Repository: spark
Updated Branches:
  refs/heads/master 4c2efde93 -> 8a0ed5a5e


[SPARK-22668][SQL] Ensure no global variables in arguments of method split by 
CodegenContext.splitExpressions()

## What changes were proposed in this pull request?

Passing global variables to the split method is dangerous, as any mutation of 
them inside the method is ignored, which may lead to unexpected behavior.

To prevent this, one approach is to make sure no expression outputs global 
variables, i.e. to localize the lifetime of mutable states in expressions.

Another approach is to make sure that, when calling `ctx.splitExpressions`, we 
don't use children's output as parameter names.

Approach 1 is actually hard to do, as we need to check all expressions and 
operators that support whole-stage codegen. Approach 2 is easier as the callers 
of `ctx.splitExpressions` are not too many.

Besides, approach 2 is more flexible, as children's output may be other things 
that can't be used as parameter names: literals, inlined statements (a + 1), etc.

close https://github.com/apache/spark/pull/19865
close https://github.com/apache/spark/pull/19938

## How was this patch tested?

existing tests

Author: Wenchen Fan <wenc...@databricks.com>

Closes #20021 from cloud-fan/codegen.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8a0ed5a5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8a0ed5a5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8a0ed5a5

Branch: refs/heads/master
Commit: 8a0ed5a5ee64a6e854c516f80df5a9729435479b
Parents: 4c2efde
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Fri Dec 22 00:21:27 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Dec 22 00:21:27 2017 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   | 18 +++++------
 .../expressions/codegen/CodeGenerator.scala     | 32 +++++++++++++++++---
 .../expressions/conditionalExpressions.scala    |  8 ++---
 .../catalyst/expressions/nullExpressions.scala  |  9 +++---
 .../sql/catalyst/expressions/predicates.scala   |  2 +-
 5 files changed, 43 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8a0ed5a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index d3a8cb5..8bb1459 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -602,13 +602,13 @@ case class Least(children: Seq[Expression]) extends 
Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val evalChildren = children.map(_.genCode(ctx))
-    val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "leastTmpIsNull")
+    ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
     val evals = evalChildren.map(eval =>
       s"""
          |${eval.code}
-         |if (!${eval.isNull} && ($tmpIsNull ||
+         |if (!${eval.isNull} && (${ev.isNull} ||
          |  ${ctx.genGreater(dataType, ev.value, eval.value)})) {
-         |  $tmpIsNull = false;
+         |  ${ev.isNull} = false;
          |  ${ev.value} = ${eval.value};
          |}
       """.stripMargin
@@ -628,10 +628,9 @@ case class Least(children: Seq[Expression]) extends 
Expression {
       foldFunctions = _.map(funcCall => s"${ev.value} = 
$funcCall;").mkString("\n"))
     ev.copy(code =
       s"""
-         |$tmpIsNull = true;
+         |${ev.isNull} = true;
          |${ctx.javaType(dataType)} ${ev.value} = 
${ctx.defaultValue(dataType)};
          |$codes
-         |final boolean ${ev.isNull} = $tmpIsNull;
       """.stripMargin)
   }
 }
@@ -682,13 +681,13 @@ case class Greatest(children: Seq[Expression]) extends 
Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val evalChildren = children.map(_.genCode(ctx))
-    val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "greatestTmpIsNull")
+    ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
     val evals = evalChildren.map(eval =>
       s"""
          |${eval.code}
-         |if (!${eval.isNull} && ($tmpIsNull ||
+         |if (!${eval.isNull} && (${ev.isNull} ||
          |  ${ctx.genGreater(dataType, eval.value, ev.value)})) {
-         |  $tmpIsNull = false;
+         |  ${ev.isNull} = false;
          |  ${ev.value} = ${eval.value};
          |}
       """.stripMargin
@@ -708,10 +707,9 @@ case class Greatest(children: Seq[Expression]) extends 
Expression {
       foldFunctions = _.map(funcCall => s"${ev.value} = 
$funcCall;").mkString("\n"))
     ev.copy(code =
       s"""
-         |$tmpIsNull = true;
+         |${ev.isNull} = true;
          |${ctx.javaType(dataType)} ${ev.value} = 
${ctx.defaultValue(dataType)};
          |$codes
-         |final boolean ${ev.isNull} = $tmpIsNull;
       """.stripMargin)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8a0ed5a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 41a920b..9adf632 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -128,7 +128,7 @@ class CodegenContext {
    * `currentVars` to null, or set `currentVars(i)` to null for certain 
columns, before calling
    * `Expression.genCode`.
    */
-  final var INPUT_ROW = "i"
+  var INPUT_ROW = "i"
 
   /**
    * Holding a list of generated columns as input of current operator, will be 
used by
@@ -146,22 +146,30 @@ class CodegenContext {
    * as a member variable
    *
    * They will be kept as member variables in generated classes like 
`SpecificProjection`.
+   *
+   * Exposed for tests only.
    */
-  val inlinedMutableStates: mutable.ArrayBuffer[(String, String)] =
+  private[catalyst] val inlinedMutableStates: mutable.ArrayBuffer[(String, 
String)] =
     mutable.ArrayBuffer.empty[(String, String)]
 
   /**
    * The mapping between mutable state types and corrseponding compacted 
arrays.
    * The keys are java type string. The values are [[MutableStateArrays]] 
which encapsulates
    * the compacted arrays for the mutable states with the same java type.
+   *
+   * Exposed for tests only.
    */
-  val arrayCompactedMutableStates: mutable.Map[String, MutableStateArrays] =
+  private[catalyst] val arrayCompactedMutableStates: mutable.Map[String, 
MutableStateArrays] =
     mutable.Map.empty[String, MutableStateArrays]
 
   // An array holds the code that will initialize each state
-  val mutableStateInitCode: mutable.ArrayBuffer[String] =
+  // Exposed for tests only.
+  private[catalyst] val mutableStateInitCode: mutable.ArrayBuffer[String] =
     mutable.ArrayBuffer.empty[String]
 
+  // Tracks the names of all the mutable states.
+  private val mutableStateNames: mutable.HashSet[String] = 
mutable.HashSet.empty
+
   /**
    * This class holds a set of names of mutableStateArrays that is used for 
compacting mutable
    * states for a certain type, and holds the next available slot of the 
current compacted array.
@@ -172,7 +180,11 @@ class CodegenContext {
 
     private[this] var currentIndex = 0
 
-    private def createNewArray() = 
arrayNames.append(freshName("mutableStateArray"))
+    private def createNewArray() = {
+      val newArrayName = freshName("mutableStateArray")
+      mutableStateNames += newArrayName
+      arrayNames.append(newArrayName)
+    }
 
     def getCurrentIndex: Int = currentIndex
 
@@ -241,6 +253,7 @@ class CodegenContext {
       val initCode = initFunc(varName)
       inlinedMutableStates += ((javaType, varName))
       mutableStateInitCode += initCode
+      mutableStateNames += varName
       varName
     } else {
       val arrays = arrayCompactedMutableStates.getOrElseUpdate(javaType, new 
MutableStateArrays)
@@ -930,6 +943,15 @@ class CodegenContext {
       // inline execution if only one block
       blocks.head
     } else {
+      if (Utils.isTesting) {
+        // Passing global variables to the split method is dangerous, as any 
mutating to it is
+        // ignored and may lead to unexpected behavior.
+        arguments.foreach { case (_, name) =>
+          assert(!mutableStateNames.contains(name),
+            s"split function argument $name cannot be a global variable.")
+        }
+      }
+
       val func = freshName(funcName)
       val argString = arguments.map { case (t, name) => s"$t $name" 
}.mkString(", ")
       val functions = blocks.zipWithIndex.map { case (body, i) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/8a0ed5a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 1a9b682..142dfb0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -190,7 +190,7 @@ case class CaseWhen(
     // It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or 
`HAS_NONNULL`,
     // We won't go on anymore on the computation.
     val resultState = ctx.freshName("caseWhenResultState")
-    val tmpResult = ctx.addMutableState(ctx.javaType(dataType), 
"caseWhenTmpResult")
+    ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value)
 
     // these blocks are meant to be inside a
     // do {
@@ -205,7 +205,7 @@ case class CaseWhen(
          |if (!${cond.isNull} && ${cond.value}) {
          |  ${res.code}
          |  $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
-         |  $tmpResult = ${res.value};
+         |  ${ev.value} = ${res.value};
          |  continue;
          |}
        """.stripMargin
@@ -216,7 +216,7 @@ case class CaseWhen(
       s"""
          |${res.code}
          |$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
-         |$tmpResult = ${res.value};
+         |${ev.value} = ${res.value};
        """.stripMargin
     }
 
@@ -264,13 +264,11 @@ case class CaseWhen(
     ev.copy(code =
       s"""
          |${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
-         |$tmpResult = ${ctx.defaultValue(dataType)};
          |do {
          |  $codes
          |} while (false);
          |// TRUE if any condition is met and the result is null, or no any 
condition is met.
          |final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL);
-         |final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult;
        """.stripMargin)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8a0ed5a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index b4f895f..470d5da 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -72,7 +72,7 @@ case class Coalesce(children: Seq[Expression]) extends 
Expression {
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "coalesceTmpIsNull")
+    ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
 
     // all the evals are meant to be in a do { ... } while (false); loop
     val evals = children.map { e =>
@@ -80,7 +80,7 @@ case class Coalesce(children: Seq[Expression]) extends 
Expression {
       s"""
          |${eval.code}
          |if (!${eval.isNull}) {
-         |  $tmpIsNull = false;
+         |  ${ev.isNull} = false;
          |  ${ev.value} = ${eval.value};
          |  continue;
          |}
@@ -103,7 +103,7 @@ case class Coalesce(children: Seq[Expression]) extends 
Expression {
       foldFunctions = _.map { funcCall =>
         s"""
            |${ev.value} = $funcCall;
-           |if (!$tmpIsNull) {
+           |if (!${ev.isNull}) {
            |  continue;
            |}
          """.stripMargin
@@ -112,12 +112,11 @@ case class Coalesce(children: Seq[Expression]) extends 
Expression {
 
     ev.copy(code =
       s"""
-         |$tmpIsNull = true;
+         |${ev.isNull} = true;
          |$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
          |do {
          |  $codes
          |} while (false);
-         |final boolean ${ev.isNull} = $tmpIsNull;
        """.stripMargin)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8a0ed5a5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index ac9f56f..f4ee3d1 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -285,7 +285,7 @@ case class In(value: Expression, list: Seq[Expression]) 
extends Predicate {
          |${valueGen.code}
          |byte $tmpResult = $HAS_NULL;
          |if (!${valueGen.isNull}) {
-         |  $tmpResult = 0;
+         |  $tmpResult = $NOT_MATCHED;
          |  $javaDataType $valueArg = ${valueGen.value};
          |  do {
          |    $codes


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to