Repository: spark
Updated Branches:
  refs/heads/master a1877f45c -> 70221903f


[SPARK-22596][SQL] set ctx.currentVars in CodegenSupport.consume

## What changes were proposed in this pull request?

`ctx.currentVars` holds the input variables for the current operator. Since these are
already determined in `CodegenSupport`, we can set them there instead of in `doConsume`.

This PR also adds more comments to help people understand the codegen framework.

After this PR, we have a principle for setting `ctx.currentVars` and
`ctx.INPUT_ROW`:
1. For the non-whole-stage-codegen path, never set them (with a few special
exceptions, such as generating ordering).
2. For the whole-stage-codegen `produce` path, we mostly don't need to set them,
but blocking operators may need to set them for expressions that produce data
from a data source, sort buffer, aggregate buffer, etc.
3. For the whole-stage-codegen `consume` path, we mostly don't need to set them,
because `currentVars` is automatically set to the child's input variables and
`INPUT_ROW` is mostly unused. A few plans need to tweak them because they have
different inputs or because they use the input row; see the sketch after this list.
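
As an illustration of principle 3, here is a minimal sketch of a consume-path
`doConsume` after this change, modeled on the `ProjectExec` change in the diff
below but simplified to leave out the handling of non-deterministic expressions:

```scala
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
  // No `ctx.currentVars = input` here anymore: consume() has already set
  // ctx.currentVars to the child's input variables, so genCode on the bound
  // expressions picks them up automatically.
  val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
  val resultVars = exprs.map(_.genCode(ctx))
  consume(ctx, resultVars)
}
```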

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #19803 from cloud-fan/codegen.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/70221903
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/70221903
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/70221903

Branch: refs/heads/master
Commit: 70221903f54eaa0514d5d189dfb6f175a62228a8
Parents: a1877f4
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Fri Nov 24 21:50:30 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Fri Nov 24 21:50:30 2017 -0800

----------------------------------------------------------------------
 .../catalyst/expressions/BoundAttribute.scala   | 23 +++++++++--------
 .../expressions/codegen/CodeGenerator.scala     | 14 +++++++---
 .../sql/execution/DataSourceScanExec.scala      | 14 +++++-----
 .../apache/spark/sql/execution/ExpandExec.scala |  3 ---
 .../spark/sql/execution/GenerateExec.scala      |  2 --
 .../sql/execution/WholeStageCodegenExec.scala   | 27 +++++++++++++++-----
 .../sql/execution/basicPhysicalOperators.scala  |  6 +----
 .../apache/spark/sql/execution/objects.scala    | 20 +++++----------
 8 files changed, 59 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 7d16118..6a17a39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -59,21 +59,24 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val javaType = ctx.javaType(dataType)
-    val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
     if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
       val oev = ctx.currentVars(ordinal)
       ev.isNull = oev.isNull
       ev.value = oev.value
-      val code = oev.code
-      oev.code = ""
-      ev.copy(code = code)
-    } else if (nullable) {
-      ev.copy(code = s"""
-        boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
-        $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""")
+      ev.copy(code = oev.code)
     } else {
-      ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false")
+      assert(ctx.INPUT_ROW != null, "INPUT_ROW and currentVars cannot both be null.")
+      val javaType = ctx.javaType(dataType)
+      val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
+      if (nullable) {
+        ev.copy(code =
+          s"""
+             |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
+             |$javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);
+           """.stripMargin)
+      } else {
+        ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = "false")
+      }
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 9df8a8d..0498e61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -134,6 +134,17 @@ class CodegenContext {
   }
 
   /**
+   * Holding the variable name of the input row of the current operator, will be used by
+   * `BoundReference` to generate code.
+   *
+   * Note that if `currentVars` is not null, `BoundReference` prefers `currentVars` over `INPUT_ROW`
+   * to generate code. If you want to make sure the generated code uses `INPUT_ROW`, you need to set
+   * `currentVars` to null, or set `currentVars(i)` to null for certain columns, before calling
+   * `Expression.genCode`.
+   */
+  final var INPUT_ROW = "i"
+
+  /**
   * Holding a list of generated columns as input of current operator, will be used by
    * BoundReference to generate code.
    */
@@ -386,9 +397,6 @@ class CodegenContext {
   final val JAVA_FLOAT = "float"
   final val JAVA_DOUBLE = "double"
 
-  /** The variable name of the input row in generated code. */
-  final var INPUT_ROW = "i"
-
   /**
    * The map from a variable name to it's next ID.
    */
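
The new precedence (prefer `currentVars`, fall back to `INPUT_ROW`) means a caller
that wants row-based accessor code must clear `currentVars` first. A minimal sketch
of that pattern, modeled on the `FileSourceScanExec` change below (`ctx` and `output`
are assumed to be an operator's `CodegenContext` and output attributes):

```scala
// Force BoundReference to generate row-based accessors: with currentVars set
// to null, genCode falls back to INPUT_ROW (emitting e.g. `row.getInt(0)`).
val row = ctx.freshName("row")
ctx.INPUT_ROW = row
ctx.currentVars = null
val outputVars = output.zipWithIndex.map { case (a, i) =>
  BoundReference(i, a.dataType, a.nullable).genCode(ctx)
}
```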

http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index a477c23..747749b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -123,7 +123,7 @@ case class RowDataSourceScanExec(
        |while ($input.hasNext()) {
        |  InternalRow $row = (InternalRow) $input.next();
        |  $numOutputRows.add(1);
-       |  ${consume(ctx, columnsRowInput, null).trim}
+       |  ${consume(ctx, columnsRowInput).trim}
        |  if (shouldStop()) return;
        |}
      """.stripMargin
@@ -355,19 +355,21 @@ case class FileSourceScanExec(
     // PhysicalRDD always just has one input
     val input = ctx.freshName("input")
    ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
-    val exprRows = output.zipWithIndex.map{ case (a, i) =>
-      BoundReference(i, a.dataType, a.nullable)
-    }
     val row = ctx.freshName("row")
+
     ctx.INPUT_ROW = row
     ctx.currentVars = null
-    val columnsRowInput = exprRows.map(_.genCode(ctx))
+    // Always provide `outputVars`, so that the framework can help us build unsafe row if the input
+    // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true.
+    val outputVars = output.zipWithIndex.map{ case (a, i) =>
+      BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+    }
     val inputRow = if (needsUnsafeRowConversion) null else row
     s"""
        |while ($input.hasNext()) {
        |  InternalRow $row = (InternalRow) $input.next();
        |  $numOutputRows.add(1);
-       |  ${consume(ctx, columnsRowInput, inputRow).trim}
+       |  ${consume(ctx, outputVars, inputRow).trim}
        |  if (shouldStop()) return;
        |}
      """.stripMargin

http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 33849f4..a7bd5eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -133,9 +133,6 @@ case class ExpandExec(
      * size explosion.
      */
 
-    // Set input variables
-    ctx.currentVars = input
-
     // Tracks whether a column has the same output for all rows.
     // Size of sameOutput array should equal N.
    // If sameOutput(i) is true, then the i-th column has the same value for all output rows given

http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index c142d3b..e1562be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -135,8 +135,6 @@ case class GenerateExec(
   override def needCopyResult: Boolean = true
 
  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    ctx.currentVars = input
-
     // Add input rows to the values when we are joining
     val values = if (join) {
       input

http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 16b5706..7166b77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -108,20 +108,22 @@ trait CodegenSupport extends SparkPlan {
 
   /**
   * Consume the generated columns or row from current SparkPlan, call its parent's `doConsume()`.
+   *
+   * Note that `outputVars` and `row` can't both be null.
    */
  final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = {
     val inputVars =
-      if (row != null) {
+      if (outputVars != null) {
+        assert(outputVars.length == output.length)
+        // outputVars will be used to generate the code for UnsafeRow, so we should copy them
+        outputVars.map(_.copy())
+      } else {
+        assert(row != null, "outputVars and row cannot both be null.")
         ctx.currentVars = null
         ctx.INPUT_ROW = row
         output.zipWithIndex.map { case (attr, i) =>
           BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
         }
-      } else {
-        assert(outputVars != null)
-        assert(outputVars.length == output.length)
-        // outputVars will be used to generate the code for UnsafeRow, so we should copy them
-        outputVars.map(_.copy())
       }
 
     val rowVar = if (row != null) {
@@ -147,6 +149,11 @@ trait CodegenSupport extends SparkPlan {
       }
     }
 
+    // Set up the `currentVars` in the codegen context, as we generate the code of `inputVars`
+    // before calling `parent.doConsume`. We can't set up `INPUT_ROW`, because parent needs to
+    // generate code of `rowVar` manually.
+    ctx.currentVars = inputVars
+    ctx.INPUT_ROW = null
     ctx.freshNamePrefix = parent.variablePrefix
    val evaluated = evaluateRequiredVariables(output, inputVars, parent.usedInputs)
     s"""
@@ -193,7 +200,8 @@ trait CodegenSupport extends SparkPlan {
   def usedInputs: AttributeSet = references
 
   /**
-   * Generate the Java source code to process the rows from child SparkPlan.
+   * Generate the Java source code to process the rows from child SparkPlan. This should only be
+   * called from `consume`.
    *
    * This should be override by subclass to support codegen.
    *
@@ -207,6 +215,11 @@ trait CodegenSupport extends SparkPlan {
    *   }
    *
   * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
+   *       When consuming as a list of variables, the code to produce the input is already
+   *       generated and `CodegenContext.currentVars` is already set. When consuming as UnsafeRow,
+   *       implementations need to put `row.code` in the generated code and set
+   *       `CodegenContext.INPUT_ROW` manually. Some plans may need more tweaks as they have
+   *       different inputs (join build side, aggregate buffer, etc.), or other special cases.
    */
  def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     throw new UnsupportedOperationException
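
To make the UnsafeRow side of this contract concrete, here is a hedged sketch of
a filter-like `doConsume` that consumes its input as a row rather than as a list
of variables. The `condition` expression and the overall shape are illustrative,
not part of this commit:

```scala
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
  // Consuming as UnsafeRow: per the contract above, we must emit row.code
  // ourselves and point INPUT_ROW at the row variable before generating
  // expression code. Clearing currentVars makes BoundReference fall back
  // to INPUT_ROW.
  ctx.INPUT_ROW = row.value
  ctx.currentVars = null
  val predicate = BindReferences.bindReference[Expression](condition, child.output).genCode(ctx)
  s"""
     |${row.code}
     |${predicate.code}
     |if (!${predicate.isNull} && ${predicate.value}) {
     |  ${consume(ctx, null, row.value)}
     |}
   """.stripMargin
}
```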

http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index f205bdf..c9a1514 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -56,9 +56,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val exprs = projectList.map(x =>
-      ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
-    ctx.currentVars = input
+    val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
     val resultVars = exprs.map(_.genCode(ctx))
     // Evaluation of non-deterministic expressions can't be deferred.
    val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
@@ -152,8 +150,6 @@ case class FilterExec(condition: Expression, child: SparkPlan)
        """.stripMargin
     }
 
-    ctx.currentVars = input
-
     // To generate the predicates we will follow this algorithm.
    // For each predicate that is not IsNotNull, we will generate them one by one loading attributes
    // as necessary. For each of both attributes, if there is an IsNotNull predicate we will

http://git-wip-us.apache.org/repos/asf/spark/blob/70221903/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index d861109..d1bd8a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -81,11 +81,8 @@ case class DeserializeToObjectExec(
   }
 
  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val bound = ExpressionCanonicalizer.execute(
-      BindReferences.bindReference(deserializer, child.output))
-    ctx.currentVars = input
-    val resultVars = bound.genCode(ctx) :: Nil
-    consume(ctx, resultVars)
+    val resultObj = BindReferences.bindReference(deserializer, child.output).genCode(ctx)
+    consume(ctx, resultObj :: Nil)
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
@@ -118,11 +115,9 @@ case class SerializeFromObjectExec(
   }
 
  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val bound = serializer.map { expr =>
-      ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output))
+    val resultVars = serializer.map { expr =>
+      BindReferences.bindReference[Expression](expr, child.output).genCode(ctx)
     }
-    ctx.currentVars = input
-    val resultVars = bound.map(_.genCode(ctx))
     consume(ctx, resultVars)
   }
 
@@ -224,12 +219,9 @@ case class MapElementsExec(
     val funcObj = Literal.create(func, ObjectType(funcClass))
    val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
 
-    val bound = ExpressionCanonicalizer.execute(
-      BindReferences.bindReference(callFunc, child.output))
-    ctx.currentVars = input
-    val resultVars = bound.genCode(ctx) :: Nil
+    val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx)
 
-    consume(ctx, resultVars)
+    consume(ctx, result :: Nil)
   }
 
   override protected def doExecute(): RDD[InternalRow] = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
