This is an automated email from the ASF dual-hosted git repository.

ron pushed a commit to branch release-1.17
in repository https://gitbox.apache.org/repos/asf/flink.git

commit b05eb8bb1a0223f2adbc6296f96d95db89716ffb
Author: Ron <ldliu...@163.com>
AuthorDate: Tue Jun 6 10:11:31 2023 +0800

    [FLINK-32220][table-runtime] Improve the adaptive local hash agg code to avoid getting values from RowData repeatedly
    
    This closes #22684
    
    (cherry picked from commit 3046875c7aa0501f9a67f280034a74ea107315e3)
---
 .../planner/codegen/ProjectionCodeGenerator.scala  | 45 +++++++++++-----------
 .../codegen/agg/batch/HashAggCodeGenerator.scala   | 30 +++++++--------
 2 files changed, 38 insertions(+), 37 deletions(-)

diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala
index e2c3759ddaa..78095999e2c 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ProjectionCodeGenerator.scala
@@ -28,7 +28,6 @@ import org.apache.flink.table.planner.functions.aggfunctions._
 import org.apache.flink.table.planner.plan.utils.AggregateInfo
 import org.apache.flink.table.runtime.generated.{GeneratedProjection, 
Projection}
 import org.apache.flink.table.types.logical.{BigIntType, LogicalType, RowType}
-import 
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldTypes
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -161,7 +160,7 @@ object ProjectionCodeGenerator {
               sumAggFunction.getResultType.getLogicalType,
               aggInfo.agg.getArgList.get(0))
           case _: MaxAggFunction | _: MinAggFunction =>
-            fieldExprs += GenerateUtils.generateFieldAccess(
+            fieldExprs += reuseFieldExprForAggFunc(
               ctx,
               inputType,
               inputTerm,
@@ -231,25 +230,23 @@ object ProjectionCodeGenerator {
       inputTerm: String,
       targetType: LogicalType,
       index: Int): GeneratedExpression = {
-    val fieldType = getFieldTypes(inputType).get(index)
-    val resultTypeTerm = primitiveTypeTermForType(fieldType)
-    val defaultValue = primitiveDefaultValue(fieldType)
-    val readCode = rowFieldReadAccess(index.toString, inputTerm, fieldType)
-    val Seq(fieldTerm, nullTerm) =
-      ctx.addReusableLocalVariables((resultTypeTerm, "field"), ("boolean", 
"isNull"))
-
-    val inputCode =
-      s"""
-         |$nullTerm = $inputTerm.isNullAt($index);
-         |$fieldTerm = $defaultValue;
-         |if (!$nullTerm) {
-         |  $fieldTerm = $readCode;
-         |}
-           """.stripMargin.trim
-
-    val expression = GeneratedExpression(fieldTerm, nullTerm, inputCode, 
fieldType)
+    val fieldExpr = reuseFieldExprForAggFunc(ctx, inputType, inputTerm, index)
     // Convert the projected value type to sum agg func target type.
-    ScalarOperatorGens.generateCast(ctx, expression, targetType, true)
+    ScalarOperatorGens.generateCast(ctx, fieldExpr, targetType, nullOnFailure 
= true)
+  }
+
+  /** Get reuse field expr if it has been evaluated before for adaptive local 
hash aggregation. */
+  def reuseFieldExprForAggFunc(
+      ctx: CodeGeneratorContext,
+      inputType: LogicalType,
+      inputTerm: String,
+      index: Int): GeneratedExpression = {
+    // Reuse the field access code if it has been evaluated before
+    ctx.getReusableInputUnboxingExprs(inputTerm, index) match {
+      case None => GenerateUtils.generateFieldAccess(ctx, inputType, 
inputTerm, index)
+      case Some(expr) =>
+        GeneratedExpression(expr.resultTerm, expr.nullTerm, NO_CODE, 
expr.resultType)
+    }
   }
 
   /**
@@ -260,13 +257,17 @@ object ProjectionCodeGenerator {
       ctx: CodeGeneratorContext,
       inputTerm: String,
       index: Int): GeneratedExpression = {
+    val fieldNullCode = ctx.getReusableInputUnboxingExprs(inputTerm, index) 
match {
+      case None => s"$inputTerm.isNullAt($index)"
+      case Some(expr) => expr.nullTerm
+    }
+
     val Seq(fieldTerm, nullTerm) =
       ctx.addReusableLocalVariables(("long", "field"), ("boolean", "isNull"))
-
     val inputCode =
       s"""
          |$fieldTerm = 0L;
-         |if (!$inputTerm.isNullAt($index)) {
+         |if (!$fieldNullCode) {
          |  $fieldTerm = 1L;
          |}
            """.stripMargin.trim
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
index e9ccb27aef0..80881bb25ad 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenerator.scala
@@ -142,21 +142,6 @@ object HashAggCodeGenerator {
         outRecordWriterTerm = currentKeyWriterTerm)
       .code
 
-    val valueProjectionCode =
-      if (!isFinal && supportAdaptiveLocalHashAgg) {
-        ProjectionCodeGenerator.genAdaptiveLocalHashAggValueProjectionCode(
-          ctx,
-          inputType,
-          classOf[BinaryRowData],
-          inputTerm = inputTerm,
-          aggInfos,
-          outRecordTerm = currentValueTerm,
-          outRecordWriterTerm = currentValueWriterTerm
-        )
-      } else {
-        ""
-      }
-
     // gen code to create groupKey, aggBuffer Type array
     // it will be used in BytesHashMap and BufferedKVExternalSorter if enable 
fallback
     val groupKeyTypesTerm = CodeGenUtils.newName("groupKeyTypes")
@@ -264,6 +249,21 @@ object HashAggCodeGenerator {
     }
     val localAggSuppressedTerm = CodeGenUtils.newName("localAggSuppressed")
     ctx.addReusableMember(s"private transient boolean $localAggSuppressedTerm 
= false;")
+    val valueProjectionCode =
+      if (!isFinal && supportAdaptiveLocalHashAgg) {
+        ProjectionCodeGenerator.genAdaptiveLocalHashAggValueProjectionCode(
+          ctx,
+          inputType,
+          classOf[BinaryRowData],
+          inputTerm = inputTerm,
+          aggInfos,
+          outRecordTerm = currentValueTerm,
+          outRecordWriterTerm = currentValueWriterTerm
+        )
+      } else {
+        ""
+      }
+
     val (
       distinctCountIncCode,
       totalCountIncCode,

Reply via email to