maropu commented on a change in pull request #20965: [SPARK-21870][SQL] Split aggregation code into small functions
URL: https://github.com/apache/spark/pull/20965#discussion_r320557819
 
 

 ##########
 File path: 
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
 ##########
 @@ -824,59 +944,158 @@ case class HashAggregateExec(
     // generating input columns, we use `currentVars`.
     ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ 
input
 
+    val aggNames = aggregateExpressions.map(_.aggregateFunction.prettyName)
+    // Computes start offsets for each aggregation function code
+    // in the underlying buffer row.
+    val bufferStartOffsets = {
+      val offsets = mutable.ArrayBuffer[Int]()
+      var curOffset = 0
+      updateExprs.foreach { exprsForOneFunc =>
+        offsets += curOffset
+        curOffset += exprsForOneFunc.length
+      }
+      offsets.toArray
+    }
+
     val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
-      val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
-      val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+      val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
+        bindReferences(updateExprsForOneFunc, inputAttr)
+      }
+      val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
       val effectiveCodes = subExprs.codes.mkString("\n")
-      val unsafeRowBufferEvals = 
ctx.withSubExprEliminationExprs(subExprs.states) {
-        boundUpdateExpr.map(_.genCode(ctx))
+      val unsafeRowBufferEvals = boundUpdateExprs.map { 
boundUpdateExprsForOneFunc =>
+        ctx.withSubExprEliminationExprs(subExprs.states) {
+          boundUpdateExprsForOneFunc.map(_.genCode(ctx))
+        }
       }
-      val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case 
(ev, i) =>
-        val dt = updateExpr(i).dataType
-        CodeGenerator.updateColumn(unsafeRowBuffer, dt, i, ev, 
updateExpr(i).nullable)
+
+      val aggCodeBlocks = updateExprs.indices.map { i =>
+        val rowBufferEvalsForOneFunc = unsafeRowBufferEvals(i)
+        val boundUpdateExprsForOneFunc = boundUpdateExprs(i)
+        val bufferOffset = bufferStartOffsets(i)
+
+        // All the update code for aggregation buffers should be placed in the 
end
+        // of each aggregation function code.
+        val updateRowBuffers = rowBufferEvalsForOneFunc.zipWithIndex.map { 
case (ev, j) =>
+          val updateExpr = boundUpdateExprsForOneFunc(j)
+          val dt = updateExpr.dataType
+          val nullable = updateExpr.nullable
+          CodeGenerator.updateColumn(unsafeRowBuffer, dt, bufferOffset + j, 
ev, nullable)
+        }
+        code"""
+           |// evaluate aggregate function for ${aggNames(i)}
+           |${evaluateVariables(rowBufferEvalsForOneFunc)}
+           |// update unsafe row buffer
+           |${updateRowBuffers.mkString("\n").trim}
+         """.stripMargin
+      }
+
+      lazy val nonSplitAggCode = {
+        s"""
+           |// common sub-expressions
+           |$effectiveCodes
+           |// evaluate aggregate functions and update aggregation buffers
+           |${aggCodeBlocks.fold(EmptyBlock)(_ + _)}
+         """.stripMargin
+      }
+
+      if (conf.codegenSplitAggregateFunc &&
+          aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
+        val maybeSplitCode = splitAggregateExpressions(
+          ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
+
+        maybeSplitCode.map { updateAggCode =>
+          s"""
+             |// do aggregate
+             |// common sub-expressions
+             |$effectiveCodes
+             |// evaluate aggregate functions and update aggregation buffers
+             |$updateAggCode
+           """.stripMargin
+        }.getOrElse {
+          nonSplitAggCode
+        }
+      } else {
+        nonSplitAggCode
       }
-      s"""
-         |// common sub-expressions
-         |$effectiveCodes
-         |// evaluate aggregate function
-         |${evaluateVariables(unsafeRowBufferEvals)}
-         |// update unsafe row buffer
-         |${updateUnsafeRowBuffer.mkString("\n").trim}
-       """.stripMargin
     }
 
     val updateRowInHashMap: String = {
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
           ctx.INPUT_ROW = fastRowBuffer
-          val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
-          val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+          val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
+            bindReferences(updateExprsForOneFunc, inputAttr)
+          }
+          val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
           val effectiveCodes = subExprs.codes.mkString("\n")
-          val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
-            boundUpdateExpr.map(_.genCode(ctx))
+          val fastRowEvals = boundUpdateExprs.map { boundUpdateExprsForOneFunc 
=>
+            ctx.withSubExprEliminationExprs(subExprs.states) {
+              boundUpdateExprsForOneFunc.map(_.genCode(ctx))
+            }
           }
-          val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
-            val dt = updateExpr(i).dataType
-            CodeGenerator.updateColumn(
-              fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized = 
true)
+
+          val aggCodeBlocks = fastRowEvals.zipWithIndex.map { case 
(fastRowEvalsForOneFunc, i) =>
+            val boundUpdateExprsForOneFunc = boundUpdateExprs(i)
+            val bufferOffset = bufferStartOffsets(i)
+            // All the update code for aggregation buffers should be placed in 
the end
+            // of each aggregation function code.
+             val updateRowBuffer = fastRowEvalsForOneFunc.zipWithIndex.map { 
case (ev, j) =>
+               val updateExpr = boundUpdateExprsForOneFunc(j)
+               val dt = updateExpr.dataType
+               val nullable = updateExpr.nullable
+               CodeGenerator.updateColumn(fastRowBuffer, dt, bufferOffset + j, 
ev, nullable,
+                 isVectorized = true)
+             }
+             code"""
+                |// evaluate aggregate function for ${aggNames(i)}
+                |${evaluateVariables(fastRowEvalsForOneFunc)}
+                |// update fast row
+                |${updateRowBuffer.mkString("\n").trim}
+              """.stripMargin
           }
 
-          // If vectorized fast hash map is on, we first generate code to 
update row
-          // in vectorized fast hash map, if the previous loop up hit 
vectorized fast hash map.
-          // Otherwise, update row in regular hash map.
-          s"""
-             |if ($fastRowBuffer != null) {
-             |  // common sub-expressions
-             |  $effectiveCodes
-             |  // evaluate aggregate function
-             |  ${evaluateVariables(fastRowEvals)}
-             |  // update fast row
-             |  ${updateFastRow.mkString("\n").trim}
-             |} else {
-             |  $updateRowInRegularHashMap
-             |}
-          """.stripMargin
+          lazy val nonSplitAggCode = {
+            // If vectorized fast hash map is on, we first generate code to 
update row
+            // in vectorized fast hash map, if the previous loop up hit 
vectorized fast hash map.
+            // Otherwise, update row in regular hash map.
+            s"""
+               |if ($fastRowBuffer != null) {
+               |  // common sub-expressions
+               |  $effectiveCodes
+               |  // evaluate aggregate functions and update aggregation 
buffers
+               |  ${aggCodeBlocks.fold(EmptyBlock)(_ + _)}
 
 Review comment:
   Nice suggestion; I'll follow it.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to