Repository: spark
Updated Branches:
  refs/heads/master 6f41c593b -> 813c0f945


[SPARK-22704][SQL] Least and Greatest use fewer global variables

## What changes were proposed in this pull request?

This PR accomplishes the following two items.

1. Reduce the number of global variables from two to one
2. Make the lifetime of the remaining global variable local to a single operation

Item 1 reduces the number of constant pool entries in the generated Java class.
Item 2 ensures that a variable is not passed as an argument to a method split
by `CodegenContext.splitExpressions()`, which is the concern addressed by #19865.

## How was this patch tested?

Added a new test to `ArithmeticExpressionSuite`.

Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>

Closes #19899 from kiszk/SPARK-22704.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/813c0f94
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/813c0f94
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/813c0f94

Branch: refs/heads/master
Commit: 813c0f945d7f03800975eaed26b86a1f30e513c9
Parents: 6f41c59
Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Authored: Thu Dec 7 00:45:51 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Thu Dec 7 00:45:51 2017 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   | 94 +++++++++++++-------
 .../expressions/ArithmeticExpressionSuite.scala | 11 +++
 2 files changed, 73 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/813c0f94/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 739bd13..1893eec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -602,23 +602,38 @@ case class Least(children: Seq[Expression]) extends Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val evalChildren = children.map(_.genCode(ctx))
-    ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
-    ctx.addMutableState(ctx.javaType(dataType), ev.value)
-    def updateEval(eval: ExprCode): String = {
+    val tmpIsNull = ctx.freshName("leastTmpIsNull")
+    ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull)
+    val evals = evalChildren.map(eval =>
       s"""
-        ${eval.code}
-        if (!${eval.isNull} && (${ev.isNull} ||
-          ${ctx.genGreater(dataType, ev.value, eval.value)})) {
-          ${ev.isNull} = false;
-          ${ev.value} = ${eval.value};
-        }
-      """
-    }
-    val codes = 
ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval))
-    ev.copy(code = s"""
-      ${ev.isNull} = true;
-      ${ev.value} = ${ctx.defaultValue(dataType)};
-      $codes""")
+         |${eval.code}
+         |if (!${eval.isNull} && ($tmpIsNull ||
+         |  ${ctx.genGreater(dataType, ev.value, eval.value)})) {
+         |  $tmpIsNull = false;
+         |  ${ev.value} = ${eval.value};
+         |}
+      """.stripMargin
+    )
+
+    val resultType = ctx.javaType(dataType)
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = evals,
+      funcName = "least",
+      extraArguments = Seq(resultType -> ev.value),
+      returnType = resultType,
+      makeSplitFunction = body =>
+        s"""
+          |$body
+          |return ${ev.value};
+        """.stripMargin,
+      foldFunctions = _.map(funcCall => s"${ev.value} = 
$funcCall;").mkString("\n"))
+    ev.copy(code =
+      s"""
+         |$tmpIsNull = true;
+         |${ctx.javaType(dataType)} ${ev.value} = 
${ctx.defaultValue(dataType)};
+         |$codes
+         |final boolean ${ev.isNull} = $tmpIsNull;
+      """.stripMargin)
   }
 }
 
@@ -668,22 +683,37 @@ case class Greatest(children: Seq[Expression]) extends Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val evalChildren = children.map(_.genCode(ctx))
-    ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
-    ctx.addMutableState(ctx.javaType(dataType), ev.value)
-    def updateEval(eval: ExprCode): String = {
+    val tmpIsNull = ctx.freshName("greatestTmpIsNull")
+    ctx.addMutableState(ctx.JAVA_BOOLEAN, tmpIsNull)
+    val evals = evalChildren.map(eval =>
       s"""
-        ${eval.code}
-        if (!${eval.isNull} && (${ev.isNull} ||
-          ${ctx.genGreater(dataType, eval.value, ev.value)})) {
-          ${ev.isNull} = false;
-          ${ev.value} = ${eval.value};
-        }
-      """
-    }
-    val codes = 
ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval))
-    ev.copy(code = s"""
-      ${ev.isNull} = true;
-      ${ev.value} = ${ctx.defaultValue(dataType)};
-      $codes""")
+         |${eval.code}
+         |if (!${eval.isNull} && ($tmpIsNull ||
+         |  ${ctx.genGreater(dataType, eval.value, ev.value)})) {
+         |  $tmpIsNull = false;
+         |  ${ev.value} = ${eval.value};
+         |}
+      """.stripMargin
+    )
+
+    val resultType = ctx.javaType(dataType)
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = evals,
+      funcName = "greatest",
+      extraArguments = Seq(resultType -> ev.value),
+      returnType = resultType,
+      makeSplitFunction = body =>
+        s"""
+           |$body
+           |return ${ev.value};
+        """.stripMargin,
+      foldFunctions = _.map(funcCall => s"${ev.value} = 
$funcCall;").mkString("\n"))
+    ev.copy(code =
+      s"""
+         |$tmpIsNull = true;
+         |${ctx.javaType(dataType)} ${ev.value} = 
${ctx.defaultValue(dataType)};
+         |$codes
+         |final boolean ${ev.isNull} = $tmpIsNull;
+      """.stripMargin)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/813c0f94/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index fb759eb..be638d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
 import org.apache.spark.sql.types._
 
 class ArithmeticExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
@@ -343,4 +344,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Least(inputsExpr), "s" * 1, EmptyRow)
     checkEvaluation(Greatest(inputsExpr), "s" * N, EmptyRow)
   }
+
+  test("SPARK-22704: Least and greatest use less global variables") {
+    val ctx1 = new CodegenContext()
+    Least(Seq(Literal(1), Literal(1))).genCode(ctx1)
+    assert(ctx1.mutableStates.size == 1)
+
+    val ctx2 = new CodegenContext()
+    Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2)
+    assert(ctx2.mutableStates.size == 1)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to