Author: Zhou Keyong (alias "Yi Chui"), technical expert on the EMR team of Alibaba's Computing Platform Division, a big data enthusiast with a strong interest in and solid understanding of Spark, currently focused on optimizing the open-source compute engines in the EMR product.
Background
The excellent performance of Spark SQL rests on two technical pillars: the Optimizer and the Runtime. The former strives to find the optimal execution plan, while the latter strives to execute that plan as fast as possible. The Runtime's optimizations fall into two levels:
1. Global optimization. Improving overall resource utilization, eliminating data skew, reducing IO, and so on, including Adaptive Execution, Shuffle Removal, etc.
2. Local optimization. Improving the execution efficiency of individual tasks, which relies mainly on codegen, specifically Expression-level and WholeStage-level codegen.
This article introduces the technical principles behind Spark Codegen.
Case Study
This section uses two concrete cases to show how codegen works.
Expression level
Consider evaluating the expression x + (1 + 2). In Scala (Catalyst) code it is expressed as:
Add(Attribute(x), Add(Literal(1), Literal(2)))
The corresponding expression tree is:

The straightforward way to evaluate this tree is to interpret it recursively:
tree.transformUp {
  case Attribute(idx) => Literal(row.getValue(idx))
  case Add(Literal(c1), Literal(c2)) => Literal(c1 + c2)
  case Literal(c) => Literal(c)
}
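To make that overhead concrete, here is a self-contained toy version of the interpreted evaluation (the Expr case classes below are hypothetical stand-ins for Catalyst's expression classes, not Spark code): every node is evaluated through pattern matching and recursion, for every input row.

sealed trait Expr
case class Attribute(idx: Int) extends Expr
case class Literal(value: Int) extends Expr
case class Add(left: Expr, right: Expr) extends Expr

object InterpretedEval {
  // Interpreter-style evaluation: pattern matching and virtual dispatch at every node.
  def eval(e: Expr, row: Array[Int]): Int = e match {
    case Attribute(idx) => row(idx)
    case Literal(v) => v
    case Add(l, r) => eval(l, row) + eval(r, row)
  }

  def main(args: Array[String]): Unit = {
    val tree = Add(Attribute(0), Add(Literal(1), Literal(2)))
    println(eval(tree, Array(5))) // 8
  }
}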
Running this code involves a lot of extra machinery beyond the arithmetic itself: type matching, virtual function calls, object creation, and so on. This overhead far exceeds the cost of evaluating the expression.
To eliminate it, Spark Codegen directly assembles the Java code that evaluates the expression and compiles it just in time. There are three steps:
1. Code generation. Generate Java code from the expression tree and wrap it in a wrapper class:
... // class wrapper
row.getValue(idx) + (1 + 2)
... // class wrapper
2. Just-in-time compilation. Use the Janino framework to compile the generated source into a class file.
3. Load and execute. Finally, load the compiled class and run it. (A minimal end-to-end sketch of these three steps follows.)
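Here is a minimal end-to-end sketch of the three steps, using Janino's ExpressionEvaluator directly. This is not Spark's actual codegen path (Spark generates a full wrapper class and compiles it through its own CodeGenerator), but it shows the generate / compile / load-and-run cycle:

import org.codehaus.janino.ExpressionEvaluator

object ExprCodegenSketch {
  def main(args: Array[String]): Unit = {
    // 1. Code generation: the Java source for x + (1 + 2).
    val source = "x + (1 + 2)"

    // 2. Just-in-time compilation with Janino.
    val evaluator = new ExpressionEvaluator()
    evaluator.setParameters(Array("x"), Array[Class[_]](classOf[Int]))
    evaluator.setExpressionType(classOf[Int])
    evaluator.cook(source)

    // 3. Load and execute: evaluate the compiled expression for a concrete x.
    val result = evaluator.evaluate(Array[AnyRef](Integer.valueOf(5)))
    println(result) // 8
  }
}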
Before-and-after measurements show an order-of-magnitude performance improvement.
WholeStage level
Consider the following SQL statement:
select count(*) from store_sales
where ss_item_sk=1000;
The generated physical plan is as follows:
The conventional way to execute this plan is the Volcano model: every operator implements the Iterator interface, and its next() method first drives the upstream operator to produce a row and then applies its own logic. Example code:
class Agg(child: Iterator[Row]) extends Iterator[Row] {
  private var doneAgg = false
  private var aggIter: Iterator[Row] = _

  private def doAgg(): Unit = {
    while (child.hasNext) {
      val row = child.next()
      // do aggregation
      ...
    }
  }

  def next(): Row = {
    if (!doneAgg) {
      doAgg()
      doneAgg = true
    }
    aggIter.next()
  }
}

class Filter(child: Iterator[Row], predicate: Row => Boolean) extends Iterator[Row] {
  def next(): Row = {
    var current = child.next()
    while (current != null && !predicate(current)) {
      current = child.next()
    }
    current
  }
}
As the code above shows, the Volcano model incurs a large number of type conversions and virtual function calls. Virtual calls defeat CPU branch prediction and cause severe performance degradation.
To eliminate this overhead, Spark WholeStageCodegen generates statically typed Java code for the whole physical plan and then, just as in the Expression case, compiles, loads, and executes it on the fly. For this example the generated Java code looks roughly like the following (illustrative only; a runnable hand-written equivalent follows it, and a real generated snippet appears later):
var count = 0
for (ss_item_sk in store_sales) {
if (ss_item_sk == 1000) {
count += 1
}
}
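For comparison, here is a hand-written, runnable version of that fused loop, using a plain Array[Long] as a stand-in for the ss_item_sk column: no iterators, no per-row virtual calls, just a tight loop the JVM JIT can optimize aggressively.

object FusedLoopSketch {
  def main(args: Array[String]): Unit = {
    // Stand-in for the ss_item_sk column of store_sales.
    val ssItemSk: Array[Long] = Array(1000L, 42L, 1000L, 7L)
    var count = 0L
    var i = 0
    while (i < ssItemSk.length) {
      if (ssItemSk(i) == 1000L) count += 1 // filter and count fused into one loop
      i += 1
    }
    println(count) // 2
  }
}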
The performance improvement before and after the optimization is as follows:
The Spark Codegen framework
The Spark Codegen framework has three core components:
1. Core interfaces/classes
2. CodegenContext
3. The Produce-Consume pattern
Each is described in detail below.
Interfaces and classes
Four core interfaces/classes:
1. CodegenSupport (interface)
An operator that implements this interface can assemble its own logic into Java code. Key methods:
produce() // emits the Java code that produces this node's output rows
consume() // emits the Java code that consumes the rows received from the upstream node
Implementations include, but are not limited to: ProjectExec, FilterExec, HashAggregateExec, SortMergeJoinExec.
2. WholeStageCodegenExec (class)
One of the implementations of CodegenSupport. It represents the fusion of all adjacent operators within a stage that implement CodegenSupport; the code it produces wraps the execution logic of all the fused operators into a single wrapper class, which is then passed to Janino for just-in-time compilation.
3. InputAdapter (class)
Another implementation of CodegenSupport; a glue class used to connect a WholeStageCodegenExec node to upstream nodes that do not implement CodegenSupport.
4. BufferedRowIterator (class)
The parent class of the Java code generated by WholeStageCodegenExec. Key methods (a simplified sketch of how the generated subclass is driven at runtime follows this list):
public InternalRow next() // returns the next row
public void append(InternalRow row) // appends a row
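To tie these pieces together, the following is a simplified sketch, modeled on what WholeStageCodegenExec.doExecute does (details vary across Spark versions), of how the generated BufferedRowIterator subclass is driven at runtime: it is instantiated per partition, initialized with the upstream iterators, and exposed again as an ordinary row iterator.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.BufferedRowIterator

object GeneratedCodeDriver {
  // newInstance stands for something like clazz.generate(references) in Spark itself.
  def drive(newInstance: () => BufferedRowIterator,
            partitionIndex: Int,
            upstream: Iterator[InternalRow]): Iterator[InternalRow] = {
    val buffer = newInstance()
    buffer.init(partitionIndex, Array(upstream)) // hand the input iterator(s) to the generated class
    new Iterator[InternalRow] {
      override def hasNext: Boolean = buffer.hasNext
      override def next(): InternalRow = buffer.next()
    }
  }
}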
CodegenContext
The core class that manages code generation. Its main responsibilities are listed below (a toy sketch of the first two follows the list):
1. Name management. Guarantees that variable names never collide within the same scope.
2. Variable management. Maintains class-level variables, decides how each variable should be declared (as a standalone field or packed into a typed array), maintains variable-initialization logic, and so on.
3. Method management. Maintains class methods.
4. Inner-class management. Maintains inner classes.
5. Common-subexpression management. Tracks identical subexpressions to avoid computing them repeatedly.
6. Size management. Keeps methods and classes from growing too large and limits the number of class fields, splitting where necessary, e.g. breaking an expression block into multiple functions, or spreading function and field definitions across multiple inner classes.
7. Dependency management. Maintains the external objects the generated class depends on, such as broadcast objects, utility objects, and metric objects.
8. Common-template management. Provides common code templates such as genComp and nullSafeExec.
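The real CodegenContext has a much larger API surface; the toy class below (hypothetical, not Spark's) illustrates just the first two responsibilities: generating collision-free names, and registering class-level mutable state together with its initialization code.

import scala.collection.mutable

class MiniCodegenContext {
  // Name management: one counter per prefix guarantees unique names in this context.
  private val nameCounters = mutable.Map.empty[String, Int]
  // Variable management: (javaType, name, initCode) for each registered class field.
  private val mutableStates = mutable.ArrayBuffer.empty[(String, String, String)]

  def freshName(prefix: String): String = {
    val id = nameCounters.getOrElse(prefix, 0)
    nameCounters(prefix) = id + 1
    s"${prefix}_$id"
  }

  def addMutableState(javaType: String, prefix: String, init: String => String): String = {
    val name = freshName(prefix)
    mutableStates += ((javaType, name, init(name)))
    name
  }

  def declareMutableStates(): String =
    mutableStates.map { case (t, n, _) => s"private $t $n;" }.mkString("\n")

  def initMutableStates(): String =
    mutableStates.map(_._3).mkString("\n")
}

object MiniCodegenContextDemo {
  def main(args: Array[String]): Unit = {
    val ctx = new MiniCodegenContext
    ctx.addMutableState("boolean", "agg_initAgg", n => s"$n = false;")
    ctx.addMutableState("long", "agg_bufValue", n => s"$n = 0L;")
    println(ctx.declareMutableStates()) // prints the generated field declarations
    println(ctx.initMutableStates())    // prints the generated initialization statements
  }
}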
Produce-Consume Pattern
Adjacent operators generate code through the Produce-Consume pattern.
Produce generates the framework code for the overall processing. For example, the code framework generated for an aggregation looks like this:
if (!initialized) {
# create a hash map, then build the aggregation hash map
# call child.produce()
initialized = true;
}
while (hashmap.hasNext()) {
row = hashmap.next();
# build the aggregation results
# create variables for results
# call consume(), which will call parent.doConsume()
if (shouldStop()) return;
}
Consume generates the logic with which the current node processes a row received from its upstream input. For example, Filter generates code like this (a toy end-to-end sketch of the produce/consume protocol follows the snippet):
# code to evaluate the predicate expression, result is isNull1 and value2
if (!isNull1 && value2) {
# call consume(), which will call parent.doConsume()
}
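To make the protocol concrete, here is a toy model of produce/consume (hypothetical classes, not Spark's CodegenSupport API): the scan produces the enclosing while loop and asks its parent to consume the current row, the filter wraps what it consumes in an if, and the sink appends the surviving row, so the whole pipeline collapses into one fused code fragment.

trait MiniCodegenSupport {
  var parent: MiniCodegenSupport = null
  def produce(): String                                     // emit the framework code
  def consume(row: String): String = parent.doConsume(row)  // hand the current row to the parent
  def doConsume(row: String): String = ""
}

// Leaf node: produces the driving loop, then lets the parent consume each row.
class MiniScan(rowVar: String) extends MiniCodegenSupport {
  def produce(): String =
    s"""while (input.hasNext()) {
       |  InternalRow $rowVar = (InternalRow) input.next();
       |${consume(rowVar)}
       |}""".stripMargin
}

// Intermediate operator: delegates produce() downstream, injects its own logic in doConsume().
class MiniFilter(child: MiniCodegenSupport, predicate: String => String) extends MiniCodegenSupport {
  child.parent = this
  def produce(): String = child.produce()
  override def doConsume(row: String): String =
    s"""  if (${predicate(row)}) {
       |${consume(row)}
       |  }""".stripMargin
}

// Top of the pipeline: appends the surviving row to the output buffer.
class MiniSink(child: MiniCodegenSupport) extends MiniCodegenSupport {
  child.parent = this
  def produce(): String = child.produce()
  override def doConsume(row: String): String = s"    append($row);"
}

object ProduceConsumeDemo {
  def main(args: Array[String]): Unit = {
    val scan = new MiniScan("row")
    val sink = new MiniSink(new MiniFilter(scan, r => s"$r.getLong(0) == 1000L"))
    println(sink.produce()) // prints the fused loop: scan + filter + append in one fragment
  }
}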
The diagram below shows the call graph of WholeStageCodegen's Java code generation fairly clearly:
For the Case Study example above, the real generated code is shown below:
== Subtree 1 / 2 ==
*(2) HashAggregate(keys=[], functions=[count(1)], output=[count(1)#326L])
+- Exchange SinglePartition
+- *(1) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#329L])
+- *(1) Project
+- *(1) Filter (isnotnull(ss_item_sk#13L) && (ss_item_sk#13L = 1000))
+- *(1) FileScan parquet [ss_item_sk#13L] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/home/admin/zhoukeyong/workspace/tpc/tpcds/data/parquet/10/store_sales/par..., PartitionFilters: [], PushedFilters: [IsNotNull(ss_item_sk), EqualTo(ss_item_sk,1000)], ReadSchema: struct<ss_item_sk:bigint>
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */ private Object[] references;
/* 008 */ private scala.collection.Iterator[] inputs;
/* 009 */ private boolean agg_initAgg_0;
/* 010 */ private boolean agg_bufIsNull_0;
/* 011 */ private long agg_bufValue_0;
/* 012 */ private scala.collection.Iterator inputadapter_input_0;
/* 013 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] agg_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[1];
/* 014 */
/* 015 */ public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 016 */ this.references = references;
/* 017 */ }
/* 018 */
/* 019 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 020 */ partitionIndex = index;
/* 021 */ this.inputs = inputs;
/* 022 */
/* 023 */ inputadapter_input_0 = inputs[0];
/* 024 */ agg_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 025 */
/* 026 */ }
/* 027 */
/* 028 */ private void agg_doAggregateWithoutKey_0() throws java.io.IOException {
/* 029 */ // initialize aggregation buffer
/* 030 */ agg_bufIsNull_0 = false;
/* 031 */ agg_bufValue_0 = 0L;
/* 032 */
/* 033 */ while (inputadapter_input_0.hasNext() && !stopEarly()) {
/* 034 */ InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
/* 035 */ long inputadapter_value_0 = inputadapter_row_0.getLong(0);
/* 036 */
/* 037 */ agg_doConsume_0(inputadapter_row_0, inputadapter_value_0);
/* 038 */ if (shouldStop()) return;
/* 039 */ }
/* 040 */
/* 041 */ }
/* 042 */
/* 043 */ private void agg_doConsume_0(InternalRow inputadapter_row_0, long agg_expr_0_0) throws java.io.IOException {
/* 044 */ // do aggregate
/* 045 */ // common sub-expressions
/* 046 */
/* 047 */ // evaluate aggregate function
/* 048 */ long agg_value_3 = -1L;
/* 049 */ agg_value_3 = agg_bufValue_0 + agg_expr_0_0;
/* 050 */ // update aggregation buffer
/* 051 */ agg_bufIsNull_0 = false;
/* 052 */ agg_bufValue_0 = agg_value_3;
/* 053 */
/* 054 */ }
/* 055 */
/* 056 */ protected void processNext() throws java.io.IOException {
/* 057 */ while (!agg_initAgg_0) {
/* 058 */ agg_initAgg_0 = true;
/* 059 */ long agg_beforeAgg_0 = System.nanoTime();
/* 060 */ agg_doAggregateWithoutKey_0();
/* 061 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* aggTime */).add((System.nanoTime() - agg_beforeAgg_0) / 1000000);
/* 062 */
/* 063 */ // output the result
/* 064 */
/* 065 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 066 */ agg_mutableStateArray_0[0].reset();
/* 067 */
/* 068 */ agg_mutableStateArray_0[0].zeroOutNullBytes();
/* 069 */
/* 070 */ agg_mutableStateArray_0[0].write(0, agg_bufValue_0);
/* 071 */ append((agg_mutableStateArray_0[0].getRow()));
/* 072 */ }
/* 073 */ }
/* 074 */
/* 075 */ }
== Subtree 2 / 2 ==
*(1) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#329L])
+- *(1) Project
+- *(1) Filter (isnotnull(ss_item_sk#13L) && (ss_item_sk#13L = 1000))
+- *(1) FileScan parquet [ss_item_sk#13L] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/home/admin/zhoukeyong/workspace/tpc/tpcds/data/parquet/10/store_sales/par..., PartitionFilters: [], PushedFilters: [IsNotNull(ss_item_sk), EqualTo(ss_item_sk,1000)], ReadSchema: struct<ss_item_sk:bigint>
Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIteratorForCodegenStage1(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=1
/* 006 */ final class GeneratedIteratorForCodegenStage1 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */ private Object[] references;
/* 008 */ private scala.collection.Iterator[] inputs;
/* 009 */ private boolean agg_initAgg_0;
/* 010 */ private boolean agg_bufIsNull_0;
/* 011 */ private long agg_bufValue_0;
/* 012 */ private long scan_scanTime_0;
/* 013 */ private boolean outputMetaColumns;
/* 014 */ private int scan_batchIdx_0;
/* 015 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] scan_mutableStateArray_3 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[3];
/* 016 */ private org.apache.spark.sql.vectorized.ColumnarBatch[] scan_mutableStateArray_1 = new org.apache.spark.sql.vectorized.ColumnarBatch[1];
/* 017 */ private scala.collection.Iterator[] scan_mutableStateArray_0 = new scala.collection.Iterator[1];
/* 018 */ private org.apache.spark.sql.execution.vectorized.OffHeapColumnVector[] scan_mutableStateArray_2 = new org.apache.spark.sql.execution.vectorized.OffHeapColumnVector[1];
/* 019 */
/* 020 */ public GeneratedIteratorForCodegenStage1(Object[] references) {
/* 021 */ this.references = references;
/* 022 */ }
/* 023 */
/* 024 */ public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */ partitionIndex = index;
/* 026 */ this.inputs = inputs;
/* 027 */
/* 028 */ scan_mutableStateArray_0[0] = inputs[0];
/* 029 */ outputMetaColumns = false;
/* 030 */ scan_mutableStateArray_3[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */ scan_mutableStateArray_3[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 032 */ scan_mutableStateArray_3[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 033 */
/* 034 */ }
/* 035 */
/* 036 */ private void agg_doAggregateWithoutKey_0() throws java.io.IOException {
/* 037 */ // initialize aggregation buffer
/* 038 */ agg_bufIsNull_0 = false;
/* 039 */ agg_bufValue_0 = 0L;
/* 040 */
/* 041 */ if (scan_mutableStateArray_1[0] == null) {
/* 042 */ scan_nextBatch_0();
/* 043 */ }
/* 044 */ while (scan_mutableStateArray_1[0] != null) {
/* 045 */ int scan_numRows_0 = scan_mutableStateArray_1[0].numRows();
/* 046 */ int scan_localEnd_0 = scan_numRows_0 - scan_batchIdx_0;
/* 047 */ for (int scan_localIdx_0 = 0; scan_localIdx_0 < scan_localEnd_0; scan_localIdx_0++) {
/* 048 */ int scan_rowIdx_0 = scan_batchIdx_0 + scan_localIdx_0;
/* 049 */ if (!scan_mutableStateArray_1[0].validAt(scan_rowIdx_0)) { continue; }
/* 050 */ do {
/* 051 */ boolean scan_isNull_0 = scan_mutableStateArray_2[0].isNullAt(scan_rowIdx_0);
/* 052 */ long scan_value_0 = scan_isNull_0 ? -1L : (scan_mutableStateArray_2[0].getLong(scan_rowIdx_0));
/* 053 */
/* 054 */ if (!(!scan_isNull_0)) continue;
/* 055 */
/* 056 */ boolean filter_value_2 = false;
/* 057 */ filter_value_2 = scan_value_0 == 1000L;
/* 058 */ if (!filter_value_2) continue;
/* 059 */
/* 060 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 061 */
/* 062 */ agg_doConsume_0();
/* 063 */
/* 064 */ } while(false);
/* 065 */ // shouldStop check is eliminated
/* 066 */ }
/* 067 */ scan_batchIdx_0 = scan_numRows_0;
/* 068 */ scan_mutableStateArray_1[0] = null;
/* 069 */ scan_nextBatch_0();
/* 070 */ }
/* 071 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* scanTime */).add(scan_scanTime_0 / (1000 * 1000));
/* 072 */ scan_scanTime_0 = 0;
/* 073 */
/* 074 */ }
/* 075 */
/* 076 */ private void scan_nextBatch_0() throws java.io.IOException {
/* 077 */ long getBatchStart = System.nanoTime();
/* 078 */ if (scan_mutableStateArray_0[0].hasNext()) {
/* 079 */ scan_mutableStateArray_1[0] = (org.apache.spark.sql.vectorized.ColumnarBatch)scan_mutableStateArray_0[0].next();
/* 080 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(scan_mutableStateArray_1[0].numRows());
/* 081 */ scan_batchIdx_0 = 0;
/* 082 */ scan_mutableStateArray_2[0] = (org.apache.spark.sql.execution.vectorized.OffHeapColumnVector) (outputMetaColumns ?
/* 083 */ scan_mutableStateArray_1[0].column(0, true) : scan_mutableStateArray_1[0].column(0));
/* 084 */
/* 085 */ }
/* 086 */ scan_scanTime_0 += System.nanoTime() - getBatchStart;
/* 087 */ }
/* 088 */
/* 089 */ private void agg_doConsume_0() throws java.io.IOException {
/* 090 */ // do aggregate
/* 091 */ // common sub-expressions
/* 092 */
/* 093 */ // evaluate aggregate function
/* 094 */ long agg_value_1 = -1L;
/* 095 */ agg_value_1 = agg_bufValue_0 + 1L;
/* 096 */ // update aggregation buffer
/* 097 */ agg_bufIsNull_0 = false;
/* 098 */ agg_bufValue_0 = agg_value_1;
/* 099 */
/* 100 */ }
/* 101 */
/* 102 */ protected void processNext() throws java.io.IOException {
/* 103 */ while (!agg_initAgg_0) {
/* 104 */ agg_initAgg_0 = true;
/* 105 */ long agg_beforeAgg_0 = System.nanoTime();
/* 106 */ agg_doAggregateWithoutKey_0();
/* 107 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[4] /* aggTime */).add((System.nanoTime() - agg_beforeAgg_0) / 1000000);
/* 108 */
/* 109 */ // output the result
/* 110 */
/* 111 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[3] /* numOutputRows */).add(1);
/* 112 */ scan_mutableStateArray_3[2].reset();
/* 113 */
/* 114 */ scan_mutableStateArray_3[2].zeroOutNullBytes();
/* 115 */
/* 116 */ scan_mutableStateArray_3[2].write(0, agg_bufValue_0);
/* 117 */ append((scan_mutableStateArray_3[2].getRow()));
/* 118 */ }
/* 119 */ }
/* 120 */
/* 121 */ }
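Output in this format (the "== Subtree i / n ==" sections followed by "Generated code:") can be reproduced for your own queries. One way, assuming Spark 2.x-style APIs and a hypothetical path to the store_sales parquet data, is:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.debug._ // adds debugCodegen() to Dataset

object DumpGeneratedCode {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("codegen-dump").getOrCreate()
    // Hypothetical location of the store_sales table.
    spark.read.parquet("/path/to/store_sales").createOrReplaceTempView("store_sales")
    val df = spark.sql("select count(*) from store_sales where ss_item_sk = 1000")
    df.debugCodegen() // prints each whole-stage subtree and its generated Java code
    spark.stop()
  }
}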