Repository: spark Updated Branches: refs/heads/master 83a6ace0d -> b41ec9977
[SPARK-18528][SQL] Fix a bug to initialise an iterator of aggregation buffer ## What changes were proposed in this pull request? This PR fixes a `NullPointerException` caused by the following `limit + aggregate` query: ``` scala> val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") scala> df.limit(2).groupBy("id").count().show WARN TaskSetManager: Lost task 0.0 in stage 9.0 (TID 8204, lvsp20hdn012.stubprod.com): java.lang.NullPointerException at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.agg_doAggregateWithKeys$(Unknown Source) at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source) ``` The root cause is that [`$doAgg()`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L596) skips the initialization of [the buffer iterator](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L603): `BaseLimitExec` sets `stopEarly=true`, so `$doAgg()` exits partway through without performing the initialization. ## How was this patch tested? Added a test to `DataFrameAggregateSuite.scala` that checks that no exception is thrown for limit + aggregate queries. Author: Takeshi YAMAMURO <linguin....@gmail.com> Closes #15980 from maropu/SPARK-18528. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b41ec997 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b41ec997 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b41ec997 Branch: refs/heads/master Commit: b41ec997786e2be42a8a2a182212a610d08b221b Parents: 83a6ace Author: Takeshi YAMAMURO <linguin....@gmail.com> Authored: Thu Dec 22 01:53:33 2016 +0100 Committer: Herman van Hovell <hvanhov...@databricks.com> Committed: Thu Dec 22 01:53:33 2016 +0100 ---------------------------------------------------------------------- .../apache/spark/sql/execution/BufferedRowIterator.java | 10 ++++++++++ .../spark/sql/execution/WholeStageCodegenExec.scala | 2 +- .../main/scala/org/apache/spark/sql/execution/limit.scala | 6 +++--- .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 8 ++++++++ 4 files changed, 22 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b41ec997/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java index 086547c..730a4ae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -70,6 +70,16 @@ public abstract class BufferedRowIterator { } /** + * Returns whether this iterator should stop fetching next row from [[CodegenSupport#inputRDDs]]. + * + * If it returns true, the caller should exit the loop that [[InputAdapter]] generates. + * This interface is mainly used to limit the number of input rows. 
+ */ + protected boolean stopEarly() { + return false; + } + + /** * Returns whether `processNext()` should stop processing next row from `input` or not. * * If it returns true, the caller should exit the loop (return from processNext()). http://git-wip-us.apache.org/repos/asf/spark/blob/b41ec997/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 516b9d5..2ead8f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -241,7 +241,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val row = ctx.freshName("row") s""" - | while ($input.hasNext()) { + | while ($input.hasNext() && !stopEarly()) { | InternalRow $row = (InternalRow) $input.next(); | ${consume(ctx, null, row).trim} | if (shouldStop()) return; http://git-wip-us.apache.org/repos/asf/spark/blob/b41ec997/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 9918ac3..757fe21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -70,10 +70,10 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport { val stopEarly = ctx.freshName("stopEarly") ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;") - ctx.addNewFunction("shouldStop", 
s""" + ctx.addNewFunction("stopEarly", s""" @Override - protected boolean shouldStop() { - return !currentRows.isEmpty() || $stopEarly; + protected boolean stopEarly() { + return $stopEarly; } """) val countTerm = ctx.freshName("count") http://git-wip-us.apache.org/repos/asf/spark/blob/b41ec997/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 7aa4f00..6451759 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -513,4 +513,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { df.groupBy($"x").agg(countDistinct($"y"), sort_array(collect_list($"z"))), Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d")))) } + + test("SPARK-18004 limit + aggregates") { + val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") + val limit2Df = df.limit(2) + checkAnswer( + limit2Df.groupBy("id").count().select($"id"), + limit2Df.select($"id")) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org