First, write the UDAF in Scala:
import scala.collection.mutable.{ArrayBuffer, WrappedArray}
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
class UDAFMedian extends UserDefinedAggregateFunction {
  // Schema of the input to the aggregate function
  def inputSchema: StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  // Schema of the aggregation buffer
  def bufferSchema: StructType = StructType(
    StructField("data_list", ArrayType(DoubleType, false)) :: Nil
  )

  // Return type of the aggregate function
  def dataType: DataType = DoubleType

  // Whether the function is deterministic, i.e. the same input always produces the same output
  def deterministic: Boolean = true

  // Initialize the aggregation buffer with an empty list
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = new ArrayBuffer[Double]()
  }

  // Handle one new input row: append its value to the buffer (null inputs are skipped)
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val bufferVal = buffer.getAs[WrappedArray[Double]](0).toBuffer
      bufferVal += input.getAs[Double](0)
      buffer(0) = bufferVal
    }
  }

  // Merge two aggregation buffers by concatenating their value lists
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[WrappedArray[Double]](0) ++ buffer2.getAs[WrappedArray[Double]](0)
  }

  // Compute the final result: sort the collected values and take the median
  def evaluate(buffer: Row): Any = {
    val sortedWindow = buffer.getAs[WrappedArray[Double]](0).sorted
    val windowSize = sortedWindow.size
    if (windowSize % 2 == 0) {
      // Even number of values: average the two middle elements
      val index = windowSize / 2
      (sortedWindow(index) + sortedWindow(index - 1)) / 2
    } else {
      // Odd number of values: take the middle element
      sortedWindow((windowSize + 1) / 2 - 1)
    }
  }
}
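Note that UserDefinedAggregateFunction is deprecated as of Spark 3.0. On Spark 3.x the same median logic can also be written against the typed Aggregator API and registered with functions.udaf. The following is only a minimal sketch of that alternative; MedianAggregator is an illustrative name and not part of the original code:

import org.apache.spark.sql.{Encoder, Encoders, functions}
import org.apache.spark.sql.expressions.Aggregator

// Minimal sketch, assuming Spark 3.x; MedianAggregator is an illustrative name
object MedianAggregator extends Aggregator[Double, Seq[Double], Double] {
  // Empty buffer for a new group
  def zero: Seq[Double] = Seq.empty[Double]
  // Append one input value to the buffer
  def reduce(buffer: Seq[Double], value: Double): Seq[Double] = buffer :+ value
  // Concatenate two partial buffers
  def merge(b1: Seq[Double], b2: Seq[Double]): Seq[Double] = b1 ++ b2
  // Sort the collected values and take the median, mirroring evaluate above
  def finish(values: Seq[Double]): Double = {
    val sorted = values.sorted
    val n = sorted.size
    if (n % 2 == 0) (sorted(n / 2) + sorted(n / 2 - 1)) / 2 else sorted(n / 2)
  }
  def bufferEncoder: Encoder[Seq[Double]] = Encoders.kryo[Seq[Double]]
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Registration would then look like:
// ss.udf.register("median", functions.udaf(MedianAggregator))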
Next, register the UDAF and use it:
import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ListBuffer
object TestMedian {
  def main(args: Array[String]): Unit = {
    val ss: SparkSession = SparkSession.builder().master("local").enableHiveSupport().getOrCreate()
    // Register the custom UDAF under the name "median"
    ss.udf.register("median", new UDAFMedian())
    // Use the median function in SQL to compute the per-class median score
    // (a sketch for building a local "scores" view follows this listing)
    val sql = "select class, median(score) from scores group by class"
    val rdd = ss.sql(sql).rdd.collect()
    // Store the query result in a ListBuffer, one comma-separated line per row
    val result: ListBuffer[String] = new ListBuffer[String]()
    for (i <- 0 until rdd.length) {
      val line = new StringBuilder()
      for (j <- 0 until rdd(i).length) {
        val value = rdd(i)(j)
        // Null values become an empty field
        if (value != null) {
          line.append(value.toString)
        }
        if (j < rdd(i).length - 1) {
          line.append(",")
        }
      }
      result.append(line.toString)
    }
  }
}
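The SQL above assumes a table named scores (with class and score columns) already exists, which is why enableHiveSupport is called. For a quick local test without a Hive table, one could instead build a small DataFrame with made-up sample data and register it as a temporary view inside main, before the ss.sql call. This is only a sketch, not part of the original program:

// Minimal sketch with made-up sample data; place inside main before running the query
import ss.implicits._

val sampleScores = Seq(
  ("A", 60.0), ("A", 70.0), ("A", 90.0),
  ("B", 55.0), ("B", 85.0)
).toDF("class", "score")
sampleScores.createOrReplaceTempView("scores")

// "A" has an odd number of scores, so its median is the middle value 70.0;
// "B" has an even number, so its median is (55.0 + 85.0) / 2 = 70.0
ss.sql("select class, median(score) from scores group by class").show()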
For the official UDAF example, see: https://docs.databricks.com/spark/latest/spark-sql/udaf-scala.html