Repository: spark Updated Branches: refs/heads/master 8fb445d9b -> 94922d79e
[SPARK-17289][SQL] Fix a bug to satisfy sort requirements in partial aggregations ## What changes were proposed in this pull request? Partial aggregations are generated in `EnsureRequirements`, but the planner fails to check whether the partial aggregation satisfies sort requirements. For the following query: ``` val df2 = (0 to 1000).map(x => (x % 2, x.toString)).toDF("a", "b").createOrReplaceTempView("t2") spark.sql("select max(b) from t2 group by a").explain(true) ``` Currently, the planner does not insert a Sort operator before the partial aggregation, which breaks sort-based partial aggregation: ``` == Physical Plan == SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)#17]) +- *Sort [a#5 ASC], false, 0 +- Exchange hashpartitioning(a#5, 200) +- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19]) +- LocalTableScan [a#5, b#6] ``` The correct plan is: ``` == Physical Plan == SortAggregate(key=[a#5], functions=[max(b#6)], output=[max(b)#17]) +- *Sort [a#5 ASC], false, 0 +- Exchange hashpartitioning(a#5, 200) +- SortAggregate(key=[a#5], functions=[partial_max(b#6)], output=[a#5, max#19]) +- *Sort [a#5 ASC], false, 0 +- LocalTableScan [a#5, b#6] ``` ## How was this patch tested? Added tests in `PlannerSuite`. Author: Takeshi YAMAMURO <linguin....@gmail.com> Closes #14865 from maropu/SPARK-17289. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/94922d79 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/94922d79 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/94922d79 Branch: refs/heads/master Commit: 94922d79e9f90fac3777db0974ccf7566b8ac3b3 Parents: 8fb445d Author: Takeshi YAMAMURO <linguin....@gmail.com> Authored: Tue Aug 30 16:43:47 2016 +0800 Committer: Cheng Lian <l...@databricks.com> Committed: Tue Aug 30 16:43:47 2016 +0800 ---------------------------------------------------------------------- .../execution/exchange/EnsureRequirements.scala | 3 ++- .../spark/sql/execution/PlannerSuite.scala | 22 +++++++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/94922d79/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index fee7010..66e99de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -164,7 +164,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { // If an aggregation needs a shuffle and support partial aggregations, a map-side partial // aggregation and a shuffle are added as children. 
val (mergeAgg, mapSideAgg) = AggUtils.createMapMergeAggregatePair(operator) - (mergeAgg, createShuffleExchange(requiredChildDistributions.head, mapSideAgg) :: Nil) + (mergeAgg, createShuffleExchange( + requiredChildDistributions.head, ensureDistributionAndOrdering(mapSideAgg)) :: Nil) case _ => // Ensure that the operator's children satisfy their output distribution requirements: val childrenWithDist = operator.children.zip(requiredChildDistributions) http://git-wip-us.apache.org/repos/asf/spark/blob/94922d79/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 07efc72..b0aa337 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, DataFrame, Row} +import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.aggregate.SortAggregateExec import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} @@ -70,6 +71,25 @@ class PlannerSuite extends SharedSQLContext { s"The plan of query $query does not have partial aggregations.") } + test("SPARK-17289 sort-based partial 
aggregation needs a sort operator as a child") { + withTempView("testSortBasedPartialAggregation") { + val schema = StructType( + StructField(s"key", IntegerType, true) :: StructField(s"value", StringType, true) :: Nil) + val rowRDD = sparkContext.parallelize((0 until 1000).map(d => Row(d % 2, d.toString))) + spark.createDataFrame(rowRDD, schema) + .createOrReplaceTempView("testSortBasedPartialAggregation") + + // This test assumes a query below uses sort-based aggregations + val planned = sql("SELECT MAX(value) FROM testSortBasedPartialAggregation GROUP BY key") + .queryExecution.executedPlan + // This line extracts both SortAggregate and Sort operators + val extractedOps = planned.collect { case n if n.nodeName contains "Sort" => n } + val aggOps = extractedOps.collect { case n if n.nodeName contains "SortAggregate" => n } + assert(extractedOps.size == 4 && aggOps.size == 2, + s"The plan $planned does not have correct sort-based partial aggregate pairs.") + } + } + test("non-partial aggregation for aggregates") { withTempView("testNonPartialAggregation") { val schema = StructType(StructField(s"value", IntegerType, true) :: Nil) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org