Repository: spark
Updated Branches:
  refs/heads/branch-1.3 a98603f8c -> 3d2eaf0a7


[SPARK-10169] [SQL] [BRANCH-1.3] Partial aggregation's plan is wrong when a 
grouping expression is used as an argument of the aggregate function

https://issues.apache.org/jira/browse/SPARK-10169

Author: Wenchen Fan <cloud0...@outlook.com>
Author: Yin Huai <yh...@databricks.com>

Closes #8380 from yhuai/aggTransformDown-branch1.3.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3d2eaf0a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3d2eaf0a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3d2eaf0a

Branch: refs/heads/branch-1.3
Commit: 3d2eaf0a7701bfd9a41ba4c1b29e5be77156a9bf
Parents: a98603f
Author: Wenchen Fan <cloud0...@outlook.com>
Authored: Mon Aug 24 13:00:49 2015 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Aug 24 13:00:49 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/planning/patterns.scala  | 14 +++++++++++--
 .../org/apache/spark/sql/SQLQuerySuite.scala    | 22 ++++++++++++++++++++
 2 files changed, 34 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3d2eaf0a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 9c8c643..d0ebe24 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -151,7 +151,10 @@ object PartialAggregation {
 
         // Replace aggregations with a new expression that computes the result 
from the already
         // computed partial evaluations and grouping values.
-        val rewrittenAggregateExpressions = 
aggregateExpressions.map(_.transformUp {
+        // transformDown is needed here because we want to match the aggregate 
function first.
+        // Otherwise, if a grouping expression is used as an argument of an 
aggregate function,
+        // we will match grouping expression first and have a wrong plan.
+        val rewrittenAggregateExpressions = 
aggregateExpressions.map(_.transformDown {
           case e: Expression if partialEvaluations.contains(new 
TreeNodeRef(e)) =>
             partialEvaluations(new TreeNodeRef(e)).finalEvaluation
 
@@ -159,8 +162,15 @@ object PartialAggregation {
             // Should trim aliases around `GetField`s. These aliases are 
introduced while
             // resolving struct field accesses, because `GetField` is not a 
`NamedExpression`.
             // (Should we just turn `GetField` into a `NamedExpression`?)
+            def trimAliases(e: Expression): Expression =
+              e.transform { case Alias(g: GetField, _) => g }
+            val trimmed = e match {
+              // Don't trim the top level Alias.
+              case Alias(child, name) => Alias(trimAliases(child), name)()
+              case _ => trimAliases(e)
+            }
             namedGroupingExpressions
-              .get(e.transform { case Alias(g: GetField, _) => g })
+              .get(trimmed)
               .map(_.toAttribute)
               .getOrElse(e)
         }).asInstanceOf[Seq[NamedExpression]]

http://git-wip-us.apache.org/repos/asf/spark/blob/3d2eaf0a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 87e7cf8..b52b606 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1099,4 +1099,26 @@ class SQLQuerySuite extends QueryTest with 
BeforeAndAfterAll {
     checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
     checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
   }
+
+  test("SPARK-10169: grouping expressions used as arguments of aggregate 
functions.") {
+    sqlCtx.sparkContext
+      .parallelize((1 to 1000), 50)
+      .map(i => Tuple1(i))
+      .toDF("i")
+      .registerTempTable("t")
+
+    val query = sqlCtx.sql(
+      """
+        |select i % 10, sum(if(i % 10 = 5, 1, 0)), count(i)
+        |from t
+        |where i % 10 = 5
+        |group by i % 10
+      """.stripMargin)
+
+    checkAnswer(
+      query,
+      Row(5, 100, 100))
+
+    dropTempTable("t")
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to