SparkSQL自定义 UDF 函数median求中位数

原文:SparkSQL自定义 UDF 函数median求中位数

前言


我的场景:提供一个聚合组件操作Spark的DataFrame,然后支持先分组在聚合的功能,这里聚合要求支持最大值个数、求和、去重后求和、均值、中位数、最大值、最小值、方差、标准差、唯一值个数、唯一值、归一化等。

实现下来发现除中位数和归一化外其他聚合均有内置函数,实现起来也就很容易了。
但是在分组后计算中位数这里卡了很长时间,最后的解决办法是:自定义一个UDF函数实现分组后中位数的计算

自定义中位数函数:CustomMedian.scala

/**
    * 自定义计算中位数聚合函数
    * qi.wang[email protected]
    */
  object CustomMedian extends UserDefinedAggregateFunction {

    override def inputSchema: StructType = StructType(StructField("input", StringType) :: Nil)
    override def bufferSchema: StructType = StructType(StructField("sum", StringType) :: StructField("count", StringType) :: Nil)
    override def dataType: DataType = DoubleType
    override def deterministic: Boolean = true // 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出

    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer(0) = ""
    }

    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      if (!input.isNullAt(0)) {
        buffer(0) = buffer.get(0) + "," + input.get(0)
      }
    }

    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.get(0) + "," + buffer2.get(0))
    }

    override def evaluate(buffer: Row): Any = {
      val list = new util.ArrayList[Integer]
      val stringList:Array[String] = buffer.getString(0).split(",")
      for (s <- stringList) {
        if (StringUtils.isNotBlank(s))
          list.add(s.toInt)
      }
      Collections.sort(list)
      val size = list.size
      var num:Double = 0L
      if (size % 2 == 1) num = list.get(((size+1) / 2) - 1).toDouble
      if (size % 2 == 0) num = (list.get(size / 2 - 1) + list.get(size / 2)) / 2.00
      num
    }
  }

函数测试

  1. 造一个数据文件:/tmp/data.csv, 内容如下
id|name|mobile|idnumber
10|aa|11111111111|111111111111111111
12|bb|12321321321|213123123213333333
13|aa|21312332322|333333333333333334
15|dd|23114567888|872837482374932794
17|bb|44444444444|827183787373733333
18|bb|55555555555|823048320999399999
  1. 测试代码
package www.relaxheart.cn

import www.relaxheart.cn.CustomMedian
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types._
import scala.util.Random


/**
  * @author 王琦[email protected]
  * @date 19/8/13 下午20:33
  * @description
  */
object MedianUDFTest extends App {

  val spark = SparkSession.builder().master("local[*]").appName("MedianUDFTest").config("spark.sql.crossJoin.enabled", "true").getOrCreate()

// 读取data.csv得到RDD
  val rdd = spark.sparkContext.textFile("/tmp/data.csv")

  // 从第一行数据中获取最后转成的DataFrame应该有多少列 并给每一列命名
  val colNames = rdd.first.split("\\|")

  // 设置DataFrame的结构
  val schema = StructType(colNames.map(fieldName => StructField(fieldName, StringType)))

  // 对每一行的数据进行处理
  val rowRDD = rdd.filter(_.split("\\|")(0) != "id").map(_.split("\\|")).map(p => Row(p: _*))

  // 创建DataFrame
  val data = spark.createDataFrame(rowRDD, schema)

  // 创建临时表
  val tmpTable = "_table"+System.currentTimeMillis()+Random.nextInt(10000000)
  data.createOrReplaceTempView(tmpTable)

 // 这步很关键,注册我们的自定义中位数函数
  spark.udf.register("median",  CustomMedian)

  // 利用SparkSQL + 自定义中位数函数实现分组后求中位数
  // 这里对测试数据按name进行分组,然后组内id的中位数
  val medianGroupDF = spark.sql(s"select name , median(id) as median from $tmpTable group by name")

  // 打印分组中位数聚合结果
  medianGroupDF.show()
}

结果验证

image.png

看打印结果是符合我们预期的。

个人博客网站:王琦的个人兴趣分享网站 | RelaxHeart网 | Tec博客

你可能感兴趣的:(SparkSQL自定义 UDF 函数median求中位数)