《SparkSQL内核剖析》【Aggregation篇】

这篇文章讨论Spark SQL中聚合相关的内容。

聚合表达式

我们先来看一条简单的聚合语句

df.groupBy("xxx").count("yyy")

这样一个聚合查询,在Spark SQL中是怎么表示的呢?核心的部分就是分组表达式GroupExpression聚合表达式AggregationExpression。我们把上面的语句拆解开,groupBy(“xxx”)就是分组表达式,而count(“yyy”)就是聚合表达式。
聚合表达式的最主要组成部分是聚合函数。

聚合函数

聚合函数可以分成下面几类:

  • 声明式聚合函数
    声明式聚合函数是指可以由Catalyst中的表达式直接构建的聚合函数,也是比较简单的聚合函数类型,最常见的count, sum,avg等都是声明式聚合函数。
  • 命令式聚合函数
    命令式聚合函数是指一类需要显式实现几个方法来操作聚合缓冲区AggBuffer中的数据的聚合函数。命令式聚合函数不那么常见,能找到的命令式聚合函数包括基数统计hyperLogLogPlus、透视转换pivotFirst等。
  • 带类型的命令式聚合函数
    带类型的命令式聚合函数是最灵活的一种聚合函数类型,它允许使用用户自定义对象作为聚合缓冲区。涉及用户自定义类型的聚合都是这种类型,例如collect_list、collect_set、percentile等。

常见的聚合函数

聚合函数 含义 类型
sum 求和 声明式
count 计数 声明式
avg 平均数 声明式
covariance 方差 声明式
first 第一个元素 声明式
last 最后一个元素 声明式
pivotFirst 透视转换 命令式
hyperLogLogPlus 基数估计 命令式
collect_list 生成取值列表 带类型的命令式
collect_set 生成取值集合 带类型的命令式
percentile 百分位数 带类型的命令式

不管是哪种类型的聚合函数,都需要实现4个方法来完成聚合运算:

  • 初始值 initialValues
  • 更新表达式 updateExpressions
  • 合并表达式 mergeExpressions
  • 生成结果表达式 evaluateExpression

这四个函数更新和合并的对象,就是聚合缓冲区。

聚合缓冲区和聚合模式

聚合查询在计算聚合值的过程中,通常需要保存相关的中间结果,比如计算max,需要保留当前的最大值;比如计算avg,需要保存当前的sum和count。这些临时的中间结果,就是保存在聚合缓冲区AggBuffer

根据聚合过程中操作缓冲区的方式, 可以将聚合分成以下几种聚合模式AggregateMode

  • Partial模式
  • PartialMerge模式
  • Final模式
  • Complete模式

Partial模式的意思就是先把局部数据进行聚合,比如先计算每个分区的sum。Final模式的意思就是把聚合缓冲区中的聚合结果再进行聚合,比如计算分区sum的sum。一般Partial模式和Final模式配合出现,Partial类似map过程,而Final类似reduce过程。

Complete模式的特点在于没有中间聚合过程,每个分组的全体值都需要在一次聚合过程中参与计算。

PartialMerge模式出现在多种类型的聚合函数同时聚合的情况,比如同时聚集sum和countDistinct。这时候缓冲区聚合之后的结果,仍然是中间结果。

聚合模式的定义,见sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

/** The mode of an [[AggregateFunction]]. */
sealed trait AggregateMode

/**
 * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation.
 * This function updates the given aggregation buffer with the original input of this
 * function. When it has processed all input rows, the aggregation buffer is returned.
 */
case object Partial extends AggregateMode

/**
 * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers
 * containing intermediate results for this function.
 * This function updates the given aggregation buffer by merging multiple aggregation buffers.
 * When it has processed all input rows, the aggregation buffer is returned.
 */
case object PartialMerge extends AggregateMode

/**
 * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers
 * containing intermediate results for this function and then generate final result.
 * This function updates the given aggregation buffer by merging multiple aggregation buffers.
 * When it has processed all input rows, the final result of this function is returned.
 */
case object Final extends AggregateMode

/**
 * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly
 * from original input rows without any partial aggregation.
 * This function updates the given aggregation buffer with the original input of this
 * function. When it has processed all input rows, the final result of this function is returned.
 */
case object Complete extends AggregateMode

sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala

abstract class Aggregator[-IN, BUF, OUT] extends Serializable {

  def zero: BUF
  def reduce(b: BUF, a: IN): BUF
  def merge(b1: BUF, b2: BUF): BUF
  def finish(reduction: BUF): OUT
  
  def bufferEncoder: Encoder[BUF]
  def outputEncoder: Encoder[OUT]
  
  def toColumn: TypedColumn[IN, OUT] = {
    implicit val bEncoder = bufferEncoder
    implicit val cEncoder = outputEncoder

    val expr =
      AggregateExpression(
        TypedAggregateExpression(this),
        Complete,
        isDistinct = false)

    new TypedColumn[IN, OUT](expr, encoderFor[OUT])
  }
}

用户自定义聚合函数

一个UDAF的例子

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

class GeometricMean extends UserDefinedAggregateFunction {
  // This is the input fields for your aggregate function.
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  // This is the internal fields you keep for computing your aggregate.
  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
    StructField("product", DoubleType) :: Nil
  )

  // This is the output type of your aggregatation function.
  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  // This is the initial value for your buffer schema.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 1.0
  }

  // This is how to update your buffer schema given an input.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getAs[Long](0) + 1
    buffer(1) = buffer.getAs[Double](1) * input.getAs[Double](0)
  }

  // This is how to merge two objects with the bufferSchema type.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Long](0) + buffer2.getAs[Long](0)
    buffer1(1) = buffer1.getAs[Double](1) * buffer2.getAs[Double](1)
  }

  // This is where you output the final value, given the final value of your bufferSchema.
  override def evaluate(buffer: Row): Any = {
    math.pow(buffer.getDouble(1), 1.toDouble / buffer.getLong(0))
  }
}

有哪些用途

实际项目中,往往存在业务自定义的一些复杂要求,只用max、min、sum、first、last等默认聚集函数无法满足需求。比如有一项指标叫做“用户经常出现的城市”,选择过去一段时间内,出现次数超过n次,且在用户出现过的多个城市中出现次数最多的;这时候自定义的聚集函数就派上了用场。

你可能感兴趣的:(Spark)