This article looks at how Spark SQL handles aggregation.
Let's start with a simple aggregate query:
df.groupBy("xxx").agg(count("yyy"))
How is such an aggregate query represented inside Spark SQL? The core pieces are the grouping expressions and the aggregate expressions (AggregateExpression). Breaking the statement apart: `groupBy("xxx")` supplies the grouping expression, and `count("yyy")` is the aggregate expression.
The main component of an aggregate expression is the aggregate function.
Aggregate functions fall into the following categories: declarative (DeclarativeAggregate, whose buffer update/merge logic is expressed as Catalyst expressions), imperative (ImperativeAggregate, which implements the buffer operations as methods), and typed imperative (TypedImperativeAggregate, an imperative variant whose buffer holds an arbitrary user-defined object).
Common aggregate functions
Aggregate function | Meaning | Category |
---|---|---|
sum | sum | declarative |
count | count | declarative |
avg | average | declarative |
covariance | covariance | declarative |
first | first value | declarative |
last | last value | declarative |
pivotFirst | pivot transformation | imperative |
hyperLogLogPlus | cardinality estimation | imperative |
collect_list | list of values | typed imperative |
collect_set | set of values | typed imperative |
percentile | percentile | typed imperative |
Whatever its category, an aggregate function performs the aggregation through four operations: initialize the buffer, update the buffer with an input row, merge two buffers, and evaluate the final result. The object these four operations update and merge is the aggregation buffer.
While computing an aggregate value, a query usually needs to keep intermediate state: computing max means keeping the current maximum; computing avg means keeping the running sum and count. This temporary intermediate state is stored in the aggregation buffer (AggBuffer).
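The idea can be sketched in plain Scala (the names here are illustrative, not Spark's internal AggBuffer API): a buffer for max only needs to hold the current maximum, updated once per input row and merged across partial results.

```scala
// Illustrative sketch of a max aggregation buffer (not Spark's internal API):
// the buffer carries only the running maximum.
case class MaxBuffer(max: Option[Int]) {
  // Consume one input row (the "update" side of the aggregation).
  def update(v: Int): MaxBuffer = MaxBuffer(Some(max.fold(v)(math.max(_, v))))
  // Combine two partial buffers (the "merge" side).
  def merge(other: MaxBuffer): MaxBuffer = (max, other.max) match {
    case (Some(a), Some(b)) => MaxBuffer(Some(math.max(a, b)))
    case (a, b)             => MaxBuffer(a.orElse(b))
  }
}

object MaxBufferDemo {
  // Two "partitions" aggregated separately, then merged:
  val p1 = List(3, 7, 2).foldLeft(MaxBuffer(None)) { (b, v) => b.update(v) }
  val p2 = List(5, 1).foldLeft(MaxBuffer(None)) { (b, v) => b.update(v) }
  val total = p1.merge(p2) // MaxBuffer(Some(7))
}
```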
Depending on how the buffer is manipulated during aggregation, Spark distinguishes the following aggregate modes (AggregateMode): Partial, PartialMerge, Final, and Complete.
Partial mode aggregates local data first, e.g. computing each partition's sum. Final mode aggregates the results already sitting in aggregation buffers, e.g. summing the per-partition sums. Partial and Final usually appear as a pair: Partial plays the role of the map side, Final the reduce side.
Complete mode has no intermediate aggregation step: all values of a group participate in a single aggregation pass.
PartialMerge mode appears when several kinds of aggregate functions are evaluated together, e.g. sum combined with countDistinct. Merging the buffers at that point still yields an intermediate result, not the final value.
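The distinct case can also be sketched in plain Scala (illustrative names, not Spark's implementation): the partial buffer is the set of values seen so far; merging two buffers still yields a buffer, and only the final step turns it into a count.

```scala
// Illustrative sketch of why distinct aggregation needs a PartialMerge step:
// merged buffers are still intermediate results (sets), not counts.
case class DistinctBuffer(seen: Set[Int]) {
  def update(v: Int): DistinctBuffer = DistinctBuffer(seen + v)   // Partial: absorb one row
  def merge(other: DistinctBuffer): DistinctBuffer =              // PartialMerge: still a buffer
    DistinctBuffer(seen ++ other.seen)
  def result: Long = seen.size.toLong                             // Final: produce the count
}
```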
The aggregate modes are defined in sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala:
/** The mode of an [[AggregateFunction]]. */
sealed trait AggregateMode
/**
* An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the aggregation buffer is returned.
*/
case object Partial extends AggregateMode
/**
* An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers
* containing intermediate results for this function.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the aggregation buffer is returned.
*/
case object PartialMerge extends AggregateMode
/**
* An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers
* containing intermediate results for this function and then generate final result.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the final result of this function is returned.
*/
case object Final extends AggregateMode
/**
* An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly
* from original input rows without any partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the final result of this function is returned.
*/
case object Complete extends AggregateMode
For typed (Dataset) aggregation, Spark exposes the Aggregator interface, defined in sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala:
abstract class Aggregator[-IN, BUF, OUT] extends Serializable {
def zero: BUF
def reduce(b: BUF, a: IN): BUF
def merge(b1: BUF, b2: BUF): BUF
def finish(reduction: BUF): OUT
def bufferEncoder: Encoder[BUF]
def outputEncoder: Encoder[OUT]
def toColumn: TypedColumn[IN, OUT] = {
implicit val bEncoder = bufferEncoder
implicit val cEncoder = outputEncoder
val expr =
AggregateExpression(
TypedAggregateExpression(this),
Complete,
isDistinct = false)
new TypedColumn[IN, OUT](expr, encoderFor[OUT])
}
}
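To see the zero/reduce/merge/finish contract in action, here is a sketch of a typed average. To keep it self-contained (no Spark dependency), it implements a local trait with the same shape; a real implementation would extend org.apache.spark.sql.expressions.Aggregator and additionally supply bufferEncoder and outputEncoder, after which its toColumn (wrapping the function in a Complete-mode AggregateExpression, as shown above) can be used in a select.

```scala
// Local stand-in for Aggregator's core contract (encoders omitted on purpose;
// MiniAggregator and MyAverage are illustrative names, not Spark API).
trait MiniAggregator[-IN, BUF, OUT] {
  def zero: BUF
  def reduce(b: BUF, a: IN): BUF
  def merge(b1: BUF, b2: BUF): BUF
  def finish(reduction: BUF): OUT
}

// The buffer mirrors avg's intermediate state: running sum and count.
case class Avg(sum: Long, count: Long)

object MyAverage extends MiniAggregator[Long, Avg, Double] {
  def zero: Avg = Avg(0L, 0L)
  def reduce(b: Avg, a: Long): Avg = Avg(b.sum + a, b.count + 1)
  def merge(b1: Avg, b2: Avg): Avg = Avg(b1.sum + b2.sum, b1.count + b2.count)
  def finish(r: Avg): Double = r.sum.toDouble / r.count
}
```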
For untyped DataFrame aggregation, custom logic goes through UserDefinedAggregateFunction (UDAF). An example UDAF that computes the geometric mean:
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
class GeometricMean extends UserDefinedAggregateFunction {
// These are the input fields for your aggregate function.
override def inputSchema: org.apache.spark.sql.types.StructType =
StructType(StructField("value", DoubleType) :: Nil)
// These are the internal fields you keep for computing your aggregate.
override def bufferSchema: StructType = StructType(
StructField("count", LongType) ::
StructField("product", DoubleType) :: Nil
)
// This is the output type of your aggregation function.
override def dataType: DataType = DoubleType
override def deterministic: Boolean = true
// This is the initial value for your buffer schema.
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0L
buffer(1) = 1.0
}
// This is how to update your buffer schema given an input.
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getAs[Long](0) + 1
buffer(1) = buffer.getAs[Double](1) * input.getAs[Double](0)
}
// This is how to merge two objects with the bufferSchema type.
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
buffer1(1) = buffer1.getAs[Double](1) * buffer2.getAs[Double](1)
}
// This is where you output the final value, given the final value of your bufferSchema.
override def evaluate(buffer: Row): Any = {
math.pow(buffer.getDouble(1), 1.toDouble / buffer.getLong(0))
}
}
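The buffer math itself can be checked without Spark. A pure-Scala mirror of the same update/merge/evaluate logic (GMBuffer is an illustrative name, not part of the UDAF API): the buffer tracks (count, product) and evaluates product^(1/count).

```scala
// Pure-Scala mirror of GeometricMean's buffer logic.
case class GMBuffer(count: Long, product: Double) {
  def update(v: Double): GMBuffer = GMBuffer(count + 1, product * v)   // per-row update
  def merge(other: GMBuffer): GMBuffer =
    GMBuffer(count + other.count, product * other.product)
  def evaluate: Double = math.pow(product, 1.0 / count)                // n-th root of the product
}
```

To use the UDAF itself, register it (e.g. spark.udf.register("gm", new GeometricMean)) and call gm inside agg or a SQL query.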
In real projects, business logic often imposes requirements that the built-in aggregate functions (max, min, sum, first, last, and so on) cannot satisfy. Take a metric such as "the city a user frequently appears in": over a recent time window, pick the city where the user appeared more than n times and, among all qualifying cities, appeared most often. This is where custom aggregate functions earn their keep.
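A hedged sketch of what the buffer logic for such a metric might look like (the class name, threshold parameter, and overall shape are illustrative, not from any Spark API): count visits per city in a Map, merge maps by summing counts, and at finish keep only cities above the threshold, returning the most frequent one.

```scala
// Illustrative sketch of a "frequently visited city" aggregation:
// buffer = visit counts per city; finish applies the "> minVisits" filter
// and picks the most frequent qualifying city (ties resolved arbitrarily).
class FrequentCity(minVisits: Long) {
  type Buf = Map[String, Long]
  def zero: Buf = Map.empty
  def reduce(b: Buf, city: String): Buf =
    b.updated(city, b.getOrElse(city, 0L) + 1)
  def merge(b1: Buf, b2: Buf): Buf =
    b2.foldLeft(b1) { case (acc, (city, n)) => acc.updated(city, acc.getOrElse(city, 0L) + n) }
  def finish(b: Buf): Option[String] = {
    val qualifying = b.filter { case (_, n) => n > minVisits }
    if (qualifying.isEmpty) None
    else Some(qualifying.maxBy { case (_, n) => n }._1)
  }
}
```

Wrapped in an Aggregator (with encoders for the Map buffer), the same logic would run as a typed aggregation.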