This is an automated email from the ASF dual-hosted git repository. libenchao pushed a commit to branch release-1.11 in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.11 by this push: new 334b7b8 [FLINK-16589][table-planner-blink] Split code for AggsHandlerCodeGenerator 334b7b8 is described below commit 334b7b8b00885135b0682756f970b4e440f0f189 Author: Benchao Li <libenc...@gmail.com> AuthorDate: Fri Jun 19 16:15:40 2020 +0800 [FLINK-16589][table-planner-blink] Split code for AggsHandlerCodeGenerator This closes #12710 --- .../planner/codegen/CodeGeneratorContext.scala | 31 +++++----- .../table/planner/codegen/ExprCodeGenerator.scala | 18 ++++-- .../table/planner/codegen/GenerateUtils.scala | 29 ++++++---- .../planner/codegen/ProjectionCodeGenerator.scala | 3 +- .../codegen/agg/AggsHandlerCodeGenerator.scala | 66 ++++++++++++++++------ .../planner/codegen/agg/DistinctAggCodeGen.scala | 3 +- .../codegen/agg/batch/HashAggCodeGenHelper.scala | 3 +- .../runtime/stream/sql/AggregateITCase.scala | 26 +++++++++ 8 files changed, 129 insertions(+), 50 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala index 5ef90a2..b9bc6fc 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala @@ -18,7 +18,6 @@ package org.apache.flink.table.planner.codegen -import org.apache.flink.api.common.ExecutionConfig import org.apache.flink.api.common.functions.{Function, RuntimeContext} import org.apache.flink.api.common.typeutils.TypeSerializer import org.apache.flink.table.api.TableConfig @@ -109,9 +108,9 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { private var currentMethodNameForLocalVariables = "DEFAULT" /** - * Flag that indicates whether the generated code is split into several methods. + * Flag map that indicates whether the generated code for method is split into several methods. */ - private var isCodeSplit = false + private val isCodeSplitMap = mutable.Map[String, Boolean]() // map of local variable statements. It will be placed in method if method code not excess // max code length, otherwise will be placed in member area of the class. The statements @@ -149,11 +148,12 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { } /** - * Set the flag [[isCodeSplit]] to be true, which indicates the generated code is split into - * several methods. + * Set the flag [[isCodeSplitMap]] to be true for methodName, which indicates + * the generated code is split into several methods. + * @param methodName the method which will be split. */ - def setCodeSplit(): Unit = { - isCodeSplit = true + def setCodeSplit(methodName: String = currentMethodNameForLocalVariables): Unit = { + isCodeSplitMap(methodName) = true } /** @@ -210,10 +210,14 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { */ def reuseMemberCode(): String = { val result = reusableMemberStatements.mkString("\n") - if (isCodeSplit) { + if (isCodeSplitMap.nonEmpty) { val localVariableAsMember = reusableLocalVariableStatements.map( - statements => statements._2.map("private " + _).mkString("\n") - ).mkString("\n") + statements => if (isCodeSplitMap.getOrElse(statements._1, false)) { + statements._2.map("private " + _).mkString("\n") + } else { + "" + } + ).filter(_.length > 0).mkString("\n") result + "\n" + localVariableAsMember } else { result @@ -224,8 +228,8 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { * @return code block of statements that will be placed in the member area of the class * if generated code is split or in local variables of method */ - def reuseLocalVariableCode(methodName: String = null): String = { - if (isCodeSplit) { + def reuseLocalVariableCode(methodName: String = currentMethodNameForLocalVariables): String = { + if (isCodeSplitMap.getOrElse(methodName, false)) { GeneratedExpression.NO_CODE } else if (methodName == null) { reusableLocalVariableStatements(currentMethodNameForLocalVariables).mkString("\n") @@ -375,8 +379,7 @@ class CodeGeneratorContext(val tableConfig: TableConfig) { clazz: Class[_], outRecordTerm: String, outRecordWriterTerm: Option[String] = None): Unit = { - val statement = generateRecordStatement(t, clazz, outRecordTerm, outRecordWriterTerm) - reusableMemberStatements.add(statement) + generateRecordStatement(t, clazz, outRecordTerm, outRecordWriterTerm, this) } /** diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index fbea1be..3928294 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -227,10 +227,12 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) outRowWriter: Option[String] = Some(DEFAULT_OUT_RECORD_WRITER_TERM), reusedOutRow: Boolean = true, outRowAlreadyExists: Boolean = false, - allowSplit: Boolean = false): GeneratedExpression = { + allowSplit: Boolean = false, + methodName: String = null): GeneratedExpression = { val fieldExprIdxToOutputRowPosMap = fieldExprs.indices.map(i => i -> i).toMap generateResultExpression(fieldExprs, fieldExprIdxToOutputRowPosMap, returnType, - returnTypeClazz, outRow, outRowWriter, reusedOutRow, outRowAlreadyExists, allowSplit) + returnTypeClazz, outRow, outRowWriter, reusedOutRow, outRowAlreadyExists, + allowSplit, methodName) } /** @@ -257,7 +259,8 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) outRowWriter: Option[String], reusedOutRow: Boolean, outRowAlreadyExists: Boolean, - allowSplit: Boolean) + allowSplit: Boolean, + methodName: String) : GeneratedExpression = { // initial type check if (returnType.getFieldCount != fieldExprs.length) { @@ -298,7 +301,11 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) val maxCodeLength = ctx.tableConfig.getMaxGeneratedCodeLength val setFieldsCode = if (allowSplit && totalLen > maxCodeLength) { // do the split. - ctx.setCodeSplit() + if (methodName != null) { + ctx.setCodeSplit(methodName) + } else { + ctx.setCodeSplit() + } setFieldsCodes.map(project => { val methodName = newName("split") val method = @@ -315,9 +322,8 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) } val outRowInitCode = if (!outRowAlreadyExists) { - val initCode = generateRecordStatement(returnType, returnTypeClazz, outRow, outRowWriter) + val initCode = generateRecordStatement(returnType, returnTypeClazz, outRow, outRowWriter, ctx) if (reusedOutRow) { - ctx.addReusableMember(initCode) NO_CODE } else { initCode diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala index 9d0fe44..d707139 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/GenerateUtils.scala @@ -208,21 +208,23 @@ object GenerateUtils { // --------------------------- General Generate Utils ---------------------------------- /** - * Generates a record declaration statement. The record can be any type of RowData or - * other types. + * Generates a record declaration statement, and add it to reusable member. The record + * can be any type of RowData or other types. * * @param t the record type * @param clazz the specified class of the type (only used when RowType) * @param recordTerm the record term to be declared * @param recordWriterTerm the record writer term (only used when BinaryRowData type) - * @return the record declaration statement + * @param ctx the code generator context + * @return the record initialization statement */ @tailrec def generateRecordStatement( t: LogicalType, clazz: Class[_], recordTerm: String, - recordWriterTerm: Option[String] = None) + recordWriterTerm: Option[String] = None, + ctx: CodeGeneratorContext) : String = t.getTypeRoot match { // ordered by type root definition case ROW | STRUCTURED_TYPE if clazz == classOf[BinaryRowData] => @@ -231,26 +233,33 @@ object GenerateUtils { ) val binaryRowWriter = className[BinaryRowWriter] val typeTerm = clazz.getCanonicalName + ctx.addReusableMember(s"$typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)});") + ctx.addReusableMember( + s"$binaryRowWriter $writerTerm = new $binaryRowWriter($recordTerm);") s""" - |final $typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)}); - |final $binaryRowWriter $writerTerm = new $binaryRowWriter($recordTerm); + |$recordTerm = new $typeTerm(${getFieldCount(t)}); + |$writerTerm = new $binaryRowWriter($recordTerm); |""".stripMargin.trim case ROW | STRUCTURED_TYPE if clazz == classOf[GenericRowData] || clazz == classOf[BoxedWrapperRowData] => val typeTerm = clazz.getCanonicalName - s"final $typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)});" + ctx.addReusableMember(s"$typeTerm $recordTerm = new $typeTerm(${getFieldCount(t)});") + s"$recordTerm = new $typeTerm(${getFieldCount(t)});" case ROW | STRUCTURED_TYPE if clazz == classOf[JoinedRowData] => val typeTerm = clazz.getCanonicalName - s"final $typeTerm $recordTerm = new $typeTerm();" + ctx.addReusableMember(s"$typeTerm $recordTerm = new $typeTerm();") + s"$recordTerm = new $typeTerm();" case DISTINCT_TYPE => generateRecordStatement( t.asInstanceOf[DistinctType].getSourceType, clazz, recordTerm, - recordWriterTerm) + recordWriterTerm, + ctx) case _ => val typeTerm = boxedTypeTermForType(t) - s"final $typeTerm $recordTerm = new $typeTerm();" + ctx.addReusableMember(s"$typeTerm $recordTerm = new $typeTerm();") + s"$recordTerm = new $typeTerm();" } def generateNullLiteral( diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala index 7581885..eceb0c9 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala @@ -124,9 +124,8 @@ object ProjectionCodeGenerator { val outRowInitCode = { val initCode = generateRecordStatement( - outType, outClass, outRecordTerm, Some(outRecordWriterTerm)) + outType, outClass, outRecordTerm, Some(outRecordWriterTerm), ctx) if (reusedOutRecord) { - ctx.addReusableMember(initCode) NO_CODE } else { initCode diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala index c89a8d6..d9f496a 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/AggsHandlerCodeGenerator.scala @@ -585,6 +585,7 @@ class AggsHandlerCodeGenerator( public final class $functionName implements $NAMESPACE_AGGS_HANDLER_FUNCTION<$namespaceClassName> { + private $namespaceClassName $NAMESPACE_TERM; ${ctx.reuseMemberCode()} public $functionName(Object[] references) throws Exception { @@ -608,14 +609,14 @@ class AggsHandlerCodeGenerator( @Override public void merge(Object ns, $ROW_DATA $MERGED_ACC_TERM) throws Exception { - $namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns; + $NAMESPACE_TERM = ($namespaceClassName) ns; $mergeCode } @Override public void setAccumulators(Object ns, $ROW_DATA $ACC_TERM) throws Exception { - $namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns; + $NAMESPACE_TERM = ($namespaceClassName) ns; $setAccumulatorsCode } @@ -631,13 +632,13 @@ class AggsHandlerCodeGenerator( @Override public $ROW_DATA getValue(Object ns) throws Exception { - $namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns; + $NAMESPACE_TERM = ($namespaceClassName) ns; $getValueCode } @Override public void cleanup(Object ns) throws Exception { - $namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns; + $NAMESPACE_TERM = ($namespaceClassName) ns; ${ctx.reuseCleanupCode()} } @@ -684,6 +685,7 @@ class AggsHandlerCodeGenerator( public final class $functionName implements ${className[NamespaceTableAggsHandleFunction[_]]}<$namespaceClassName> { + private $namespaceClassName $NAMESPACE_TERM; ${ctx.reuseMemberCode()} private $CONVERT_COLLECTOR_TYPE_TERM $MEMBER_COLLECTOR_TERM; @@ -709,14 +711,14 @@ class AggsHandlerCodeGenerator( @Override public void merge(Object ns, $ROW_DATA $MERGED_ACC_TERM) throws Exception { - $namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns; + $NAMESPACE_TERM = ($namespaceClassName) ns; $mergeCode } @Override public void setAccumulators(Object ns, $ROW_DATA $ACC_TERM) throws Exception { - $namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns; + $NAMESPACE_TERM = ($namespaceClassName) ns; $setAccumulatorsCode } @@ -735,13 +737,13 @@ class AggsHandlerCodeGenerator( $COLLECTOR<$ROW_DATA> $COLLECTOR_TERM) throws Exception { $MEMBER_COLLECTOR_TERM.$COLLECTOR_TERM = $COLLECTOR_TERM; - $namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns; + $NAMESPACE_TERM = ($namespaceClassName) ns; $emitValueCode } @Override public void cleanup(Object ns) throws Exception { - $namespaceClassName $NAMESPACE_TERM = ($namespaceClassName) ns; + $NAMESPACE_TERM = ($namespaceClassName) ns; ${ctx.reuseCleanupCode()} } @@ -806,7 +808,9 @@ class AggsHandlerCodeGenerator( accTypeInfo, classOf[GenericRowData], outRow = accTerm, - reusedOutRow = false) + reusedOutRow = false, + allowSplit = true, + methodName = methodName) s""" |${ctx.reuseLocalVariableCode(methodName)} @@ -829,7 +833,9 @@ class AggsHandlerCodeGenerator( accTypeInfo, classOf[GenericRowData], outRow = accTerm, - reusedOutRow = false) + reusedOutRow = false, + allowSplit = true, + methodName = methodName) s""" |${ctx.reuseLocalVariableCode(methodName)} @@ -845,7 +851,8 @@ class AggsHandlerCodeGenerator( // bind input1 as accumulators val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL) .bindInput(accTypeInfo, inputTerm = ACC_TERM) - val body = aggBufferCodeGens.map(_.setAccumulator(exprGenerator)).mkString("\n") + val body = splitExpressionsIfNecessary( + aggBufferCodeGens.map(_.setAccumulator(exprGenerator)), methodName) s""" |${ctx.reuseLocalVariableCode(methodName)} @@ -859,7 +866,8 @@ class AggsHandlerCodeGenerator( ctx.startNewLocalVariableStatement(methodName) val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL) - val body = aggBufferCodeGens.map(_.resetAccumulator(exprGenerator)).mkString("\n") + val body = splitExpressionsIfNecessary(aggBufferCodeGens.map(_.resetAccumulator(exprGenerator)), + methodName) s""" |${ctx.reuseLocalVariableCode(methodName)} @@ -878,7 +886,8 @@ class AggsHandlerCodeGenerator( // bind input1 as inputRow val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL) .bindInput(inputType, inputTerm = ACCUMULATE_INPUT_TERM) - val body = aggActionCodeGens.map(_.accumulate(exprGenerator)).mkString("\n") + val body = splitExpressionsIfNecessary( + aggActionCodeGens.map(_.accumulate(exprGenerator)), methodName) s""" |${ctx.reuseLocalVariableCode(methodName)} |${ctx.reuseInputUnboxingCode(ACCUMULATE_INPUT_TERM)} @@ -901,7 +910,8 @@ class AggsHandlerCodeGenerator( // bind input1 as inputRow val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL) .bindInput(inputType, inputTerm = RETRACT_INPUT_TERM) - val body = aggActionCodeGens.map(_.retract(exprGenerator)).mkString("\n") + val body = splitExpressionsIfNecessary( + aggActionCodeGens.map(_.retract(exprGenerator)), methodName) s""" |${ctx.reuseLocalVariableCode(methodName)} |${ctx.reuseInputUnboxingCode(RETRACT_INPUT_TERM)} @@ -935,7 +945,8 @@ class AggsHandlerCodeGenerator( // bind input1 as otherAcc val exprGenerator = new ExprCodeGenerator(ctx, INPUT_NOT_NULL) .bindInput(mergedAccType, inputTerm = MERGED_ACC_TERM) - val body = aggActionCodeGens.map(_.merge(exprGenerator)).mkString("\n") + val body = splitExpressionsIfNecessary( + aggActionCodeGens.map(_.merge(exprGenerator)), methodName) s""" |${ctx.reuseLocalVariableCode(methodName)} |${ctx.reuseInputUnboxingCode(MERGED_ACC_TERM)} @@ -947,6 +958,27 @@ class AggsHandlerCodeGenerator( } } + private def splitExpressionsIfNecessary(exprs: Array[String], methodName: String): String = { + val totalLen = exprs.map(_.length).sum + val maxCodeLength = ctx.tableConfig.getMaxGeneratedCodeLength + if (totalLen > maxCodeLength) { + ctx.setCodeSplit(methodName) + exprs.map(expr => { + val splitMethodName = newName("split_" + methodName) + val method = + s""" + |private void $splitMethodName() throws Exception { + | $expr + |} + |""".stripMargin + ctx.addReusableMember(method) + s"$splitMethodName();" + }).mkString("\n") + } else { + exprs.mkString("\n") + } + } + private def getWindowExpressions( windowProperties: Seq[PlannerWindowProperty]): Seq[GeneratedExpression] = { windowProperties.map { @@ -1006,7 +1038,9 @@ class AggsHandlerCodeGenerator( valueType, classOf[GenericRowData], outRow = aggValueTerm, - reusedOutRow = false) + reusedOutRow = false, + allowSplit = true, + methodName = methodName) s""" |${ctx.reuseLocalVariableCode(methodName)} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala index ea319b3..99d41e2 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DistinctAggCodeGen.scala @@ -481,10 +481,11 @@ class DistinctAggCodeGen( if (useBackupDataView) { // this is called in the merge method val otherMapViewTerm = newName("otherMapView") + ctx.addReusableMember(s"private $MAP_VIEW $otherMapViewTerm;") val code = s""" |${expr.code} - |$MAP_VIEW $otherMapViewTerm = null; + |$otherMapViewTerm = null; |if (!${expr.nullTerm}) { | $otherMapViewTerm = ${genToExternal(ctx, externalAccType, expr.resultTerm)}; |} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala index abc5791..5ba6227 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala @@ -407,7 +407,8 @@ object HashAggCodeGenHelper { outRowWriter = None, reusedOutRow = true, outRowAlreadyExists = true, - allowSplit = false + allowSplit = false, + methodName = null ) } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala index 2a76d6e..ccf76db 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala @@ -1334,4 +1334,30 @@ class AggregateITCase( val expected = Seq("3,29.39,tom...@gmail.com") assertEquals(expected.sorted, sink.getRetractResults.sorted) } + + @Test + def testAggregationCodeSplit(): Unit = { + + val t = env.fromCollection(TestData.smallTupleData3) + .toTable(tEnv, 'a, 'b, 'c) + tEnv.createTemporaryView("MyTable", t) + + val columnNumber = 500 + + val selectList = Stream.range(3, columnNumber) + .map(i => s"SUM(CASE WHEN a IS NOT NULL AND a > $i THEN 0 WHEN a < 0 THEN 0 ELSE $i END)") + .mkString(",") + val sqlQuery = s"select $selectList from MyTable group by b, c" + + val result = tEnv.sqlQuery(sqlQuery).toRetractStream[Row] + val sink = new TestingRetractSink + result.addSink(sink) + env.execute() + + val expected = Stream.range(3, columnNumber).map(_.toString).mkString(",") + assertEquals(sink.getRawResults.size, 3) + sink.getRetractResults.foreach(result => + assertEquals(expected, result) + ) + } }