Repository: spark
Updated Branches:
  refs/heads/master 295df746e -> a8af4da12


[SPARK-22682][SQL] HashExpression does not need to create global variables

## What changes were proposed in this pull request?

It turns out that `HashExpression` can pass some values around via parameters 
when splitting code into methods, which saves some global variable slots.

This also prevents a weird case in which a global variable appears in a parameter 
list, which was discovered in https://github.com/apache/spark/pull/19865

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenc...@databricks.com>

Closes #19878 from cloud-fan/minor.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a8af4da1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a8af4da1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a8af4da1

Branch: refs/heads/master
Commit: a8af4da12ce43cd5528a53b5f7f454e9dbe71d6e
Parents: 295df74
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Dec 5 12:43:05 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Dec 5 12:43:05 2017 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/hash.scala   | 118 +++++++++++++------
 .../expressions/HashExpressionsSuite.scala      |  34 ++++--
 2 files changed, 106 insertions(+), 46 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a8af4da1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index c3289b8..d0ed2ab 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -270,17 +270,36 @@ abstract class HashExpression[E] extends Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     ev.isNull = "false"
-    val childrenHash = ctx.splitExpressions(children.map { child =>
+
+    val childrenHash = children.map { child =>
       val childGen = child.genCode(ctx)
       childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
         computeHash(childGen.value, child.dataType, ev.value, ctx)
       }
-    })
+    }
+
+    val hashResultType = ctx.javaType(dataType)
+    val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
+      childrenHash.mkString("\n")
+    } else {
+      ctx.splitExpressions(
+        expressions = childrenHash,
+        funcName = "computeHash",
+        arguments = Seq("InternalRow" -> ctx.INPUT_ROW, hashResultType -> 
ev.value),
+        returnType = hashResultType,
+        makeSplitFunction = body =>
+          s"""
+             |$body
+             |return ${ev.value};
+           """.stripMargin,
+        foldFunctions = _.map(funcCall => s"${ev.value} = 
$funcCall;").mkString("\n"))
+    }
 
-    ctx.addMutableState(ctx.javaType(dataType), ev.value)
-    ev.copy(code = s"""
-      ${ev.value} = $seed;
-      $childrenHash""")
+    ev.copy(code =
+      s"""
+         |$hashResultType ${ev.value} = $seed;
+         |$codes
+       """.stripMargin)
   }
 
   protected def nullSafeElementHash(
@@ -389,13 +408,21 @@ abstract class HashExpression[E] extends Expression {
       input: String,
       result: String,
       fields: Array[StructField]): String = {
-    val hashes = fields.zipWithIndex.map { case (field, index) =>
+    val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
       nullSafeElementHash(input, index.toString, field.nullable, 
field.dataType, result, ctx)
     }
+    val hashResultType = ctx.javaType(dataType)
     ctx.splitExpressions(
-      expressions = hashes,
-      funcName = "getHash",
-      arguments = Seq("InternalRow" -> input))
+      expressions = fieldsHash,
+      funcName = "computeHashForStruct",
+      arguments = Seq("InternalRow" -> input, hashResultType -> result),
+      returnType = hashResultType,
+      makeSplitFunction = body =>
+        s"""
+           |$body
+           |return $result;
+         """.stripMargin,
+      foldFunctions = _.map(funcCall => s"$result = 
$funcCall;").mkString("\n"))
   }
 
   @tailrec
@@ -610,25 +637,44 @@ case class HiveHash(children: Seq[Expression]) extends 
HashExpression[Int] {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     ev.isNull = "false"
+
     val childHash = ctx.freshName("childHash")
-    val childrenHash = ctx.splitExpressions(children.map { child =>
+    val childrenHash = children.map { child =>
       val childGen = child.genCode(ctx)
       val codeToComputeHash = ctx.nullSafeExec(child.nullable, 
childGen.isNull) {
         computeHash(childGen.value, child.dataType, childHash, ctx)
       }
       s"""
          |${childGen.code}
+         |$childHash = 0;
          |$codeToComputeHash
          |${ev.value} = (31 * ${ev.value}) + $childHash;
-         |$childHash = 0;
        """.stripMargin
-    })
+    }
 
-    ctx.addMutableState(ctx.javaType(dataType), ev.value)
-    ctx.addMutableState(ctx.JAVA_INT, childHash, s"$childHash = 0;")
-    ev.copy(code = s"""
-      ${ev.value} = $seed;
-      $childrenHash""")
+    val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
+      childrenHash.mkString("\n")
+    } else {
+      ctx.splitExpressions(
+        expressions = childrenHash,
+        funcName = "computeHash",
+        arguments = Seq("InternalRow" -> ctx.INPUT_ROW, ctx.JAVA_INT -> 
ev.value),
+        returnType = ctx.JAVA_INT,
+        makeSplitFunction = body =>
+          s"""
+             |${ctx.JAVA_INT} $childHash = 0;
+             |$body
+             |return ${ev.value};
+           """.stripMargin,
+        foldFunctions = _.map(funcCall => s"${ev.value} = 
$funcCall;").mkString("\n"))
+    }
+
+    ev.copy(code =
+      s"""
+         |${ctx.JAVA_INT} ${ev.value} = $seed;
+         |${ctx.JAVA_INT} $childHash = 0;
+         |$codes
+       """.stripMargin)
   }
 
   override def eval(input: InternalRow = null): Int = {
@@ -730,23 +776,29 @@ case class HiveHash(children: Seq[Expression]) extends 
HashExpression[Int] {
       input: String,
       result: String,
       fields: Array[StructField]): String = {
-    val localResult = ctx.freshName("localResult")
     val childResult = ctx.freshName("childResult")
-    fields.zipWithIndex.map { case (field, index) =>
+    val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
+      val computeFieldHash = nullSafeElementHash(
+        input, index.toString, field.nullable, field.dataType, childResult, 
ctx)
       s"""
-         $childResult = 0;
-         ${nullSafeElementHash(input, index.toString, field.nullable, 
field.dataType,
-           childResult, ctx)}
-         $localResult = (31 * $localResult) + $childResult;
-       """
-    }.mkString(
-      s"""
-         int $localResult = 0;
-         int $childResult = 0;
-       """,
-      "",
-      s"$result = (31 * $result) + $localResult;"
-    )
+         |$childResult = 0;
+         |$computeFieldHash
+         |$result = (31 * $result) + $childResult;
+       """.stripMargin
+    }
+
+    s"${ctx.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
+      expressions = fieldsHash,
+      funcName = "computeHashForStruct",
+      arguments = Seq("InternalRow" -> input, ctx.JAVA_INT -> result),
+      returnType = ctx.JAVA_INT,
+      makeSplitFunction = body =>
+        s"""
+           |${ctx.JAVA_INT} $childResult = 0;
+           |$body
+           |return $result;
+           """.stripMargin,
+      foldFunctions = _.map(funcCall => s"$result = 
$funcCall;").mkString("\n"))
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a8af4da1/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 112a4a0..4281c89 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -27,6 +27,7 @@ import org.scalatest.exceptions.TestFailedException
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
 import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, 
GenericArrayData}
@@ -620,23 +621,30 @@ class HashExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
   }
 
   test("SPARK-18207: Compute hash for a lot of expressions") {
+    def checkResult(schema: StructType, input: InternalRow): Unit = {
+      val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
+        BoundReference(i, f.dataType, true)
+      }
+      val murmur3HashExpr = Murmur3Hash(exprs, 42)
+      val murmur3HashPlan = 
GenerateMutableProjection.generate(Seq(murmur3HashExpr))
+      val murmursHashEval = Murmur3Hash(exprs, 42).eval(input)
+      assert(murmur3HashPlan(input).getInt(0) == murmursHashEval)
+
+      val hiveHashExpr = HiveHash(exprs)
+      val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr))
+      val hiveHashEval = HiveHash(exprs).eval(input)
+      assert(hiveHashPlan(input).getInt(0) == hiveHashEval)
+    }
+
     val N = 1000
     val wideRow = new GenericInternalRow(
       Seq.tabulate(N)(i => UTF8String.fromString(i.toString)).toArray[Any])
-    val schema = StructType((1 to N).map(i => StructField("", StringType)))
-
-    val exprs = schema.fields.zipWithIndex.map { case (f, i) =>
-      BoundReference(i, f.dataType, true)
-    }
-    val murmur3HashExpr = Murmur3Hash(exprs, 42)
-    val murmur3HashPlan = 
GenerateMutableProjection.generate(Seq(murmur3HashExpr))
-    val murmursHashEval = Murmur3Hash(exprs, 42).eval(wideRow)
-    assert(murmur3HashPlan(wideRow).getInt(0) == murmursHashEval)
+    val schema = StructType((1 to N).map(i => StructField(i.toString, 
StringType)))
+    checkResult(schema, wideRow)
 
-    val hiveHashExpr = HiveHash(exprs)
-    val hiveHashPlan = GenerateMutableProjection.generate(Seq(hiveHashExpr))
-    val hiveHashEval = HiveHash(exprs).eval(wideRow)
-    assert(hiveHashPlan(wideRow).getInt(0) == hiveHashEval)
+    val nestedRow = InternalRow(wideRow)
+    val nestedSchema = new StructType().add("nested", schema)
+    checkResult(nestedSchema, nestedRow)
   }
 
   test("SPARK-22284: Compute hash for nested structs") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to