Repository: spark Updated Branches: refs/heads/master 295df746e -> a8af4da12
[SPARK-22682][SQL] HashExpression does not need to create global variables ## What changes were proposed in this pull request? It turns out that `HashExpression` can pass some values around via parameters when splitting code into methods, which saves some global variable slots. This also prevents a weird case where a global variable appears in the parameter list, which was discovered by https://github.com/apache/spark/pull/19865 ## How was this patch tested? existing tests Author: Wenchen Fan <wenc...@databricks.com> Closes #19878 from cloud-fan/minor. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a8af4da1 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a8af4da1 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a8af4da1 Branch: refs/heads/master Commit: a8af4da12ce43cd5528a53b5f7f454e9dbe71d6e Parents: 295df74 Author: Wenchen Fan <wenc...@databricks.com> Authored: Tue Dec 5 12:43:05 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Tue Dec 5 12:43:05 2017 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/expressions/hash.scala | 118 +++++++++++++------ .../expressions/HashExpressionsSuite.scala | 34 ++++-- 2 files changed, 106 insertions(+), 46 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/a8af4da1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index c3289b8..d0ed2ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -270,17 
+270,36 @@ abstract class HashExpression[E] extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" - val childrenHash = ctx.splitExpressions(children.map { child => + + val childrenHash = children.map { child => val childGen = child.genCode(ctx) childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, ev.value, ctx) } - }) + } + + val hashResultType = ctx.javaType(dataType) + val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + childrenHash.mkString("\n") + } else { + ctx.splitExpressions( + expressions = childrenHash, + funcName = "computeHash", + arguments = Seq("InternalRow" -> ctx.INPUT_ROW, hashResultType -> ev.value), + returnType = hashResultType, + makeSplitFunction = body => + s""" + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + } - ctx.addMutableState(ctx.javaType(dataType), ev.value) - ev.copy(code = s""" - ${ev.value} = $seed; - $childrenHash""") + ev.copy(code = + s""" + |$hashResultType ${ev.value} = $seed; + |$codes + """.stripMargin) } protected def nullSafeElementHash( @@ -389,13 +408,21 @@ abstract class HashExpression[E] extends Expression { input: String, result: String, fields: Array[StructField]): String = { - val hashes = fields.zipWithIndex.map { case (field, index) => + val fieldsHash = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) } + val hashResultType = ctx.javaType(dataType) ctx.splitExpressions( - expressions = hashes, - funcName = "getHash", - arguments = Seq("InternalRow" -> input)) + expressions = fieldsHash, + funcName = "computeHashForStruct", + arguments = Seq("InternalRow" -> input, hashResultType -> result), + returnType = hashResultType, + makeSplitFunction = body => + s""" + |$body + |return $result; + """.stripMargin, + 
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) } @tailrec @@ -610,25 +637,44 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" + val childHash = ctx.freshName("childHash") - val childrenHash = ctx.splitExpressions(children.map { child => + val childrenHash = children.map { child => val childGen = child.genCode(ctx) val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, childHash, ctx) } s""" |${childGen.code} + |$childHash = 0; |$codeToComputeHash |${ev.value} = (31 * ${ev.value}) + $childHash; - |$childHash = 0; """.stripMargin - }) + } - ctx.addMutableState(ctx.javaType(dataType), ev.value) - ctx.addMutableState(ctx.JAVA_INT, childHash, s"$childHash = 0;") - ev.copy(code = s""" - ${ev.value} = $seed; - $childrenHash""") + val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { + childrenHash.mkString("\n") + } else { + ctx.splitExpressions( + expressions = childrenHash, + funcName = "computeHash", + arguments = Seq("InternalRow" -> ctx.INPUT_ROW, ctx.JAVA_INT -> ev.value), + returnType = ctx.JAVA_INT, + makeSplitFunction = body => + s""" + |${ctx.JAVA_INT} $childHash = 0; + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + } + + ev.copy(code = + s""" + |${ctx.JAVA_INT} ${ev.value} = $seed; + |${ctx.JAVA_INT} $childHash = 0; + |$codes + """.stripMargin) } override def eval(input: InternalRow = null): Int = { @@ -730,23 +776,29 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { input: String, result: String, fields: Array[StructField]): String = { - val localResult = ctx.freshName("localResult") val childResult = ctx.freshName("childResult") - fields.zipWithIndex.map { case (field, index) => + val fieldsHash = fields.zipWithIndex.map 
{ case (field, index) => + val computeFieldHash = nullSafeElementHash( + input, index.toString, field.nullable, field.dataType, childResult, ctx) s""" - $childResult = 0; - ${nullSafeElementHash(input, index.toString, field.nullable, field.dataType, - childResult, ctx)} - $localResult = (31 * $localResult) + $childResult; - """ - }.mkString( - s""" - int $localResult = 0; - int $childResult = 0; - """, - "", - s"$result = (31 * $result) + $localResult;" - ) + |$childResult = 0; + |$computeFieldHash + |$result = (31 * $result) + $childResult; + """.stripMargin + } + + s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + expressions = fieldsHash, + funcName = "computeHashForStruct", + arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result), + returnType = ctx.JAVA_INT, + makeSplitFunction = body => + s""" + |${ctx.JAVA_INT} $childResult = 0; + |$body + |return $result; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) } } http://git-wip-us.apache.org/repos/asf/spark/blob/a8af4da1/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 112a4a0..4281c89 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} @@ -620,23 +621,30 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-18207: Compute hash for a lot of expressions") { + def checkResult(schema: StructType, input: InternalRow): Unit = { + val exprs = schema.fields.zipWithIndex.map { case (f, i) => + BoundReference(i, f.dataType, true) + } + val murmur3HashExpr = Murmur3Hash(exprs, 42) + val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr)) + val murmursHashEval = Murmur3Hash(exprs, 42).eval(input) + assert(murmur3HashPlan(input).getInt(0) == murmursHashEval) + + val hiveHashExpr = HiveHash(exprs) + val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr)) + val hiveHashEval = HiveHash(exprs).eval(input) + assert(hiveHashPlan(input).getInt(0) == hiveHashEval) + } + val N = 1000 val wideRow = new GenericInternalRow( Seq.tabulate(N)(i => UTF8String.fromString(i.toString)).toArray[Any]) - val schema = StructType((1 to N).map(i => StructField("", StringType))) - - val exprs = schema.fields.zipWithIndex.map { case (f, i) => - BoundReference(i, f.dataType, true) - } - val murmur3HashExpr = Murmur3Hash(exprs, 42) - val murmur3HashPlan = GenerateMutableProjection.generate(Seq(murmur3HashExpr)) - val murmursHashEval = Murmur3Hash(exprs, 42).eval(wideRow) - assert(murmur3HashPlan(wideRow).getInt(0) == murmursHashEval) + val schema = StructType((1 to N).map(i => StructField(i.toString, StringType))) + checkResult(schema, wideRow) - val hiveHashExpr = HiveHash(exprs) - val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr)) - val hiveHashEval = HiveHash(exprs).eval(wideRow) - assert(hiveHashPlan(wideRow).getInt(0) == hiveHashEval) + val nestedRow = InternalRow(wideRow) + val nestedSchema = new StructType().add("nested", schema) + checkResult(nestedSchema, 
nestedRow) } test("SPARK-22284: Compute hash for nested structs") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org