Repository: spark
Updated Branches:
  refs/heads/master 4637fc08a -> b9dfdcc63


Revert "[SPARK-13031] [SQL] cleanup codegen and improve test coverage"

This reverts commit cc18a7199240bf3b03410c1ba6704fe7ce6ae38e.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b9dfdcc6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b9dfdcc6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b9dfdcc6

Branch: refs/heads/master
Commit: b9dfdcc63bb12bc24de96060e756889c2ceda519
Parents: 4637fc0
Author: Davies Liu <davies....@gmail.com>
Authored: Thu Jan 28 17:01:12 2016 -0800
Committer: Davies Liu <davies....@gmail.com>
Committed: Thu Jan 28 17:01:12 2016 -0800

----------------------------------------------------------------------
 .../expressions/codegen/CodeGenerator.scala     |  13 +-
 .../codegen/GenerateMutableProjection.scala     |   2 +-
 .../spark/sql/execution/WholeStageCodegen.scala | 188 ++++++-------------
 .../execution/aggregate/TungstenAggregate.scala |  88 +++------
 .../spark/sql/execution/basicOperators.scala    |  96 +++++-----
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 103 +++++-----
 .../sql/execution/metric/SQLMetricsSuite.scala  |  34 ++--
 .../apache/spark/sql/test/SQLTestUtils.scala    |   2 +-
 .../spark/sql/util/DataFrameCallbackSuite.scala |  10 +-
 9 files changed, 202 insertions(+), 334 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b9dfdcc6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index e6704cf..2747c31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -145,22 +145,13 @@ class CodegenContext {
   private val curId = new java.util.concurrent.atomic.AtomicInteger()
 
   /**
-    * A prefix used to generate fresh name.
-    */
-  var freshNamePrefix = ""
-
-  /**
    * Returns a term name that is unique within this instance of a 
`CodeGenerator`.
    *
    * (Since we aren't in a macro context we do not seem to have access to the 
built in `freshName`
    * function.)
    */
-  def freshName(name: String): String = {
-    if (freshNamePrefix == "") {
-      s"$name${curId.getAndIncrement}"
-    } else {
-      s"${freshNamePrefix}_$name${curId.getAndIncrement}"
-    }
+  def freshName(prefix: String): String = {
+    s"$prefix${curId.getAndIncrement}"
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/b9dfdcc6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index ec31db1..d9fe761 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -93,7 +93,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
             // Can't call setNullAt on DecimalType, because we need to keep 
the offset
             s"""
               if (this.isNull_$i) {
-                ${ctx.setColumn("mutableRow", e.dataType, i, "null")};
+                ${ctx.setColumn("mutableRow", e.dataType, i, null)};
               } else {
                 ${ctx.setColumn("mutableRow", e.dataType, i, 
s"this.value_$i")};
               }

http://git-wip-us.apache.org/repos/asf/spark/blob/b9dfdcc6/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index ef81ba6..57f4945 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -22,11 +22,9 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, 
Expression, LeafExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.util.Utils
 
 /**
   * An interface for those physical operators that support codegen.
@@ -44,16 +42,10 @@ trait CodegenSupport extends SparkPlan {
   private var parent: CodegenSupport = null
 
   /**
-    * Returns the RDD of InternalRow which generates the input rows.
+    * Returns an input RDD of InternalRow and Java source code to process them.
     */
-  def upstream(): RDD[InternalRow]
-
-  /**
-    * Returns Java source code to process the rows from upstream.
-    */
-  def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
+  def produce(ctx: CodegenContext, parent: CodegenSupport): (RDD[InternalRow], 
String) = {
     this.parent = parent
-    ctx.freshNamePrefix = nodeName
     doProduce(ctx)
   }
 
@@ -74,41 +66,16 @@ trait CodegenSupport extends SparkPlan {
     *     # call consume(), wich will call parent.doConsume()
     *   }
     */
-  protected def doProduce(ctx: CodegenContext): String
+  protected def doProduce(ctx: CodegenContext): (RDD[InternalRow], String)
 
   /**
-    * Consume the columns generated from current SparkPlan, call it's parent.
+    * Consume the columns generated from current SparkPlan, call it's parent 
or create an iterator.
     */
-  def consume(ctx: CodegenContext, input: Seq[ExprCode], row: String = null): 
String = {
-    if (input != null) {
-      assert(input.length == output.length)
-    }
-    parent.consumeChild(ctx, this, input, row)
+  protected def consume(ctx: CodegenContext, columns: Seq[ExprCode]): String = 
{
+    assert(columns.length == output.length)
+    parent.doConsume(ctx, this, columns)
   }
 
-  /**
-    * Consume the columns generated from it's child, call doConsume() or emit 
the rows.
-    */
-  def consumeChild(
-      ctx: CodegenContext,
-      child: SparkPlan,
-      input: Seq[ExprCode],
-      row: String = null): String = {
-    ctx.freshNamePrefix = nodeName
-    if (row != null) {
-      ctx.currentVars = null
-      ctx.INPUT_ROW = row
-      val evals = child.output.zipWithIndex.map { case (attr, i) =>
-        BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
-      }
-      s"""
-         | ${evals.map(_.code).mkString("\n")}
-         | ${doConsume(ctx, evals)}
-       """.stripMargin
-    } else {
-      doConsume(ctx, input)
-    }
-  }
 
   /**
     * Generate the Java source code to process the rows from child SparkPlan.
@@ -122,9 +89,7 @@ trait CodegenSupport extends SparkPlan {
     *     # call consume(), which will call parent.doConsume()
     *   }
     */
-  protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = 
{
-    throw new UnsupportedOperationException
-  }
+  def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): 
String
 }
 
 
@@ -137,36 +102,31 @@ trait CodegenSupport extends SparkPlan {
 case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport 
{
 
   override def output: Seq[Attribute] = child.output
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-
-  override def doPrepare(): Unit = {
-    child.prepare()
-  }
 
-  override def doExecute(): RDD[InternalRow] = {
-    child.execute()
-  }
+  override def supportCodegen: Boolean = true
 
-  override def supportCodegen: Boolean = false
-
-  override def upstream(): RDD[InternalRow] = {
-    child.execute()
-  }
-
-  override def doProduce(ctx: CodegenContext): String = {
+  override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, 
x._1.dataType, true))
     val row = ctx.freshName("row")
     ctx.INPUT_ROW = row
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
-    s"""
-       | while (input.hasNext()) {
+    val code = s"""
+       |  while (input.hasNext()) {
        |   InternalRow $row = (InternalRow) input.next();
        |   ${columns.map(_.code).mkString("\n")}
        |   ${consume(ctx, columns)}
        | }
      """.stripMargin
+    (child.execute(), code)
+  }
+
+  def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): 
String = {
+    throw new UnsupportedOperationException
+  }
+
+  override def doExecute(): RDD[InternalRow] = {
+    throw new UnsupportedOperationException
   }
 
   override def simpleString: String = "INPUT"
@@ -183,20 +143,16 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
   *
   * -> execute()
   *     |
-  *  doExecute() --------->   upstream() -------> upstream() ------> execute()
-  *     |
-  *      ----------------->   produce()
+  *  doExecute() -------->   produce()
   *                             |
   *                          doProduce()  -------> produce()
   *                                                   |
-  *                                                doProduce()
+  *                                                doProduce() ---> execute()
   *                                                   |
   *                                                consume()
-  *                        consumeChild() <-----------|
+  *                          doConsume()  ------------|
   *                             |
-  *                          doConsume()
-  *                             |
-  *  consumeChild()  <-----  consume()
+  *  doConsume()  <-----    consume()
   *
   * SparkPlan A should override doProduce() and doConsume().
   *
@@ -206,48 +162,37 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
 case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
   extends SparkPlan with CodegenSupport {
 
-  override def supportCodegen: Boolean = false
-
   override def output: Seq[Attribute] = plan.output
-  override def outputPartitioning: Partitioning = plan.outputPartitioning
-  override def outputOrdering: Seq[SortOrder] = plan.outputOrdering
-
-  override def doPrepare(): Unit = {
-    plan.prepare()
-  }
 
   override def doExecute(): RDD[InternalRow] = {
     val ctx = new CodegenContext
-    val code = plan.produce(ctx, this)
+    val (rdd, code) = plan.produce(ctx, this)
     val references = ctx.references.toArray
     val source = s"""
       public Object generate(Object[] references) {
-        return new GeneratedIterator(references);
+       return new GeneratedIterator(references);
       }
 
       class GeneratedIterator extends 
org.apache.spark.sql.execution.BufferedRowIterator {
 
-        private Object[] references;
-        ${ctx.declareMutableStates()}
-        ${ctx.declareAddedFunctions()}
+       private Object[] references;
+       ${ctx.declareMutableStates()}
 
-        public GeneratedIterator(Object[] references) {
+       public GeneratedIterator(Object[] references) {
          this.references = references;
          ${ctx.initMutableStates()}
-        }
+       }
 
-        protected void processNext() throws java.io.IOException {
+       protected void processNext() {
          $code
-        }
+       }
       }
-      """
-
+     """
     // try to compile, helpful for debug
     // println(s"${CodeFormatter.format(source)}")
     CodeGenerator.compile(source)
 
-    plan.upstream().mapPartitions { iter =>
-
+    rdd.mapPartitions { iter =>
       val clazz = CodeGenerator.compile(source)
       val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
       buffer.setInput(iter)
@@ -258,47 +203,29 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     }
   }
 
-  override def upstream(): RDD[InternalRow] = {
+  override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
     throw new UnsupportedOperationException
   }
 
-  override def doProduce(ctx: CodegenContext): String = {
-    throw new UnsupportedOperationException
-  }
-
-  override def consumeChild(
-      ctx: CodegenContext,
-      child: SparkPlan,
-      input: Seq[ExprCode],
-      row: String = null): String = {
-
-    if (row != null) {
-      // There is an UnsafeRow already
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: 
Seq[ExprCode]): String = {
+    if (input.nonEmpty) {
+      val colExprs = output.zipWithIndex.map { case (attr, i) =>
+        BoundReference(i, attr.dataType, attr.nullable)
+      }
+      // generate the code to create a UnsafeRow
+      ctx.currentVars = input
+      val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
       s"""
-         | currentRow = $row;
+         | ${code.code.trim}
+         | currentRow = ${code.value};
          | return;
-       """.stripMargin
+     """.stripMargin
     } else {
-      assert(input != null)
-      if (input.nonEmpty) {
-        val colExprs = output.zipWithIndex.map { case (attr, i) =>
-          BoundReference(i, attr.dataType, attr.nullable)
-        }
-        // generate the code to create a UnsafeRow
-        ctx.currentVars = input
-        val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
-        s"""
-           | ${code.code.trim}
-           | currentRow = ${code.value};
-           | return;
-         """.stripMargin
-      } else {
-        // There is no columns
-        s"""
-           | currentRow = unsafeRow;
-           | return;
-         """.stripMargin
-      }
+      // There is no columns
+      s"""
+         | currentRow = unsafeRow;
+         | return;
+       """.stripMargin
     }
   }
 
@@ -319,7 +246,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
     builder.append(simpleString)
     builder.append("\n")
 
-    plan.generateTreeString(depth + 2, lastChildren :+ false :+ true, builder)
+    plan.generateTreeString(depth + 1, lastChildren :+children.isEmpty :+ 
true, builder)
     if (children.nonEmpty) {
       children.init.foreach(_.generateTreeString(depth + 1, lastChildren :+ 
false, builder))
       children.last.generateTreeString(depth + 1, lastChildren :+ true, 
builder)
@@ -359,14 +286,13 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
         case plan: CodegenSupport if supportCodegen(plan) &&
           // Whole stage codegen is only useful when there are at least two 
levels of operators that
           // support it (save at least one projection/iterator).
-          (Utils.isTesting || plan.children.exists(supportCodegen)) =>
+          plan.children.exists(supportCodegen) =>
 
           var inputs = ArrayBuffer[SparkPlan]()
           val combined = plan.transform {
             case p if !supportCodegen(p) =>
-              val input = apply(p)  // collapse them recursively
-              inputs += input
-              InputAdapter(input)
+              inputs += p
+              InputAdapter(p)
           }.asInstanceOf[CodegenSupport]
           WholeStageCodegen(combined, inputs)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/b9dfdcc6/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index cbd2634..23e54f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -117,7 +117,9 @@ case class TungstenAggregate(
   override def supportCodegen: Boolean = {
     groupingExpressions.isEmpty &&
       // ImperativeAggregate is not supported right now
-      
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
+      
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
 &&
+      // final aggregation only have one row, do not need to codegen
+      !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
   }
 
   // The variables used as aggregation buffer
@@ -125,11 +127,7 @@ case class TungstenAggregate(
 
   private val modes = aggregateExpressions.map(_.mode).distinct
 
-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], 
String) = {
     val initAgg = ctx.freshName("initAgg")
     ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
 
@@ -139,80 +137,50 @@ case class TungstenAggregate(
     bufVars = initExpr.map { e =>
       val isNull = ctx.freshName("bufIsNull")
       val value = ctx.freshName("bufValue")
-      ctx.addMutableState("boolean", isNull, "")
-      ctx.addMutableState(ctx.javaType(e.dataType), value, "")
       // The initial expression should not access any column
       val ev = e.gen(ctx)
       val initVars = s"""
-         | $isNull = ${ev.isNull};
-         | $value = ${ev.value};
+         | boolean $isNull = ${ev.isNull};
+         | ${ctx.javaType(e.dataType)} $value = ${ev.value};
        """.stripMargin
       ExprCode(ev.code + initVars, isNull, value)
     }
 
-    // generate variables for output
-    val (resultVars, genResult) = if (modes.contains(Final) | 
modes.contains(Complete)) {
-      // evaluate aggregate results
-      ctx.currentVars = bufVars
-      val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
-      val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, bufferAttrs).gen(ctx)
-      }
-      // evaluate result expressions
-      ctx.currentVars = aggResults
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
-      }
-      (resultVars, s"""
-        | ${aggResults.map(_.code).mkString("\n")}
-        | ${resultVars.map(_.code).mkString("\n")}
-       """.stripMargin)
-    } else {
-      // output the aggregate buffer directly
-      (bufVars, "")
-    }
-
-    val doAgg = ctx.freshName("doAgg")
-    ctx.addNewFunction(doAgg,
+    val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, 
this)
+    val source =
       s"""
-         | private void $doAgg() {
+         | if (!$initAgg) {
+         |   $initAgg = true;
+         |
          |   // initialize aggregation buffer
          |   ${bufVars.map(_.code).mkString("\n")}
          |
-         |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+         |   $childSource
+         |
+         |   // output the result
+         |   ${consume(ctx, bufVars)}
          | }
-       """.stripMargin)
+       """.stripMargin
 
-    s"""
-       | if (!$initAgg) {
-       |   $initAgg = true;
-       |   $doAgg();
-       |
-       |   // output the result
-       |   $genResult
-       |
-       |   ${consume(ctx, resultVars)}
-       | }
-     """.stripMargin
+    (rdd, source)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: 
Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = 
aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
-    val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
-    val updateExpr = aggregateExpressions.flatMap { e =>
-      e.mode match {
-        case Partial | Complete =>
-          
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
-        case PartialMerge | Final =>
-          
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
-      }
+    // the mode could be only Partial or PartialMerge
+    val updateExpr = if (modes.contains(Partial)) {
+      functions.flatMap(_.updateExpressions)
+    } else {
+      functions.flatMap(_.mergeExpressions)
     }
 
+    val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, 
inputAttr))
     ctx.currentVars = bufVars ++ input
     // TODO: support subexpression elimination
-    val updates = updateExpr.zipWithIndex.map { case (e, i) =>
-      val ev = BindReferences.bindReference[Expression](e, inputAttrs).gen(ctx)
+    val codes = boundExpr.zipWithIndex.map { case (e, i) =>
+      val ev = e.gen(ctx)
       s"""
          | ${ev.code}
          | ${bufVars(i).isNull} = ${ev.isNull};
@@ -222,7 +190,7 @@ case class TungstenAggregate(
 
     s"""
        | // do aggregate and update aggregation buffer
-       | ${updates.mkString("")}
+       | ${codes.mkString("")}
      """.stripMargin
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b9dfdcc6/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index e7a73d5..6deb72a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -37,15 +37,11 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
 
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
 
-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], 
String) = {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: 
Seq[ExprCode]): String = {
     val exprs = projectList.map(x =>
       ExpressionCanonicalizer.execute(BindReferences.bindReference(x, 
child.output)))
     ctx.currentVars = input
@@ -80,15 +76,11 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
     "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of 
input rows"),
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of 
output rows"))
 
-  override def upstream(): RDD[InternalRow] = {
-    child.asInstanceOf[CodegenSupport].upstream()
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], 
String) = {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: 
Seq[ExprCode]): String = {
     val expr = ExpressionCanonicalizer.execute(
       BindReferences.bindReference(condition, child.output))
     ctx.currentVars = input
@@ -161,21 +153,17 @@ case class Range(
     output: Seq[Attribute])
   extends LeafNode with CodegenSupport {
 
-  override def upstream(): RDD[InternalRow] = {
-    sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => 
InternalRow(i))
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
-    val initTerm = ctx.freshName("initRange")
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], 
String) = {
+    val initTerm = ctx.freshName("range_initRange")
     ctx.addMutableState("boolean", initTerm, s"$initTerm = false;")
-    val partitionEnd = ctx.freshName("partitionEnd")
+    val partitionEnd = ctx.freshName("range_partitionEnd")
     ctx.addMutableState("long", partitionEnd, s"$partitionEnd = 0L;")
-    val number = ctx.freshName("number")
+    val number = ctx.freshName("range_number")
     ctx.addMutableState("long", number, s"$number = 0L;")
-    val overflow = ctx.freshName("overflow")
+    val overflow = ctx.freshName("range_overflow")
     ctx.addMutableState("boolean", overflow, s"$overflow = false;")
 
-    val value = ctx.freshName("value")
+    val value = ctx.freshName("range_value")
     val ev = ExprCode("", "false", value)
     val BigInt = classOf[java.math.BigInteger].getName
     val checkEnd = if (step > 0) {
@@ -184,42 +172,38 @@ case class Range(
       s"$number > $partitionEnd"
     }
 
-    ctx.addNewFunction("initRange",
-      s"""
-        | private void initRange(int idx) {
-        |   $BigInt index = $BigInt.valueOf(idx);
-        |   $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
-        |   $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
-        |   $BigInt step = $BigInt.valueOf(${step}L);
-        |   $BigInt start = $BigInt.valueOf(${start}L);
-        |
-        |   $BigInt st = 
index.multiply(numElement).divide(numSlice).multiply(step).add(start);
-        |   if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-        |     $number = Long.MAX_VALUE;
-        |   } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-        |     $number = Long.MIN_VALUE;
-        |   } else {
-        |     $number = st.longValue();
-        |   }
-        |
-        |   $BigInt end = 
index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
-        |     .multiply(step).add(start);
-        |   if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
-        |     $partitionEnd = Long.MAX_VALUE;
-        |   } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
-        |     $partitionEnd = Long.MIN_VALUE;
-        |   } else {
-        |     $partitionEnd = end.longValue();
-        |   }
-        | }
-       """.stripMargin)
+    val rdd = sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
+      .map(i => InternalRow(i))
 
-    s"""
+    val code = s"""
       | // initialize Range
       | if (!$initTerm) {
       |   $initTerm = true;
       |   if (input.hasNext()) {
-      |     initRange(((InternalRow) input.next()).getInt(0));
+      |     $BigInt index = $BigInt.valueOf(((InternalRow) 
input.next()).getInt(0));
+      |     $BigInt numSlice = $BigInt.valueOf(${numSlices}L);
+      |     $BigInt numElement = $BigInt.valueOf(${numElements.toLong}L);
+      |     $BigInt step = $BigInt.valueOf(${step}L);
+      |     $BigInt start = $BigInt.valueOf(${start}L);
+      |
+      |     $BigInt st = 
index.multiply(numElement).divide(numSlice).multiply(step).add(start);
+      |     if (st.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+      |       $number = Long.MAX_VALUE;
+      |     } else if (st.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+      |       $number = Long.MIN_VALUE;
+      |     } else {
+      |       $number = st.longValue();
+      |     }
+      |
+      |     $BigInt end = 
index.add($BigInt.ONE).multiply(numElement).divide(numSlice)
+      |       .multiply(step).add(start);
+      |     if (end.compareTo($BigInt.valueOf(Long.MAX_VALUE)) > 0) {
+      |       $partitionEnd = Long.MAX_VALUE;
+      |     } else if (end.compareTo($BigInt.valueOf(Long.MIN_VALUE)) < 0) {
+      |       $partitionEnd = Long.MIN_VALUE;
+      |     } else {
+      |       $partitionEnd = end.longValue();
+      |     }
       |   } else {
       |     return;
       |   }
@@ -234,6 +218,12 @@ case class Range(
       |  ${consume(ctx, Seq(ev))}
       | }
      """.stripMargin
+
+    (rdd, code)
+  }
+
+  def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): 
String = {
+    throw new UnsupportedOperationException
   }
 
   protected override def doExecute(): RDD[InternalRow] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/b9dfdcc6/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 51a50c1..989cb29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1939,61 +1939,58 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }
 
   test("Common subexpression elimination") {
-    // TODO: support subexpression elimination in whole stage codegen
-    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
-      // select from a table to prevent constant folding.
-      val df = sql("SELECT a, b from testData2 limit 1")
-      checkAnswer(df, Row(1, 1))
-
-      checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
-      checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
-
-      // This does not work because the expressions get grouped like (a + a) + 1
-      checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
-      checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
-
-      // Identity udf that tracks the number of times it is called.
-      val countAcc = sparkContext.accumulator(0, "CallCount")
-      sqlContext.udf.register("testUdf", (x: Int) => {
-        countAcc.++=(1)
-        x
-      })
-
-      // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
-      // is correct.
-      def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
-        countAcc.setValue(0)
-        checkAnswer(df, expectedResult)
-        assert(countAcc.value == expectedCount)
-      }
+    // select from a table to prevent constant folding.
+    val df = sql("SELECT a, b from testData2 limit 1")
+    checkAnswer(df, Row(1, 1))
+
+    checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2))
+    checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3))
+
+    // This does not work because the expressions get grouped like (a + a) + 1
+    checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3))
+    checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3))
+
+    // Identity udf that tracks the number of times it is called.
+    val countAcc = sparkContext.accumulator(0, "CallCount")
+    sqlContext.udf.register("testUdf", (x: Int) => {
+      countAcc.++=(1)
+      x
+    })
 
-      verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
-      verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
-      verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
-      verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
-      verifyCallCount(
-        df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
-
-      verifyCallCount(
-        df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
-
-      val testUdf = functions.udf((x: Int) => {
-        countAcc.++=(1)
-        x
-      })
-      verifyCallCount(
-        df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
-
-      // Would be nice if semantic equals for `+` understood commutative
-      verifyCallCount(
-        df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
-
-      // Try disabling it via configuration.
-      sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
-      verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
-      sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
-      verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+    // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value
+    // is correct.
+    def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = {
+      countAcc.setValue(0)
+      checkAnswer(df, expectedResult)
+      assert(countAcc.value == expectedCount)
     }
+
+    verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1)
+    verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
+    verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1)
+    verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2)
+    verifyCallCount(
+      df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1)
+
+    verifyCallCount(
+      df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+    val testUdf = functions.udf((x: Int) => {
+      countAcc.++=(1)
+      x
+    })
+    verifyCallCount(
+      df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1)
+
+    // Would be nice if semantic equals for `+` understood commutative
+    verifyCallCount(
+      df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2)
+
+    // Try disabling it via configuration.
+    sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false")
+    verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2)
+    sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true")
+    verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1)
   }
 
   test("SPARK-10707: nullability should be correctly propagated through set operations (1)") {

http://git-wip-us.apache.org/repos/asf/spark/blob/b9dfdcc6/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 82f6811..cbae19e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -335,24 +335,22 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
 
   test("save metrics") {
     withTempPath { file =>
-      withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
-        val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
-        // Assume the execution plan is
-        // PhysicalRDD(nodeId = 0)
-        person.select('name).write.format("json").save(file.getAbsolutePath)
-        sparkContext.listenerBus.waitUntilEmpty(10000)
-        val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
-        assert(executionIds.size === 1)
-        val executionId = executionIds.head
-        val jobs = sqlContext.listener.getExecution(executionId).get.jobs
-        // Use "<=" because there is a race condition that we may miss some jobs
-        // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
-        assert(jobs.size <= 1)
-        val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
-        // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
-        // However, we still can check the value.
-        assert(metricValues.values.toSeq === Seq("2"))
-      }
+      val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
+      // Assume the execution plan is
+      // PhysicalRDD(nodeId = 0)
+      person.select('name).write.format("json").save(file.getAbsolutePath)
+      sparkContext.listenerBus.waitUntilEmpty(10000)
+      val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+      assert(executionIds.size === 1)
+      val executionId = executionIds.head
+      val jobs = sqlContext.listener.getExecution(executionId).get.jobs
+      // Use "<=" because there is a race condition that we may miss some jobs
+      // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
+      assert(jobs.size <= 1)
+      val metricValues = sqlContext.listener.getExecutionMetrics(executionId)
+      // Because "save" will create a new DataFrame internally, we cannot get the real metric id.
+      // However, we still can check the value.
+      assert(metricValues.values.toSeq === Seq("2"))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b9dfdcc6/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 7d6bff8..d481437 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -199,7 +199,7 @@ private[sql] trait SQLTestUtils
     val schema = df.schema
     val childRDD = df
       .queryExecution
-      .sparkPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
+      .executedPlan.asInstanceOf[org.apache.spark.sql.execution.Filter]
       .child
       .execute()
       .map(row => Row.fromSeq(row.copy().toSeq(schema)))

http://git-wip-us.apache.org/repos/asf/spark/blob/b9dfdcc6/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index a3e5243..9a24a24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -97,12 +97,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
     }
     sqlContext.listenerManager.register(listener)
 
-    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
-      val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
-      df.collect()
-      df.collect()
-      Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
-    }
+    val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count()
+    df.collect()
+    df.collect()
+    Seq(1 -> "a", 2 -> "a").toDF("i", "j").groupBy("i").count().collect()
 
     assert(metrics.length == 3)
     assert(metrics(0) == 1)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to