[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r179391181 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -257,6 +259,78 @@ case class HashAggregateExec( """.stripMargin } + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + context: CodegenContext, + aggregateExpression: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggregateExpression) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { --- End diff -- hey, good news! Thanks for letting me know ;) --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user rednaxelafx commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r179367236 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -257,6 +259,78 @@ case class HashAggregateExec( """.stripMargin } + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + context: CodegenContext, + aggregateExpression: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggregateExpression) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { --- End diff -- Once we have @viirya 's https://github.com/apache/spark/pull/20043 merged we won't need the ugly `CodegenContext.isJavaIdentifier` hack any more >_<||| --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
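For readers following the quoted `getInputVariableReferences`, here is a minimal sketch of the same stack-style walk over a toy expression tree. `Expr`, `Var`, `Lit`, `Add`, and `collectArgs` are hypothetical names for illustration only, not Spark classes. The `Var`/`Lit` split stands in for the string test that `CodegenContext.isJavaIdentifier` performs in the diff.

```scala
object ArgCollector {
  // Toy expression tree (hypothetical, not Spark's Expression hierarchy).
  sealed trait Expr { def children: Seq[Expr] }
  final case class Var(javaType: String, name: String) extends Expr {
    def children: Seq[Expr] = Nil
  }
  final case class Lit(code: String) extends Expr {
    def children: Seq[Expr] = Nil
  }
  final case class Add(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
  }

  // Walks the tree iteratively and collects (type name, variable name) pairs.
  // Literals contribute nothing; duplicates collapse because the result is a set.
  def collectArgs(root: Expr): Set[(String, String)] = {
    val args = scala.collection.mutable.Set[(String, String)]()
    var pending: List[Expr] = List(root)
    while (pending.nonEmpty) {
      val current = pending.head
      pending = pending.tail
      current match {
        case Var(javaType, name) => args += ((javaType, name))
        case other => pending = other.children.toList ::: pending
      }
    }
    args.toSet
  }
}
```

In this toy model a variable referenced twice still yields a single parameter, which is why the quoted code accumulates into a set rather than a list.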
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19082#discussion_r178716133

--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ---
@@ -825,52 +924,92 @@ case class HashAggregateExec(
     ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input

     val updateRowInRegularHashMap: String = {
-      ctx.INPUT_ROW = unsafeRowBuffer
+      // We need to copy the aggregation row buffer to a local row first because each aggregate
+      // function directly updates the buffer when it finishes.
--- End diff --

We need this copy because: https://github.com/apache/spark/pull/19082#discussion_r143326742
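A toy illustration of the hazard named in the quoted code comment (each aggregate function writes to the buffer as soon as it finishes). The two-slot buffer and its update rules are hypothetical and chosen only so that one update reads a slot another update writes. This is a sketch of the general problem, not a summary of the linked discussion.

```scala
object LocalCopyDemo {
  // Each "split function" reads its inputs from `in` and writes one slot of `out`.
  private def updateSlot0(in: Array[Long], out: Array[Long], x: Long): Unit =
    out(0) = in(0) + x

  private def updateSlot1(in: Array[Long], out: Array[Long]): Unit =
    out(1) = in(1) + in(0)

  // Reads and writes share one array, so slot 1 sees the already-updated slot 0.
  def updateInPlace(buffer: Array[Long], x: Long): Unit = {
    updateSlot0(buffer, buffer, x)
    updateSlot1(buffer, buffer)
  }

  // Reads come from a snapshot taken before any write, so both updates see old values.
  def updateWithLocalCopy(buffer: Array[Long], x: Long): Unit = {
    val local = buffer.clone()
    updateSlot0(local, buffer, x)
    updateSlot1(local, buffer)
  }
}
```

Starting from `Array(10L, 1L)` with `x = 5`, the in-place version leaves slot 1 at 16 while the snapshot version leaves it at 11. The cost of that correctness is the extra copy that the review thread is debating.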
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19082#discussion_r163305111

--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ---
@@ -825,52 +924,92 @@ case class HashAggregateExec(
     ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input

     val updateRowInRegularHashMap: String = {
-      ctx.INPUT_ROW = unsafeRowBuffer
+      // We need to copy the aggregation row buffer to a local row first because each aggregate
+      // function directly updates the buffer when it finishes.
--- End diff --

Why does this matter? We should avoid unnecessary data copies as much as possible.
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu closed the pull request at:

    https://github.com/apache/spark/pull/19082
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19082#discussion_r156249393

--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala ---
@@ -269,28 +343,50 @@ case class HashAggregateExec(
           e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
         }
       }
-    ctx.currentVars = bufVars ++ input
+
+    // We need to copy the aggregation buffer to local variables first because each aggregate
+    // function directly updates the buffer when it finishes.
--- End diff --

Just FYI: the discussion in https://github.com/apache/spark/pull/19865 also shows that we need local copies.
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r156229346 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -863,25 +984,43 @@ case class HashAggregateExec( } val updateRowInUnsafeRowMap: String = { - ctx.INPUT_ROW = unsafeRowBuffer + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. + val localRowBuffer = ctx.freshName("localUnsafeRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();" + + ctx.INPUT_ROW = localRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + + val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType -ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) +val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) +s""" + | // evaluate aggregate function + | ${ev.code} + | // update unsafe row buffer + | $updateColumnCode + """.stripMargin } + + val updateAggValCode = splitAggregateExpressions( +ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, +Seq(("InternalRow", unsafeRowBuffer))) --- End diff -- ok --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user gatorsmile commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r156184646 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -863,25 +984,43 @@ case class HashAggregateExec( } val updateRowInUnsafeRowMap: String = { - ctx.INPUT_ROW = unsafeRowBuffer + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. + val localRowBuffer = ctx.freshName("localUnsafeRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();" + + ctx.INPUT_ROW = localRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + + val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType -ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) +val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) +s""" + | // evaluate aggregate function + | ${ev.code} + | // update unsafe row buffer + | $updateColumnCode + """.stripMargin } + + val updateAggValCode = splitAggregateExpressions( +ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, +Seq(("InternalRow", unsafeRowBuffer))) --- End diff -- ``` ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, Seq(("InternalRow", unsafeRowBuffer))) ``` --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r156092874 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -863,25 +984,43 @@ case class HashAggregateExec( } val updateRowInUnsafeRowMap: String = { - ctx.INPUT_ROW = unsafeRowBuffer + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. + val localRowBuffer = ctx.freshName("localUnsafeRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();" + + ctx.INPUT_ROW = localRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + + val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType -ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) +val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) +s""" + | // evaluate aggregate function + | ${ev.code} + | // update unsafe row buffer + | $updateColumnCode + """.stripMargin } + + val updateAggValCode = splitAggregateExpressions( +ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, +Seq(("InternalRow", unsafeRowBuffer))) --- End diff -- Need more indents here? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r156092752 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -256,6 +258,85 @@ case class HashAggregateExec( """.stripMargin } + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggExpr) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { +argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { +argSet += (("boolean", exprCode.isNull)) + } + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { +argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { +argSet += (("boolean", isNull)) + } +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits aggregate code into small functions because JVMs does not compile too long functions + private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: 
Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum length of parameters in non-static Java methods is 254, but a parameter of + // type long or double contributes two units to the length. So, this method gives up + // splitting the code if the parameter length goes over 127. + val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq + + // This is for testing/benchmarking only + val maxParamNumInJavaMethod = + sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match { --- End diff -- ok --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user gatorsmile commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r156002166 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -805,26 +908,44 @@ case class HashAggregateExec( def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = { - ctx.INPUT_ROW = fastRowBuffer + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. + val localRowBuffer = ctx.freshName("localFastRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $fastRowBuffer.copy();" + + ctx.INPUT_ROW = localRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } - val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) => + + val evalAndUpdateCodes = fastRowEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType -ctx.updateColumn(fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized) +val updateColumnCode = ctx.updateColumn( + fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized) +s""" + | // evaluate aggregate function + | ${ev.code} + | // update fast row + | $updateColumnCode + """.stripMargin } + + val updateAggValCode = splitAggregateExpressions( +ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, +Seq(("InternalRow", fastRowBuffer))) --- End diff -- indents --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user gatorsmile commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19082#discussion_r15593

--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala ---
@@ -1070,6 +1071,24 @@ class CodegenContext {
   }
 }

+object CodegenContext {
+
+  private val javaKeywords = Set(
+    "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class", "const",
+    "continue", "default", "do", "double", "else", "extends", "false", "final", "finally", "float",
+    "for", "goto", "if", "implements", "import", "instanceof", "int", "interface", "long", "native",
+    "new", "null", "package", "private", "protected", "public", "return", "short", "static",
+    "strictfp", "super", "switch", "synchronized", "this", "throw", "throws", "transient", "true",
+    "try", "void", "volatile", "while"
+  )
+
+  def isJavaIdentifier(str: String): Boolean = str match {
+    case null | "" => false
+    case _ => !javaKeywords.contains(str) && isJavaIdentifierStart(str.charAt(0)) &&
+      (1 until str.length).forall(i => isJavaIdentifierPart(str.charAt(i)))
+  }
--- End diff --

```Scala
/**
 * Returns true if the given `str` is a valid java identifier.
 */
def isJavaIdentifier(str: String): Boolean = str match {
  case null | "" => false
  case _ => !javaKeywords.contains(str) && isJavaIdentifierStart(str.charAt(0)) &&
    (1 until str.length).forall(i => isJavaIdentifierPart(str.charAt(i)))
}
```
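A self-contained sketch of the check under review, for anyone who wants to try it outside Spark. The object name `IdentifierCheck` is hypothetical. The keyword set is the one from the quoted diff plus `enum`, which the quoted list omits even though it is a Java keyword; adding it here is this sketch's own assumption, not part of the PR.

```scala
object IdentifierCheck {
  // Reserved words and literals that can never be used as Java identifiers.
  // "enum" is added by this sketch; the list in the quoted diff does not include it.
  private val javaKeywords: Set[String] = Set(
    "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class", "const",
    "continue", "default", "do", "double", "else", "enum", "extends", "false", "final", "finally",
    "float", "for", "goto", "if", "implements", "import", "instanceof", "int", "interface", "long",
    "native", "new", "null", "package", "private", "protected", "public", "return", "short",
    "static", "strictfp", "super", "switch", "synchronized", "this", "throw", "throws",
    "transient", "true", "try", "void", "volatile", "while"
  )

  // True when `str` is non-empty, is not a reserved word, starts with a legal
  // identifier-start character, and continues with legal identifier-part characters.
  def isJavaIdentifier(str: String): Boolean = str match {
    case null | "" => false
    case _ =>
      !javaKeywords.contains(str) &&
        Character.isJavaIdentifierStart(str.charAt(0)) &&
        (1 until str.length).forall(i => Character.isJavaIdentifierPart(str.charAt(i)))
  }
}
```

Generated variable names such as `agg_value1` pass, while literal snippets such as `true`, `390239`, or `-1.0` do not, which is the distinction the quoted `getInputVariableReferences` relies on when deciding what to pass as a method parameter.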
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user gatorsmile commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r156003342 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -256,6 +258,85 @@ case class HashAggregateExec( """.stripMargin } + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggExpr) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { +argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { +argSet += (("boolean", exprCode.isNull)) + } + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { +argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { +argSet += (("boolean", isNull)) + } +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits aggregate code into small functions because JVMs does not compile too long functions + private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: 
Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum length of parameters in non-static Java methods is 254, but a parameter of + // type long or double contributes two units to the length. So, this method gives up + // splitting the code if the parameter length goes over 127. + val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq + + // This is for testing/benchmarking only + val maxParamNumInJavaMethod = + sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match { --- End diff -- Let us introduce an internal SQLConf. If the number is high enough, we can disable this feature. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user gatorsmile commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19082#discussion_r156000426

--- Diff: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala ---
@@ -380,4 +380,19 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
         s"Incorrect Evaluation: expressions: $exprAnd, actual: $actualAnd, expected: $expectedAnd")
     }
   }
+
+  test("SPARK-21870 check if CodegenContext.isJavaIdentifier works correctly") {
+    assert(CodegenContext.isJavaIdentifier("agg_value") === true)
+    assert(CodegenContext.isJavaIdentifier("agg_value1") === true)
+    assert(CodegenContext.isJavaIdentifier("bhj_value4") === true)
+    assert(CodegenContext.isJavaIdentifier("smj_value6") === true)
+    assert(CodegenContext.isJavaIdentifier("rdd_value7") === true)
+    assert(CodegenContext.isJavaIdentifier("scan_isNull") === true)
+    assert(CodegenContext.isJavaIdentifier("test") === true)
+    assert(CodegenContext.isJavaIdentifier("true") === false)
+    assert(CodegenContext.isJavaIdentifier("false") === false)
+    assert(CodegenContext.isJavaIdentifier("390239") === false)
+    assert(CodegenContext.isJavaIdentifier(literal) === false)
+    assert(CodegenContext.isJavaIdentifier(double) === false)
--- End diff --

```Scala
import CodegenContext.isJavaIdentifier
// positive cases
assert(isJavaIdentifier("agg_value"))
assert(isJavaIdentifier("agg_value1"))
assert(isJavaIdentifier("bhj_value4"))
assert(isJavaIdentifier("smj_value6"))
assert(isJavaIdentifier("rdd_value7"))
assert(isJavaIdentifier("scan_isNull"))
assert(isJavaIdentifier("test"))
// negative cases
assert(!isJavaIdentifier("true"))
assert(!isJavaIdentifier("false"))
assert(!isJavaIdentifier("390239"))
assert(!isJavaIdentifier(literal))
assert(!isJavaIdentifier(double))
```
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user gatorsmile commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r156002134 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -863,25 +984,43 @@ case class HashAggregateExec( } val updateRowInUnsafeRowMap: String = { - ctx.INPUT_ROW = unsafeRowBuffer + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. + val localRowBuffer = ctx.freshName("localUnsafeRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $unsafeRowBuffer.copy();" + + ctx.INPUT_ROW = localRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { boundUpdateExpr.map(_.genCode(ctx)) } - val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + + val evalAndUpdateCodes = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => val dt = updateExpr(i).dataType -ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) +val updateColumnCode = ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) +s""" + | // evaluate aggregate function + | ${ev.code} + | // update unsafe row buffer + | $updateColumnCode + """.stripMargin } + + val updateAggValCode = splitAggregateExpressions( +ctx, boundUpdateExpr, evalAndUpdateCodes, subExprs.states, +Seq(("InternalRow", unsafeRowBuffer))) --- End diff -- indents. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r143903396 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,89 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggExpr) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { +argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { +argSet += (("boolean", exprCode.isNull)) + } + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { +argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { +argSet += (("boolean", isNull)) + } +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in non-static Java methods is 254, so this method gives + // up splitting the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method --- End diff -- @kiszk Thanks for pinging me. I've updated the similar check in #18931 too. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r143897243 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,89 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggExpr) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { +argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { +argSet += (("boolean", exprCode.isNull)) + } + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { +argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { +argSet += (("boolean", isNull)) + } +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in non-static Java methods is 254, so this method gives + // up splitting the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method --- End diff -- @a10y Thanks for the info. Very helpful. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r143897031 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,89 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggExpr) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { +argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { +argSet += (("boolean", exprCode.isNull)) + } + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { +argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { +argSet += (("boolean", isNull)) + } +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in non-static Java methods is 254, so this method gives + // up splitting the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method --- End diff -- I just changed the code to give up splitting if the number of parameters goes over 127 because, IIUC, the aggregate functions currently implemented in Spark do not go over that limit. I feel it is somewhat complicated to check the types there...
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r143895917 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,89 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggExpr) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { +argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { +argSet += (("boolean", exprCode.isNull)) + } + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { +argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { +argSet += (("boolean", isNull)) + } +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in non-static Java methods is 254, so this method gives + // up splitting the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method --- End diff -- @a10y Good catch. You are right. We have 254 slots, and each long or double takes two slots, so we need to check the types of the parameters, too. cc: @maropu, @viirya
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r143876358 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,89 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggExpr) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { +argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { +argSet += (("boolean", exprCode.isNull)) + } + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { +argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { +argSet += (("boolean", isNull)) + } +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in non-static Java methods is 254, so this method gives + // up splitting the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method --- End diff -- Aha, I'll recheck this. Thanks.
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user a10y commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r143836700 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,89 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggExpr) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { +argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { +argSet += (("boolean", exprCode.isNull)) + } + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { +argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { +argSet += (("boolean", isNull)) + } +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in non-static Java methods is 254, so this method gives + // up splitting the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method --- End diff -- Unless you wanna check the types of all the parameters you might be better off halving this to 127 parameters for the worst case. Though I'm not sure how many codegens this affects...
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user a10y commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r143836110 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,89 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggExpr) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { +argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { +argSet += (("boolean", exprCode.isNull)) + } + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { +argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { +argSet += (("boolean", isNull)) + } +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in non-static Java methods is 254, so this method gives + // up splitting the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method --- End diff -- If you read the spec closely, it actually says that there are 255 slots, where 1 slot is taken by **this** and 2 slots each are taken up by **long** and **double** parameters. https://docs.oracle.com/javase/specs/jvms/se8/html/jvms-4.html#jvms-4.3.3
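The slot accounting described in this comment can be sketched in a few lines of Scala. This is only an illustration, not code from the PR: `ParamSlots`, `slotsFor`, `totalSlots`, and `fitsInOneMethod` are hypothetical names, and the parameters are modelled as the same (type name, variable name) pairs that `argSet` collects.

```scala
object ParamSlots {
  // JVMS 4.3.3: a method descriptor may represent at most 255 units of
  // parameters; `this` contributes one unit for instance methods, and
  // each long or double parameter contributes two units.
  val MaxSlots = 255

  // Slots taken by a single parameter, given its Java type name.
  def slotsFor(javaType: String): Int = javaType match {
    case "long" | "double" => 2
    case _ => 1
  }

  // Total slots for an instance method taking the given (type, name) pairs.
  def totalSlots(params: Seq[(String, String)]): Int =
    1 + params.map { case (tpe, _) => slotsFor(tpe) }.sum

  // True if the parameters fit into a single instance method.
  def fitsInOneMethod(params: Seq[(String, String)]): Boolean =
    totalSlots(params) <= MaxSlots
}
```

Under this accounting, 254 `int` parameters fit but only 127 `long` parameters do, which is where the conservative bound of 127 discussed in this thread comes from.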
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r143359416 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -797,26 +904,44 @@ case class HashAggregateExec( def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = { - ctx.INPUT_ROW = fastRowBuffer + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. + val localRowBuffer = ctx.freshName("localFastRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $fastRowBuffer.copy();" --- End diff --
I just passed the local variable as an argument to each function:
```
// do aggregate
// copy aggregation row buffer to the local
InternalRow agg_localFastRowBuffer = agg_fastAggBuffer.copy();
// common sub-expressions
boolean agg_isNull27 = false;
long agg_value30 = -1L;
if (!false) {
  agg_value30 = (long) inputadapter_value;
}
// process aggregate functions to update aggregation buffer
agg_doAggregateVal_add2(inputadapter_value, agg_value30, agg_fastAggBuffer, agg_localFastRowBuffer, agg_isNull27);
agg_doAggregateVal_add3(inputadapter_value, agg_value30, agg_fastAggBuffer, agg_localFastRowBuffer, agg_isNull27);
agg_doAggregateVal_if1(inputadapter_value, agg_value30, agg_fastAggBuffer, agg_localFastRowBuffer, agg_isNull27);
```
Since each split function directly updates the input row, we need to copy it to a local row first so that all the split functions can reference the old state.
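The hazard behind the copy can be shown with a toy example. This is a sketch under loud assumptions: a plain `Array[Long]` stands in for the aggregation row buffer, and `updateSlot0` and `updateSlot1` are hypothetical stand-ins for the generated `agg_doAggregateVal_*` functions; none of these names come from the PR.

```scala
object SnapshotDemo {
  // Hypothetical split function: writes slot 0 from the old value of slot 0.
  def updateSlot0(x: Long, buf: Array[Long], old: Array[Long]): Unit =
    buf(0) = old(0) + x

  // Hypothetical split function: writes slot 1, but also reads slot 0.
  def updateSlot1(buf: Array[Long], old: Array[Long]): Unit =
    buf(1) = old(0) + old(1)

  // With a snapshot, every function sees the state from before this row.
  def withSnapshot(x: Long, buf: Array[Long]): Array[Long] = {
    val old = buf.clone()
    updateSlot0(x, buf, old)
    updateSlot1(buf, old)
    buf
  }

  // Without a snapshot, the second function observes the first one's write.
  def withoutSnapshot(x: Long, buf: Array[Long]): Array[Long] = {
    updateSlot0(x, buf, buf)
    updateSlot1(buf, buf)
    buf
  }
}
```

Starting from a buffer of `[10, 1]` and an input of `5`, the snapshot version leaves slot 1 at `11`, while the in-place version leaves it at `16` because it reads the already-updated slot 0.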
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r143326742 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -797,26 +904,44 @@ case class HashAggregateExec( def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = { - ctx.INPUT_ROW = fastRowBuffer + // We need to copy the aggregation row buffer to a local row first because each aggregate + // function directly updates the buffer when it finishes. + val localRowBuffer = ctx.freshName("localFastRowBuffer") + val initLocalRowBuffer = s"InternalRow $localRowBuffer = $fastRowBuffer.copy();" --- End diff -- Why do we need to copy the row buffer? You bind `updateExpr` to the local copied row buffer, but the evaluation happens in the split functions. Isn't it possible that `updateExpr` can't find the local variable of the copied row buffer inside those functions?
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r143218854 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala --- @@ -944,6 +945,24 @@ class CodegenContext { } } +object CodegenContext { + + private val javaKeywords = Set( --- End diff -- cc: @rednaxelafx
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r143216395 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala --- @@ -944,6 +945,24 @@ class CodegenContext { } } +object CodegenContext { + + private val javaKeywords = Set( --- End diff -- Do we need to add `enum`?
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r136900835 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,89 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +// `argSet` collects all the pairs of variable names and their types, the first in the pair is +// a type name and the second is a variable name. +val argSet = mutable.Set[(String, String)]() +val stack = mutable.Stack[Expression](aggExpr) +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { +argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { +argSet += (("boolean", exprCode.isNull)) + } + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { +argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { +argSet += (("boolean", isNull)) + } +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in Java methods is 255, so this method gives up splitting + // the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method + // descriptor, where the limit includes one unit for this in the case of instance + // or interface method invocations. + val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq + + // This is for testing/benchmarking only + val maxParamNumInJavaMethod = + sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match { +case null | "" => 255 --- End diff -- If line 314 uses `<=`, this should be 254. In the previous commit, `<` was used.
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r136518024 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,92 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // We assume a prefix has lower cases and a name has camel cases + private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r --- End diff -- Good suggestion; I was also looking for a better approach. I'll try to fix this.
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r136517567 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,92 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // We assume a prefix has lower cases and a name has camel cases + private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r + + // Returns true if a given name id belongs to this `CodegenContext` + private def isVariable(nameId: String): Boolean = nameId match { +case variableName() => true +case _ => false + } + + // Extracts all the outer references for a given `aggExpr`. This result will be used to split + // aggregation into small functions. + private def getOuterReferences( --- End diff -- ok, I'll rename this.
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user rednaxelafx commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r136506452 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,92 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // We assume a prefix has lower cases and a name has camel cases + private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r + + // Returns true if a given name id belongs to this `CodegenContext` + private def isVariable(nameId: String): Boolean = nameId match { +case variableName() => true +case _ => false + } + + // Extracts all the outer references for a given `aggExpr`. This result will be used to split + // aggregation into small functions. + private def getOuterReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +val stack = mutable.Stack[Expression](aggExpr) +val argSet = mutable.Set[(String, String)]() +val addIfNotLiteral = (value: String, tpe: String) => { --- End diff -- Hmm. Just a cosmetic style comment: I would have declared `addIfNotLiteral` with a `def` instead of making it a `scala.Function2[String, String, Unit]`. BTW, can we add a comment to `val argSet` explaining what the two fields of the `Tuple2[String, String]` mean? And then also make this `addIfNotLiteral` function take the arguments in the same order as the tuple.
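The two style suggestions in this comment can be sketched as follows. This is a standalone illustration, not the PR's code: `ArgCollector`, `collect`, and `looksLikeVariable` are hypothetical names, and `looksLikeVariable` is only a simplified stand-in for the PR's `isVariable` check.

```scala
import scala.collection.mutable

object ArgCollector {
  // Simplified stand-in for the PR's variable check (hypothetical):
  // accepts Java-identifier-shaped strings except boolean/null literals.
  private def looksLikeVariable(s: String): Boolean =
    s.nonEmpty &&
      Character.isJavaIdentifierStart(s.charAt(0)) &&
      s.forall(c => Character.isJavaIdentifierPart(c)) &&
      !Set("true", "false", "null").contains(s)

  def collect(candidates: Seq[(String, String)]): Set[(String, String)] = {
    // Each pair is (Java type name, variable name), in that order.
    val argSet = mutable.Set[(String, String)]()

    // A local `def` rather than a Function2 value; its parameters follow
    // the same (type, name) order as the tuples stored in `argSet`.
    def addIfNotLiteral(tpe: String, value: String): Unit = {
      if (looksLikeVariable(value)) argSet += ((tpe, value))
    }

    candidates.foreach { case (tpe, value) => addIfNotLiteral(tpe, value) }
    argSet.toSet
  }
}
```

Keeping the parameter order aligned with the tuple makes a swapped call such as `addIfNotLiteral(value, tpe)` easier to spot, since both arguments are plain `String`s and the compiler cannot catch the mistake.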
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user rednaxelafx commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r136506046 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,92 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // We assume a prefix has lower cases and a name has camel cases + private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r --- End diff --
I know the regular expression is tempting, but there's actually a better way to do this along the lines of your idea, under the current framework. I've got a piece of code sitting in my own workspace that checks for Java identifiers:
```scala
object CodegenContext {
  private val javaKeywords = Set(
    "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char", "class",
    "const", "continue", "default", "do", "double", "else", "extends", "false", "final",
    "finally", "float", "for", "goto", "if", "implements", "import", "instanceof", "int",
    "interface", "long", "native", "new", "null", "package", "private", "protected", "public",
    "return", "short", "static", "strictfp", "super", "switch", "synchronized", "this",
    "throw", "throws", "transient", "true", "try", "void", "volatile", "while"
  )

  def isJavaIdentifier(str: String): Boolean = str match {
    case null | "" => false
    case _ =>
      java.lang.Character.isJavaIdentifierStart(str.charAt(0)) &&
        (1 until str.length).forall(
          i => java.lang.Character.isJavaIdentifierPart(str.charAt(i))) &&
        !javaKeywords.contains(str)
  }
}
```
Feel free to use it here if you'd like. This is the way `java.lang.Character.isJavaIdentifierStart()` and `java.lang.Character.isJavaIdentifierPart()` are supposed to be used anyway, nothing creative. If you want to use it in a `case` the way you're using the regular expression, just wrap the util above in an `unapply()`. But I'd say simply making `def isVariable(nameId: String) = CodegenContext.isJavaIdentifier(nameId)` is clean enough.
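The `unapply()` wrapping mentioned at the end of this comment could look roughly like the sketch below. It is a hedged illustration rather than the code that was merged: `JavaIdentifier` and `VariableCheck` are hypothetical names, and the keyword set is the one quoted in the comment plus `enum`, which is a reserved word since Java 5 and is an addition made here.

```scala
object JavaIdentifier {
  // Keyword list from the review comment, with `enum` added (assumption).
  private val javaKeywords = Set(
    "abstract", "assert", "boolean", "break", "byte", "case", "catch", "char",
    "class", "const", "continue", "default", "do", "double", "else", "enum",
    "extends", "false", "final", "finally", "float", "for", "goto", "if",
    "implements", "import", "instanceof", "int", "interface", "long", "native",
    "new", "null", "package", "private", "protected", "public", "return",
    "short", "static", "strictfp", "super", "switch", "synchronized", "this",
    "throw", "throws", "transient", "true", "try", "void", "volatile", "while")

  def isJavaIdentifier(str: String): Boolean = str match {
    case null | "" => false
    case _ =>
      Character.isJavaIdentifierStart(str.charAt(0)) &&
        (1 until str.length).forall(i => Character.isJavaIdentifierPart(str.charAt(i))) &&
        !javaKeywords.contains(str)
  }

  // Extractor: succeeds only for valid, non-keyword Java identifiers.
  def unapply(str: String): Option[String] =
    if (isJavaIdentifier(str)) Some(str) else None
}

object VariableCheck {
  // Pattern-match form, mirroring how the regex was used in `isVariable`.
  def isVariable(nameId: String): Boolean = nameId match {
    case JavaIdentifier(_) => true
    case _ => false
  }
}
```

With this extractor, generated names such as `agg_value30` match, while literals such as `false` or `-1L` that can appear in an `ExprCode` value do not.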
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r136500779 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,92 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // We assume a prefix has lower cases and a name has camel cases + private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r + + // Returns true if a given name id belongs to this `CodegenContext` + private def isVariable(nameId: String): Boolean = nameId match { +case variableName() => true +case _ => false + } + + // Extracts all the outer references for a given `aggExpr`. This result will be used to split + // aggregation into small functions. + private def getOuterReferences( --- End diff -- `OuterReference` actually has a special meaning in correlated subqueries, so this name can be confusing.
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r136490017 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,92 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // We assume a prefix has lower cases and a name has camel cases + private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r + + // Returns true if a given name id belongs to this `CodegenContext` + private def isVariable(nameId: String): Boolean = nameId match { +case variableName() => true +case _ => false + } + + // Extracts all the outer references for a given `aggExpr`. This result will be used to split + // aggregation into small functions. + private def getOuterReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +val stack = mutable.Stack[Expression](aggExpr) +val argSet = mutable.Set[(String, String)]() +val addIfNotLiteral = (value: String, tpe: String) => { + if (isVariable(value)) { +argSet += ((tpe, value)) + } +} +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + addIfNotLiteral(exprCode.value, ctx.javaType(e.dataType)) + addIfNotLiteral(exprCode.isNull, "boolean") + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val argVal = ctx.currentVars(ref.ordinal).value + addIfNotLiteral(argVal, ctx.javaType(ref.dataType)) + addIfNotLiteral(ctx.currentVars(ref.ordinal).isNull, "boolean") +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in Java methods is 255, so this method gives up splitting + // the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method + // descriptor, where the limit includes one unit for this in the case of instance + // or interface method invocations. + val args = (getOuterReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq + + // This is for testing/benchmarking only + val maxParamNumInJavaMethod = + sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match { --- End diff -- This is a test-only option, so I think we need not check that.
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r136455832 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,92 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // We assume a prefix has lower cases and a name has camel cases + private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r + + // Returns true if a given name id belongs to this `CodegenContext` + private def isVariable(nameId: String): Boolean = nameId match { +case variableName() => true +case _ => false + } + + // Extracts all the outer references for a given `aggExpr`. This result will be used to split + // aggregation into small functions. + private def getOuterReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +val stack = mutable.Stack[Expression](aggExpr) +val argSet = mutable.Set[(String, String)]() +val addIfNotLiteral = (value: String, tpe: String) => { + if (isVariable(value)) { +argSet += ((tpe, value)) + } +} +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + addIfNotLiteral(exprCode.value, ctx.javaType(e.dataType)) + addIfNotLiteral(exprCode.isNull, "boolean") + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val argVal = ctx.currentVars(ref.ordinal).value + addIfNotLiteral(argVal, ctx.javaType(ref.dataType)) + addIfNotLiteral(ctx.currentVars(ref.ordinal).isNull, "boolean") +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in Java methods is 255, so this method gives up splitting + // the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method + // descriptor, where the limit includes one unit for this in the case of instance + // or interface method invocations. + val args = (getOuterReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq + + // This is for testing/benchmarking only + val maxParamNumInJavaMethod = + sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match { --- End diff -- Can we add a check for the case where a user specifies a value greater than 255?
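The check asked about in the comment above could be sketched roughly as follows. This is only an illustration under stated assumptions: `MaxParamCheck`, `JvmMaxParamSlots` and `resolveMaxParamNum` are hypothetical names that do not appear in the PR, and the helper takes the raw option string directly instead of calling `sqlContext.getConf`.

```scala
import scala.util.Try

object MaxParamCheck {
  // Hypothetical constant: a JVM method descriptor allows at most 255 parameter units.
  val JvmMaxParamSlots = 255

  // Hypothetical helper: parse the raw option value. An unset or empty option
  // falls back to the JVM limit; anything outside [1, 255] is rejected.
  def resolveMaxParamNum(raw: String): Either[String, Int] = raw match {
    case null | "" => Right(JvmMaxParamSlots)
    case s =>
      Try(s.trim.toInt).toOption match {
        case Some(n) if n >= 1 && n <= JvmMaxParamSlots => Right(n)
        case Some(n) => Left(s"value $n is outside [1, $JvmMaxParamSlots]")
        case None => Left(s"'$s' is not an integer")
      }
  }
}
```

Returning an `Either` keeps the decision about how to react (throw, log, or silently clamp) with the caller, which seems reasonable for an option that is meant for testing and benchmarking only.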
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r136207629 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,92 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // We assume a prefix has lower cases and a name has camel cases + private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r + + // Returns true if a given name id belongs to this `CodegenContext` + private def isVariable(nameId: String): Boolean = nameId match { +case variableName() => true +case _ => false + } + + // Extracts all the outer references for a given `aggExpr`. This result will be used to split + // aggregation into small functions. + private def getOuterReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +val stack = mutable.Stack[Expression](aggExpr) +val argSet = mutable.Set[(String, String)]() +val addIfNotLiteral = (value: String, tpe: String) => { + if (isVariable(value)) { +argSet += ((tpe, value)) + } +} +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + addIfNotLiteral(exprCode.value, ctx.javaType(e.dataType)) + addIfNotLiteral(exprCode.isNull, "boolean") + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val argVal = ctx.currentVars(ref.ordinal).value + addIfNotLiteral(argVal, ctx.javaType(ref.dataType)) + addIfNotLiteral(ctx.currentVars(ref.ordinal).isNull, "boolean") +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in Java methods is 255, so this method gives up splitting + // the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method + // descriptor, where the limit includes one unit for this in the case of instance + // or interface method invocations. + val args = (getOuterReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq + + // This is for testing/benchmarking only + val maxParamNumInJavaMethod = + sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match { +case null | "" => 256 --- End diff -- oh, good catch! I'll fix
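The `isVariable` check in the diff quoted above can be tried in isolation. The sketch below copies the regex and the pattern match from the diff into a standalone object; the wrapper name `VariableNameCheck` is made up for illustration and is not part of the PR.

```scala
object VariableNameCheck {
  // Same pattern as in the quoted diff: a lower-case prefix, an underscore,
  // a run of letters, and an optional numeric suffix.
  private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r

  // A Scala Regex used as a pattern only matches when the whole string matches,
  // so literals and compound expressions are rejected.
  def isVariable(nameId: String): Boolean = nameId match {
    case variableName() => true
    case _ => false
  }
}
```

With these definitions, generated names such as `agg_bufValue1` or `inputadapter_isNull` pass, while literals such as `-1L` or `false` are filtered out, which is the behavior `addIfNotLiteral` depends on.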
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r136090141 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -244,6 +246,92 @@ case class HashAggregateExec( protected override val shouldStopRequired = false + // We assume a prefix has lower cases and a name has camel cases + private val variableName = "^[a-z]+_[a-zA-Z]+[0-9]*".r + + // Returns true if a given name id belongs to this `CodegenContext` + private def isVariable(nameId: String): Boolean = nameId match { +case variableName() => true +case _ => false + } + + // Extracts all the outer references for a given `aggExpr`. This result will be used to split + // aggregation into small functions. + private def getOuterReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { +val stack = mutable.Stack[Expression](aggExpr) +val argSet = mutable.Set[(String, String)]() +val addIfNotLiteral = (value: String, tpe: String) => { + if (isVariable(value)) { +argSet += ((tpe, value)) + } +} +while (stack.nonEmpty) { + stack.pop() match { +case e if subExprs.contains(e) => + val exprCode = subExprs(e) + addIfNotLiteral(exprCode.value, ctx.javaType(e.dataType)) + addIfNotLiteral(exprCode.isNull, "boolean") + // Since the children possibly has common expressions, we push them here + stack.pushAll(e.children) +case ref: BoundReference +if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val argVal = ctx.currentVars(ref.ordinal).value + addIfNotLiteral(argVal, ctx.javaType(ref.dataType)) + addIfNotLiteral(ctx.currentVars(ref.ordinal).isNull, "boolean") +case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) +case e => + stack.pushAll(e.children) + } +} + +argSet.toSet + } + + // Splits the aggregation into small functions because the HotSpot does not compile + // too long functions. 
+ private def splitAggregateExpressions( + ctx: CodegenContext, + aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { +aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum number of parameters in Java methods is 255, so this method gives up splitting + // the code if the number goes over the limit. + // You can find more information about the limit in the JVM specification: + // - The number of method parameters is limited to 255 by the definition of a method + // descriptor, where the limit includes one unit for this in the case of instance + // or interface method invocations. + val args = (getOuterReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq + + // This is for testing/benchmarking only + val maxParamNumInJavaMethod = + sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match { +case null | "" => 256 --- End diff -- Since `$doAggVal` is a [non-static method](https://stackoverflow.com/questions/30581531/maximum-number-of-parameters-in-java-method-declaration), this number should be `255`.
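The accounting behind this limit can be sketched as below. In the JVM specification's definition of a method descriptor, `long` and `double` parameters contribute two units each, every other type contributes one, and instance methods spend one unit on `this`. The names `ParamSlots`, `slotsOf`, `paramSlots` and `fitsInMethod` are hypothetical, and parameter types are given as Java type-name strings, like the `(type, name)` pairs collected in the diff.

```scala
object ParamSlots {
  // long and double take two units in a method descriptor; all other types take one.
  def slotsOf(javaType: String): Int = javaType match {
    case "long" | "double" => 2
    case _ => 1
  }

  // Total units used by a method: one for `this` on instance methods,
  // plus the units of every declared parameter.
  def paramSlots(paramTypes: Seq[String], isStatic: Boolean): Int =
    (if (isStatic) 0 else 1) + paramTypes.map(slotsOf).sum

  // True when the method stays within the 255-unit descriptor limit.
  def fitsInMethod(paramTypes: Seq[String], isStatic: Boolean): Boolean =
    paramSlots(paramTypes, isStatic) <= 255
}
```

Under this rule an instance method has room for at most 254 one-unit parameters, and fewer when `long` or `double` values are passed, so counting arguments alone may underestimate how close a generated method is to the limit.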
[GitHub] spark pull request #19082: [SPARK-21870][SQL] Split aggregation code into sm...
GitHub user maropu opened a pull request:

https://github.com/apache/spark/pull/19082

[SPARK-21870][SQL] Split aggregation code into small functions for the HotSpot

## What changes were proposed in this pull request?

This PR proposes to split the aggregation code in `HashAggregateExec` into smaller pieces for the HotSpot JVM. In #18810, we saw a performance regression when HotSpot did not compile overly long functions (the limit is 8000 bytes of bytecode). I checked and found that the code generated for `HashAggregateExec` frequently goes over this limit, for example:

```
spark.range(1000).selectExpr("id % 1024 AS a", "id AS b").write.saveAsTable("t")
sql("SELECT a, KURTOSIS(b) FROM t GROUP BY a")
```

This query goes over the limit: the actual bytecode size is `12356`. This PR splits the aggregation code into small separate functions. For a simple example:

```
sql("SELECT SUM(a), AVG(a) FROM VALUES(1) t(a)").debugCodegen
```

- generated code with this PR:
```
private void agg_doAggregateWithoutKey() throws java.io.IOException {
  // initialize aggregation buffer
  final long agg_value = -1L;
  agg_bufIsNull = true;
  agg_bufValue = agg_value;
  boolean agg_isNull1 = false;
  double agg_value1 = -1.0;
  if (!false) {
    agg_value1 = (double) 0;
  }
  agg_bufIsNull1 = agg_isNull1;
  agg_bufValue1 = agg_value1;
  agg_bufIsNull2 = false;
  agg_bufValue2 = 0L;

  while (inputadapter_input.hasNext() && !stopEarly()) {
    InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
    boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
    long inputadapter_value = inputadapter_isNull ? -1L : (inputadapter_row.getLong(0));
    boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
    double inputadapter_value1 = inputadapter_isNull1 ? -1.0 : (inputadapter_row.getDouble(1));
    boolean inputadapter_isNull2 = inputadapter_row.isNullAt(2);
    long inputadapter_value2 = inputadapter_isNull2 ? -1L : (inputadapter_row.getLong(2));

    // do aggregate
    // copy aggregation buffer to the local
    boolean agg_localBufIsNull = agg_bufIsNull;
    long agg_localBufValue = agg_bufValue;
    boolean agg_localBufIsNull1 = agg_bufIsNull1;
    double agg_localBufValue1 = agg_bufValue1;
    boolean agg_localBufIsNull2 = agg_bufIsNull2;
    long agg_localBufValue2 = agg_bufValue2;
    // common sub-expressions

    // process aggregate functions to update aggregation buffer
    agg_doAggregateVal_coalesce(agg_localBufIsNull, agg_localBufValue, inputadapter_value, inputadapter_isNull);
    agg_doAggregateVal_add(agg_localBufValue1, inputadapter_isNull1, inputadapter_value1, agg_localBufIsNull1);
    agg_doAggregateVal_add1(inputadapter_isNull2, inputadapter_value2, agg_localBufIsNull2, agg_localBufValue2);
    if (shouldStop()) return;
  }
```

- generated code in the current master:
```
private void agg_doAggregateWithoutKey() throws java.io.IOException {
  // initialize aggregation buffer
  final long agg_value = -1L;
  agg_bufIsNull = true;
  agg_bufValue = agg_value;
  boolean agg_isNull1 = false;
  double agg_value1 = -1.0;
  if (!false) {
    agg_value1 = (double) 0;
  }
  agg_bufIsNull1 = agg_isNull1;
  agg_bufValue1 = agg_value1;
  agg_bufIsNull2 = false;
  agg_bufValue2 = 0L;

  while (inputadapter_input.hasNext() && !stopEarly()) {
    InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
    boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
    long inputadapter_value = inputadapter_isNull ? -1L : (inputadapter_row.getLong(0));
    boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
    double inputadapter_value1 = inputadapter_isNull1 ? -1.0 : (inputadapter_row.getDouble(1));
    boolean inputadapter_isNull2 = inputadapter_row.isNullAt(2);
    long inputadapter_value2 = inputadapter_isNull2 ? -1L : (inputadapter_row.getLong(2));

    // do aggreg