天天看點

Scala UDAF + Spark SQL 實作求中位數

首先用Scala寫一個UDAF函數

import scala.collection.mutable.{ArrayBuffer, WrappedArray}
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._


/**
 * Spark SQL user-defined aggregate function that returns the median of a
 * double-typed column.
 *
 * Every non-null value of a group is appended to an array-typed aggregation
 * buffer, and the buffer is sorted once in `evaluate`. The buffer therefore
 * holds all values of the group at the same time, and each `update` builds a
 * new sequence containing the previous values plus the new one.
 *
 * NULL inputs are skipped, and a group without any non-null value yields SQL
 * NULL.
 *
 * NOTE(review): `UserDefinedAggregateFunction` is marked deprecated in newer
 * Spark releases in favour of `Aggregator` - confirm against the Spark version
 * in use before building on this class.
 */
class UDAFMedian extends UserDefinedAggregateFunction {

  /** Schema of the aggregate's input: a single double column. */
  def inputSchema: StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  /** Schema of the aggregation buffer: one array of (non-null) doubles. */
  def bufferSchema: StructType = StructType(
    StructField("data_list", ArrayType(DoubleType, containsNull = false)) :: Nil
  )

  /** Data type of the aggregate's result. */
  def dataType: DataType = DoubleType

  /** The same input always produces the same output. */
  def deterministic: Boolean = true

  /** Starts every group with an empty list of values. */
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Seq.empty[Double]
  }

  /**
   * Appends one input value to the buffer.
   *
   * NULL inputs are ignored; reading them as a primitive double would
   * otherwise add a spurious 0.0 to the group.
   */
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getSeq[Double](0) :+ input.getDouble(0)
    }
  }

  /** Merges two partial buffers by concatenating their value lists. */
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getSeq[Double](0) ++ buffer2.getSeq[Double](0)
  }

  /**
   * Computes the median of the collected values.
   *
   * @return the middle value for an odd number of values, the mean of the two
   *         middle values for an even number, or `null` (SQL NULL) when the
   *         group contains no values
   */
  def evaluate(buffer: Row): Any = {
    val sortedValues = buffer.getSeq[Double](0).toArray.sorted
    val count = sortedValues.length
    if (count == 0) {
      // Spark represents SQL NULL as null; there is no median of zero values.
      null
    } else if (count % 2 == 0) {
      val upperMiddle = count / 2
      (sortedValues(upperMiddle - 1) + sortedValues(upperMiddle)) / 2
    } else {
      sortedValues(count / 2)
    }
  }

}
           

其次,注冊該UDAF并使用

import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ListBuffer

/**
 * Demo entry point: registers [[UDAFMedian]] under the SQL name `median`,
 * runs a grouped query against the `scores` table and prints every result row
 * as a comma-separated line.
 */
object TestMedian {

  /**
   * Runs the demo query in a local Spark session with Hive support.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {

    val ss: SparkSession =
      SparkSession.builder().master("local").enableHiveSupport().getOrCreate()

    try {
      // Register the custom UDAF under the name "median".
      ss.udf.register("median", new UDAFMedian())

      // Use the median function from SQL.
      val sql = "select class, median(score) from scores group by class"
      val rows = ss.sql(sql).collect()

      // One comma-separated line per result row; NULL cells become "".
      val result: List[String] = rows.toList.map { row =>
        row.toSeq.map(value => Option(value).fold("")(_.toString)).mkString(",")
      }

      // Print the lines so the demo has an observable outcome.
      result.foreach(println)
    } finally {
      // Release the session even when the query fails.
      ss.stop()
    }
  }
}
           

官方UDAF示例參考: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html