This is an automated email from the ASF dual-hosted git repository. ron pushed a commit to branch release-1.17 in repository https://gitbox.apache.org/repos/asf/flink.git
commit b05eb8bb1a0223f2adbc6296f96d95db89716ffb Author: Ron <ldliu...@163.com> AuthorDate: Tue Jun 6 10:11:31 2023 +0800 [FLINK-32220][table-runtime] Improving the adaptive local hash agg code to avoid get value from RowData repeatedly This closes #22684 (cherry picked from commit 3046875c7aa0501f9a67f280034a74ea107315e3) --- .../planner/codegen/ProjectionCodeGenerator.scala | 45 +++++++++++----------- .../codegen/agg/batch/HashAggCodeGenerator.scala | 30 +++++++-------- 2 files changed, 38 insertions(+), 37 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala index e2c3759ddaa..78095999e2c 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala @@ -28,7 +28,6 @@ import org.apache.flink.table.planner.functions.aggfunctions._ import org.apache.flink.table.planner.plan.utils.AggregateInfo import org.apache.flink.table.runtime.generated.{GeneratedProjection, Projection} import org.apache.flink.table.types.logical.{BigIntType, LogicalType, RowType} -import org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldTypes import scala.collection.mutable.ArrayBuffer @@ -161,7 +160,7 @@ object ProjectionCodeGenerator { sumAggFunction.getResultType.getLogicalType, aggInfo.agg.getArgList.get(0)) case _: MaxAggFunction | _: MinAggFunction => - fieldExprs += GenerateUtils.generateFieldAccess( + fieldExprs += reuseFieldExprForAggFunc( ctx, inputType, inputTerm, @@ -231,25 +230,23 @@ object ProjectionCodeGenerator { inputTerm: String, targetType: LogicalType, index: Int): GeneratedExpression = { - val fieldType = getFieldTypes(inputType).get(index) - val resultTypeTerm = primitiveTypeTermForType(fieldType) - val defaultValue = primitiveDefaultValue(fieldType) - val readCode = rowFieldReadAccess(index.toString, inputTerm, fieldType) - val Seq(fieldTerm, nullTerm) = - ctx.addReusableLocalVariables((resultTypeTerm, "field"), ("boolean", "isNull")) - - val inputCode = - s""" - |$nullTerm = $inputTerm.isNullAt($index); - |$fieldTerm = $defaultValue; - |if (!$nullTerm) { - | $fieldTerm = $readCode; - |} - """.stripMargin.trim - - val expression = GeneratedExpression(fieldTerm, nullTerm, inputCode, fieldType) + val fieldExpr = reuseFieldExprForAggFunc(ctx, inputType, inputTerm, index) // Convert the projected value type to sum agg func target type. - ScalarOperatorGens.generateCast(ctx, expression, targetType, true) + ScalarOperatorGens.generateCast(ctx, fieldExpr, targetType, nullOnFailure = true) + } + + /** Get reuse field expr if it has been evaluated before for adaptive local hash aggregation. */ + def reuseFieldExprForAggFunc( + ctx: CodeGeneratorContext, + inputType: LogicalType, + inputTerm: String, + index: Int): GeneratedExpression = { + // Reuse the field access code if it has been evaluated before + ctx.getReusableInputUnboxingExprs(inputTerm, index) match { + case None => GenerateUtils.generateFieldAccess(ctx, inputType, inputTerm, index) + case Some(expr) => + GeneratedExpression(expr.resultTerm, expr.nullTerm, NO_CODE, expr.resultType) + } } /** @@ -260,13 +257,17 @@ object ProjectionCodeGenerator { ctx: CodeGeneratorContext, inputTerm: String, index: Int): GeneratedExpression = { + val fieldNullCode = ctx.getReusableInputUnboxingExprs(inputTerm, index) match { + case None => s"$inputTerm.isNullAt($index)" + case Some(expr) => expr.nullTerm + } + val Seq(fieldTerm, nullTerm) = ctx.addReusableLocalVariables(("long", "field"), ("boolean", "isNull")) - val inputCode = s""" |$fieldTerm = 0L; - |if (!$inputTerm.isNullAt($index)) { + |if (!$fieldNullCode) { | $fieldTerm = 1L; |} """.stripMargin.trim diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala index e9ccb27aef0..80881bb25ad 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala @@ -142,21 +142,6 @@ object HashAggCodeGenerator { outRecordWriterTerm = currentKeyWriterTerm) .code - val valueProjectionCode = - if (!isFinal && supportAdaptiveLocalHashAgg) { - ProjectionCodeGenerator.genAdaptiveLocalHashAggValueProjectionCode( - ctx, - inputType, - classOf[BinaryRowData], - inputTerm = inputTerm, - aggInfos, - outRecordTerm = currentValueTerm, - outRecordWriterTerm = currentValueWriterTerm - ) - } else { - "" - } - // gen code to create groupKey, aggBuffer Type array // it will be used in BytesHashMap and BufferedKVExternalSorter if enable fallback val groupKeyTypesTerm = CodeGenUtils.newName("groupKeyTypes") @@ -264,6 +249,21 @@ object HashAggCodeGenerator { } val localAggSuppressedTerm = CodeGenUtils.newName("localAggSuppressed") ctx.addReusableMember(s"private transient boolean $localAggSuppressedTerm = false;") + val valueProjectionCode = + if (!isFinal && supportAdaptiveLocalHashAgg) { + ProjectionCodeGenerator.genAdaptiveLocalHashAggValueProjectionCode( + ctx, + inputType, + classOf[BinaryRowData], + inputTerm = inputTerm, + aggInfos, + outRecordTerm = currentValueTerm, + outRecordWriterTerm = currentValueWriterTerm + ) + } else { + "" + } + val ( distinctCountIncCode, totalCountIncCode,