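Whole-stage codegen is a feature introduced in Spark 2.0, so it gets its own section here at the end.
For background, see the Spark JIRA: https://issues.apache.org/jira/browse/SPARK-12795
Whole-stage codegen brings a large performance improvement.
As the figure illustrates, it translates a tree of operators into a single piece of code, so rows no longer pass through a chain of per-operator iterator calls, which is where most of the gain comes from.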
More on how codegen works, along with benchmark results:
https://databricks.com/blog/2015/04/13/deep-dive-into-spark-sqls-catalyst-optimizer.html
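To make the "tree into one piece of code" idea concrete, here is a minimal Java sketch. It is not Spark code: `FusionSketch`, `RowIterator`, and the three toy operators are hypothetical names invented for illustration. It evaluates the same plan (scan, keep even values, multiply by 10, sum) once as a chain of iterators and once as the single loop that a fused stage boils down to.

```java
import java.util.NoSuchElementException;

// Hypothetical toy model; none of these names are Spark classes.
final class FusionSketch {
    /** Iterator-style operator: one virtual call per row per operator. */
    interface RowIterator {
        boolean hasNext();
        long next();
    }

    static RowIterator scan(long[] data) {
        return new RowIterator() {
            private int i = 0;
            public boolean hasNext() { return i < data.length; }
            public long next() { return data[i++]; }
        };
    }

    static RowIterator filterEven(RowIterator child) {
        return new RowIterator() {
            private long pending;
            private boolean ready = false;
            public boolean hasNext() {
                // Pull from the child until a row passes the predicate.
                while (!ready && child.hasNext()) {
                    long v = child.next();
                    if (v % 2 == 0) { pending = v; ready = true; }
                }
                return ready;
            }
            public long next() {
                if (!hasNext()) throw new NoSuchElementException();
                ready = false;
                return pending;
            }
        };
    }

    static RowIterator timesTen(RowIterator child) {
        return new RowIterator() {
            public boolean hasNext() { return child.hasNext(); }
            public long next() { return child.next() * 10; }
        };
    }

    /** Interpreted plan: Project(x * 10) over Filter(x even) over Scan. */
    static long sumInterpreted(long[] data) {
        RowIterator plan = timesTen(filterEven(scan(data)));
        long sum = 0;
        while (plan.hasNext()) sum += plan.next();
        return sum;
    }

    /** The same plan collapsed by hand into one loop, as a fused stage would be. */
    static long sumFused(long[] data) {
        long sum = 0;
        for (long v : data) {
            if (v % 2 == 0) sum += v * 10;
        }
        return sum;
    }

    public static void main(String[] args) {
        long[] data = {1, 2, 3, 4, 5, 6};
        System.out.println(sumInterpreted(data)); // 120
        System.out.println(sumFused(data));       // 120
    }
}
```

Both methods return the same result; the fused version simply has no operator boundaries left for the JVM to call across.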
Whole-stage codegen is enabled by default:
val WHOLESTAGE_CODEGEN_ENABLED = buildConf("spark.sql.codegen.wholeStage")
  .internal()
  .doc("When true, the whole stage (of multiple operators) will be compiled into single java" +
    " method.")
  .booleanConf
  .createWithDefault(true)
Its entry point is in preparations:
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
python.ExtractPythonUDFs,
PlanSubqueries(sparkSession),
EnsureRequirements(sparkSession.sessionState.conf),
CollapseCodegenStages(sparkSession.sessionState.conf),
ReuseExchange(sparkSession.sessionState.conf),
ReuseSubquery(sparkSession.sessionState.conf))
Of these rules, CollapseCodegenStages is the entry point of the codegen optimization.
Its apply method runs the codegen logic only when whole-stage codegen is enabled:
def apply(plan: SparkPlan): SparkPlan = {
if (conf.wholeStageEnabled) {
WholeStageCodegenId.resetPerQuery()
insertWholeStageCodegen(plan)
} else {
plan
}
}
WholeStageCodegenId is simply an incrementing per-thread counter; resetPerQuery resets it to 1:
object WholeStageCodegenId {
private val codegenStageCounter = ThreadLocal.withInitial(new Supplier[Integer] {
override def get() = 1 // TODO: change to Scala lambda syntax when upgraded to Scala 2.12+
})
def resetPerQuery(): Unit = codegenStageCounter.set(1)
def getNextStageId(): Int = {
val counter = codegenStageCounter
val id = counter.get()
counter.set(id + 1)
id
}
}
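For readers more at home in Java, the counter above can be sketched as follows. This is a plain-Java rendering of the same idea, not Spark's code; the class name `StageIdSketch` is made up for the example.

```java
// Hypothetical Java rendering of the per-thread stage-id counter.
final class StageIdSketch {
    // One counter per thread, starting at 1.
    private static final ThreadLocal<Integer> COUNTER = ThreadLocal.withInitial(() -> 1);

    /** Reset the calling thread's counter to 1, as done at the start of each query. */
    static void resetPerQuery() {
        COUNTER.set(1);
    }

    /** Return the current id and advance the counter by one. */
    static int getNextStageId() {
        int id = COUNTER.get();
        COUNTER.set(id + 1);
        return id;
    }

    public static void main(String[] args) {
        resetPerQuery();
        System.out.println(getNextStageId()); // 1
        System.out.println(getNextStageId()); // 2
        resetPerQuery();
        System.out.println(getNextStageId()); // 1
    }
}
```

Because the state lives in a ThreadLocal, queries planned on different threads do not interfere with each other's numbering.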
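Remember the numbers 1, 2, ... 5 in front of each stage in the physical plan shown earlier? Those are WholeStageCodegenId values, which tie each codegen-generated class to its operators; the leading * means the stage went through codegen. Note that Exchange is not code-generated: it does no computation and is only a shuffle step.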
*(5) Project [B#6]
+- *(5) SortMergeJoin [B#6], [B#14], Inner
:- *(2) Sort [B#6 ASC NULLS FIRST], false, 0
: +- Exchange(coordinator id: 1121577170) hashpartitioning(B#6, 200), coordinator[target post-shuffle partition size: 67108864]
: +- *(1) Project [B#6]
: +- *(1) Filter isnotnull(B#6)
: +- *(1) FileScan json [B#6] Batched: false, Format: JSON, Location: InMemoryFileIndex[file:examples/src/main/resources/test.json], PartitionFilters: [], PushedFilters: [IsNotNull(B)], ReadSchema: struct<B:string>
+- *(4) Sort [B#14 ASC NULLS FIRST], false, 0
+- Exchange(coordinator id: 1121577170) hashpartitioning(B#14, 200), coordinator[target post-shuffle partition size: 67108864]
+- *(3) Project [B#14]
+- *(3) Filter isnotnull(B#14)
+- *(3) FileScan json [B#14] Batched: false, Format: JSON, Location: InMemoryFileIndex[file:examples/src/main/resources/test2.json], PartitionFilters: [], PushedFilters: [IsNotNull(B)], ReadSchema: struct<B:string>
Next, a WholeStageCodegenExec is inserted on top of the plans that support codegen:
/**
* Inserts a WholeStageCodegen on top of those that support codegen.
*/
private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
// For operators that will output domain object, do not insert WholeStageCodegen for it as
// domain object can not be written into unsafe row.
case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] =>
plan.withNewChildren(plan.children.map(insertWholeStageCodegen))
case plan: CodegenSupport if supportCodegen(plan) =>
WholeStageCodegenExec(insertInputAdapter(plan))(WholeStageCodegenId.getNextStageId())
case other =>
other.withNewChildren(other.children.map(insertWholeStageCodegen))
}
If a child plan does not support codegen, an InputAdapter is inserted on top of that child:
/**
* Inserts an InputAdapter on top of those that do not support codegen.
*/
private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
case p if !supportCodegen(p) =>
// collapse them recursively
InputAdapter(insertWholeStageCodegen(p))
case j: SortMergeJoinExec =>
// The children of SortMergeJoin should do codegen separately.
j.withNewChildren(j.children.map(child => InputAdapter(insertWholeStageCodegen(child))))
case p =>
p.withNewChildren(p.children.map(insertInputAdapter))
}
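The interplay of the two functions above can be sketched on a toy plan tree. The Java below is an illustration under simplifying assumptions, not Spark's implementation: `Node`, `codegen` (standing in for supportCodegen), and `splitChildren` (standing in for the SortMergeJoinExec case) are hypothetical. Instead of wrapping nodes, it just records a stage id on each node; ids are handed out after the children have been processed, mirroring the fact that `insertInputAdapter(plan)` is evaluated before `getNextStageId()`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical toy model of stage collapsing; not Spark classes.
final class CollapseSketch {
    static final class Node {
        final String name;
        final boolean codegen;        // stands in for supportCodegen(plan)
        final boolean splitChildren;  // join-like node whose children compile separately
        final List<Node> children;
        int stageId = 0;              // 0 means "not inside any codegen stage"

        Node(String name, boolean codegen, boolean splitChildren, Node... children) {
            this.name = name;
            this.codegen = codegen;
            this.splitChildren = splitChildren;
            this.children = Arrays.asList(children);
        }
    }

    private int nextId = 1;

    /** Mirrors insertWholeStageCodegen: a codegen-capable node starts a new stage. */
    void insertWholeStage(Node n) {
        if (n.codegen) {
            List<Node> members = new ArrayList<>();
            insertInputAdapter(n, members);   // stages below are numbered first
            int id = nextId++;
            for (Node m : members) m.stageId = id;
        } else {
            for (Node c : n.children) insertWholeStage(c);
        }
    }

    /** Mirrors insertInputAdapter: collect the current stage, cutting at boundaries. */
    private void insertInputAdapter(Node n, List<Node> members) {
        if (!n.codegen) {                     // an InputAdapter would be placed here
            insertWholeStage(n);
            return;
        }
        members.add(n);
        for (Node c : n.children) {
            if (n.splitChildren) insertWholeStage(c);
            else insertInputAdapter(c, members);
        }
    }

    static String render(Node n, int depth) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < depth; i++) sb.append("  ");
        if (n.stageId > 0) sb.append("*(").append(n.stageId).append(") ");
        sb.append(n.name).append('\n');
        for (Node c : n.children) sb.append(render(c, depth + 1));
        return sb.toString();
    }

    static Node branch(String file) {
        Node scan = new Node("FileScan " + file, true, false);
        Node filter = new Node("Filter", true, false, scan);
        Node project = new Node("Project", true, false, filter);
        Node exchange = new Node("Exchange", false, false, project);
        return new Node("Sort", true, false, exchange);
    }

    static String demo() {
        Node join = new Node("SortMergeJoin", true, true,
                branch("test.json"), branch("test2.json"));
        Node root = new Node("Project", true, false, join);
        new CollapseSketch().insertWholeStage(root);
        return render(root, 0);
    }

    public static void main(String[] args) {
        System.out.print(demo());
    }
}
```

Run on a tree shaped like the join plan shown earlier, the sketch assigns stage 1 and 3 to the two scan pipelines, 2 and 4 to the two sorts, 5 to the join and the top Project, and leaves both Exchange nodes unnumbered.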
For a physical plan to support codegen, it must implement the CodegenSupport trait and override doProduce and doConsume.
Take SortExec, which supports codegen, as an example:
// Only the doProduce and doConsume methods are shown
case class SortExec(
sortOrder: Seq[SortOrder],
global: Boolean,
child: SparkPlan,
testSpillFrequency: Int = 0)
extends UnaryExecNode with CodegenSupport {
// The returned string is the code to be executed
override protected def doProduce(ctx: CodegenContext): String = {
val needToSort = ctx.addMutableState(ctx.JAVA_BOOLEAN, "needToSort", v => s"$v = true;")
// Initialize the class member variables. This includes the instance of the Sorter and
// the iterator to return sorted rows.
val thisPlan = ctx.addReferenceObj("plan", this)
// Inline mutable state since not many Sort operations in a task
sorterVariable = ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, "sorter",
v => s"$v = $thisPlan.createSorter();", forceInline = true)
val metrics = ctx.addMutableState(classOf[TaskMetrics].getName, "metrics",
v => s"$v = org.apache.spark.TaskContext.get().taskMetrics();", forceInline = true)
val sortedIterator = ctx.addMutableState("scala.collection.Iterator<UnsafeRow>", "sortedIter",
forceInline = true)
val addToSorter = ctx.freshName("addToSorter")
val addToSorterFuncName = ctx.addNewFunction(addToSorter,
s"""
| private void $addToSorter() throws java.io.IOException {
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| }
""".stripMargin.trim)
val outputRow = ctx.freshName("outputRow")
val peakMemory = metricTerm(ctx, "peakMemory")
val spillSize = metricTerm(ctx, "spillSize")
val spillSizeBefore = ctx.freshName("spillSizeBefore")
val sortTime = metricTerm(ctx, "sortTime")
s"""
| if ($needToSort) {
| long $spillSizeBefore = $metrics.memoryBytesSpilled();
| $addToSorterFuncName();
| $sortedIterator = $sorterVariable.sort();
| $sortTime.add($sorterVariable.getSortTimeNanos() / 1000000);
| $peakMemory.add($sorterVariable.getPeakMemoryUsage());
| $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
| $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
| $needToSort = false;
| }
|
| while ($sortedIterator.hasNext()) {
| UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
| ${consume(ctx, null, outputRow)}
| if (shouldStop()) return;
| }
""".stripMargin.trim
}
// The returned string is the code to be executed
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
s"""
|${row.code}
|$sorterVariable.insertRow((UnsafeRow)${row.value});
""".stripMargin
}
This is the code returned by doConsume. doConsume is responsible for consuming the input rows; here it simply inserts each row into sort_sorter_0:
sort_sorter_0.insertRow((UnsafeRow)inputadapter_row_0);
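This is the code returned by doProduce. doProduce is responsible for emitting rows, and the logic is just as simple: ignoring the metrics bookkeeping, it only calls sort_sorter_0.sort() to sort the rows and then emits them.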
if (sort_needToSort_0) {
long sort_spillSizeBefore_0 = sort_metrics_0.memoryBytesSpilled();
sort_addToSorter_0();
sort_sortedIter_0 = sort_sorter_0.sort();
((org.apache.spark.sql.execution.metric.SQLMetric) references[3] /* sortTime */).add(sort_sorter_0.getSortTimeNanos() / 1000000);
((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* peakMemory */).add(sort_sorter_0.getPeakMemoryUsage());
((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* spillSize */).add(sort_metrics_0.memoryBytesSpilled() - sort_spillSizeBefore_0);
sort_metrics_0.incPeakExecutionMemory(sort_sorter_0.getPeakMemoryUsage());
sort_needToSort_0 = false;
}
while (sort_sortedIter_0.hasNext()) {
UnsafeRow sort_outputRow_0 = (UnsafeRow)sort_sortedIter_0.next();
append(sort_outputRow_0);
if (shouldStop()) return;
}
The following call prints all of the generated code:
spark.sql("SELECT A.B FROM A JOIN C ON A.B = C.B").queryExecution.debug.codegen()
One of the SortExec nodes is stage 4; here is its complete generated code:
public Object generate(Object[] references) {
return new GeneratedIteratorForCodegenStage4(references);
}
// The logic is simple: sort, then emit
final class GeneratedIteratorForCodegenStage4 extends org.apache.spark.sql.execution.BufferedRowIterator {
private Object[] references;
private scala.collection.Iterator[] inputs;
private boolean sort_needToSort_0;
private org.apache.spark.sql.execution.UnsafeExternalRowSorter sort_sorter_0;
private org.apache.spark.executor.TaskMetrics sort_metrics_0;
private scala.collection.Iterator<UnsafeRow> sort_sortedIter_0;
private scala.collection.Iterator inputadapter_input_0;
private scala.collection.Iterator inputadapter_input_1;
public GeneratedIteratorForCodegenStage4(Object[] references) {
this.references = references;
}
public void init(int index, scala.collection.Iterator[] inputs) {
partitionIndex = index;
this.inputs = inputs;
sort_needToSort_0 = true;
sort_sorter_0 = ((org.apache.spark.sql.execution.SortExec) references[0] /* plan */).createSorter();
sort_metrics_0 = org.apache.spark.TaskContext.get().taskMetrics();
inputadapter_input_0 = inputs[0];
inputadapter_input_1 = inputs[0];
}
private void sort_addToSorter_0() throws java.io.IOException {
// Read rows from the input; this loop is generated by the child's produce method
while (inputadapter_input_1.hasNext() && !stopEarly()) {
InternalRow inputadapter_row_1 = (InternalRow) inputadapter_input_1.next();
sort_sorter_0.insertRow((UnsafeRow)inputadapter_row_1);
if (shouldStop()) return;
}
}
protected void processNext() throws java.io.IOException {
if (sort_needToSort_0) {
long sort_spillSizeBefore_0 = sort_metrics_0.memoryBytesSpilled();
sort_addToSorter_0();
sort_sortedIter_0 = sort_sorter_0.sort();
((org.apache.spark.sql.execution.metric.SQLMetric) references[3] /* sortTime */).add(sort_sorter_0.getSortTimeNanos() / 1000000);
((org.apache.spark.sql.execution.metric.SQLMetric) references[1] /* peakMemory */).add(
sort_sorter_0.getPeakMemoryUsage());
((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* spillSize */).add(
sort_metrics_0.memoryBytesSpilled() - sort_spillSizeBefore_0);
sort_metrics_0.incPeakExecutionMemory(sort_sorter_0.getPeakMemoryUsage());
sort_needToSort_0 = false;
}
while (sort_sortedIter_0.hasNext()) {
UnsafeRow sort_outputRow_0 = (UnsafeRow)sort_sortedIter_0.next();
// Write to the output; this line is generated by the parent's consume method
append(sort_outputRow_0);
if (shouldStop()) return;
}
}
}
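The generated class relies on `append`, `shouldStop`, and `processNext` from its base class. The Java sketch below shows how those pieces fit together. It is a simplified stand-in written for this article, not the real BufferedRowIterator: `BufferedIter` and `SortStage` are hypothetical names, rows are plain integers, and an in-memory list replaces the external sorter.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

// Hypothetical, simplified model of a buffered row iterator.
final class BufferedIteratorSketch {
    abstract static class BufferedIter<T> {
        private final LinkedList<T> currentRows = new LinkedList<>();

        /** Subclasses (the generated code) push rows via append(). */
        protected abstract void processNext();

        protected void append(T row) { currentRows.add(row); }

        /** True once at least one row is buffered, so the producer loop can pause. */
        protected boolean shouldStop() { return !currentRows.isEmpty(); }

        public boolean hasNext() {
            if (currentRows.isEmpty()) processNext();
            return !currentRows.isEmpty();
        }

        public T next() { return currentRows.remove(); }
    }

    /** Same shape as the generated sort stage: sort once, then emit row by row. */
    static final class SortStage extends BufferedIter<Integer> {
        private final Iterator<Integer> input;
        private boolean needToSort = true;
        private Iterator<Integer> sortedIter;

        SortStage(Iterator<Integer> input) { this.input = input; }

        @Override
        protected void processNext() {
            if (needToSort) {
                // Stand-in for addToSorter() followed by sorter.sort().
                List<Integer> buffer = new ArrayList<>();
                while (input.hasNext()) buffer.add(input.next());
                Collections.sort(buffer);
                sortedIter = buffer.iterator();
                needToSort = false;
            }
            while (sortedIter.hasNext()) {
                append(sortedIter.next());
                if (shouldStop()) return;
            }
        }
    }

    static List<Integer> sorted(List<Integer> input) {
        SortStage stage = new SortStage(input.iterator());
        List<Integer> out = new ArrayList<>();
        while (stage.hasNext()) out.add(stage.next());
        return out;
    }

    public static void main(String[] args) {
        System.out.println(sorted(Arrays.asList(3, 1, 2))); // [1, 2, 3]
    }
}
```

The `if (shouldStop()) return;` check is what turns the generated push-style loop back into a pull-style iterator: the loop pauses as soon as a row is available and resumes on the next hasNext call.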
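So the idea behind codegen is that a child plan produces data for its parent plan, and the parent plan consumes the child's data. Concretely, a parent's doProduce calls produce on its child, and the code the child generates calls back into the parent's doConsume for each row. Chaining plans this way yields the final code to execute, which is why each plan has to override doProduce and doConsume.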
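The produce/consume handshake can be shown with a tiny string-based code generator. This Java sketch is a toy under stated assumptions, not Spark's CodegenSupport: `Op`, `Scan`, `Filter`, `Project`, and `Root` are hypothetical classes, rows are single ints, and `Root` plays the role that WholeStageCodegenExec plays by emitting `append(...)`.

```java
// Hypothetical miniature of the produce/consume protocol; emits Java source as a string.
final class ProduceConsumeSketch {
    abstract static class Op {
        private Op parent;

        /** Called by the parent: remember who consumes our rows, then generate code. */
        final String produce(Op parent) {
            this.parent = parent;
            return doProduce();
        }

        /** Called by an operator on itself to hand a row variable to its parent. */
        final String consume(String rowVar) {
            return parent.doConsume(rowVar);
        }

        abstract String doProduce();

        String doConsume(String rowVar) {
            throw new UnsupportedOperationException("leaf operators consume nothing");
        }
    }

    static final class Scan extends Op {          // leaf: owns the loop
        String doProduce() {
            return "while (input.hasNext()) {\n"
                 + "  int row = input.next();\n"
                 + indentLines(consume("row"))
                 + "}\n";
        }
    }

    static final class Filter extends Op {
        private final Op child;
        private final String predicate;           // e.g. "> 10"
        Filter(Op child, String predicate) { this.child = child; this.predicate = predicate; }
        String doProduce() { return child.produce(this); }
        String doConsume(String rowVar) {
            return "if (!(" + rowVar + " " + predicate + ")) continue;\n" + consume(rowVar);
        }
    }

    static final class Project extends Op {
        private final Op child;
        private final String expr;                // e.g. "* 2"
        Project(Op child, String expr) { this.child = child; this.expr = expr; }
        String doProduce() { return child.produce(this); }
        String doConsume(String rowVar) {
            return "int projected = " + rowVar + " " + expr + ";\n" + consume("projected");
        }
    }

    static final class Root extends Op {          // top of the stage: emits rows
        private final Op child;
        Root(Op child) { this.child = child; }
        String doProduce() { return child.produce(this); }
        String doConsume(String rowVar) { return "append(" + rowVar + ");\n"; }
    }

    static String indentLines(String code) {
        StringBuilder sb = new StringBuilder();
        for (String line : code.split("\n")) sb.append("  ").append(line).append('\n');
        return sb.toString();
    }

    static String generate() {
        Op plan = new Root(new Project(new Filter(new Scan(), "> 10"), "* 2"));
        return plan.produce(null);
    }

    public static void main(String[] args) {
        System.out.print(generate());
    }
}
```

Calls to produce travel down the tree until the leaf emits the loop, and calls to consume travel back up, each operator contributing its fragment to the loop body. The result is one loop containing the filter, the projection, and the final `append`, with no operator boundaries left.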
Finally, let's look at the doExecute method of WholeStageCodegenExec:
override def doExecute(): RDD[InternalRow] = {
// Obtain the generated code
val (ctx, cleanedSource) = doCodeGen()
// try to compile and fallback if it failed
val (_, maxCodeSize) = try {
// Compile the code
CodeGenerator.compile(cleanedSource)
} catch {
case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback =>
// We should already saw the error message
logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
return child.execute()
}
// Check if compiled code has a too large function
if (maxCodeSize > sqlContext.conf.hugeMethodLimit) {
logInfo(s"Found too long generated codes and JIT optimization might not work: " +
s"the bytecode size ($maxCodeSize) is above the limit " +
s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " +
s"for this plan (id=$codegenStageId). To avoid this, you can raise the limit " +
s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString")
child match {
// The fallback solution of batch file source scan still uses WholeStageCodegenExec
case f: FileSourceScanExec if f.supportsBatch => // do nothing
case _ => return child.execute()
}
}
val references = ctx.references.toArray
val durationMs = longMetric("pipelineTime")
val rdds = child.asInstanceOf[CodegenSupport].inputRDDs()
assert(rdds.size <= 2, "Up to two input RDDs can be supported")
if (rdds.length == 1) {
rdds.head.mapPartitionsWithIndex { (index, iter) =>
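// Compile the code into a class (Janino is used under the hood).
val (clazz, _) = CodeGenerator.compile(cleanedSource)
// Create an instance of the generated iterator class
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
// Initialize it with this partition's input and run the logic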
buffer.init(index, Array(iter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
if (!v) durationMs += buffer.durationMs()
v
}
override def next: InternalRow = buffer.next()
}
}
} else {
// Right now, we support up to two input RDDs.
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
Iterator((leftIter, rightIter))
// a small hack to obtain the correct partition index
}.mapPartitionsWithIndex { (index, zippedIter) =>
val (leftIter, rightIter) = zippedIter.next()
val (clazz, _) = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
buffer.init(index, Array(leftIter, rightIter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
if (!v) durationMs += buffer.durationMs()
v
}
override def next: InternalRow = buffer.next()
}
}
}
}
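The logic is not complicated: the generated source is compiled into a class with Janino, an instance of the generated iterator is created, and it is run against each partition. If compilation fails or the compiled method is too large, execution falls back to child.execute().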
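The fallback decisions in doExecute can be condensed into a small decision function. The Java below is only a summary sketch: `FallbackSketch`, `Path`, and `choose` are hypothetical names, and it assumes codegen fallback is enabled and the code is not running under test (otherwise a compile failure would propagate as an exception instead).

```java
// Hypothetical summary of the fallback decisions; not Spark code.
final class FallbackSketch {
    enum Path { GENERATED_CODE, CHILD_EXECUTE }

    /**
     * @param compileSucceeded whether compiling the generated source worked
     * @param maxCodeSize      largest compiled method size, in bytes of bytecode
     * @param hugeMethodLimit  configured limit above which the JIT may give up
     * @param batchFileScan    true if the child is a batch-capable file scan
     */
    static Path choose(boolean compileSucceeded, int maxCodeSize,
                       int hugeMethodLimit, boolean batchFileScan) {
        // Assumes fallback is enabled: a failed compile reverts to the child plan.
        if (!compileSucceeded) return Path.CHILD_EXECUTE;
        // Oversized methods also revert, except for batch file scans.
        if (maxCodeSize > hugeMethodLimit && !batchFileScan) return Path.CHILD_EXECUTE;
        return Path.GENERATED_CODE;
    }

    public static void main(String[] args) {
        // The limit of 8000 is an arbitrary value chosen for this example.
        System.out.println(choose(true, 100, 8000, false));  // GENERATED_CODE
        System.out.println(choose(true, 9000, 8000, false)); // CHILD_EXECUTE
    }
}
```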
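That wraps up the introduction to whole-stage codegen. The mechanism is less complicated than it may seem, and a careful read of the source code is enough to understand it.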