Github user gatorsmile commented on a diff in the pull request: https://github.com/apache/spark/pull/19082#discussion_r156003342 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala --- @@ -256,6 +258,85 @@ case class HashAggregateExec( """.stripMargin } + // Extracts all the input variable references for a given `aggExpr`. This result will be used + // to split aggregation into small functions. + private def getInputVariableReferences( + ctx: CodegenContext, + aggExpr: Expression, + subExprs: Map[Expression, SubExprEliminationState]): Set[(String, String)] = { + // `argSet` collects all the pairs of variable names and their types; the first element in the + // pair is a type name and the second is a variable name. + val argSet = mutable.Set[(String, String)]() + val stack = mutable.Stack[Expression](aggExpr) + while (stack.nonEmpty) { + stack.pop() match { + case e if subExprs.contains(e) => + val exprCode = subExprs(e) + if (CodegenContext.isJavaIdentifier(exprCode.value)) { + argSet += ((ctx.javaType(e.dataType), exprCode.value)) + } + if (CodegenContext.isJavaIdentifier(exprCode.isNull)) { + argSet += (("boolean", exprCode.isNull)) + } + // Since the children may have common expressions, we push them here + stack.pushAll(e.children) + case ref: BoundReference + if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null => + val value = ctx.currentVars(ref.ordinal).value + val isNull = ctx.currentVars(ref.ordinal).isNull + if (CodegenContext.isJavaIdentifier(value)) { + argSet += ((ctx.javaType(ref.dataType), value)) + } + if (CodegenContext.isJavaIdentifier(isNull)) { + argSet += (("boolean", isNull)) + } + case _: BoundReference => + argSet += (("InternalRow", ctx.INPUT_ROW)) + case e => + stack.pushAll(e.children) + } + } + + argSet.toSet + } + + // Splits aggregate code into small functions because the JVM does not compile functions that are too long + private def splitAggregateExpressions( + ctx: CodegenContext, + 
aggExprs: Seq[Expression], + evalAndUpdateCodes: Seq[String], + subExprs: Map[Expression, SubExprEliminationState], + otherArgs: Seq[(String, String)] = Seq.empty): Seq[String] = { + aggExprs.zipWithIndex.map { case (aggExpr, i) => + // The maximum length of parameters in non-static Java methods is 254, but a parameter of + // type long or double contributes two units to the length. So, this method gives up + // splitting the code if the parameter length goes over 127. + val args = (getInputVariableReferences(ctx, aggExpr, subExprs) ++ otherArgs).toSeq + + // This is for testing/benchmarking only + val maxParamNumInJavaMethod = + sqlContext.getConf("spark.sql.codegen.aggregate.maxParamNumInJavaMethod", null) match { --- End diff -- Let us introduce an internal SQLConf. If the number is high enough, we can disable this feature.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org