Repository: spark
Updated Branches:
  refs/heads/master da7bfac48 -> 25bba58d1


[SPARK-13404] [SQL] Create variables for input row when it's actually used

## What changes were proposed in this pull request?

This PR changes how we generate the code for the output variables passed
from a plan to its parent.

Right now, they are generated before calling consume() of its parent. This is
inefficient: if the parent is a Filter or Join, which could filter out most of
the rows, the time spent accessing columns that are not used by the Filter or
Join is wasted.

This PR tries to improve this by deferring the access of columns until they are
actually used by a plan. After this PR, a plan does not need to generate code
to evaluate the variables for output; it just passes the ExprCode to its parent
via `consume()`. `parent.consumeChild()` checks the child's output against
`usedInputs` and generates the code for the columns that are part of
`usedInputs` before calling `doConsume()`.
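
The deferral can be pictured with a small, self-contained model. This is a
sketch only, not Spark's implementation: `Var`, `DeferredEval`, `sampleRow()`
and the string-keyed `required` set are hypothetical stand-ins for `ExprCode`,
`evaluateRequiredVariables()` and `AttributeSet`. Each variable carries the code
that would load its column; the consumer emits that code only for the columns
it uses, then clears it so the column is not loaded twice.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Hypothetical stand-in for ExprCode: the pending code that would load a
// column, plus the name of the attribute it belongs to.
final class Var {
    final String attribute;
    String code; // empty string means "already evaluated, nothing left to emit"

    Var(String attribute, String code) {
        this.attribute = attribute;
        this.code = code;
    }
}

final class DeferredEval {
    // Emit the pending code only for the variables the consumer needs, then
    // clear it so a later consumer does not evaluate the same column again.
    static String evaluateRequired(List<Var> vars, Set<String> required) {
        StringBuilder out = new StringBuilder();
        for (Var v : vars) {
            if (!v.code.isEmpty() && required.contains(v.attribute)) {
                out.append(v.code.trim()).append('\n');
                v.code = "";
            }
        }
        return out.toString();
    }

    // Two illustrative columns whose loads are both still pending.
    static List<Var> sampleRow() {
        List<Var> vars = new ArrayList<>();
        vars.add(new Var("a", "int a = row.getInt(0);"));
        vars.add(new Var("b", "long b = row.getLong(1);"));
        return vars;
    }
}
```

In this model, a filter that references only `a` emits the load of `a` alone.
The load of `b` stays pending until a later operator asks for it, and a second
request for `a` emits nothing.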

This PR also changes the `if` from
```
if (cond) {
  xxx
}
```
to
```
if (!cond) continue;
xxx
```
The new form helps reduce the nested indentation for multiple levels of
Filter and BroadcastHashJoin.
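
Both forms select the same rows as long as the code runs inside a per-row
loop, which `continue` requires. A minimal sketch, assuming a hypothetical
two-level filter (`x > 0`, then `x % 2 == 0`) in place of real generated code;
`GuardStyle` and its methods are illustrative names only:

```java
import java.util.ArrayList;
import java.util.List;

final class GuardStyle {
    // Nested form: each additional filter adds one level of indentation.
    static List<Integer> nested(int[] rows) {
        List<Integer> out = new ArrayList<>();
        for (int x : rows) {
            if (x > 0) {
                if (x % 2 == 0) {
                    out.add(x);
                }
            }
        }
        return out;
    }

    // Guard form: a failed check skips to the next row, so the body stays flat
    // no matter how many filters are stacked.
    static List<Integer> flat(int[] rows) {
        List<Integer> out = new ArrayList<>();
        for (int x : rows) {
            if (!(x > 0)) continue;
            if (!(x % 2 == 0)) continue;
            out.add(x);
        }
        return out;
    }
}
```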

It also adds some comments for operators.

## How was this patch tested?

Unit tests. Manually ran TPCDS Q55; this PR improves the performance by about
30% (scale=10, from 2.56s to 1.96s).

Author: Davies Liu <dav...@databricks.com>

Closes #11274 from davies/gen_defer.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/25bba58d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/25bba58d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/25bba58d

Branch: refs/heads/master
Commit: 25bba58d160d0d24e40db1ca595200a52db922ed
Parents: da7bfac
Author: Davies Liu <dav...@databricks.com>
Authored: Mon Mar 7 20:09:08 2016 -0800
Committer: Davies Liu <davies....@gmail.com>
Committed: Mon Mar 7 20:09:08 2016 -0800

----------------------------------------------------------------------
 .../catalyst/expressions/BoundAttribute.scala   |  7 +-
 .../expressions/codegen/CodeGenerator.scala     |  2 +
 .../spark/sql/execution/ExistingRDD.scala       | 41 +++++----
 .../org/apache/spark/sql/execution/Expand.scala |  4 +-
 .../spark/sql/execution/WholeStageCodegen.scala | 96 ++++++++++++++------
 .../execution/aggregate/TungstenAggregate.scala | 50 +++++-----
 .../spark/sql/execution/basicOperators.scala    | 32 ++++---
 .../sql/execution/joins/BroadcastHashJoin.scala | 84 ++++++++---------
 .../sql/execution/joins/SortMergeJoin.scala     | 63 +++++++------
 9 files changed, 224 insertions(+), 155 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/25bba58d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 4727ff1..72fe065 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -62,9 +62,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, 
nullable: Boolean)
     val javaType = ctx.javaType(dataType)
     val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
     if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
-      ev.isNull = ctx.currentVars(ordinal).isNull
-      ev.value = ctx.currentVars(ordinal).value
-      ""
+      val oev = ctx.currentVars(ordinal)
+      ev.isNull = oev.isNull
+      ev.value = oev.value
+      oev.code
     } else if (nullable) {
       s"""
         boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);

http://git-wip-us.apache.org/repos/asf/spark/blob/25bba58d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 63e1956..c4265a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -37,6 +37,8 @@ import org.apache.spark.util.Utils
  * Java source for evaluating an [[Expression]] given a [[InternalRow]] of 
input.
  *
  * @param code The sequence of statements required to evaluate the expression.
+ *             It should be empty string, if `isNull` and `value` are already 
existed, or no code
+ *             needed to evaluate them (literals).
  * @param isNull A term that holds a boolean value representing whether the 
expression evaluated
  *                 to null.
  * @param value A term for a (possibly primitive) value of the result of the 
evaluation. Not

http://git-wip-us.apache.org/repos/asf/spark/blob/25bba58d/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 4ad0750..3662ed7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -151,9 +151,6 @@ private[sql] case class PhysicalRDD(
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, 
x._1.dataType, true))
     val row = ctx.freshName("row")
     val numOutputRows = metricTerm(ctx, "numOutputRows")
-    ctx.INPUT_ROW = row
-    ctx.currentVars = null
-    val columns = exprs.map(_.gen(ctx))
 
     // The input RDD can either return (all) ColumnarBatches or InternalRows. 
We determine this
     // by looking at the first value of the RDD and then calling the function 
which will process
@@ -161,7 +158,9 @@ private[sql] case class PhysicalRDD(
     // TODO: The abstractions between this class and SqlNewHadoopRDD makes it 
difficult to know
     // here which path to use. Fix this.
 
-
+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns1 = exprs.map(_.gen(ctx))
     val scanBatches = ctx.freshName("processBatches")
     ctx.addNewFunction(scanBatches,
       s"""
@@ -170,12 +169,11 @@ private[sql] case class PhysicalRDD(
       |     int numRows = $batch.numRows();
       |     if ($idx == 0) $numOutputRows.add(numRows);
       |
-      |     while ($idx < numRows) {
+      |     while (!shouldStop() && $idx < numRows) {
       |       InternalRow $row = $batch.getRow($idx++);
-      |       ${columns.map(_.code).mkString("\n").trim}
-      |       ${consume(ctx, columns).trim}
-      |       if (shouldStop()) return;
+      |       ${consume(ctx, columns1).trim}
       |     }
+      |     if (shouldStop()) return;
       |
       |     if (!$input.hasNext()) {
       |       $batch = null;
@@ -186,30 +184,37 @@ private[sql] case class PhysicalRDD(
       |   }
       | }""".stripMargin)
 
+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns2 = exprs.map(_.gen(ctx))
+    val inputRow = if (isUnsafeRow) row else null
     val scanRows = ctx.freshName("processRows")
     ctx.addNewFunction(scanRows,
       s"""
        | private void $scanRows(InternalRow $row) throws java.io.IOException {
-       |   while (true) {
+       |   boolean firstRow = true;
+       |   while (!shouldStop() && (firstRow || $input.hasNext())) {
+       |     if (firstRow) {
+       |       firstRow = false;
+       |     } else {
+       |       $row = (InternalRow) $input.next();
+       |     }
        |     $numOutputRows.add(1);
-       |     ${columns.map(_.code).mkString("\n").trim}
-       |     ${consume(ctx, columns).trim}
-       |     if (shouldStop()) return;
-       |     if (!$input.hasNext()) break;
-       |     $row = (InternalRow)$input.next();
+       |     ${consume(ctx, columns2, inputRow).trim}
        |   }
        | }""".stripMargin)
 
+    val value = ctx.freshName("value")
     s"""
        | if ($batch != null) {
        |   $scanBatches();
        | } else if ($input.hasNext()) {
-       |   Object value = $input.next();
-       |   if (value instanceof $columnarBatchClz) {
-       |     $batch = ($columnarBatchClz)value;
+       |   Object $value = $input.next();
+       |   if ($value instanceof $columnarBatchClz) {
+       |     $batch = ($columnarBatchClz)$value;
        |     $scanBatches();
        |   } else {
-       |     $scanRows((InternalRow)value);
+       |     $scanRows((InternalRow) $value);
        |   }
        | }
      """.stripMargin

http://git-wip-us.apache.org/repos/asf/spark/blob/25bba58d/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 12998a3..524285b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -185,8 +185,10 @@ case class Expand(
 
     val numOutput = metricTerm(ctx, "numOutputRows")
     val i = ctx.freshName("i")
+    // these column have to declared before the loop.
+    val evaluate = evaluateVariables(outputColumns)
     s"""
-       |${outputColumns.map(_.code).mkString("\n").trim}
+       |$evaluate
        |for (int $i = 0; $i < ${projections.length}; $i ++) {
        |  switch ($i) {
        |    ${cases.mkString("\n").trim}

http://git-wip-us.apache.org/repos/asf/spark/blob/25bba58d/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 6d231bf..45578d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -81,11 +81,14 @@ trait CodegenSupport extends SparkPlan {
     this.parent = parent
     ctx.freshNamePrefix = variablePrefix
     waitForSubqueries()
-    doProduce(ctx)
+    s"""
+       |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
+       |${doProduce(ctx)}
+     """.stripMargin
   }
 
   /**
-    * Generate the Java source code to process, should be overrided by 
subclass to support codegen.
+    * Generate the Java source code to process, should be overridden by 
subclass to support codegen.
     *
     * doProduce() usually generate the framework, for example, aggregation 
could generate this:
     *
@@ -94,11 +97,11 @@ trait CodegenSupport extends SparkPlan {
     *     # call child.produce()
     *     initialized = true;
     *   }
-    *   while (hashmap.hasNext()) {
+    *   while (!shouldStop() && hashmap.hasNext()) {
     *     row = hashmap.next();
     *     # build the aggregation results
-    *     # create varialbles for results
-    *     # call consume(), wich will call parent.doConsume()
+    *     # create variables for results
+    *     # call consume(), which will call parent.doConsume()
     *   }
     */
   protected def doProduce(ctx: CodegenContext): String
@@ -114,27 +117,71 @@ trait CodegenSupport extends SparkPlan {
   }
 
   /**
-    * Consume the columns generated from it's child, call doConsume() or emit 
the rows.
+    * Returns source code to evaluate all the variables, and clear the code of 
them, to prevent
+    * them to be evaluated twice.
+    */
+  protected def evaluateVariables(variables: Seq[ExprCode]): String = {
+    val evaluate = variables.filter(_.code != 
"").map(_.code.trim).mkString("\n")
+    variables.foreach(_.code = "")
+    evaluate
+  }
+
+  /**
+    * Returns source code to evaluate the variables for required attributes, 
and clear the code
+    * of evaluated variables, to prevent them to be evaluated twice..
     */
+  protected def evaluateRequiredVariables(
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode],
+      required: AttributeSet): String = {
+    var evaluateVars = ""
+    variables.zipWithIndex.foreach { case (ev, i) =>
+      if (ev.code != "" && required.contains(attributes(i))) {
+        evaluateVars += ev.code.trim + "\n"
+        ev.code = ""
+      }
+    }
+    evaluateVars
+  }
+
+  /**
+   * The subset of inputSet those should be evaluated before this plan.
+   *
+   * We will use this to insert some code to access those columns that are 
actually used by current
+   * plan before calling doConsume().
+   */
+  def usedInputs: AttributeSet = references
+
+  /**
+   * Consume the columns generated from its child, call doConsume() or emit 
the rows.
+   *
+   * An operator could generate variables for the output, or a row, either one 
could be null.
+   *
+   * If the row is not null, we create variables to access the columns that 
are actually used by
+   * current plan before calling doConsume().
+   */
   def consumeChild(
       ctx: CodegenContext,
       child: SparkPlan,
       input: Seq[ExprCode],
       row: String = null): String = {
     ctx.freshNamePrefix = variablePrefix
-    if (row != null) {
-      ctx.currentVars = null
-      ctx.INPUT_ROW = row
-      val evals = child.output.zipWithIndex.map { case (attr, i) =>
-        BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+    val inputVars =
+      if (row != null) {
+        ctx.currentVars = null
+        ctx.INPUT_ROW = row
+        child.output.zipWithIndex.map { case (attr, i) =>
+          BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+        }
+      } else {
+        input
       }
-      s"""
-         | ${evals.map(_.code).mkString("\n")}
-         | ${doConsume(ctx, evals)}
-       """.stripMargin
-    } else {
-      doConsume(ctx, input)
-    }
+    s"""
+       |
+       |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
+       |${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
+       |${doConsume(ctx, inputVars)}
+     """.stripMargin
   }
 
   /**
@@ -145,9 +192,8 @@ trait CodegenSupport extends SparkPlan {
     * For example, Filter will generate the code like this:
     *
     *   # code to evaluate the predicate expression, result is isNull1 and 
value2
-    *   if (isNull1 || value2) {
-    *     # call consume(), which will call parent.doConsume()
-    *   }
+    *   if (isNull1 || !value2) continue;
+    *   # call consume(), which will call parent.doConsume()
     */
   protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = 
{
     throw new UnsupportedOperationException
@@ -190,13 +236,9 @@ case class InputAdapter(child: SparkPlan) extends 
UnaryNode with CodegenSupport
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
     s"""
-       | while ($input.hasNext()) {
+       | while (!shouldStop() && $input.hasNext()) {
        |   InternalRow $row = (InternalRow) $input.next();
-       |   ${columns.map(_.code).mkString("\n").trim}
        |   ${consume(ctx, columns).trim}
-       |   if (shouldStop()) {
-       |     return;
-       |   }
        | }
      """.stripMargin
   }
@@ -332,10 +374,12 @@ case class WholeStageCodegen(child: SparkPlan) extends 
UnaryNode with CodegenSup
         val colExprs = output.zipWithIndex.map { case (attr, i) =>
           BoundReference(i, attr.dataType, attr.nullable)
         }
+        val evaluateInputs = evaluateVariables(input)
         // generate the code to create a UnsafeRow
         ctx.currentVars = input
         val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
         s"""
+           |$evaluateInputs
            |${code.code.trim}
            |append(${code.value}.copy());
          """.stripMargin.trim

http://git-wip-us.apache.org/repos/asf/spark/blob/25bba58d/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a467229..f07add8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -116,6 +116,8 @@ case class TungstenAggregate(
   // all the mode of aggregate expressions
   private val modes = aggregateExpressions.map(_.mode).distinct
 
+  override def usedInputs: AttributeSet = inputSet
+
   override def supportCodegen: Boolean = {
     // ImperativeAggregate is not supported right now
     
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
@@ -164,23 +166,24 @@ case class TungstenAggregate(
        """.stripMargin
       ExprCode(ev.code + initVars, isNull, value)
     }
+    val initBufVar = evaluateVariables(bufVars)
 
     // generate variables for output
-    val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
     val (resultVars, genResult) = if (modes.contains(Final) || 
modes.contains(Complete)) {
       // evaluate aggregate results
       ctx.currentVars = bufVars
       val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+        BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
       }
+      val evaluateAggResults = evaluateVariables(aggResults)
       // evaluate result expressions
       ctx.currentVars = aggResults
       val resultVars = resultExpressions.map { e =>
         BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
       }
       (resultVars, s"""
-        | ${aggResults.map(_.code).mkString("\n")}
-        | ${resultVars.map(_.code).mkString("\n")}
+        |$evaluateAggResults
+        |${evaluateVariables(resultVars)}
        """.stripMargin)
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
       // output the aggregate buffer directly
@@ -188,7 +191,7 @@ case class TungstenAggregate(
     } else {
       // no aggregate function, the result should be literals
       val resultVars = resultExpressions.map(_.gen(ctx))
-      (resultVars, resultVars.map(_.code).mkString("\n"))
+      (resultVars, evaluateVariables(resultVars))
     }
 
     val doAgg = ctx.freshName("doAggregateWithoutKey")
@@ -196,7 +199,7 @@ case class TungstenAggregate(
       s"""
          | private void $doAgg() throws java.io.IOException {
          |   // initialize aggregation buffer
-         |   ${bufVars.map(_.code).mkString("\n")}
+         |   $initBufVar
          |
          |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
          | }
@@ -204,7 +207,7 @@ case class TungstenAggregate(
 
     val numOutput = metricTerm(ctx, "numOutputRows")
     s"""
-       | if (!$initAgg) {
+       | while (!$initAgg) {
        |   $initAgg = true;
        |   $doAgg();
        |
@@ -241,7 +244,7 @@ case class TungstenAggregate(
     }
     s"""
        | // do aggregate
-       | ${aggVals.map(_.code).mkString("\n").trim}
+       | ${evaluateVariables(aggVals)}
        | // update aggregation buffer
        | ${updates.mkString("\n").trim}
      """.stripMargin
@@ -252,8 +255,7 @@ case class TungstenAggregate(
   private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
     .filter(_.isInstanceOf[DeclarativeAggregate])
     .map(_.asInstanceOf[DeclarativeAggregate])
-  private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
-  private val bufferSchema = StructType.fromAttributes(bufferAttributes)
+  private val bufferSchema = 
StructType.fromAttributes(aggregateBufferAttributes)
 
   // The name for HashMap
   private var hashMapTerm: String = _
@@ -318,7 +320,7 @@ case class TungstenAggregate(
       val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
       val mergeProjection = newMutableProjection(
         mergeExpr,
-        bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
+        aggregateBufferAttributes ++ 
declFunctions.flatMap(_.inputAggBufferAttributes),
         subexpressionEliminationEnabled)()
       val joinedRow = new JoinedRow()
 
@@ -380,15 +382,18 @@ case class TungstenAggregate(
       val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
         BoundReference(i, e.dataType, e.nullable).gen(ctx)
       }
+      val evaluateKeyVars = evaluateVariables(keyVars)
       ctx.INPUT_ROW = bufferTerm
-      val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
+      val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, 
i) =>
         BoundReference(i, e.dataType, e.nullable).gen(ctx)
       }
+      val evaluateBufferVars = evaluateVariables(bufferVars)
       // evaluate the aggregation result
       ctx.currentVars = bufferVars
       val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, bufferAttributes).gen(ctx)
+        BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
       }
+      val evaluateAggResults = evaluateVariables(aggResults)
       // generate the final result
       ctx.currentVars = keyVars ++ aggResults
       val inputAttrs = groupingAttributes ++ aggregateAttributes
@@ -396,11 +401,9 @@ case class TungstenAggregate(
         BindReferences.bindReference(e, inputAttrs).gen(ctx)
       }
       s"""
-       ${keyVars.map(_.code).mkString("\n")}
-       ${bufferVars.map(_.code).mkString("\n")}
-       ${aggResults.map(_.code).mkString("\n")}
-       ${resultVars.map(_.code).mkString("\n")}
-
+       $evaluateKeyVars
+       $evaluateBufferVars
+       $evaluateAggResults
        ${consume(ctx, resultVars)}
        """
 
@@ -422,10 +425,7 @@ case class TungstenAggregate(
       val eval = resultExpressions.map{ e =>
         BindReferences.bindReference(e, groupingAttributes).gen(ctx)
       }
-      s"""
-       ${eval.map(_.code).mkString("\n")}
-       ${consume(ctx, eval)}
-       """
+      consume(ctx, eval)
     }
   }
 
@@ -508,8 +508,8 @@ case class TungstenAggregate(
     ctx.currentVars = input
     val hashEval = BindReferences.bindReference(hashExpr, 
child.output).gen(ctx)
 
-    val inputAttr = bufferAttributes ++ child.output
-    ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
+    val inputAttr = aggregateBufferAttributes ++ child.output
+    ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ 
input
     ctx.INPUT_ROW = buffer
     // TODO: support subexpression elimination
     val evals = updateExpr.map(BindReferences.bindReference(_, 
inputAttr).gen(ctx))
@@ -557,7 +557,7 @@ case class TungstenAggregate(
      $incCounter
 
      // evaluate aggregate function
-     ${evals.map(_.code).mkString("\n").trim}
+     ${evaluateVariables(evals)}
      // update aggregate buffer
      ${updates.mkString("\n").trim}
      """

http://git-wip-us.apache.org/repos/asf/spark/blob/25bba58d/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index b2f443c..4a9e736 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -39,15 +39,26 @@ case class Project(projectList: Seq[NamedExpression], 
child: SparkPlan)
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
+  override def usedInputs: AttributeSet = {
+    // only the attributes those are used at least twice should be evaluated 
before this plan,
+    // otherwise we could defer the evaluation until output attribute is 
actually used.
+    val usedExprIds = projectList.flatMap(_.collect {
+      case a: Attribute => a.exprId
+    })
+    val usedMoreThanOnce = usedExprIds.groupBy(id => id).filter(_._2.size > 
1).keySet
+    references.filter(a => usedMoreThanOnce.contains(a.exprId))
+  }
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     val exprs = projectList.map(x =>
       ExpressionCanonicalizer.execute(BindReferences.bindReference(x, 
child.output)))
     ctx.currentVars = input
-    val output = exprs.map(_.gen(ctx))
+    val resultVars = exprs.map(_.gen(ctx))
+    // Evaluation of non-deterministic expressions can't be deferred.
+    val nonDeterministicAttrs = 
projectList.filterNot(_.deterministic).map(_.toAttribute)
     s"""
-       | ${output.map(_.code).mkString("\n")}
-       |
-       | ${consume(ctx, output)}
+       |${evaluateRequiredVariables(output, resultVars, 
AttributeSet(nonDeterministicAttrs))}
+       |${consume(ctx, resultVars)}
      """.stripMargin
   }
 
@@ -89,11 +100,10 @@ case class Filter(condition: Expression, child: SparkPlan) 
extends UnaryNode wit
       s""
     }
     s"""
-       | ${eval.code}
-       | if ($nullCheck ${eval.value}) {
-       |   $numOutput.add(1);
-       |   ${consume(ctx, ctx.currentVars)}
-       | }
+       |${eval.code}
+       |if (!($nullCheck ${eval.value})) continue;
+       |$numOutput.add(1);
+       |${consume(ctx, ctx.currentVars)}
      """.stripMargin
   }
 
@@ -228,15 +238,13 @@ case class Range(
       |   }
       | }
       |
-      | while (!$overflow && $checkEnd) {
+      | while (!$overflow && $checkEnd && !shouldStop()) {
       |  long $value = $number;
       |  $number += ${step}L;
       |  if ($number < $value ^ ${step}L < 0) {
       |    $overflow = true;
       |  }
       |  ${consume(ctx, Seq(ev))}
-      |
-      |  if (shouldStop()) return;
       | }
      """.stripMargin
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/25bba58d/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 6699dba..c52662a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -190,40 +190,38 @@ case class BroadcastHashJoin(
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
-    val resultVars = buildSide match {
-      case BuildLeft => buildVars ++ input
-      case BuildRight => input ++ buildVars
-    }
     val numOutput = metricTerm(ctx, "numOutputRows")
 
-    val outputCode = if (condition.isDefined) {
+    val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from build side that used by condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, 
expr.references)
       // filter the output via condition
-      ctx.currentVars = resultVars
-      val ev = BindReferences.bindReference(condition.get, 
this.output).gen(ctx)
+      ctx.currentVars = input ++ buildVars
+      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ 
buildPlan.output).gen(ctx)
       s"""
+         |$eval
          |${ev.code}
-         |if (!${ev.isNull} && ${ev.value}) {
-         |  $numOutput.add(1);
-         |  ${consume(ctx, resultVars)}
-         |}
+         |if (${ev.isNull} || !${ev.value}) continue;
        """.stripMargin
     } else {
-      s"""
-         |$numOutput.add(1);
-         |${consume(ctx, resultVars)}
-       """.stripMargin
+      ""
     }
 
+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
       s"""
          |// generate join key for stream side
          |${keyEv.code}
          |// find matches from HashedRelation
          |UnsafeRow $matched = $anyNull ? null: 
(UnsafeRow)$relationTerm.getValue(${keyEv.value});
-         |if ($matched != null) {
-         |  ${buildVars.map(_.code).mkString("\n")}
-         |  $outputCode
-         |}
+         |if ($matched == null) continue;
+         |$checkCondition
+         |$numOutput.add(1);
+         |${consume(ctx, resultVars)}
        """.stripMargin
 
     } else {
@@ -236,13 +234,13 @@ case class BroadcastHashJoin(
          |${keyEv.code}
          |// find matches from HashRelation
          |$bufferType $matches = $anyNull ? null : 
($bufferType)$relationTerm.get(${keyEv.value});
-         |if ($matches != null) {
-         |  int $size = $matches.size();
-         |  for (int $i = 0; $i < $size; $i++) {
-         |    UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
-         |    ${buildVars.map(_.code).mkString("\n")}
-         |    $outputCode
-         |  }
+         |if ($matches == null) continue;
+         |int $size = $matches.size();
+         |for (int $i = 0; $i < $size; $i++) {
+         |  UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+         |  $checkCondition
+         |  $numOutput.add(1);
+         |  ${consume(ctx, resultVars)}
          |}
        """.stripMargin
     }
@@ -257,21 +255,21 @@ case class BroadcastHashJoin(
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
-    val resultVars = buildSide match {
-      case BuildLeft => buildVars ++ input
-      case BuildRight => input ++ buildVars
-    }
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     // filter the output via condition
     val conditionPassed = ctx.freshName("conditionPassed")
     val checkCondition = if (condition.isDefined) {
-      ctx.currentVars = resultVars
-      val ev = BindReferences.bindReference(condition.get, 
this.output).gen(ctx)
+      val expr = condition.get
+      // evaluate the variables from build side that used by condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, 
expr.references)
+      ctx.currentVars = input ++ buildVars
+      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ 
buildPlan.output).gen(ctx)
       s"""
          |boolean $conditionPassed = true;
+         |${eval.trim}
+         |${ev.code}
          |if ($matched != null) {
-         |  ${ev.code}
          |  $conditionPassed = !${ev.isNull} && ${ev.value};
          |}
        """.stripMargin
@@ -279,17 +277,21 @@ case class BroadcastHashJoin(
       s"final boolean $conditionPassed = true;"
     }
 
+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
       s"""
          |// generate join key for stream side
          |${keyEv.code}
          |// find matches from HashedRelation
          |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-         |${buildVars.map(_.code).mkString("\n")}
          |${checkCondition.trim}
          |if (!$conditionPassed) {
-         |  // reset to null
-         |  ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
+         |  $matched = null;
+         |  // reset the variables those are already evaluated.
+         |  ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")}
          |}
          |$numOutput.add(1);
          |${consume(ctx, resultVars)}
@@ -311,13 +313,11 @@ case class BroadcastHashJoin(
          |// the last iteration of this loop is to emit an empty row if there is no matched rows.
          |for (int $i = 0; $i <= $size; $i++) {
          |  UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
-         |  ${buildVars.map(_.code).mkString("\n")}
          |  ${checkCondition.trim}
-         |  if ($conditionPassed && ($i < $size || !$found)) {
-         |    $found = true;
-         |    $numOutput.add(1);
-         |    ${consume(ctx, resultVars)}
-         |  }
+         |  if (!$conditionPassed || ($i == $size && $found)) continue;
+         |  $found = true;
+         |  $numOutput.add(1);
+         |  ${consume(ctx, resultVars)}
          |}
        """.stripMargin
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/25bba58d/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 7ec4027..cffd6f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -306,11 +306,11 @@ case class SortMergeJoin(
       val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
         condRefs.contains(a)
       }
-      val beforeCond = used.map(_._2.code).mkString("\n")
-      val afterCond = notUsed.map(_._2.code).mkString("\n")
+      val beforeCond = evaluateVariables(used.map(_._2))
+      val afterCond = evaluateVariables(notUsed.map(_._2))
       (beforeCond, afterCond)
     } else {
-      (variables.map(_.code).mkString("\n"), "")
+      (evaluateVariables(variables), "")
     }
   }
 
@@ -326,41 +326,48 @@ case class SortMergeJoin(
     val leftVars = createLeftVars(ctx, leftRow)
     val rightRow = ctx.freshName("rightRow")
     val rightVars = createRightVar(ctx, rightRow)
-    val resultVars = leftVars ++ rightVars
-
-    // Check condition
-    ctx.currentVars = resultVars
-    val cond = if (condition.isDefined) {
-      BindReferences.bindReference(condition.get, output).gen(ctx)
-    } else {
-      ExprCode("", "false", "true")
-    }
-    // Split the code of creating variables based on whether it's used by condition or not.
-    val loaded = ctx.freshName("loaded")
-    val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
-    val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
-
 
     val size = ctx.freshName("size")
     val i = ctx.freshName("i")
     val numOutput = metricTerm(ctx, "numOutputRows")
+    val (beforeLoop, condCheck) = if (condition.isDefined) {
+      // Split the code of creating variables based on whether it's used by condition or not.
+      val loaded = ctx.freshName("loaded")
+      val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
+      val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
+      // Generate code for condition
+      ctx.currentVars = leftVars ++ rightVars
+      val cond = BindReferences.bindReference(condition.get, output).gen(ctx)
+      // evaluate the columns those used by condition before loop
+      val before = s"""
+           |boolean $loaded = false;
+           |$leftBefore
+         """.stripMargin
+
+      val checking = s"""
+         |$rightBefore
+         |${cond.code}
+         |if (${cond.isNull} || !${cond.value}) continue;
+         |if (!$loaded) {
+         |  $loaded = true;
+         |  $leftAfter
+         |}
+         |$rightAfter
+     """.stripMargin
+      (before, checking)
+    } else {
+      (evaluateVariables(leftVars), "")
+    }
+
     s"""
        |while (findNextInnerJoinRows($leftInput, $rightInput)) {
        |  int $size = $matches.size();
-       |  boolean $loaded = false;
-       |  $leftBefore
+       |  ${beforeLoop.trim}
        |  for (int $i = 0; $i < $size; $i ++) {
        |    InternalRow $rightRow = (InternalRow) $matches.get($i);
-       |    $rightBefore
-       |    ${cond.code}
-       |    if (${cond.isNull} || !${cond.value}) continue;
-       |    if (!$loaded) {
-       |      $loaded = true;
-       |      $leftAfter
-       |    }
-       |    $rightAfter
+       |    ${condCheck.trim}
        |    $numOutput.add(1);
-       |    ${consume(ctx, resultVars)}
+       |    ${consume(ctx, leftVars ++ rightVars)}
        |  }
        |  if (shouldStop()) return;
        |}
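
Note on the guard rewrite in the BroadcastHashJoin outer-join loop above: the old nested `if ($conditionPassed && ($i < $size || !$found))` and the new `if (!$conditionPassed || ($i == $size && $found)) continue;` should be equivalent, since `$i` never exceeds `$size`. The following is only a rough Java sketch of the shape of the generated loop, not the code Spark actually emits. The class name, the `emit` method, and the use of strings in place of `UnsafeRow` are assumptions made for illustration; a null `matched` stands for the empty row produced on the extra final iteration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical stand-in for the generated outer-join loop; not Spark code.
class OuterJoinLoopSketch {
    // Returns the rows that would be passed to consume() for one streamed row.
    static List<String> emit(String streamed, List<String> matches, Predicate<String> condition) {
        List<String> out = new ArrayList<>();
        boolean found = false;
        int size = matches.size();
        // The last iteration (i == size) emits an empty row if nothing matched.
        for (int i = 0; i <= size; i++) {
            String matched = i < size ? matches.get(i) : null;
            boolean conditionPassed = true;
            if (matched != null) {
                conditionPassed = condition.test(matched);
            }
            // Early exit instead of wrapping the rest of the body in an if block.
            if (!conditionPassed || (i == size && found)) continue;
            found = true;
            out.add(streamed + "," + matched);
        }
        return out;
    }
}
```

With the guard written this way, the code produced by the parent's consume() stays at the same indent level however many Filter or BroadcastHashJoin operators are stacked.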

