Repository: spark
Updated Branches:
  refs/heads/master 3887b7eef -> 295df746e
[SPARK-22677][SQL] cleanup whole stage codegen for hash aggregate

## What changes were proposed in this pull request?

The `HashAggregateExec` whole stage codegen path is a little messy and hard to
understand. This PR cleans it up a bit, especially the fast hash map part.

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenc...@databricks.com>

Closes #19869 from cloud-fan/hash-agg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/295df746
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/295df746
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/295df746

Branch: refs/heads/master
Commit: 295df746ecb1def5530a044d6670b28821da89f0
Parents: 3887b7e
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Dec 5 12:38:26 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Dec 5 12:38:26 2017 +0800

----------------------------------------------------------------------
 .../execution/aggregate/HashAggregateExec.scala | 402 +++++++++----------
 1 file changed, 195 insertions(+), 207 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/295df746/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 9139788..26d8cd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
+import org.apache.spark.sql.execution.vectorized.{ColumnarRow, MutableColumnarRow}
 import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
 import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.util.Utils
@@ -444,6 +444,7 @@ case class HashAggregateExec(
     val funcName = ctx.freshName("doAggregateWithKeysOutput")
     val keyTerm = ctx.freshName("keyTerm")
     val bufferTerm = ctx.freshName("bufferTerm")
+    val numOutput = metricTerm(ctx, "numOutputRows")

     val body = if (modes.contains(Final) || modes.contains(Complete)) {
@@ -520,6 +521,7 @@ case class HashAggregateExec(
       s"""
         private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
             throws java.io.IOException {
+          $numOutput.add(1);
          $body
        }
       """)
@@ -549,7 +551,7 @@ case class HashAggregateExec(
     isSupported && isNotByteArrayDecimalType
   }

-  private def enableTwoLevelHashMap(ctx: CodegenContext) = {
+  private def enableTwoLevelHashMap(ctx: CodegenContext): Unit = {
     if (!checkIfFastHashMapSupported(ctx)) {
       if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) {
         logInfo("spark.sql.codegen.aggregate.map.twolevel.enabled is set to true, but"
@@ -560,9 +562,8 @@ case class HashAggregateExec(
       // This is for testing/benchmarking only.
       // We enforce to first level to be a vectorized hashmap, instead of the default row-based one.
-      sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", null) match {
-        case "true" => isVectorizedHashMapEnabled = true
-        case null | "" | "false" => None
-      }
+      isVectorizedHashMapEnabled = sqlContext.getConf(
+        "spark.sql.codegen.aggregate.map.vectorized.enable", "false") == "true"
     }
   }
@@ -573,94 +574,84 @@ case class HashAggregateExec(
       enableTwoLevelHashMap(ctx)
     } else {
       sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", null) match {
-        case "true" => logWarning("Two level hashmap is disabled but vectorized hashmap is " +
-          "enabled.")
-        case null | "" | "false" => None
+        case "true" =>
+          logWarning("Two level hashmap is disabled but vectorized hashmap is enabled.")
+        case _ =>
       }
     }

-    fastHashMapTerm = ctx.freshName("fastHashMap")
-    val fastHashMapClassName = ctx.freshName("FastHashMap")
-    val fastHashMapGenerator =
-      if (isVectorizedHashMapEnabled) {
-        new VectorizedHashMapGenerator(ctx, aggregateExpressions,
-          fastHashMapClassName, groupingKeySchema, bufferSchema)
-      } else {
-        new RowBasedHashMapGenerator(ctx, aggregateExpressions,
-          fastHashMapClassName, groupingKeySchema, bufferSchema)
-      }
     val thisPlan = ctx.addReferenceObj("plan", this)

-    // Create a name for iterator from vectorized HashMap
+    // Create a name for the iterator from the fast hash map.
     val iterTermForFastHashMap = ctx.freshName("fastHashMapIter")
     if (isFastHashMapEnabled) {
+      // Generates the fast hash map class and creates the fast hash map term.
+      fastHashMapTerm = ctx.freshName("fastHashMap")
+      val fastHashMapClassName = ctx.freshName("FastHashMap")
       if (isVectorizedHashMapEnabled) {
+        val generatedMap = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
+          fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
+        ctx.addInnerClass(generatedMap)
+
         ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
           s"$fastHashMapTerm = new $fastHashMapClassName();")
         ctx.addMutableState(
-          "java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarRow>",
+          s"java.util.Iterator<${classOf[ColumnarRow].getName}>",
           iterTermForFastHashMap)
       } else {
+        val generatedMap = new RowBasedHashMapGenerator(ctx, aggregateExpressions,
+          fastHashMapClassName, groupingKeySchema, bufferSchema).generate()
+        ctx.addInnerClass(generatedMap)
+
         ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
           s"$fastHashMapTerm = new $fastHashMapClassName(" +
             s"$thisPlan.getTaskMemoryManager(), $thisPlan.getEmptyAggregationBuffer());")
         ctx.addMutableState(
-          "org.apache.spark.unsafe.KVIterator",
+          "org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow>",
           iterTermForFastHashMap)
       }
     }

+    // Create a name for the iterator from the regular hash map.
+    val iterTerm = ctx.freshName("mapIter")
+    ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm)

     // create hashMap
     hashMapTerm = ctx.freshName("hashMap")
     val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
-    ctx.addMutableState(hashMapClassName, hashMapTerm)
+    ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
     sorterTerm = ctx.freshName("sorter")
     ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm)

-    // Create a name for iterator from HashMap
-    val iterTerm = ctx.freshName("mapIter")
-    ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm)
-
-    def generateGenerateCode(): String = {
-      if (isFastHashMapEnabled) {
-        if (isVectorizedHashMapEnabled) {
-          s"""
-             | ${fastHashMapGenerator.asInstanceOf[VectorizedHashMapGenerator].generate()}
-           """.stripMargin
-        } else {
-          s"""
-             | ${fastHashMapGenerator.asInstanceOf[RowBasedHashMapGenerator].generate()}
-           """.stripMargin
-        }
-      } else ""
-    }
-    ctx.addInnerClass(generateGenerateCode())
-
     val doAgg = ctx.freshName("doAggregateWithKeys")
     val peakMemory = metricTerm(ctx, "peakMemory")
     val spillSize = metricTerm(ctx, "spillSize")
     val avgHashProbe = metricTerm(ctx, "avgHashProbe")

-    val doAggFuncName = ctx.addNewFunction(doAgg,
-      s"""
-        private void $doAgg() throws java.io.IOException {
-          $hashMapTerm = $thisPlan.createHashMap();
-          ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
-          ${if (isFastHashMapEnabled) {
-              s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""}
+    val finishRegularHashMap = s"$iterTerm = $thisPlan.finishAggregate(" +
+      s"$hashMapTerm, $sorterTerm, $peakMemory, $spillSize, $avgHashProbe);"
+    val finishHashMap = if (isFastHashMapEnabled) {
+      s"""
+         |$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();
+         |$finishRegularHashMap
+       """.stripMargin
+    } else {
+      finishRegularHashMap
+    }

-          $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize,
-            $avgHashProbe);
-        }
-       """)
+    val doAggFuncName = ctx.addNewFunction(doAgg,
+      s"""
+         |private void $doAgg() throws java.io.IOException {
+         |  ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+         |  $finishHashMap
+         |}
       """.stripMargin)

     // generate code for output
     val keyTerm = ctx.freshName("aggKey")
     val bufferTerm = ctx.freshName("aggBuffer")
     val outputFunc = generateResultFunction(ctx)
-    val numOutput = metricTerm(ctx, "numOutputRows")

-    def outputFromGeneratedMap: String = {
+    def outputFromFastHashMap: String = {
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
           outputFromVectorizedMap
@@ -672,48 +663,56 @@ case class HashAggregateExec(

     def outputFromRowBasedMap: String = {
       s"""
-        while ($iterTermForFastHashMap.next()) {
-          $numOutput.add(1);
-          UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
-          UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
-          $outputFunc($keyTerm, $bufferTerm);
-
-          if (shouldStop()) return;
-        }
-        $fastHashMapTerm.close();
-      """
+         |while ($iterTermForFastHashMap.next()) {
+         |  UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
+         |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
+         |  $outputFunc($keyTerm, $bufferTerm);
+         |
+         |  if (shouldStop()) return;
+         |}
+         |$fastHashMapTerm.close();
+       """.stripMargin
     }

     // Iterate over the aggregate rows and convert them from ColumnarRow to UnsafeRow
     def outputFromVectorizedMap: String = {
-      val row = ctx.freshName("fastHashMapRow")
-      ctx.currentVars = null
-      ctx.INPUT_ROW = row
-      val generateKeyRow = GenerateUnsafeProjection.createCode(ctx,
-        groupingKeySchema.toAttributes.zipWithIndex
+      val row = ctx.freshName("fastHashMapRow")
+      ctx.currentVars = null
+      ctx.INPUT_ROW = row
+      val generateKeyRow = GenerateUnsafeProjection.createCode(ctx,
+        groupingKeySchema.toAttributes.zipWithIndex
           .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }
-      )
-      val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
-        bufferSchema.toAttributes.zipWithIndex
-          .map { case (attr, i) =>
-            BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) })
-      s"""
-         | while ($iterTermForFastHashMap.hasNext()) {
-         |   $numOutput.add(1);
-         |   org.apache.spark.sql.execution.vectorized.ColumnarRow $row =
-         |     (org.apache.spark.sql.execution.vectorized.ColumnarRow)
-         |     $iterTermForFastHashMap.next();
-         |   ${generateKeyRow.code}
-         |   ${generateBufferRow.code}
-         |   $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
-         |
-         |   if (shouldStop()) return;
-         | }
-         |
-         | $fastHashMapTerm.close();
-       """.stripMargin
+      )
+      val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
+        bufferSchema.toAttributes.zipWithIndex.map { case (attr, i) =>
+          BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
+        })
+      val columnarRowCls = classOf[ColumnarRow].getName
+      s"""
+         |while ($iterTermForFastHashMap.hasNext()) {
+         |  $columnarRowCls $row = ($columnarRowCls) $iterTermForFastHashMap.next();
+         |  ${generateKeyRow.code}
+         |  ${generateBufferRow.code}
+         |  $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
+         |
+         |  if (shouldStop()) return;
+         |}
+         |
+         |$fastHashMapTerm.close();
+       """.stripMargin
     }

+    def outputFromRegularHashMap: String = {
+      s"""
+         |while ($iterTerm.next()) {
+         |  UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
+         |  UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
+         |  $outputFunc($keyTerm, $bufferTerm);
+         |
+         |  if (shouldStop()) return;
+         |}
       """.stripMargin
+    }

     val aggTime = metricTerm(ctx, "aggTime")
     val beforeAgg = ctx.freshName("beforeAgg")
@@ -726,16 +725,8 @@ case class HashAggregateExec(
       }

       // output the result
-      ${outputFromGeneratedMap}
-
-      while ($iterTerm.next()) {
-        $numOutput.add(1);
-        UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
-        UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
-        $outputFunc($keyTerm, $bufferTerm);
-
-        if (shouldStop()) return;
-      }
+      $outputFromFastHashMap
+      $outputFromRegularHashMap

       $iterTerm.close();
       if ($sorterTerm == null) {
@@ -745,13 +736,11 @@ case class HashAggregateExec(
   }

   private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    // create grouping key
-    ctx.currentVars = input
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
       ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
     val fastRowKeys = ctx.generateExpressions(
-      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
     val unsafeRowKeys = unsafeRowKeyCode.value
     val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
     val fastRowBuffer = ctx.freshName("fastAggBuffer")
@@ -768,12 +757,8 @@ case class HashAggregateExec(

     // generate hash code for key
     val hashExpr = Murmur3Hash(groupingExpressions, 42)
-    ctx.currentVars = input
     val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx)

-    val inputAttr = aggregateBufferAttributes ++ child.output
-    ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
-
     val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter,
       incCounter) = if (testFallbackStartsAt.isDefined) {
       val countTerm = ctx.freshName("fallbackCounter")
@@ -784,86 +769,65 @@ case class HashAggregateExec(
       ("true", "true", "", "")
     }

-    // We first generate code to probe and update the fast hash map. If the probe is
-    // successful the corresponding fast row buffer will hold the mutable row
-    val findOrInsertFastHashMap: Option[String] = {
+    val findOrInsertRegularHashMap: String =
+      s"""
+         |// generate grouping key
+         |${unsafeRowKeyCode.code.trim}
+         |${hashEval.code.trim}
+         |if ($checkFallbackForBytesToBytesMap) {
+         |  // try to get the buffer from hash map
+         |  $unsafeRowBuffer =
+         |    $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value});
+         |}
+         |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
+         |// aggregation after processing all input rows.
+         |if ($unsafeRowBuffer == null) {
+         |  if ($sorterTerm == null) {
+         |    $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+         |  } else {
+         |    $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+         |  }
+         |  $resetCounter
+         |  // the hash map had been spilled, it should have enough memory now,
+         |  // try to allocate buffer again.
+         |  $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow(
+         |    $unsafeRowKeys, ${hashEval.value});
+         |  if ($unsafeRowBuffer == null) {
+         |    // failed to allocate the first page
+         |    throw new OutOfMemoryError("No enough memory for aggregation");
+         |  }
+         |}
+       """.stripMargin
+
+    val findOrInsertHashMap: String = {
       if (isFastHashMapEnabled) {
-        Option(
-          s"""
-             |
-             |if ($checkFallbackForGeneratedHashMap) {
-             |  ${fastRowKeys.map(_.code).mkString("\n")}
-             |  if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
-             |    $fastRowBuffer = $fastHashMapTerm.findOrInsert(
-             |        ${fastRowKeys.map(_.value).mkString(", ")});
-             |  }
-             |}
-           """.stripMargin)
+        // If fast hash map is on, we first generate code to probe and update the fast hash map.
+        // If the probe is successful the corresponding fast row buffer will hold the mutable row.
+        s"""
+           |if ($checkFallbackForGeneratedHashMap) {
+           |  ${fastRowKeys.map(_.code).mkString("\n")}
+           |  if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
+           |    $fastRowBuffer = $fastHashMapTerm.findOrInsert(
+           |      ${fastRowKeys.map(_.value).mkString(", ")});
+           |  }
+           |}
+           |// Cannot find the key in fast hash map, try regular hash map.
+           |if ($fastRowBuffer == null) {
+           |  $findOrInsertRegularHashMap
+           |}
         """.stripMargin
       } else {
-        None
+        findOrInsertRegularHashMap
       }
     }

+    val inputAttr = aggregateBufferAttributes ++ child.output
+    // Here we set `currentVars(0)` to `currentVars(numBufferSlots - 1)` to null, so that when
+    // generating code for buffer columns we use `INPUT_ROW` (which will be the buffer row), while
+    // when generating input columns we use `currentVars`.
+    ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input

-    def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = {
-      ctx.INPUT_ROW = fastRowBuffer
-      val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
-      val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
-      val effectiveCodes = subExprs.codes.mkString("\n")
-      val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
-        boundUpdateExpr.map(_.genCode(ctx))
-      }
-      val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
-        val dt = updateExpr(i).dataType
-        ctx.updateColumn(fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized)
-      }
-      Option(
-        s"""
-           |// common sub-expressions
-           |$effectiveCodes
-           |// evaluate aggregate function
-           |${evaluateVariables(fastRowEvals)}
-           |// update fast row
-           |${updateFastRow.mkString("\n").trim}
-           |
-         """.stripMargin)
-    }
-
-    // Next, we generate code to probe and update the unsafe row hash map.
-    val findOrInsertInUnsafeRowMap: String = {
-      s"""
-         | if ($fastRowBuffer == null) {
-         |   // generate grouping key
-         |   ${unsafeRowKeyCode.code.trim}
-         |   ${hashEval.code.trim}
-         |   if ($checkFallbackForBytesToBytesMap) {
-         |     // try to get the buffer from hash map
-         |     $unsafeRowBuffer =
-         |       $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value});
-         |   }
-         |   // Can't allocate buffer from the hash map. Spill the map and fallback to sort-based
-         |   // aggregation after processing all input rows.
-         |   if ($unsafeRowBuffer == null) {
-         |     if ($sorterTerm == null) {
-         |       $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
-         |     } else {
-         |       $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
-         |     }
-         |     $resetCounter
-         |     // the hash map had be spilled, it should have enough memory now,
-         |     // try to allocate buffer again.
-         |     $unsafeRowBuffer =
-         |       $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value});
-         |     if ($unsafeRowBuffer == null) {
-         |       // failed to allocate the first page
-         |       throw new OutOfMemoryError("No enough memory for aggregation");
-         |     }
-         |   }
-         | }
-       """.stripMargin
-    }
-
-    val updateRowInUnsafeRowMap: String = {
+    val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
       val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
@@ -882,45 +846,69 @@ case class HashAggregateExec(
         |${evaluateVariables(unsafeRowBufferEvals)}
         |// update unsafe row buffer
         |${updateUnsafeRowBuffer.mkString("\n").trim}
-      """.stripMargin
+       """.stripMargin
+    }
+
+    val updateRowInHashMap: String = {
+      if (isFastHashMapEnabled) {
+        ctx.INPUT_ROW = fastRowBuffer
+        val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+        val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+        val effectiveCodes = subExprs.codes.mkString("\n")
+        val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
+          boundUpdateExpr.map(_.genCode(ctx))
+        }
+        val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
+          val dt = updateExpr(i).dataType
+          ctx.updateColumn(
+            fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorizedHashMapEnabled)
+        }

+        // If fast hash map is on, we first generate code to update the row in the fast hash map,
+        // if the previous lookup hit the fast hash map. Otherwise, update the row in the regular hash map.
+        s"""
+           |if ($fastRowBuffer != null) {
+           |  // common sub-expressions
+           |  $effectiveCodes
+           |  // evaluate aggregate function
+           |  ${evaluateVariables(fastRowEvals)}
+           |  // update fast row
+           |  ${updateFastRow.mkString("\n").trim}
+           |} else {
+           |  $updateRowInRegularHashMap
+           |}
+         """.stripMargin
+      } else {
+        updateRowInRegularHashMap
+      }
     }

+    val declareRowBuffer: String = if (isFastHashMapEnabled) {
+      val fastRowType = if (isVectorizedHashMapEnabled) {
+        classOf[MutableColumnarRow].getName
+      } else {
+        "UnsafeRow"
+      }
+      s"""
+         |UnsafeRow $unsafeRowBuffer = null;
+         |$fastRowType $fastRowBuffer = null;
+       """.stripMargin
+    } else {
+      s"UnsafeRow $unsafeRowBuffer = null;"
+    }

     // We try to do hash map based in-memory aggregation first. If there is not enough memory (the
     // hash map will return null for new key), we spill the hash map to disk to free memory, then
     // continue to do in-memory aggregation and spilling until all the rows had been processed.
     // Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
     s"""
-     UnsafeRow $unsafeRowBuffer = null;
-     ${
-       if (isVectorizedHashMapEnabled) {
-         s"""
-            | ${classOf[MutableColumnarRow].getName} $fastRowBuffer = null;
-          """.stripMargin
-       } else {
-         s"""
-            | UnsafeRow $fastRowBuffer = null;
-          """.stripMargin
-       }
-     }
+     $declareRowBuffer

-     ${findOrInsertFastHashMap.getOrElse("")}
-
-     $findOrInsertInUnsafeRowMap
+     $findOrInsertHashMap

      $incCounter

-     if ($fastRowBuffer != null) {
-       // update fast row
-       ${
-         if (isFastHashMapEnabled) {
-           updateRowInFastHashMap(isVectorizedHashMapEnabled).getOrElse("")
-         } else ""
-       }
-     } else {
-       // update unsafe row
-       $updateRowInUnsafeRowMap
-     }
+     $updateRowInHashMap
     """
   }
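The heart of the cleanup, distilled: the fast-hash-map code paths are no longer built as `Option[String]` fragments stitched together at the call site (the old `findOrInsertFastHashMap` / `findOrInsertInUnsafeRowMap`); each path is now a plain `String`, and the fast-map variant embeds the regular-map fallback directly (the new `findOrInsertHashMap` / `findOrInsertRegularHashMap`). The following runnable Scala sketch mirrors that composition pattern with simplified, hypothetical names; it illustrates the shape of the change, not the actual `HashAggregateExec` code.

// Sketch only: simplified stand-ins for the codegen-string composition above.
object CodegenCompositionSketch {

  // Fallback path: probe the regular (unsafe-row) hash map.
  def findOrInsertRegularHashMap(buffer: String): String =
    s"""|$buffer = hashMap.getAggregationBufferFromUnsafeRow(keys, hash);
        |if ($buffer == null) { /* spill, then retry allocation */ }""".stripMargin

  // Combined path: the fast-map probe simply wraps the regular-map fallback,
  // so the caller no longer needs Option/getOrElse plumbing.
  def findOrInsertHashMap(fastEnabled: Boolean, fastBuffer: String, buffer: String): String =
    if (fastEnabled) {
      s"""|$fastBuffer = fastHashMap.findOrInsert(keys);
          |// Cannot find the key in the fast hash map, try the regular one.
          |if ($fastBuffer == null) {
          |  ${findOrInsertRegularHashMap(buffer)}
          |}""".stripMargin
    } else {
      findOrInsertRegularHashMap(buffer)
    }

  def main(args: Array[String]): Unit =
    println(findOrInsertHashMap(fastEnabled = true,
      fastBuffer = "fastRowBuffer", buffer = "unsafeRowBuffer"))
}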
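At runtime, the generated code performs a two-level lookup: it probes the generated fast hash map first and falls back to the regular `UnsafeFixedWidthAggregationMap` (which spills to an `UnsafeKVExternalSorter` when it cannot allocate a buffer) whenever the fast probe returns null. Below is a hedged sketch of that control flow, using ordinary Scala collections as stand-ins for the real off-heap maps; the spill-and-retry step is deliberately elided.

import scala.collection.mutable

// Sketch only: mutable.HashMap stands in for the generated fast map and for
// UnsafeFixedWidthAggregationMap; spilling to UnsafeKVExternalSorter is elided.
final class TwoLevelLookupSketch(fastCapacity: Int) {
  private val fastMap = mutable.HashMap.empty[Long, Array[Long]]
  private val regularMap = mutable.HashMap.empty[Long, Array[Long]]

  // Returns the aggregation buffer for `key`: fast map first, regular map on miss.
  def findOrInsert(key: Long): Array[Long] =
    if (fastMap.contains(key) || fastMap.size < fastCapacity) {
      // Fast map hit, or the fast map still has room for a new key.
      fastMap.getOrElseUpdate(key, Array(0L))
    } else {
      // Fast map is full and does not know this key: fall back to the regular
      // map (the real code would spill and retry if this too returned null).
      regularMap.getOrElseUpdate(key, Array(0L))
    }
}

object TwoLevelLookupSketch {
  def main(args: Array[String]): Unit = {
    val map = new TwoLevelLookupSketch(fastCapacity = 2)
    Seq(1L, 2L, 3L, 1L, 3L).foreach(k => map.findOrInsert(k)(0) += 1) // per-key counts
  }
}

Keeping the first level small and fixed-size is what makes the generated probe cheap; correctness is preserved because every key the fast map rejects still lands in the regular map, and the output phase drains both iterators.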