Repository: spark
Updated Branches:
  refs/heads/master 1e07fff24 -> b70e483cb


[SPARK-22617][SQL] make splitExpressions extract current input of the context

## What changes were proposed in this pull request?

In most cases, when we call `CodegenContext.splitExpressions`, we want to split the 
code into methods and pass the current inputs of the codegen context to these 
methods so that the code in these methods can still be evaluated.

This PR makes this expectation explicit, while still keeping the advanced version of 
`splitExpressions` for customizing the inputs passed to the generated methods.

## How was this patch tested?

existing test

Author: Wenchen Fan <wenc...@databricks.com>

Closes #19827 from cloud-fan/codegen.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b70e483c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b70e483c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b70e483c

Branch: refs/heads/master
Commit: b70e483cb32d07eaab80739cd0cfcd8fe922547c
Parents: 1e07fff
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Nov 28 22:57:30 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Nov 28 22:57:30 2017 +0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/arithmetic.scala   |  4 +-
 .../expressions/codegen/CodeGenerator.scala     | 13 ++--
 .../codegen/GenerateMutableProjection.scala     |  4 +-
 .../codegen/GenerateSafeProjection.scala        | 37 ++++++-----
 .../codegen/GenerateUnsafeProjection.scala      | 67 ++++++++++----------
 .../expressions/complexTypeCreator.scala        | 31 ++++-----
 .../sql/catalyst/expressions/generators.scala   |  2 +-
 .../spark/sql/catalyst/expressions/hash.scala   | 26 +++++---
 .../catalyst/expressions/nullExpressions.scala  |  2 +-
 .../catalyst/expressions/objects/objects.scala  |  6 +-
 .../expressions/stringExpressions.scala         |  2 +-
 11 files changed, 108 insertions(+), 86 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index e5a1096..d98f7b3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -614,7 +614,7 @@ case class Least(children: Seq[Expression]) extends 
Expression {
         }
       """
     }
-    val codes = ctx.splitExpressions(ctx.INPUT_ROW, 
evalChildren.map(updateEval))
+    val codes = ctx.splitExpressions(evalChildren.map(updateEval))
     ev.copy(code = s"""
       ${ev.isNull} = true;
       ${ev.value} = ${ctx.defaultValue(dataType)};
@@ -680,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends 
Expression {
         }
       """
     }
-    val codes = ctx.splitExpressions(ctx.INPUT_ROW, 
evalChildren.map(updateEval))
+    val codes = ctx.splitExpressions(evalChildren.map(updateEval))
     ev.copy(code = s"""
       ${ev.isNull} = true;
       ${ev.value} = ${ctx.defaultValue(dataType)};

http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 0498e61..668c816 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -781,15 +781,18 @@ class CodegenContext {
    * beyond 1000kb, we declare a private, inner sub-class, and the function is 
inlined to it
    * instead, because classes have a constant pool limit of 65,536 named 
values.
    *
-   * @param row the variable name of row that is used by expressions
+   * Note that we will extract the current inputs of this context and pass 
them to the generated
+   * functions. The input is `INPUT_ROW` for normal codegen path, and 
`currentVars` for whole
+   * stage codegen path. Whole stage codegen path is not supported yet.
+   *
    * @param expressions the codes to evaluate expressions.
    */
-  def splitExpressions(row: String, expressions: Seq[String]): String = {
-    if (row == null || currentVars != null) {
-      // Cannot split these expressions because they are not created from a 
row object.
+  def splitExpressions(expressions: Seq[String]): String = {
+    // TODO: support whole stage codegen
+    if (INPUT_ROW == null || currentVars != null) {
       return expressions.mkString("\n")
     }
-    splitExpressions(expressions, funcName = "apply", arguments = 
("InternalRow", row) :: Nil)
+    splitExpressions(expressions, funcName = "apply", arguments = 
("InternalRow", INPUT_ROW) :: Nil)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 802e8bd..5fdbda5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -91,8 +91,8 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], MutableP
         ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
     }
 
-    val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
-    val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)
+    val allProjections = ctx.splitExpressions(projectionCodes)
+    val allUpdates = ctx.splitExpressions(updates)
 
     val codeBody = s"""
       public java.lang.Object generate(Object[] references) {

http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 1e4ac3f..5d35cce 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -45,7 +45,8 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
       ctx: CodegenContext,
       input: String,
       schema: StructType): ExprCode = {
-    val tmp = ctx.freshName("tmp")
+    // Puts `input` in a local variable to avoid to re-evaluate it if it's a 
statement.
+    val tmpInput = ctx.freshName("tmpInput")
     val output = ctx.freshName("safeRow")
     val values = ctx.freshName("values")
     // These expressions could be split into multiple functions
@@ -54,17 +55,21 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
     val rowClass = classOf[GenericInternalRow].getName
 
     val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) 
=>
-      val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt)
+      val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, 
i.toString), dt)
       s"""
-        if (!$tmp.isNullAt($i)) {
+        if (!$tmpInput.isNullAt($i)) {
           ${converter.code}
           $values[$i] = ${converter.value};
         }
       """
     }
-    val allFields = ctx.splitExpressions(tmp, fieldWriters)
+    val allFields = ctx.splitExpressions(
+      expressions = fieldWriters,
+      funcName = "writeFields",
+      arguments = Seq("InternalRow" -> tmpInput)
+    )
     val code = s"""
-      final InternalRow $tmp = $input;
+      final InternalRow $tmpInput = $input;
       $values = new Object[${schema.length}];
       $allFields
       final InternalRow $output = new $rowClass($values);
@@ -78,20 +83,22 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
       ctx: CodegenContext,
       input: String,
       elementType: DataType): ExprCode = {
-    val tmp = ctx.freshName("tmp")
+    // Puts `input` in a local variable to avoid to re-evaluate it if it's a 
statement.
+    val tmpInput = ctx.freshName("tmpInput")
     val output = ctx.freshName("safeArray")
     val values = ctx.freshName("values")
     val numElements = ctx.freshName("numElements")
     val index = ctx.freshName("index")
     val arrayClass = classOf[GenericArrayData].getName
 
-    val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, 
index), elementType)
+    val elementConverter = convertToSafe(
+      ctx, ctx.getValue(tmpInput, elementType, index), elementType)
     val code = s"""
-      final ArrayData $tmp = $input;
-      final int $numElements = $tmp.numElements();
+      final ArrayData $tmpInput = $input;
+      final int $numElements = $tmpInput.numElements();
       final Object[] $values = new Object[$numElements];
       for (int $index = 0; $index < $numElements; $index++) {
-        if (!$tmp.isNullAt($index)) {
+        if (!$tmpInput.isNullAt($index)) {
           ${elementConverter.code}
           $values[$index] = ${elementConverter.value};
         }
@@ -107,14 +114,14 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
       input: String,
       keyType: DataType,
       valueType: DataType): ExprCode = {
-    val tmp = ctx.freshName("tmp")
+    val tmpInput = ctx.freshName("tmpInput")
     val output = ctx.freshName("safeMap")
     val mapClass = classOf[ArrayBasedMapData].getName
 
-    val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType)
-    val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", 
valueType)
+    val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", 
keyType)
+    val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", 
valueType)
     val code = s"""
-      final MapData $tmp = $input;
+      final MapData $tmpInput = $input;
       ${keyConverter.code}
       ${valueConverter.code}
       final MapData $output = new $mapClass(${keyConverter.value}, 
${valueConverter.value});
@@ -152,7 +159,7 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
             }
           """
     }
-    val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes)
+    val allExpressions = ctx.splitExpressions(expressionCodes)
 
     val codeBody = s"""
       public java.lang.Object generate(Object[] references) {

http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 4bd50ae..b022457 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -36,7 +36,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
     case NullType => true
     case t: AtomicType => true
     case _: CalendarIntervalType => true
-    case t: StructType => t.toSeq.forall(field => canSupport(field.dataType))
+    case t: StructType => t.forall(field => canSupport(field.dataType))
     case t: ArrayType if canSupport(t.elementType) => true
     case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true
     case udt: UserDefinedType[_] => canSupport(udt.sqlType)
@@ -49,25 +49,18 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
       input: String,
       fieldTypes: Seq[DataType],
       bufferHolder: String): String = {
+    // Puts `input` in a local variable to avoid to re-evaluate it if it's a 
statement.
+    val tmpInput = ctx.freshName("tmpInput")
     val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
-      val javaType = ctx.javaType(dt)
-      val isNullVar = ctx.freshName("isNull")
-      val valueVar = ctx.freshName("value")
-      val defaultValue = ctx.defaultValue(dt)
-      val readValue = ctx.getValue(input, dt, i.toString)
-      val code =
-        s"""
-          boolean $isNullVar = $input.isNullAt($i);
-          $javaType $valueVar = $isNullVar ? $defaultValue : $readValue;
-        """
-      ExprCode(code, isNullVar, valueVar)
+      ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, 
i.toString))
     }
 
     s"""
-      if ($input instanceof UnsafeRow) {
-        ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)}
+      final InternalRow $tmpInput = $input;
+      if ($tmpInput instanceof UnsafeRow) {
+        ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)}
       } else {
-        ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, 
bufferHolder)}
+        ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, 
bufferHolder)}
       }
     """
   }
@@ -167,9 +160,20 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
         }
     }
 
+    val writeFieldsCode = if (isTopLevel && (row == null || ctx.currentVars != 
null)) {
+      // TODO: support whole stage codegen
+      writeFields.mkString("\n")
+    } else {
+      assert(row != null, "the input row name cannot be null when generating 
code to write it.")
+      ctx.splitExpressions(
+        expressions = writeFields,
+        funcName = "writeFields",
+        arguments = Seq("InternalRow" -> row))
+    }
+
     s"""
       $resetWriter
-      ${ctx.splitExpressions(row, writeFields)}
+      $writeFieldsCode
     """.trim
   }
 
@@ -179,13 +183,14 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
       input: String,
       elementType: DataType,
       bufferHolder: String): String = {
+    // Puts `input` in a local variable to avoid to re-evaluate it if it's a 
statement.
+    val tmpInput = ctx.freshName("tmpInput")
     val arrayWriterClass = classOf[UnsafeArrayWriter].getName
     val arrayWriter = ctx.freshName("arrayWriter")
     ctx.addMutableState(arrayWriterClass, arrayWriter,
       s"$arrayWriter = new $arrayWriterClass();")
     val numElements = ctx.freshName("numElements")
     val index = ctx.freshName("index")
-    val element = ctx.freshName("element")
 
     val et = elementType match {
       case udt: UserDefinedType[_] => udt.sqlType
@@ -201,6 +206,7 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
     }
 
     val tmpCursor = ctx.freshName("tmpCursor")
+    val element = ctx.getValue(tmpInput, et, index)
     val writeElement = et match {
       case t: StructType =>
         s"""
@@ -233,17 +239,17 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
 
     val primitiveTypeName = if (ctx.isPrimitiveType(jt)) 
ctx.primitiveTypeName(et) else ""
     s"""
-      if ($input instanceof UnsafeArrayData) {
-        ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
+      final ArrayData $tmpInput = $input;
+      if ($tmpInput instanceof UnsafeArrayData) {
+        ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)}
       } else {
-        final int $numElements = $input.numElements();
+        final int $numElements = $tmpInput.numElements();
         $arrayWriter.initialize($bufferHolder, $numElements, 
$elementOrOffsetSize);
 
         for (int $index = 0; $index < $numElements; $index++) {
-          if ($input.isNullAt($index)) {
+          if ($tmpInput.isNullAt($index)) {
             $arrayWriter.setNull$primitiveTypeName($index);
           } else {
-            final $jt $element = ${ctx.getValue(input, et, index)};
             $writeElement
           }
         }
@@ -258,19 +264,16 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
       keyType: DataType,
       valueType: DataType,
       bufferHolder: String): String = {
-    val keys = ctx.freshName("keys")
-    val values = ctx.freshName("values")
+    // Puts `input` in a local variable to avoid to re-evaluate it if it's a 
statement.
+    val tmpInput = ctx.freshName("tmpInput")
     val tmpCursor = ctx.freshName("tmpCursor")
 
-
     // Writes out unsafe map according to the format described in 
`UnsafeMapData`.
     s"""
-      if ($input instanceof UnsafeMapData) {
-        ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)}
+      final MapData $tmpInput = $input;
+      if ($tmpInput instanceof UnsafeMapData) {
+        ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)}
       } else {
-        final ArrayData $keys = $input.keyArray();
-        final ArrayData $values = $input.valueArray();
-
         // preserve 8 bytes to write the key array numBytes later.
         $bufferHolder.grow(8);
         $bufferHolder.cursor += 8;
@@ -278,11 +281,11 @@ object GenerateUnsafeProjection extends 
CodeGenerator[Seq[Expression], UnsafePro
         // Remember the current cursor so that we can write numBytes of key 
array later.
         final int $tmpCursor = $bufferHolder.cursor;
 
-        ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
+        ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, 
bufferHolder)}
         // Write the numBytes of key array into the first 8 bytes.
         Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, 
$bufferHolder.cursor - $tmpCursor);
 
-        ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
+        ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, 
bufferHolder)}
       }
     """
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 2a00d57..57a7f2e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends 
Expression {
     val (preprocess, assigns, postprocess, arrayData) =
       GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false)
     ev.copy(
-      code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + 
postprocess,
+      code = preprocess + ctx.splitExpressions(assigns) + postprocess,
       value = arrayData,
       isNull = "false")
   }
@@ -216,10 +216,10 @@ case class CreateMap(children: Seq[Expression]) extends 
Expression {
       s"""
        final boolean ${ev.isNull} = false;
        $preprocessKeyData
-       ${ctx.splitExpressions(ctx.INPUT_ROW, assignKeys)}
+       ${ctx.splitExpressions(assignKeys)}
        $postprocessKeyData
        $preprocessValueData
-       ${ctx.splitExpressions(ctx.INPUT_ROW, assignValues)}
+       ${ctx.splitExpressions(assignValues)}
        $postprocessValueData
        final MapData ${ev.value} = new $mapClass($keyArrayData, 
$valueArrayData);
       """
@@ -351,24 +351,25 @@ case class CreateNamedStruct(children: Seq[Expression]) 
extends CreateNamedStruc
     val rowClass = classOf[GenericInternalRow].getName
     val values = ctx.freshName("values")
     ctx.addMutableState("Object[]", values, s"$values = null;")
-
-    ev.copy(code = s"""
-      $values = new Object[${valExprs.size}];""" +
-      ctx.splitExpressions(
-        ctx.INPUT_ROW,
-        valExprs.zipWithIndex.map { case (e, i) =>
-          val eval = e.genCode(ctx)
-          eval.code + s"""
+    val valuesCode = ctx.splitExpressions(
+      valExprs.zipWithIndex.map { case (e, i) =>
+        val eval = e.genCode(ctx)
+        s"""
+          ${eval.code}
           if (${eval.isNull}) {
             $values[$i] = null;
           } else {
             $values[$i] = ${eval.value};
           }"""
-        }) +
+      })
+
+    ev.copy(code =
       s"""
-        final InternalRow ${ev.value} = new $rowClass($values);
-        $values = null;
-      """, isNull = "false")
+         |$values = new Object[${valExprs.size}];
+         |$valuesCode
+         |final InternalRow ${ev.value} = new $rowClass($values);
+         |$values = null;
+       """.stripMargin, isNull = "false")
   }
 
   override def prettyName: String = "named_struct"

http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 8618f49..f1aa130 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -203,7 +203,7 @@ case class Stack(children: Seq[Expression]) extends 
Generator {
     ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new 
InternalRow[$numRows];")
     val values = children.tail
     val dataTypes = values.take(numFields).map(_.dataType)
-    val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row 
=>
+    val code = ctx.splitExpressions(Seq.tabulate(numRows) { row =>
       val fields = Seq.tabulate(numFields) { col =>
         val index = row * numFields + col
         if (index < values.length) values(index) else Literal(null, 
dataTypes(col))

http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 9e0786e..c3289b8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -270,7 +270,7 @@ abstract class HashExpression[E] extends Expression {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     ev.isNull = "false"
-    val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { 
child =>
+    val childrenHash = ctx.splitExpressions(children.map { child =>
       val childGen = child.genCode(ctx)
       childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
         computeHash(childGen.value, child.dataType, ev.value, ctx)
@@ -330,9 +330,9 @@ abstract class HashExpression[E] extends Expression {
     } else {
       val bytes = ctx.freshName("bytes")
       s"""
-            final byte[] $bytes = 
$input.toJavaBigDecimal().unscaledValue().toByteArray();
-            ${genHashBytes(bytes, result)}
-          """
+         |final byte[] $bytes = 
$input.toJavaBigDecimal().unscaledValue().toByteArray();
+         |${genHashBytes(bytes, result)}
+       """.stripMargin
     }
   }
 
@@ -392,7 +392,10 @@ abstract class HashExpression[E] extends Expression {
     val hashes = fields.zipWithIndex.map { case (field, index) =>
       nullSafeElementHash(input, index.toString, field.nullable, 
field.dataType, result, ctx)
     }
-    ctx.splitExpressions(input, hashes)
+    ctx.splitExpressions(
+      expressions = hashes,
+      funcName = "getHash",
+      arguments = Seq("InternalRow" -> input))
   }
 
   @tailrec
@@ -608,12 +611,17 @@ case class HiveHash(children: Seq[Expression]) extends 
HashExpression[Int] {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     ev.isNull = "false"
     val childHash = ctx.freshName("childHash")
-    val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { 
child =>
+    val childrenHash = ctx.splitExpressions(children.map { child =>
       val childGen = child.genCode(ctx)
-      childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
+      val codeToComputeHash = ctx.nullSafeExec(child.nullable, 
childGen.isNull) {
         computeHash(childGen.value, child.dataType, childHash, ctx)
-      } + s"${ev.value} = (31 * ${ev.value}) + $childHash;" +
-        s"\n$childHash = 0;"
+      }
+      s"""
+         |${childGen.code}
+         |$codeToComputeHash
+         |${ev.value} = (31 * ${ev.value}) + $childHash;
+         |$childHash = 0;
+       """.stripMargin
     })
 
     ctx.addMutableState(ctx.javaType(dataType), ev.value)

http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 5eaf3f2..173e171 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -91,7 +91,7 @@ case class Coalesce(children: Seq[Expression]) extends 
Expression {
     ev.copy(code = s"""
       ${ev.isNull} = true;
       ${ev.value} = ${ctx.defaultValue(dataType)};
-      ${ctx.splitExpressions(ctx.INPUT_ROW, evals)}""")
+      ${ctx.splitExpressions(evals)}""")
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 006d37f..e2bc79d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -101,7 +101,7 @@ trait InvokeLike extends Expression with NonSQLExpression {
         """
       }
     }
-    val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
+    val argCode = ctx.splitExpressions(argCodes)
 
     (argCode, argValues.mkString(", "), resultIsNull)
   }
@@ -1119,7 +1119,7 @@ case class CreateExternalRow(children: Seq[Expression], 
schema: StructType)
          """
     }
 
-    val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
+    val childrenCode = ctx.splitExpressions(childrenCodes)
     val schemaField = ctx.addReferenceObj("schema", schema)
 
     val code = s"""
@@ -1254,7 +1254,7 @@ case class InitializeJavaBean(beanInstance: Expression, 
setters: Map[String, Exp
            ${javaBeanInstance}.$setterMethod(${fieldGen.value});
          """
     }
-    val initializeCode = ctx.splitExpressions(ctx.INPUT_ROW, initialize.toSeq)
+    val initializeCode = ctx.splitExpressions(initialize.toSeq)
 
     val code = s"""
       ${instanceGen.code}

http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index d629eb7..ee5cf92 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -208,7 +208,7 @@ case class ConcatWs(children: Seq[Expression])
         }
       }.unzip
 
-      val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code))
+      val codes = ctx.splitExpressions(evals.map(_.code))
       val varargCounts = ctx.splitExpressions(
         expressions = varargCount,
         funcName = "varargCountsConcatWs",


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to