Many algorithms in spark-mllib use the treeAggregate method; using it instead of aggregate can improve performance. For example, the GaussianMixture model in mllib gains roughly 20% (see treeAggregate).
I wasn't very familiar with this style of aggregation before, so this post writes it down.
1. An Example
import org.apache.spark.sql.SparkSession

object AggDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .appName("agg")
      .master("local")
      .getOrCreate()
    val sc = spark.sparkContext

    def seqOp(s1: Int, s2: Int): Int = {
      println("seq: " + s1 + ":" + s2)
      s1 + s2
    }

    def combOp(c1: Int, c2: Int): Int = {
      println("comb: " + c1 + ":" + c2)
      c1 + c2
    }

    // 12 elements spread across 6 partitions (reconstructed from the output below)
    val rdd = sc.parallelize(1 to 12).repartition(6)
    val res1 = rdd.aggregate(0)(seqOp, combOp)
    // val res2 = rdd.treeAggregate(0)(seqOp, combOp)
    println(res1)
    // println(res2)
  }
}
aggregate:
seq: 0:6
seq: 6:12
comb: 0:18
seq: 0:1
seq: 1:7
comb: 18:8
seq: 0:2
seq: 2:8
comb: 26:10
seq: 0:3
seq: 3:9
comb: 36:12
seq: 0:4
seq: 4:10
comb: 48:14
seq: 0:5
seq: 5:11
comb: 62:16
78
treeAggregate:
seq: 0:6
seq: 6:12
seq: 0:1
seq: 1:7
seq: 0:2
seq: 2:8
seq: 0:3
seq: 3:9
seq: 0:4
seq: 4:10
seq: 0:5
seq: 5:11
[Stage 2:> (0 + 0) / 2]
comb: 18:10
comb: 28:14
comb: 8:12
comb: 20:16
comb: 42:36
78
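Reading the treeAggregate trace: the six per-partition sums are 18, 8, 10, 12, 14, 16, and they are first combined into two intermediate partitions keyed by partitionIndex % 2, which is where comb: 18:10, comb: 8:12, and so on come from. A quick plain-Scala check of this grouping (my reconstruction, not Spark code):

// Reconstructing the intermediate combine round seen in the trace above
val partSums = Seq(18, 8, 10, 12, 14, 16)                            // per-partition seqOp results
val grouped = partSums.zipWithIndex.groupBy { case (_, i) => i % 2 } // 2 intermediate partitions
val intermediate = grouped.values.map(_.map(_._1).sum)               // 42 and 36
println(intermediate.sum)                                            // 78, the final reduce on the driver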
2. Aggregate
treeAggregate is a special form of aggregate, so understanding treeAggregate first requires understanding how aggregate combines data. The method is defined as follows:
def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U
From the definition, aggregate takes three parameters:
- the initial value of the aggregation: zeroValue: U
- the function applied to the elements within a partition: seqOp
- the function that merges partial results: combOp
aggregate runs seqOp inside each partition, folding over the partition's elements starting from zeroValue. It then folds the per-partition results with combOp, again starting from zeroValue.
Note: zeroValue is applied exactly once by the seqOp fold in each partition, and once more by the final combOp.
[Figure: aggregate's flow, seqOp within each partition starting from zeroValue, then combOp over the partition results starting from zeroValue.]
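As a minimal sketch of these semantics in plain Scala (no Spark; the two-element partitions value below is a hypothetical split):

// aggregate's semantics in plain Scala, with a hypothetical 2-partition split
val zero = 0
val seqOp = (acc: Int, x: Int) => acc + x
val combOp = (a: Int, b: Int) => a + b
val partitions = Seq(Seq(1, 2, 3), Seq(4, 5, 6))

// seqOp folds each partition, starting from zeroValue in every partition
val perPartition = partitions.map(_.foldLeft(zero)(seqOp))

// combOp folds the partition results, starting from zeroValue once more
// (in Spark this final fold runs on the driver)
val result = perPartition.foldLeft(zero)(combOp) // 21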
3. treeAggregate
def treeAggregate[U: ClassTag](zeroValue: U)(
    seqOp: (U, T) => U,
    combOp: (U, U) => U,
    depth: Int = 2): U
The difference from aggregate is the extra depth parameter; the other parameters have the same meaning. After the seqOp phase, aggregate pulls every partial result to the driver and walks over them once with combOp to get the final value. treeAggregate does not combine everything in one step: the seqOp results may be large, and pulling them all to the driver could cause an OutOfMemory error. Instead, it first aggregates the partition results locally (via reduceByKey), merging partitions when there are too many, and only then brings the results to the driver for a final reduce.
Note: unlike aggregate, combOp runs in one or more intermediate rounds on the executors, which avoids shipping every partial value to the driver; also, zeroValue does not participate in the combOp phase at all.
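To make the depth parameter concrete, here is the fan-in ("scale") computation from the source below, evaluated with the numbers of the 6-partition example from section 1 (default depth = 2):

// scale computation from treeAggregate, plugged with numPartitions = 6, depth = 2
val numPartitions = 6
val depth = 2
val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
// scale == 3; the loop condition 6 > 3 + ceil(6.0 / 3) = 5 holds once,
// so the 6 partition results are first combined into 6 / 3 = 2 partitions
// (consistent with the two-partition Stage 2 in the treeAggregate output above),
// and the final reduce on the driver only sees those 2 values.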
The details can be seen in the source:
/**
 * Aggregates the elements of this RDD in a multi-level tree pattern.
 *
 * @param depth suggested depth of the tree (default: 2)
 * @see [[org.apache.spark.rdd.RDD#aggregate]]
 */
def treeAggregate[U: ClassTag](zeroValue: U)(
    seqOp: (U, T) => U,
    combOp: (U, U) => U,
    depth: Int = 2): U = withScope {
  require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
  if (partitions.length == 0) {
    Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
  } else {
    val cleanSeqOp = context.clean(seqOp)
    val cleanCombOp = context.clean(combOp)
    val aggregatePartition =
      (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
    var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
    var numPartitions = partiallyAggregated.partitions.length
    val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
    // If creating an extra level doesn't help reduce
    // the wall-clock time, we stop tree aggregation.
    // Don't trigger TreeAggregation when it doesn't save wall-clock time
    while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
      numPartitions /= scale
      val curNumPartitions = numPartitions
      partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
        (i, iter) => iter.map((i % curNumPartitions, _))
      }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
    }
    partiallyAggregated.reduce(cleanCombOp)
  }
}
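For reference, the commented-out call in the example uses the default depth = 2; a different depth can be passed explicitly (a sketch reusing rdd, seqOp, and combOp from section 1):

// depth defaults to 2, as in the commented-out line of the first example
val res2 = rdd.treeAggregate(0)(seqOp, combOp)
// a larger depth lowers the fan-in per round, so with many partitions
// more intermediate combine rounds run before results reach the driver
val res3 = rdd.treeAggregate(0)(seqOp, combOp, depth = 3)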
[Figure: treeAggregate's multi-level combine, from per-partition results through intermediate partitions to the final reduce on the driver.]
References:
https://www.cnblogs.com/drawwindows/p/5762392.html
http://blog.csdn.net/lookqlp/article/details/52121057
https://www.jianshu.com/p/27222830d21a