Repository: spark
Updated Branches:
  refs/heads/master 3887b7eef -> 295df746e


[SPARK-22677][SQL] cleanup whole stage codegen for hash aggregate

## What changes were proposed in this pull request?

The `HashAggregateExec` whole stage codegen path is a little messy and hard to 
understand. This patch cleans it up a little, especially the fast hash map 
part.

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenc...@databricks.com>

Closes #19869 from cloud-fan/hash-agg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/295df746
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/295df746
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/295df746

Branch: refs/heads/master
Commit: 295df746ecb1def5530a044d6670b28821da89f0
Parents: 3887b7e
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Dec 5 12:38:26 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Dec 5 12:38:26 2017 +0800

----------------------------------------------------------------------
 .../execution/aggregate/HashAggregateExec.scala | 402 +++++++++----------
 1 file changed, 195 insertions(+), 207 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/295df746/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 9139788..26d8cd7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
+import org.apache.spark.sql.execution.vectorized.{ColumnarRow, 
MutableColumnarRow}
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
 import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.util.Utils
@@ -444,6 +444,7 @@ case class HashAggregateExec(
     val funcName = ctx.freshName("doAggregateWithKeysOutput")
     val keyTerm = ctx.freshName("keyTerm")
     val bufferTerm = ctx.freshName("bufferTerm")
+    val numOutput = metricTerm(ctx, "numOutputRows")
 
     val body =
     if (modes.contains(Final) || modes.contains(Complete)) {
@@ -520,6 +521,7 @@ case class HashAggregateExec(
       s"""
         private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
             throws java.io.IOException {
+          $numOutput.add(1);
           $body
         }
        """)
@@ -549,7 +551,7 @@ case class HashAggregateExec(
     isSupported  && isNotByteArrayDecimalType
   }
 
-  private def enableTwoLevelHashMap(ctx: CodegenContext) = {
+  private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {
     if (!checkIfFastHashMapSupported(ctx)) {
       if (modes.forall(mode => mode == Partial || mode == PartialMerge) && 
!Utils.isTesting) {
         logInfo("spark.sql.codegen.aggregate.map.twolevel.enabled is set to 
true, but"
@@ -560,9 +562,8 @@ case class HashAggregateExec(
 
       // This is for testing/benchmarking only.
       // We enforce to first level to be a vectorized hashmap, instead of the 
default row-based one.
-      sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", 
null) match {
-        case "true" => isVectorizedHashMapEnabled = true
-        case null | "" | "false" => None      }
+      isVectorizedHashMapEnabled = sqlContext.getConf(
+        "spark.sql.codegen.aggregate.map.vectorized.enable", "false") == "true"
     }
   }
 
@@ -573,94 +574,84 @@ case class HashAggregateExec(
       enableTwoLevelHashMap(ctx)
     } else {
       sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", 
null) match {
-        case "true" => logWarning("Two level hashmap is disabled but 
vectorized hashmap is " +
-          "enabled.")
-        case null | "" | "false" => None
+        case "true" =>
+          logWarning("Two level hashmap is disabled but vectorized hashmap is 
enabled.")
+        case _ =>
       }
     }
-    fastHashMapTerm = ctx.freshName("fastHashMap")
-    val fastHashMapClassName = ctx.freshName("FastHashMap")
-    val fastHashMapGenerator =
-      if (isVectorizedHashMapEnabled) {
-        new VectorizedHashMapGenerator(ctx, aggregateExpressions,
-          fastHashMapClassName, groupingKeySchema, bufferSchema)
-      } else {
-        new RowBasedHashMapGenerator(ctx, aggregateExpressions,
-          fastHashMapClassName, groupingKeySchema, bufferSchema)
-      }
 
     val thisPlan = ctx.addReferenceObj("plan", this)
 
-    // Create a name for iterator from vectorized HashMap
+    // Create a name for the iterator from the fast hash map.
     val iterTermForFastHashMap = ctx.freshName("fastHashMapIter")
     if (isFastHashMapEnabled) {
+      // Generates the fast hash map class and creates the fast hash map term.
+      fastHashMapTerm = ctx.freshName("fastHashMap")
+      val fastHashMapClassName = ctx.freshName("FastHashMap")
       if (isVectorizedHashMapEnabled) {
+        val generatedMap = new VectorizedHashMapGenerator(ctx, 
aggregateExpressions,
+          fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
+        ctx.addInnerClass(generatedMap)
+
         ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
           s"$fastHashMapTerm = new $fastHashMapClassName();")
         ctx.addMutableState(
-          
"java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarRow>",
+          s"java.util.Iterator<${classOf[ColumnarRow].getName}>",
           iterTermForFastHashMap)
       } else {
+        val generatedMap = new RowBasedHashMapGenerator(ctx, 
aggregateExpressions,
+          fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
+        ctx.addInnerClass(generatedMap)
+
         ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
           s"$fastHashMapTerm = new $fastHashMapClassName(" +
             s"$thisPlan.getTaskMemoryManager(), 
$thisPlan.getEmptyAggregationBuffer());")
         ctx.addMutableState(
-          "org.apache.spark.unsafe.KVIterator",
+          "org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
           iterTermForFastHashMap)
       }
     }
 
+    // Create a name for the iterator from the regular hash map.
+    val iterTerm = ctx.freshName("mapIter")
+    ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, 
iterTerm)
     // create hashMap
     hashMapTerm = ctx.freshName("hashMap")
     val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
-    ctx.addMutableState(hashMapClassName, hashMapTerm)
+    ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = 
$thisPlan.createHashMap();")
     sorterTerm = ctx.freshName("sorter")
     ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm)
 
-    // Create a name for iterator from HashMap
-    val iterTerm = ctx.freshName("mapIter")
-    ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, 
iterTerm)
-
-    def generateGenerateCode(): String = {
-      if (isFastHashMapEnabled) {
-        if (isVectorizedHashMapEnabled) {
-          s"""
-               | 
${fastHashMapGenerator.asInstanceOf[VectorizedHashMapGenerator].generate()}
-          """.stripMargin
-        } else {
-          s"""
-               | 
${fastHashMapGenerator.asInstanceOf[RowBasedHashMapGenerator].generate()}
-          """.stripMargin
-        }
-      } else ""
-    }
-    ctx.addInnerClass(generateGenerateCode())
-
     val doAgg = ctx.freshName("doAggregateWithKeys")
     val peakMemory = metricTerm(ctx, "peakMemory")
     val spillSize = metricTerm(ctx, "spillSize")
     val avgHashProbe = metricTerm(ctx, "avgHashProbe")
-    val doAggFuncName = ctx.addNewFunction(doAgg,
-      s"""
-        private void $doAgg() throws java.io.IOException {
-          $hashMapTerm = $thisPlan.createHashMap();
-          ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
 
-          ${if (isFastHashMapEnabled) {
-              s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} 
else ""}
+    val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" +
+      s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe);"
+    val finishHashMap = if (isFastHashMapEnabled) {
+      s"""
+         |$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();
+         |$finishRegularHashMap
+       """.stripMargin
+    } else {
+      finishRegularHashMap
+    }
 
-          $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, 
$peakMemory, $spillSize,
-            $avgHashProbe);
-        }
-       """)
+    val doAggFuncName = ctx.addNewFunction(doAgg,
+      s"""
+         |private void $doAgg() throws java.io.IOException {
+         |  ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+         |  $finishHashMap
+         |}
+       """.stripMargin)
 
     // generate code for output
     val keyTerm = ctx.freshName("aggKey")
     val bufferTerm = ctx.freshName("aggBuffer")
     val outputFunc = generateResultFunction(ctx)
-    val numOutput = metricTerm(ctx, "numOutputRows")
 
-    def outputFromGeneratedMap: String = {
+    def outputFromFastHashMap: String = {
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
           outputFromVectorizedMap
@@ -672,48 +663,56 @@ case class HashAggregateExec(
 
     def outputFromRowBasedMap: String = {
       s"""
-       while ($iterTermForFastHashMap.next()) {
-         $numOutput.add(1);
-         UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
-         UnsafeRow $bufferTerm = (UnsafeRow) 
$iterTermForFastHashMap.getValue();
-         $outputFunc($keyTerm, $bufferTerm);
-
-         if (shouldStop()) return;
-       }
-       $fastHashMapTerm.close();
-     """
+         |while ($iterTermForFastHashMap.next()) {
+         |  UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
+         |  UnsafeRow $bufferTerm = (UnsafeRow) 
$iterTermForFastHashMap.getValue();
+         |  $outputFunc($keyTerm, $bufferTerm);
+         |
+         |  if (shouldStop()) return;
+         |}
+         |$fastHashMapTerm.close();
+       """.stripMargin
     }
 
     // Iterate over the aggregate rows and convert them from ColumnarRow to 
UnsafeRow
     def outputFromVectorizedMap: String = {
-        val row = ctx.freshName("fastHashMapRow")
-        ctx.currentVars = null
-        ctx.INPUT_ROW = row
-        val generateKeyRow = GenerateUnsafeProjection.createCode(ctx,
-          groupingKeySchema.toAttributes.zipWithIndex
+      val row = ctx.freshName("fastHashMapRow")
+      ctx.currentVars = null
+      ctx.INPUT_ROW = row
+      val generateKeyRow = GenerateUnsafeProjection.createCode(ctx,
+        groupingKeySchema.toAttributes.zipWithIndex
           .map { case (attr, i) => BoundReference(i, attr.dataType, 
attr.nullable) }
-        )
-        val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
-          bufferSchema.toAttributes.zipWithIndex
-          .map { case (attr, i) =>
-            BoundReference(groupingKeySchema.length + i, attr.dataType, 
attr.nullable) })
-        s"""
-           | while ($iterTermForFastHashMap.hasNext()) {
-           |   $numOutput.add(1);
-           |   org.apache.spark.sql.execution.vectorized.ColumnarRow $row =
-           |     (org.apache.spark.sql.execution.vectorized.ColumnarRow)
-           |     $iterTermForFastHashMap.next();
-           |   ${generateKeyRow.code}
-           |   ${generateBufferRow.code}
-           |   $outputFunc(${generateKeyRow.value}, 
${generateBufferRow.value});
-           |
-           |   if (shouldStop()) return;
-           | }
-           |
-           | $fastHashMapTerm.close();
-         """.stripMargin
+      )
+      val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
+        bufferSchema.toAttributes.zipWithIndex.map { case (attr, i) =>
+          BoundReference(groupingKeySchema.length + i, attr.dataType, 
attr.nullable)
+        })
+      val columnarRowCls = classOf[ColumnarRow].getName
+      s"""
+         |while ($iterTermForFastHashMap.hasNext()) {
+         |  $columnarRowCls $row = ($columnarRowCls) 
$iterTermForFastHashMap.next();
+         |  ${generateKeyRow.code}
+         |  ${generateBufferRow.code}
+         |  $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
+         |
+         |  if (shouldStop()) return;
+         |}
+         |
+         |$fastHashMapTerm.close();
+       """.stripMargin
     }
 
+    def outputFromRegularHashMap: String = {
+      s"""
+         |while ($iterTerm.next()) {
+         |  UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
+         |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
+         |  $outputFunc($keyTerm, $bufferTerm);
+         |
+         |  if (shouldStop()) return;
+         |}
+       """.stripMargin
+    }
 
     val aggTime = metricTerm(ctx, "aggTime")
     val beforeAgg = ctx.freshName("beforeAgg")
@@ -726,16 +725,8 @@ case class HashAggregateExec(
      }
 
      // output the result
-     ${outputFromGeneratedMap}
-
-     while ($iterTerm.next()) {
-       $numOutput.add(1);
-       UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
-       UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
-       $outputFunc($keyTerm, $bufferTerm);
-
-       if (shouldStop()) return;
-     }
+     $outputFromFastHashMap
+     $outputFromRegularHashMap
 
      $iterTerm.close();
      if ($sorterTerm == null) {
@@ -745,13 +736,11 @@ case class HashAggregateExec(
   }
 
   private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): 
String = {
-
     // create grouping key
-    ctx.currentVars = input
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
       ctx, groupingExpressions.map(e => 
BindReferences.bindReference[Expression](e, child.output)))
     val fastRowKeys = ctx.generateExpressions(
-          groupingExpressions.map(e => 
BindReferences.bindReference[Expression](e, child.output)))
+      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, 
child.output)))
     val unsafeRowKeys = unsafeRowKeyCode.value
     val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
     val fastRowBuffer = ctx.freshName("fastAggBuffer")
@@ -768,12 +757,8 @@ case class HashAggregateExec(
 
     // generate hash code for key
     val hashExpr = Murmur3Hash(groupingExpressions, 42)
-    ctx.currentVars = input
     val hashEval = BindReferences.bindReference(hashExpr, 
child.output).genCode(ctx)
 
-    val inputAttr = aggregateBufferAttributes ++ child.output
-    ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ 
input
-
     val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, 
resetCounter,
     incCounter) = if (testFallbackStartsAt.isDefined) {
       val countTerm = ctx.freshName("fallbackCounter")
@@ -784,86 +769,65 @@ case class HashAggregateExec(
       ("true", "true", "", "")
     }
 
-    // We first generate code to probe and update the fast hash map. If the 
probe is
-    // successful the corresponding fast row buffer will hold the mutable row
-    val findOrInsertFastHashMap: Option[String] = {
+    val findOrInsertRegularHashMap: String =
+      s"""
+         |// generate grouping key
+         |${unsafeRowKeyCode.code.trim}
+         |${hashEval.code.trim}
+         |if ($checkFallbackForBytesToBytesMap) {
+         |  // try to get the buffer from hash map
+         |  $unsafeRowBuffer =
+         |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, 
${hashEval.value});
+         |}
+         |// Can't allocate buffer from the hash map. Spill the map and 
fallback to sort-based
+         |// aggregation after processing all input rows.
+         |if ($unsafeRowBuffer == null) {
+         |  if ($sorterTerm == null) {
+         |    $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+         |  } else {
+         |    
$sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+         |  }
+         |  $resetCounter
+         |  // the hash map had be spilled, it should have enough memory now,
+         |  // try to allocate buffer again.
+         |  $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
+         |    $unsafeRowKeys, ${hashEval.value});
+         |  if ($unsafeRowBuffer == null) {
+         |    // failed to allocate the first page
+         |    throw new OutOfMemoryError("No enough memory for aggregation");
+         |  }
+         |}
+       """.stripMargin
+
+    val findOrInsertHashMap: String = {
       if (isFastHashMapEnabled) {
-        Option(
-          s"""
-             |
-             |if ($checkFallbackForGeneratedHashMap) {
-             |  ${fastRowKeys.map(_.code).mkString("\n")}
-             |  if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
-             |    $fastRowBuffer = $fastHashMapTerm.findOrInsert(
-             |        ${fastRowKeys.map(_.value).mkString(", ")});
-             |  }
-             |}
-         """.stripMargin)
+        // If fast hash map is on, we first generate code to probe and update 
the fast hash map.
+        // If the probe is successful the corresponding fast row buffer will 
hold the mutable row.
+        s"""
+           |if ($checkFallbackForGeneratedHashMap) {
+           |  ${fastRowKeys.map(_.code).mkString("\n")}
+           |  if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
+           |    $fastRowBuffer = $fastHashMapTerm.findOrInsert(
+           |      ${fastRowKeys.map(_.value).mkString(", ")});
+           |  }
+           |}
+           |// Cannot find the key in fast hash map, try regular hash map.
+           |if ($fastRowBuffer == null) {
+           |  $findOrInsertRegularHashMap
+           |}
+         """.stripMargin
       } else {
-        None
+        findOrInsertRegularHashMap
       }
     }
 
+    val inputAttr = aggregateBufferAttributes ++ child.output
+    // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, 
so that when
+    // generating code for buffer columns, we use `INPUT_ROW`(will be the 
buffer row), while
+    // generating input columns, we use `currentVars`.
+    ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ 
input
 
-    def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = {
-      ctx.INPUT_ROW = fastRowBuffer
-      val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, 
inputAttr))
-      val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
-      val effectiveCodes = subExprs.codes.mkString("\n")
-      val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
-        boundUpdateExpr.map(_.genCode(ctx))
-      }
-      val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
-        val dt = updateExpr(i).dataType
-        ctx.updateColumn(fastRowBuffer, dt, i, ev, updateExpr(i).nullable, 
isVectorized)
-      }
-      Option(
-        s"""
-           |// common sub-expressions
-           |$effectiveCodes
-           |// evaluate aggregate function
-           |${evaluateVariables(fastRowEvals)}
-           |// update fast row
-           |${updateFastRow.mkString("\n").trim}
-           |
-         """.stripMargin)
-    }
-
-    // Next, we generate code to probe and update the unsafe row hash map.
-    val findOrInsertInUnsafeRowMap: String = {
-      s"""
-         | if ($fastRowBuffer == null) {
-         |   // generate grouping key
-         |   ${unsafeRowKeyCode.code.trim}
-         |   ${hashEval.code.trim}
-         |   if ($checkFallbackForBytesToBytesMap) {
-         |     // try to get the buffer from hash map
-         |     $unsafeRowBuffer =
-         |       
$hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, 
${hashEval.value});
-         |   }
-         |   // Can't allocate buffer from the hash map. Spill the map and 
fallback to sort-based
-         |   // aggregation after processing all input rows.
-         |   if ($unsafeRowBuffer == null) {
-         |     if ($sorterTerm == null) {
-         |       $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
-         |     } else {
-         |       
$sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
-         |     }
-         |     $resetCounter
-         |     // the hash map had be spilled, it should have enough memory 
now,
-         |     // try to allocate buffer again.
-         |     $unsafeRowBuffer =
-         |       
$hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, 
${hashEval.value});
-         |     if ($unsafeRowBuffer == null) {
-         |       // failed to allocate the first page
-         |       throw new OutOfMemoryError("No enough memory for 
aggregation");
-         |     }
-         |   }
-         | }
-       """.stripMargin
-    }
-
-    val updateRowInUnsafeRowMap: String = {
+    val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
       val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, 
inputAttr))
       val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
@@ -882,45 +846,69 @@ case class HashAggregateExec(
          |${evaluateVariables(unsafeRowBufferEvals)}
          |// update unsafe row buffer
          |${updateUnsafeRowBuffer.mkString("\n").trim}
-           """.stripMargin
+       """.stripMargin
+    }
+
+    val updateRowInHashMap: String = {
+      if (isFastHashMapEnabled) {
+        ctx.INPUT_ROW = fastRowBuffer
+        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, 
inputAttr))
+        val subExprs = 
ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+        val effectiveCodes = subExprs.codes.mkString("\n")
+        val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
+          boundUpdateExpr.map(_.genCode(ctx))
+        }
+        val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
+          val dt = updateExpr(i).dataType
+          ctx.updateColumn(
+            fastRowBuffer, dt, i, ev, updateExpr(i).nullable, 
isVectorizedHashMapEnabled)
+        }
+
+        // If fast hash map is on, we first generate code to update row in 
fast hash map, if the
+        // previous lookup hit the fast hash map. Otherwise, update row in 
regular hash map.
+        s"""
+           |if ($fastRowBuffer != null) {
+           |  // common sub-expressions
+           |  $effectiveCodes
+           |  // evaluate aggregate function
+           |  ${evaluateVariables(fastRowEvals)}
+           |  // update fast row
+           |  ${updateFastRow.mkString("\n").trim}
+           |} else {
+           |  $updateRowInRegularHashMap
+           |}
+       """.stripMargin
+      } else {
+        updateRowInRegularHashMap
+      }
     }
 
+    val declareRowBuffer: String = if (isFastHashMapEnabled) {
+      val fastRowType = if (isVectorizedHashMapEnabled) {
+        classOf[MutableColumnarRow].getName
+      } else {
+        "UnsafeRow"
+      }
+      s"""
+         |UnsafeRow $unsafeRowBuffer = null;
+         |$fastRowType $fastRowBuffer = null;
+       """.stripMargin
+    } else {
+      s"UnsafeRow $unsafeRowBuffer = null;"
+    }
 
     // We try to do hash map based in-memory aggregation first. If there is 
not enough memory (the
     // hash map will return null for new key), we spill the hash map to disk 
to free memory, then
     // continue to do in-memory aggregation and spilling until all the rows 
had been processed.
     // Finally, sort the spilled aggregate buffers by key, and merge them 
together for same key.
     s"""
-     UnsafeRow $unsafeRowBuffer = null;
-     ${
-        if (isVectorizedHashMapEnabled) {
-          s"""
-             | ${classOf[MutableColumnarRow].getName} $fastRowBuffer = null;
-           """.stripMargin
-        } else {
-          s"""
-             | UnsafeRow $fastRowBuffer = null;
-           """.stripMargin
-        }
-      }
+     $declareRowBuffer
 
-     ${findOrInsertFastHashMap.getOrElse("")}
-
-     $findOrInsertInUnsafeRowMap
+     $findOrInsertHashMap
 
      $incCounter
 
-     if ($fastRowBuffer != null) {
-       // update fast row
-       ${
-          if (isFastHashMapEnabled) {
-            updateRowInFastHashMap(isVectorizedHashMapEnabled).getOrElse("")
-          } else ""
-        }
-     } else {
-       // update unsafe row
-       $updateRowInUnsafeRowMap
-     }
+     $updateRowInHashMap
      """
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to