Repository: spark
Updated Branches:
  refs/heads/master 8fb445d9b -> 94922d79e


[SPARK-17289][SQL] Fix a bug to satisfy sort requirements in partial aggregations

## What changes were proposed in this pull request?
Partial aggregations are generated in `EnsureRequirements`, but the planner fails to
check whether the generated partial aggregation satisfies its sort requirements.
For the following query:
```
// createOrReplaceTempView returns Unit, so the result is not bound to a val
(0 to 1000).map(x => (x % 2, x.toString)).toDF("a", "b").createOrReplaceTempView("t2")
spark.sql("select max(b) from t2 group by a").explain(true)
```
Currently, the planner does not insert a Sort operator below the partial SortAggregate.
This breaks sort-based partial aggregation, which expects its input to be sorted by the grouping keys:
```
== Physical Plan ==
SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)#17])
+- *Sort [a#5 ASC], false, 0
   +- Exchange hashpartitioning(a#5, 200)
      +- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19])
         +- LocalTableScan [a#5, b#6]
```
The correct plan also has a Sort below the partial aggregation:
```
== Physical Plan ==
SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)#17])
+- *Sort [a#5 ASC], false, 0
   +- Exchange hashpartitioning(a#5, 200)
      +- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19])
         +- *Sort [a#5 ASC], false, 0
            +- LocalTableScan [a#5, b#6]
```
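
The difference between the two plans can be illustrated with a toy model. This is only a sketch in plain Scala: `Plan`, `Scan`, `Sort`, `Exchange`, `SortAggregate` and `ToyPlanner` are hypothetical stand-ins, not Spark's classes, and `ensureOrdering` only imitates the ordering half of what `ensureDistributionAndOrdering` is assumed to do. The buggy planner applies the ordering check to the merge-side aggregate only, while the fixed one also applies it to the map-side aggregate before wrapping it in a shuffle.

```scala
// Toy plan nodes (hypothetical; they only mimic the shape of a physical plan).
sealed trait Plan {
  def children: Seq[Plan]
  // Columns the node's output is sorted by (empty means no known ordering).
  def outputOrdering: Seq[String] = Nil
}
case class Scan(name: String) extends Plan {
  def children: Seq[Plan] = Nil
}
case class Sort(keys: Seq[String], child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
  override def outputOrdering: Seq[String] = keys
}
// A shuffle does not preserve ordering, so outputOrdering stays empty.
case class Exchange(keys: Seq[String], child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
}
case class SortAggregate(keys: Seq[String], partial: Boolean, child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
  // Sort-based aggregation needs its input sorted by the grouping keys.
  def requiredChildOrdering: Seq[String] = keys
  override def outputOrdering: Seq[String] = keys
}

object ToyPlanner {
  // Insert a Sort below the aggregate if its child is not already sorted as required.
  def ensureOrdering(agg: SortAggregate): SortAggregate = {
    val required = agg.requiredChildOrdering
    if (required.isEmpty || agg.child.outputOrdering.startsWith(required)) agg
    else agg.copy(child = Sort(required, agg.child))
  }

  // Before the fix: the map-side aggregate is placed under the shuffle unchecked.
  def planBuggy(keys: Seq[String], input: Plan): Plan = {
    val mapSide = SortAggregate(keys, partial = true, input)
    ensureOrdering(SortAggregate(keys, partial = false, Exchange(keys, mapSide)))
  }

  // After the fix: the map-side aggregate is checked before the shuffle is added.
  def planFixed(keys: Seq[String], input: Plan): Plan = {
    val mapSide = ensureOrdering(SortAggregate(keys, partial = true, input))
    ensureOrdering(SortAggregate(keys, partial = false, Exchange(keys, mapSide)))
  }

  // Count Sort nodes anywhere in the plan tree.
  def countSorts(plan: Plan): Int = {
    val self = plan match {
      case _: Sort => 1
      case _ => 0
    }
    self + plan.children.map(countSorts).sum
  }

  def main(args: Array[String]): Unit = {
    println(countSorts(planBuggy(Seq("a"), Scan("t2")))) // 1: only the merge side is sorted
    println(countSorts(planFixed(Seq("a"), Scan("t2")))) // 2: both sides are sorted
  }
}
```

In this model the fixed plan has two `Sort` nodes and two `SortAggregate` nodes, which matches the four "Sort"-named operators the new `PlannerSuite` test expects.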

## How was this patch tested?
Added tests in `PlannerSuite`.

Author: Takeshi YAMAMURO <linguin....@gmail.com>

Closes #14865 from maropu/SPARK-17289.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/94922d79
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/94922d79
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/94922d79

Branch: refs/heads/master
Commit: 94922d79e9f90fac3777db0974ccf7566b8ac3b3
Parents: 8fb445d
Author: Takeshi YAMAMURO <linguin....@gmail.com>
Authored: Tue Aug 30 16:43:47 2016 +0800
Committer: Cheng Lian <l...@databricks.com>
Committed: Tue Aug 30 16:43:47 2016 +0800

----------------------------------------------------------------------
 .../execution/exchange/EnsureRequirements.scala |  3 ++-
 .../spark/sql/execution/PlannerSuite.scala      | 22 +++++++++++++++++++-
 2 files changed, 23 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/94922d79/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index fee7010..66e99de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -164,7 +164,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         // If an aggregation needs a shuffle and support partial aggregations, a map-side partial
         // aggregation and a shuffle are added as children.
         val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator)
-        (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil)
+        (mergeAgg, createShuffleExchange(
+          requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil)
       case _ =>
         // Ensure that the operator's children satisfy their output distribution requirements:
         val childrenWithDist = operator.children.zip(requiredChildDistributions)

http://git-wip-us.apache.org/repos/asf/spark/blob/94922d79/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 07efc72..b0aa337 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,12 +18,13 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, DataFrame, Row}
+import org.apache.spark.sql.{execution, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.aggregate.SortAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
@@ -70,6 +71,25 @@ class PlannerSuite extends SharedSQLContext {
       s"The plan of query $query does not have partial aggregations.")
   }
 
+  test("SPARK-17289 sort-based partial aggregation needs a sort operator as a child") {
+    withTempView("testSortBasedPartialAggregation") {
+      val schema = StructType(
+        StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil)
+      val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString)))
+      spark.createDataFrame(rowRDD, schema)
+        .createOrReplaceTempView("testSortBasedPartialAggregation")
+
+      // This test assumes a query below uses sort-based aggregations
+      val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key")
+        .queryExecution.executedPlan
+      // This line extracts both SortAggregate and Sort operators
+      val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n }
+      val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n }
+      assert(extractedOps.size == 4 && aggOps.size == 2,
+        s"The plan $planned does not have correct sort-based partial aggregate pairs.")
+    }
+  }
+
   test("non-partial aggregation for aggregates") {
     withTempView("testNonPartialAggregation") {
       val schema = StructType(StructField(s"value", IntegerType, true) :: Nil)

