
Understanding Spark's aggregate and treeAggregate

Many algorithms in spark-mllib use the treeAggregate method; using it instead of aggregate improves performance. For example, mllib's GaussianMixture model gains about 20% in performance, see treeAggregate.

I did not understand this aggregation pattern very well before, so this post records it.

1. An example

import org.apache.spark.sql.SparkSession

def main(args: Array[String]): Unit = {
  val spark = SparkSession
    .builder
    .appName(s"agg")
    .master("local")
    .getOrCreate()
  val sc = spark.sparkContext

  def seqOp(s1:Int, s2:Int):Int = {
    println("seq: "+s1+":"+s2)
    s1 + s2
  }

  def combOp(c1: Int, c2: Int): Int = {
    println("comb: "+c1+":"+c2)
    c1 + c2
  }

  val rdd = sc.parallelize(1 to 12).repartition(6)
  val res1 = rdd.aggregate(0)(seqOp, combOp)
// val res2 = rdd.treeAggregate(0)(seqOp, combOp)
  println(res1)
// println(res2)
}
           

aggregate:

seq: 0:6
seq: 6:12
comb: 0:18
seq: 0:1
seq: 1:7
comb: 18:8
seq: 0:2
seq: 2:8
comb: 26:10
seq: 0:3
seq: 3:9
comb: 36:12
seq: 0:4
seq: 4:10
comb: 48:14
seq: 0:5
seq: 5:11
comb: 62:16
78
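
Reading the aggregate trace: each of the six partitions is folded with seqOp starting from zeroValue 0 (for the first partition: 0+6 = 6, 6+12 = 18), and the driver then folds the six partial sums 18, 8, 10, 12, 14, 16 one by one with combOp, again starting from 0: 0+18 = 18, 18+8 = 26, 26+10 = 36, 36+12 = 48, 48+14 = 62, 62+16 = 78.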

treeAggregate:

seq: 0:6
seq: 6:12
seq: 0:1
seq: 1:7
seq: 0:2
seq: 2:8
seq: 0:3
seq: 3:9
seq: 0:4
seq: 4:10
seq: 0:5
seq: 5:11
comb: 18:10
comb: 28:14
comb: 8:12
comb: 20:16
comb: 42:36
78

2. Aggregate

treeAggregate is a specialized form of aggregate, so understanding treeAggregate first requires understanding how aggregate performs its aggregation over the data. The method is defined as follows:

def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U
           

From the definition of aggregate, we can see that it takes three arguments:

  1. the initial value of the aggregation: zeroValue: U
  2. the function applied to the elements inside each partition: seqOp
  3. the function that combines the partition results: combOp

aggregate applies seqOp within each partition, starting from zeroValue and folding in every element of that partition. It then applies combOp, again starting from zeroValue, to fold in the per-partition results.

Note: zeroValue is applied only once by seqOp within each partition, and once more by the final combOp.
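
A quick numeric check of this note (a minimal sketch; it assumes the `sc` from the example in section 1, and the zeroValue of 10 and the name `rdd2` are illustrative):

// 1..6 split into 3 partitions: [1,2], [3,4], [5,6]
// seqOp folds zeroValue into each partition: 10+1+2 = 13, 10+3+4 = 17, 10+5+6 = 21
// combOp folds zeroValue in once more:       10 + 13 + 17 + 21 = 61
val rdd2 = sc.parallelize(1 to 6, 3)
println(rdd2.aggregate(10)(_ + _, _ + _))  // 61 = sum(1..6) + 10 * (3 partitions + 1)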

A diagram of the computation above:

[figure: the aggregate computation flow]

3. treeAggregate

def treeAggregate[U: ClassTag](zeroValue: U)(
  seqOp: (U, T) => U,
  combOp: (U, U) => U,
  depth: Int = 2): U
           

Unlike aggregate, treeAggregate takes an extra depth parameter; the other parameters have the same meaning. With aggregate, once seqOp has run, every partition's result is pulled to the driver, which walks over those results once with combOp to produce the final value. treeAggregate does not combine everything in one step: the per-partition results of seqOp may be large, and pulling them all to the driver at once could cause an OutOfMemory error. Instead, it first partially aggregates the partition results (via reduceByKey), coalescing partitions when there are too many, and only then brings the remaining values to the driver for a final reduce.

Note: the difference from aggregate is that combOp is applied two or more times across the intermediate tree levels on the executor side, so not all of the local partition values are sent to the driver. Also, the initial zeroValue does not participate in combOp.
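
Reusing the small sketch from section 2 (same assumptions: the `sc` from section 1, an illustrative zeroValue of 10, 3 partitions), both points are easy to observe:

val rdd2 = sc.parallelize(1 to 6, 3)
// aggregate: zeroValue enters every partition's seqOp and the final combOp
println(rdd2.aggregate(10)(_ + _, _ + _))      // 61 = 21 + 10*3 + 10
// treeAggregate: zeroValue only enters the per-partition seqOp, not combOp
println(rdd2.treeAggregate(10)(_ + _, _ + _))  // 51 = 21 + 10*3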

The details can be seen in the source code:

/**
   * Aggregates the elements of this RDD in a multi-level tree pattern.
   *
   * @param depth suggested depth of the tree (default: 2)
   * @see [[org.apache.spark.rdd.RDD#aggregate]]
   */
  def treeAggregate[U: ClassTag](zeroValue: U)(
      seqOp: (U, T) => U,
      combOp: (U, U) => U,
      depth: Int = 2): U = withScope {
    require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
    if (partitions.length == 0) {
      Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
    } else {
      val cleanSeqOp = context.clean(seqOp)
      val cleanCombOp = context.clean(combOp)
      val aggregatePartition =
        (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
      var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
      var numPartitions = partiallyAggregated.partitions.length
      val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
      // If creating an extra level doesn't help reduce
      // the wall-clock time, we stop tree aggregation.

      // Don't trigger TreeAggregation when it doesn't save wall-clock time
      while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
        numPartitions /= scale
        val curNumPartitions = numPartitions
        partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex {
          (i, iter) => iter.map((i % curNumPartitions, _))
        }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
      }
      partiallyAggregated.reduce(cleanCombOp)
    }
  }
           
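
To see how the while loop shrinks the number of partitions, here is a small standalone sketch (plain Scala, not a Spark API; the object and method names are illustrative) that mirrors the scale / while-loop logic above and prints how many partitions remain at each tree level before the final reduce on the driver:

// Mirrors the scale computation and the partition-shrinking loop of treeAggregate.
import scala.collection.mutable.ArrayBuffer

object TreeAggregateLevels {
  def levels(initialPartitions: Int, depth: Int = 2): List[Int] = {
    val scale = math.max(math.ceil(math.pow(initialPartitions, 1.0 / depth)).toInt, 2)
    var numPartitions = initialPartitions
    val steps = ArrayBuffer(numPartitions)
    while (numPartitions > scale + math.ceil(numPartitions.toDouble / scale)) {
      numPartitions /= scale
      steps += numPartitions
    }
    steps.toList
  }

  def main(args: Array[String]): Unit = {
    println(levels(6))    // List(6, 2): the 6-partition example shrinks to 2 partitions, then the driver reduces
    println(levels(100))  // List(100, 10): with the default depth of 2, 100 partitions shrink to 10
  }
}

For the 6-partition example in section 1, scale = max(ceil(6^(1/2)), 2) = 3 and 6 > 3 + ceil(6/3) = 5, so one tree level shrinks the 6 partitions to 2; those are the two intermediate values (42 and 36) that the driver combines at the end of the treeAggregate trace.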

Again, a diagram to illustrate:

[figure: the treeAggregate computation flow]

