Repository: spark
Updated Branches:
  refs/heads/branch-2.1 60e02a173 -> 021952d58


[SPARK-18528][SQL] Fix a bug to initialise the aggregation buffer iterator

## What changes were proposed in this pull request?
This PR fixes a `NullPointerException` thrown by the following
`limit + aggregate` query:
```
scala> val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value")
scala> df.limit(2).groupBy("id").count().show
WARN TaskSetManager: Lost task 0.0 in stage 9.0 (TID 8204, lvsp20hdn012.stubprod.com): java.lang.NullPointerException
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.agg_doAggregateWithKeys$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
```
The root cause is that
[`$doAgg()`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L596)
skips the initialization of [the buffer
iterator](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L603):
`BaseLimitExec` sets `stopEarly=true`, which makes `$doAgg()` return in the
middle of the method, before the iterator is initialized. The output phase
then dereferences the still-null iterator, as sketched below.
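
To see the failure in isolation, here is a minimal, self-contained sketch
(plain Java, not Spark's generated code; all names are illustrative). It
mirrors the shape of the generated `agg_doAggregateWithKeys$()` /
`processNext()` pair: the early return triggered by the limit skips the
statement that initializes the buffer iterator, and the output phase then
dereferences `null`:
```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Minimal illustration of the bug's shape; not Spark code.
public class EarlyExitNpe {
  private static final List<Integer> buffer = new ArrayList<>();
  private static Iterator<Integer> bufferIter = null;  // the "buffer iterator"

  // Stand-in for the generated shouldStop(): true once the limit (2) is hit.
  private static boolean shouldStop(int consumed) {
    return consumed >= 2;
  }

  // Stand-in for the generated agg_doAggregateWithKeys$().
  private static void doAggregate(List<Integer> input) {
    int consumed = 0;
    for (int row : input) {
      buffer.add(row);                   // "update the aggregation hash map"
      consumed++;
      if (shouldStop(consumed)) return;  // early exit skips the init below
    }
    bufferIter = buffer.iterator();      // never reached once the limit hits
  }

  // Stand-in for the output phase of the generated processNext().
  public static void main(String[] args) {
    doAggregate(Arrays.asList(1, 2, 3, 4));
    while (bufferIter.hasNext()) {       // NullPointerException: still null
      System.out.println(bufferIter.next());
    }
  }
}
```
Running this throws a `NullPointerException` at the marked line, analogous
to the stack trace above.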

## How was this patch tested?
Added a test to `DataFrameAggregateSuite.scala` to check that no exception
is thrown for limit + aggregate queries.

Author: Takeshi YAMAMURO <linguin....@gmail.com>

Closes #15980 from maropu/SPARK-18528.

(cherry picked from commit b41ec997786e2be42a8a2a182212a610d08b221b)
Signed-off-by: Herman van Hovell <hvanhov...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/021952d5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/021952d5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/021952d5

Branch: refs/heads/branch-2.1
Commit: 021952d5808715d0b9d6c716f8b67cd550f7982e
Parents: 60e02a1
Author: Takeshi YAMAMURO <linguin....@gmail.com>
Authored: Thu Dec 22 01:53:33 2016 +0100
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Thu Dec 22 01:53:44 2016 +0100

----------------------------------------------------------------------
 .../apache/spark/sql/execution/BufferedRowIterator.java   | 10 ++++++++++
 .../spark/sql/execution/WholeStageCodegenExec.scala       |  2 +-
 .../main/scala/org/apache/spark/sql/execution/limit.scala |  6 +++---
 .../org/apache/spark/sql/DataFrameAggregateSuite.scala    |  8 ++++++++
 4 files changed, 22 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/021952d5/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
index 086547c..730a4ae 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -70,6 +70,16 @@ public abstract class BufferedRowIterator {
   }
 
   /**
+   * Returns whether this iterator should stop fetching next row from [[CodegenSupport#inputRDDs]].
+   *
+   * If it returns true, the caller should exit the loop that [[InputAdapter]] generates.
+   * This interface is mainly used to limit the number of input rows.
+   */
+  protected boolean stopEarly() {
+    return false;
+  }
+
+  /**
    * Returns whether `processNext()` should stop processing next row from `input` or not.
    *
    * If it returns true, the caller should exit the loop (return from processNext()).
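
The two hooks now have clearly separated jobs: `shouldStop()` asks the
generated `processNext()` to return so the caller can drain buffered rows,
after which processing resumes, while `stopEarly()` permanently stops
pulling rows from the input. A stripped-down sketch of how a limit operator
is expected to use them (stand-in classes and names, not the real
`BufferedRowIterator` API):
```java
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

// Toy stand-in for BufferedRowIterator; only the relevant members.
abstract class MiniBufferedRowIterator {
  protected final LinkedList<Integer> currentRows = new LinkedList<>();

  protected void append(Integer row) { currentRows.add(row); }

  // Pause emitting: the caller drains currentRows, then resumes the loop.
  protected boolean shouldStop() { return !currentRows.isEmpty(); }

  // Stop consuming input for good (e.g. a limit has been reached).
  protected boolean stopEarly() { return false; }

  protected abstract void processNext();
}

// Hypothetical limit operator over the toy base class.
class SketchedLimitIterator extends MiniBufferedRowIterator {
  private final Iterator<Integer> input;
  private final int limit;
  private int count = 0;
  private boolean reachedLimit = false;

  SketchedLimitIterator(List<Integer> rows, int limit) {
    this.input = rows.iterator();
    this.limit = limit;
  }

  @Override protected boolean stopEarly() { return reachedLimit; }

  @Override protected void processNext() {
    // Same loop shape that InputAdapter generates after this patch.
    while (input.hasNext() && !stopEarly()) {
      Integer row = input.next();
      if (count < limit) {
        count += 1;
        append(row);
      } else {
        reachedLimit = true;
      }
      if (shouldStop()) return;  // yield the buffered row to the caller
    }
    // Initialization placed after the loop (as in $doAgg()) always runs,
    // because stopEarly() ends the loop instead of aborting the method.
  }
}
```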

http://git-wip-us.apache.org/repos/asf/spark/blob/021952d5/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 516b9d5..2ead8f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -241,7 +241,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp
     ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
     val row = ctx.freshName("row")
     s"""
-       | while ($input.hasNext()) {
+       | while ($input.hasNext() && !stopEarly()) {
        |   InternalRow $row = (InternalRow) $input.next();
        |   ${consume(ctx, null, row).trim}
        |   if (shouldStop()) return;
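
A one-line change with broad effect: every input-consumption loop that
`InputAdapter` generates now also checks `stopEarly()`, so an operator such
as a limit can stop the stage from pulling further input rows without
having to hijack `shouldStop()` for that purpose.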

http://git-wip-us.apache.org/repos/asf/spark/blob/021952d5/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 9918ac3..757fe21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -70,10 +70,10 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
     val stopEarly = ctx.freshName("stopEarly")
     ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")
 
-    ctx.addNewFunction("shouldStop", s"""
+    ctx.addNewFunction("stopEarly", s"""
       @Override
-      protected boolean shouldStop() {
-        return !currentRows.isEmpty() || $stopEarly;
+      protected boolean stopEarly() {
+        return $stopEarly;
       }
     """)
     val countTerm = ctx.freshName("count")
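
The behavioural difference is scope: the old `shouldStop()` override
returned true permanently once the limit was hit, which also aborted a
parent's `processNext()` body mid-flight, skipping `$doAgg()`'s post-loop
initialization. The new `stopEarly()` override only short-circuits the
input-consumption loop, so straight-line code after that loop, such as
creating the aggregation buffer iterator, always runs.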

http://git-wip-us.apache.org/repos/asf/spark/blob/021952d5/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 7aa4f00..6451759 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -513,4 +513,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))),
       Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d"))))
   }
+
+  test("SPARK-18528 limit + aggregates") {
+    val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value")
+    val limit2Df = df.limit(2)
+    checkAnswer(
+      limit2Df.groupBy("id").count().select($"id"),
+      limit2Df.select($"id"))
+  }
 }
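
Note how the assertion sidesteps nondeterminism: rather than pinning down
which two rows `limit(2)` returns, it compares the grouped ids against the
ids of the same limited DataFrame. Every id in the input is unique, so
grouping leaves the ids unchanged, and the check passes exactly when the
limit + aggregate pipeline completes without the `NullPointerException`
and covers the limited rows.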

