Repository: spark
Updated Branches:
  refs/heads/master 03fdc92e4 -> ced6ccf0d


[SPARK-22701][SQL] add ctx.splitExpressionsWithCurrentInputs

## What changes were proposed in this pull request?

This pattern appears many times in the codebase:
```
if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
  exprs.mkString("\n")
} else {
  ctx.splitExpressions(...)
}
```

This PR adds a `ctx.splitExpressionsWithCurrentInputs` method for this pattern.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #19895 from cloud-fan/splitExpression.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ced6ccf0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ced6ccf0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ced6ccf0

Branch: refs/heads/master
Commit: ced6ccf0d6f362e299f270ed2a474f2e14f845da
Parents: 03fdc92
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Dec 5 10:15:15 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Tue Dec 5 10:15:15 2017 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   |  4 +-
 .../expressions/codegen/CodeGenerator.scala     | 44 ++++-----
 .../codegen/GenerateMutableProjection.scala     |  4 +-
 .../codegen/GenerateSafeProjection.scala        |  2 +-
 .../expressions/complexTypeCreator.scala        |  6 +-
 .../expressions/conditionalExpressions.scala    | 84 ++++++++---------
 .../sql/catalyst/expressions/generators.scala   |  2 +-
 .../spark/sql/catalyst/expressions/hash.scala   | 55 +++++-------
 .../catalyst/expressions/nullExpressions.scala  | 94 +++++++++-----------
 .../catalyst/expressions/objects/objects.scala  |  6 +-
 .../sql/catalyst/expressions/predicates.scala   | 47 +++++-----
 .../expressions/stringExpressions.scala         | 37 ++++----
 12 files changed, 179 insertions(+), 206 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index d98f7b3..739bd13 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -614,7 +614,7 @@ case class Least(children: Seq[Expression]) extends 
Expression {
         }
       """
     }
-    val codes = ctx.splitExpressions(evalChildren.map(updateEval))
+    val codes = 
ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval))
     ev.copy(code = s"""
       ${ev.isNull} = true;
       ${ev.value} = ${ctx.defaultValue(dataType)};
@@ -680,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends 
Expression {
         }
       """
     }
-    val codes = ctx.splitExpressions(evalChildren.map(updateEval))
+    val codes = 
ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval))
     ev.copy(code = s"""
       ${ev.isNull} = true;
       ${ev.value} = ${ctx.defaultValue(dataType)};

http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 1645db1..670c82e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -781,29 +781,26 @@ class CodegenContext {
    * beyond 1000kb, we declare a private, inner sub-class, and the function is 
inlined to it
    * instead, because classes have a constant pool limit of 65,536 named 
values.
    *
-   * Note that we will extract the current inputs of this context and pass 
them to the generated
-   * functions. The input is `INPUT_ROW` for normal codegen path, and 
`currentVars` for whole
-   * stage codegen path. Whole stage codegen path is not supported yet.
-   *
-   * @param expressions the codes to evaluate expressions.
-   */
-  def splitExpressions(expressions: Seq[String]): String = {
-    splitExpressions(expressions, funcName = "apply", extraArguments = Nil)
-  }
-
-  /**
-   * Similar to [[splitExpressions(expressions: Seq[String])]], but has 
customized function name
-   * and extra arguments.
+   * Note that different from `splitExpressions`, we will extract the current 
inputs of this
+   * context and pass them to the generated functions. The input is 
`INPUT_ROW` for normal codegen
+   * path, and `currentVars` for whole stage codegen path. Whole stage codegen 
path is not
+   * supported yet.
    *
    * @param expressions the codes to evaluate expressions.
    * @param funcName the split function name base.
-   * @param extraArguments the list of (type, name) of the arguments of the 
split function
-   *                       except for ctx.INPUT_ROW
-  */
-  def splitExpressions(
+   * @param extraArguments the list of (type, name) of the arguments of the 
split function,
+   *                       except for the current inputs like `ctx.INPUT_ROW`.
+   * @param returnType the return type of the split function.
+   * @param makeSplitFunction makes split function body, e.g. add preparation 
or cleanup.
+   * @param foldFunctions folds the split function calls.
+   */
+  def splitExpressionsWithCurrentInputs(
       expressions: Seq[String],
-      funcName: String,
-      extraArguments: Seq[(String, String)]): String = {
+      funcName: String = "apply",
+      extraArguments: Seq[(String, String)] = Nil,
+      returnType: String = "void",
+      makeSplitFunction: String => String = identity,
+      foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): 
String = {
     // TODO: support whole stage codegen
     if (INPUT_ROW == null || currentVars != null) {
       expressions.mkString("\n")
@@ -811,13 +808,18 @@ class CodegenContext {
       splitExpressions(
         expressions,
         funcName,
-        arguments = ("InternalRow", INPUT_ROW) +: extraArguments)
+        ("InternalRow", INPUT_ROW) +: extraArguments,
+        returnType,
+        makeSplitFunction,
+        foldFunctions)
     }
   }
 
   /**
    * Splits the generated code of expressions into multiple functions, because 
function has
-   * 64kb code size limit in JVM
+   * 64kb code size limit in JVM. If the class to which the function would be 
inlined would grow
+   * beyond 1000kb, we declare a private, inner sub-class, and the function is 
inlined to it
+   * instead, because classes have a constant pool limit of 65,536 named 
values.
    *
    * @param expressions the codes to evaluate expressions.
    * @param funcName the split function name base.

http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 5fdbda5..bd8312e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -91,8 +91,8 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], MutableP
         ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
     }
 
-    val allProjections = ctx.splitExpressions(projectionCodes)
-    val allUpdates = ctx.splitExpressions(updates)
+    val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes)
+    val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates)
 
     val codeBody = s"""
       public java.lang.Object generate(Object[] references) {

http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 5d35cce..44e7148 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -159,7 +159,7 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
             }
           """
     }
-    val allExpressions = ctx.splitExpressions(expressionCodes)
+    val allExpressions = ctx.splitExpressionsWithCurrentInputs(expressionCodes)
 
     val codeBody = s"""
       public java.lang.Object generate(Object[] references) {

http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index fc68bf4..087b210 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -108,7 +108,7 @@ private [sql] object GenArrayData {
          }
        """
       }
-      val assignmentString = ctx.splitExpressions(
+      val assignmentString = ctx.splitExpressionsWithCurrentInputs(
         expressions = assignments,
         funcName = "apply",
         extraArguments = ("Object[]", arrayDataName) :: Nil)
@@ -139,7 +139,7 @@ private [sql] object GenArrayData {
          }
        """
       }
-      val assignmentString = ctx.splitExpressions(
+      val assignmentString = ctx.splitExpressionsWithCurrentInputs(
         expressions = assignments,
         funcName = "apply",
         extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil)
@@ -357,7 +357,7 @@ case class CreateNamedStruct(children: Seq[Expression]) 
extends CreateNamedStruc
     val rowClass = classOf[GenericInternalRow].getName
     val values = ctx.freshName("values")
     ctx.addMutableState("Object[]", values, s"$values = null;")
-    val valuesCode = ctx.splitExpressions(
+    val valuesCode = ctx.splitExpressionsWithCurrentInputs(
       valExprs.zipWithIndex.map { case (e, i) =>
         val eval = e.genCode(ctx)
         s"""

http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 43e6431..ae5f714 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -219,57 +219,51 @@ case class CaseWhen(
 
     val allConditions = cases ++ elseCode
 
-    val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
-        allConditions.mkString("\n")
-      } else {
-        // This generates code like:
-        //   conditionMet = caseWhen_1(i);
-        //   if(conditionMet) {
-        //     continue;
-        //   }
-        //   conditionMet = caseWhen_2(i);
-        //   if(conditionMet) {
-        //     continue;
-        //   }
-        //   ...
-        // and the declared methods are:
-        //   private boolean caseWhen_1234() {
-        //     boolean conditionMet = false;
-        //     do {
-        //       // here the evaluation of the conditions
-        //     } while (false);
-        //     return conditionMet;
-        //   }
-        ctx.splitExpressions(allConditions, "caseWhen",
-          ("InternalRow", ctx.INPUT_ROW) :: Nil,
-          returnType = ctx.JAVA_BOOLEAN,
-          makeSplitFunction = {
-            func =>
-              s"""
-                ${ctx.JAVA_BOOLEAN} $conditionMet = false;
-                do {
-                  $func
-                } while (false);
-                return $conditionMet;
-              """
-          },
-          foldFunctions = { funcCalls =>
-            funcCalls.map { funcCall =>
-              s"""
-                $conditionMet = $funcCall;
-                if ($conditionMet) {
-                  continue;
-                }"""
-            }.mkString
-          })
-      }
+    // This generates code like:
+    //   conditionMet = caseWhen_1(i);
+    //   if(conditionMet) {
+    //     continue;
+    //   }
+    //   conditionMet = caseWhen_2(i);
+    //   if(conditionMet) {
+    //     continue;
+    //   }
+    //   ...
+    // and the declared methods are:
+    //   private boolean caseWhen_1234() {
+    //     boolean conditionMet = false;
+    //     do {
+    //       // here the evaluation of the conditions
+    //     } while (false);
+    //     return conditionMet;
+    //   }
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = allConditions,
+      funcName = "caseWhen",
+      returnType = ctx.JAVA_BOOLEAN,
+      makeSplitFunction = func =>
+        s"""
+           |${ctx.JAVA_BOOLEAN} $conditionMet = false;
+           |do {
+           |  $func
+           |} while (false);
+           |return $conditionMet;
+         """.stripMargin,
+      foldFunctions = _.map { funcCall =>
+        s"""
+           |$conditionMet = $funcCall;
+           |if ($conditionMet) {
+           |  continue;
+           |}
+         """.stripMargin
+      }.mkString)
 
     ev.copy(code = s"""
       ${ev.isNull} = true;
       ${ev.value} = ${ctx.defaultValue(dataType)};
       ${ctx.JAVA_BOOLEAN} $conditionMet = false;
       do {
-        $code
+        $codes
       } while (false);""")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index f1aa130..cd38783 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -203,7 +203,7 @@ case class Stack(children: Seq[Expression]) extends 
Generator {
     ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new 
InternalRow[$numRows];")
     val values = children.tail
     val dataTypes = values.take(numFields).map(_.dataType)
-    val code = ctx.splitExpressions(Seq.tabulate(numRows) { row =>
+    val code = ctx.splitExpressionsWithCurrentInputs(Seq.tabulate(numRows) { 
row =>
       val fields = Seq.tabulate(numFields) { col =>
         val index = row * numFields + col
         if (index < values.length) values(index) else Literal(null, 
dataTypes(col))

http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index d0ed2ab..055ebf6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -279,21 +279,17 @@ abstract class HashExpression[E] extends Expression {
     }
 
     val hashResultType = ctx.javaType(dataType)
-    val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
-      childrenHash.mkString("\n")
-    } else {
-      ctx.splitExpressions(
-        expressions = childrenHash,
-        funcName = "computeHash",
-        arguments = Seq("InternalRow" -> ctx.INPUT_ROW, hashResultType -> 
ev.value),
-        returnType = hashResultType,
-        makeSplitFunction = body =>
-          s"""
-             |$body
-             |return ${ev.value};
-           """.stripMargin,
-        foldFunctions = _.map(funcCall => s"${ev.value} = 
$funcCall;").mkString("\n"))
-    }
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = childrenHash,
+      funcName = "computeHash",
+      extraArguments = Seq(hashResultType -> ev.value),
+      returnType = hashResultType,
+      makeSplitFunction = body =>
+        s"""
+           |$body
+           |return ${ev.value};
+         """.stripMargin,
+      foldFunctions = _.map(funcCall => s"${ev.value} = 
$funcCall;").mkString("\n"))
 
     ev.copy(code =
       s"""
@@ -652,22 +648,19 @@ case class HiveHash(children: Seq[Expression]) extends 
HashExpression[Int] {
        """.stripMargin
     }
 
-    val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
-      childrenHash.mkString("\n")
-    } else {
-      ctx.splitExpressions(
-        expressions = childrenHash,
-        funcName = "computeHash",
-        arguments = Seq("InternalRow" -> ctx.INPUT_ROW, ctx.JAVA_INT -> 
ev.value),
-        returnType = ctx.JAVA_INT,
-        makeSplitFunction = body =>
-          s"""
-             |${ctx.JAVA_INT} $childHash = 0;
-             |$body
-             |return ${ev.value};
-           """.stripMargin,
-        foldFunctions = _.map(funcCall => s"${ev.value} = 
$funcCall;").mkString("\n"))
-    }
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = childrenHash,
+      funcName = "computeHash",
+      extraArguments = Seq(ctx.JAVA_INT -> ev.value),
+      returnType = ctx.JAVA_INT,
+      makeSplitFunction = body =>
+        s"""
+           |${ctx.JAVA_INT} $childHash = 0;
+           |$body
+           |return ${ev.value};
+         """.stripMargin,
+      foldFunctions = _.map(funcCall => s"${ev.value} = 
$funcCall;").mkString("\n"))
+
 
     ev.copy(code =
       s"""

http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 3b52a0e..26c9a41 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -87,37 +87,32 @@ case class Coalesce(children: Seq[Expression]) extends 
Expression {
          |}
        """.stripMargin
     }
-    val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
-        evals.mkString("\n")
-      } else {
-        ctx.splitExpressions(evals, "coalesce",
-          ("InternalRow", ctx.INPUT_ROW) :: Nil,
-          makeSplitFunction = {
-            func =>
-              s"""
-                |do {
-                |  $func
-                |} while (false);
-              """.stripMargin
-          },
-          foldFunctions = { funcCalls =>
-            funcCalls.map { funcCall =>
-              s"""
-                 |$funcCall;
-                 |if (!${ev.isNull}) {
-                 |  continue;
-                 |}
-               """.stripMargin
-            }.mkString
-          })
-      }
+
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = evals,
+      funcName = "coalesce",
+      makeSplitFunction = func =>
+        s"""
+           |do {
+           |  $func
+           |} while (false);
+         """.stripMargin,
+      foldFunctions = _.map { funcCall =>
+        s"""
+           |$funcCall;
+           |if (!${ev.isNull}) {
+           |  continue;
+           |}
+         """.stripMargin
+      }.mkString)
+
 
     ev.copy(code =
       s"""
          |${ev.isNull} = true;
          |${ev.value} = ${ctx.defaultValue(dataType)};
          |do {
-         |  $code
+         |  $codes
          |} while (false);
        """.stripMargin)
   }
@@ -415,39 +410,32 @@ case class AtLeastNNonNulls(n: Int, children: 
Seq[Expression]) extends Predicate
       }
     }
 
-    val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
-        evals.mkString("\n")
-      } else {
-        ctx.splitExpressions(
-          expressions = evals,
-          funcName = "atLeastNNonNulls",
-          arguments = ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_INT, 
nonnull) :: Nil,
-          returnType = ctx.JAVA_INT,
-          makeSplitFunction = { body =>
-            s"""
-               |do {
-               |  $body
-               |} while (false);
-               |return $nonnull;
-             """.stripMargin
-          },
-          foldFunctions = { funcCalls =>
-            funcCalls.map(funcCall =>
-              s"""
-                 |$nonnull = $funcCall;
-                 |if ($nonnull >= $n) {
-                 |  continue;
-                 |}
-               """.stripMargin).mkString("\n")
-          }
-        )
-      }
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = evals,
+      funcName = "atLeastNNonNulls",
+      extraArguments = (ctx.JAVA_INT, nonnull) :: Nil,
+      returnType = ctx.JAVA_INT,
+      makeSplitFunction = body =>
+        s"""
+           |do {
+           |  $body
+           |} while (false);
+           |return $nonnull;
+         """.stripMargin,
+      foldFunctions = _.map { funcCall =>
+        s"""
+           |$nonnull = $funcCall;
+           |if ($nonnull >= $n) {
+           |  continue;
+           |}
+         """.stripMargin
+      }.mkString)
 
     ev.copy(code =
       s"""
          |${ctx.JAVA_INT} $nonnull = 0;
          |do {
-         |  $code
+         |  $codes
          |} while (false);
          |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n;
        """.stripMargin, isNull = "false")

http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index e2bc79d..730b2ff 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -101,7 +101,7 @@ trait InvokeLike extends Expression with NonSQLExpression {
         """
       }
     }
-    val argCode = ctx.splitExpressions(argCodes)
+    val argCode = ctx.splitExpressionsWithCurrentInputs(argCodes)
 
     (argCode, argValues.mkString(", "), resultIsNull)
   }
@@ -1119,7 +1119,7 @@ case class CreateExternalRow(children: Seq[Expression], 
schema: StructType)
          """
     }
 
-    val childrenCode = ctx.splitExpressions(childrenCodes)
+    val childrenCode = ctx.splitExpressionsWithCurrentInputs(childrenCodes)
     val schemaField = ctx.addReferenceObj("schema", schema)
 
     val code = s"""
@@ -1254,7 +1254,7 @@ case class InitializeJavaBean(beanInstance: Expression, 
setters: Map[String, Exp
            ${javaBeanInstance}.$setterMethod(${fieldGen.value});
          """
     }
-    val initializeCode = ctx.splitExpressions(initialize.toSeq)
+    val initializeCode = 
ctx.splitExpressionsWithCurrentInputs(initialize.toSeq)
 
     val code = s"""
       ${instanceGen.code}

http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 75cc9b3..04e6694 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -253,31 +253,26 @@ case class In(value: Expression, list: Seq[Expression]) 
extends Predicate {
          |  continue;
          |}
        """.stripMargin)
-    val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) {
-        listCode.mkString("\n")
-      } else {
-        ctx.splitExpressions(
-          expressions = listCode,
-          funcName = "valueIn",
-          arguments = ("InternalRow", ctx.INPUT_ROW) :: (javaDataType, 
valueArg) :: Nil,
-          makeSplitFunction = { body =>
-            s"""
-               |do {
-               |  $body
-               |} while (false);
-             """.stripMargin
-          },
-          foldFunctions = { funcCalls =>
-            funcCalls.map(funcCall =>
-              s"""
-                 |$funcCall;
-                 |if (${ev.value}) {
-                 |  continue;
-                 |}
-               """.stripMargin).mkString("\n")
-          }
-        )
-      }
+
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = listCode,
+      funcName = "valueIn",
+      extraArguments = (javaDataType, valueArg) :: Nil,
+      makeSplitFunction = body =>
+        s"""
+           |do {
+           |  $body
+           |} while (false);
+         """.stripMargin,
+      foldFunctions = _.map { funcCall =>
+        s"""
+           |$funcCall;
+           |if (${ev.value}) {
+           |  continue;
+           |}
+         """.stripMargin
+      }.mkString("\n"))
+
     ev.copy(code =
       s"""
          |${valueGen.code}
@@ -286,7 +281,7 @@ case class In(value: Expression, list: Seq[Expression]) 
extends Predicate {
          |if (!${ev.isNull}) {
          |  $javaDataType $valueArg = ${valueGen.value};
          |  do {
-         |    $code
+         |    $codes
          |  } while (false);
          |}
        """.stripMargin)

http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 34917ac..47f0b57 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -73,7 +73,7 @@ case class Concat(children: Seq[Expression]) extends 
Expression with ImplicitCas
         }
       """
     }
-    val codes = ctx.splitExpressions(
+    val codes = ctx.splitExpressionsWithCurrentInputs(
       expressions = inputs,
       funcName = "valueConcat",
       extraArguments = ("UTF8String[]", args) :: Nil)
@@ -152,7 +152,7 @@ case class ConcatWs(children: Seq[Expression])
           ""
         }
       }
-      val codes = ctx.splitExpressions(
+      val codes = ctx.splitExpressionsWithCurrentInputs(
           expressions = inputs,
           funcName = "valueConcatWs",
           extraArguments = ("UTF8String[]", args) :: Nil)
@@ -200,31 +200,32 @@ case class ConcatWs(children: Seq[Expression])
         }
       }.unzip
 
-      val codes = ctx.splitExpressions(evals.map(_.code))
-      val varargCounts = ctx.splitExpressions(
+      val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code))
+
+      val varargCounts = ctx.splitExpressionsWithCurrentInputs(
         expressions = varargCount,
         funcName = "varargCountsConcatWs",
-        arguments = ("InternalRow", ctx.INPUT_ROW) :: Nil,
         returnType = "int",
         makeSplitFunction = body =>
           s"""
-           int $varargNum = 0;
-           $body
-           return $varargNum;
-           """,
-        foldFunctions = _.mkString(s"$varargNum += ", s";\n$varargNum += ", 
";"))
-      val varargBuilds = ctx.splitExpressions(
+             |int $varargNum = 0;
+             |$body
+             |return $varargNum;
+           """.stripMargin,
+        foldFunctions = _.map(funcCall => s"$varargNum += 
$funcCall;").mkString("\n"))
+
+      val varargBuilds = ctx.splitExpressionsWithCurrentInputs(
         expressions = varargBuild,
         funcName = "varargBuildsConcatWs",
-        arguments =
-          ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: 
("int", idxInVararg) :: Nil,
+        extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: 
Nil,
         returnType = "int",
         makeSplitFunction = body =>
           s"""
-           $body
-           return $idxInVararg;
-           """,
-        foldFunctions = _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", 
";"))
+             |$body
+             |return $idxInVararg;
+           """.stripMargin,
+        foldFunctions = _.map(funcCall => s"$idxInVararg = 
$funcCall;").mkString("\n"))
+
       ev.copy(
         s"""
         $codes
@@ -1380,7 +1381,7 @@ case class FormatString(children: Expression*) extends 
Expression with ImplicitC
          $argList[$index] = $value;
        """
     }
-    val argListCodes = ctx.splitExpressions(
+    val argListCodes = ctx.splitExpressionsWithCurrentInputs(
       expressions = argListCode,
       funcName = "valueFormatString",
       extraArguments = ("Object[]", argList) :: Nil)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to