Repository: spark Updated Branches: refs/heads/branch-1.4 c73498773 -> 2671551a9
[SPARK-10169] [SQL] [BRANCH-1.4] Partial aggregation's plan is wrong when a grouping expression is used as an argument of the aggregate function https://issues.apache.org/jira/browse/SPARK-10169 Author: Yin Huai <yh...@databricks.com> Author: Wenchen Fan <cloud0...@outlook.com> Closes #8379 from yhuai/aggTransformDown-branch1.4. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2671551a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2671551a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2671551a Branch: refs/heads/branch-1.4 Commit: 2671551a94f203bcfb3d0ab11e551c2f9865f4ea Parents: c734987 Author: Yin Huai <yh...@databricks.com> Authored: Mon Aug 24 13:02:06 2015 -0700 Committer: Michael Armbrust <mich...@databricks.com> Committed: Mon Aug 24 13:02:06 2015 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/planning/patterns.scala | 13 ++++++++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 22 ++++++++++++++++++++ 2 files changed, 33 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2671551a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 1dd75a8..c1b88d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -151,7 +151,10 @@ object PartialAggregation { // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. 
- val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { + // transformDown is needed at here because we want to match aggregate function first. + // Otherwise, if a grouping expression is used as an argument of an aggregate function, + // we will match grouping expression first and have a wrong plan. + val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => partialEvaluations(new TreeNodeRef(e)).finalEvaluation @@ -159,7 +162,13 @@ object PartialAggregation { // Should trim aliases around `GetField`s. These aliases are introduced while // resolving struct field accesses, because `GetField` is not a `NamedExpression`. // (Should we just turn `GetField` into a `NamedExpression`?) - val trimmed = e.transform { case Alias(g: ExtractValue, _) => g } + def trimAliases(e: Expression): Expression = + e.transform { case Alias(g: ExtractValue, _) => g } + val trimmed = e match { + // Don't trim the top level Alias. 
+ case Alias(child, name) => Alias(trimAliases(child), name)() + case _ => trimAliases(e) + } namedGroupingExpressions .find { case (k, v) => k semanticEquals trimmed } .map(_._2.toAttribute) http://git-wip-us.apache.org/repos/asf/spark/blob/2671551a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8a0679e..1067b94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1335,4 +1335,26 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { checkAnswer(sql("SELECT a.`c.b`, `b.$q`[0].`a@!.q`, `q.w`.`w.i&`[0] FROM t"), Row(1, 1, 1)) } + + test("SPARK-10169: grouping expressions used as arguments of aggregate functions.") { + sqlCtx.sparkContext + .parallelize((1 to 1000), 50) + .map(i => Tuple1(i)) + .toDF("i") + .registerTempTable("t") + + val query = sqlCtx.sql( + """ + |select i % 10, sum(if(i % 10 = 5, 1, 0)), count(i) + |from t + |where i % 10 = 5 + |group by i % 10 + """.stripMargin) + + checkAnswer( + query, + Row(5, 100, 100)) + + dropTempTable("t") + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org