Repository: spark
Updated Branches:
  refs/heads/master 3f98375d8 -> 1283c3d11


[SPARK-20725][SQL] partial aggregate should behave correctly for sameResult

## What changes were proposed in this pull request?

For aggregate functions in `PartialMerge` or `Final` mode, the input is 
aggregate buffers rather than the actual children expressions, so the actual 
children expressions do not affect the result. We should therefore normalize 
their expr ids when canonicalizing, so that two semantically identical plans 
compare equal under `sameResult`.
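
As a minimal sketch of the symptom (assuming a `SparkSession` named `spark`, 
e.g. in `spark-shell`; this mirrors the regression test added below), two 
independently built but identical aggregation queries should be recognized as 
interchangeable:

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

// Two independently built but identical aggregation queries.
val plan1 = spark.range(10).agg(sum($"id")).queryExecution.executedPlan
val plan2 = spark.range(10).agg(sum($"id")).queryExecution.executedPlan

// Before this patch, the partial/final aggregates carried query-specific
// expr ids in their children, so this check returned false.
assert(plan1.sameResult(plan2))
```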

## How was this patch tested?

A new regression test in `SameResultSuite`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #17964 from cloud-fan/tmp.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1283c3d1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1283c3d1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1283c3d1

Branch: refs/heads/master
Commit: 1283c3d11af6d55eaf0e40d6df09dc6bcc198322
Parents: 3f98375
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Sat May 13 12:09:06 2017 -0700
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Sat May 13 12:09:06 2017 -0700

----------------------------------------------------------------------
 .../catalyst/expressions/aggregate/interfaces.scala   | 14 ++++++++++++--
 .../apache/spark/sql/catalyst/plans/QueryPlan.scala   |  4 ++--
 .../apache/spark/sql/execution/SameResultSuite.scala  | 12 ++++++++++++
 3 files changed, 26 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1283c3d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index 80c25d0..fffcc7c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -105,12 +105,22 @@ case class AggregateExpression(
   }
 
   // We compute the same thing regardless of our final result.
-  override lazy val canonicalized: Expression =
+  override lazy val canonicalized: Expression = {
+    val normalizedAggFunc = mode match {
+      // For PartialMerge or Final mode, the input to the `aggregateFunction` is aggregate buffers,
+      // and the actual children of `aggregateFunction` are not used, so we normalize their expr ids here.
+      case PartialMerge | Final => aggregateFunction.transform {
+        case a: AttributeReference => a.withExprId(ExprId(0))
+      }
+      case Partial | Complete => aggregateFunction
+    }
+
     AggregateExpression(
-      aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+      normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction],
       mode,
       isDistinct,
       ExprId(0))
+  }
 
   override def children: Seq[Expression] = aggregateFunction :: Nil
   override def dataType: DataType = aggregateFunction.dataType
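
The normalization above follows the standard canonicalization pattern: replace 
auto-generated ids with a fixed placeholder before comparing trees. A 
standalone toy sketch of the idea (`Attr` and `Sum` are hypothetical stand-ins 
for `AttributeReference` and an aggregate function, not the Spark classes):

```scala
object CanonicalizeSketch {
  sealed trait Expr
  case class Attr(name: String, id: Long) extends Expr
  case class Sum(child: Expr) extends Expr

  // Replace every auto-generated id with 0, mirroring `withExprId(ExprId(0))`:
  // trees built by separate but identical queries then compare equal.
  def canonicalize(e: Expr): Expr = e match {
    case Attr(name, _) => Attr(name, 0L)
    case Sum(child)    => Sum(canonicalize(child))
  }

  def main(args: Array[String]): Unit = {
    val a = Sum(Attr("id", 1L)) // id minted by the first query
    val b = Sum(Attr("id", 7L)) // a fresh id minted by the second query
    assert(a != b)                              // raw trees differ
    assert(canonicalize(a) == canonicalize(b))  // canonical forms match
  }
}
```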

http://git-wip-us.apache.org/repos/asf/spark/blob/1283c3d1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 51faa33..959fcf7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -286,7 +286,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
 
     def recursiveTransform(arg: Any): AnyRef = arg match {
       case e: Expression => transformExpression(e)
-      case Some(e: Expression) => Some(transformExpression(e))
+      case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
       case seq: Traversable[_] => seq.map(recursiveTransform)
@@ -320,7 +320,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
 
     productIterator.flatMap {
       case e: Expression => e :: Nil
-      case Some(e: Expression) => e :: Nil
+      case s: Some[_] => seqToExpressions(s.toSeq)
       case seq: Traversable[_] => seqToExpressions(seq)
       case other => Nil
     }.toSeq
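
The `QueryPlan` change widens the `Option` handling: `Some(e: Expression)` 
only matched an `Option` that directly wraps an expression, silently skipping 
values like `Some(Seq(expr))`, while the new patterns recurse into whatever 
the `Option` wraps. A toy sketch of the difference (hypothetical `Expr`/`Lit` 
types, not Spark's):

```scala
object OptionTraversalSketch {
  sealed trait Expr
  case class Lit(v: Int) extends Expr

  // Old pattern: only an Option that directly wraps an expression is seen.
  def collectOld(arg: Any): Seq[Expr] = arg match {
    case e: Expr       => Seq(e)
    case Some(e: Expr) => Seq(e)
    case _             => Nil
  }

  // New pattern: recurse into whatever the Option wraps.
  def collectNew(arg: Any): Seq[Expr] = arg match {
    case e: Expr             => Seq(e)
    case Some(value)         => collectNew(value)
    case seq: Traversable[_] => seq.flatMap(collectNew).toSeq
    case _                   => Nil
  }

  def main(args: Array[String]): Unit = {
    val wrapped = Some(Seq(Lit(1), Lit(2)))
    assert(collectOld(wrapped).isEmpty)                // expressions were skipped
    assert(collectNew(wrapped) == Seq(Lit(1), Lit(2))) // now collected
  }
}
```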

http://git-wip-us.apache.org/repos/asf/spark/blob/1283c3d1/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala
index 25e4ca0..aaf51b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala
@@ -18,12 +18,14 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 
 /**
  * Tests for the sameResult function for [[SparkPlan]]s.
  */
 class SameResultSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._
 
   test("FileSourceScanExec: different orders of data filters and partition filters") {
     withTempPath { path =>
@@ -46,4 +48,14 @@ class SameResultSuite extends QueryTest with SharedSQLContext {
     df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
       .asInstanceOf[FileSourceScanExec]
   }
+
+  test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
+    val df1 = spark.range(10).agg(sum($"id"))
+    val df2 = spark.range(10).agg(sum($"id"))
+    assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))
+
+    val df3 = spark.range(10).agg(sumDistinct($"id"))
+    val df4 = spark.range(10).agg(sumDistinct($"id"))
+    assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan))
+  }
 }
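
Note that the `sumDistinct` variant of the test is not redundant: distinct 
aggregation is planned with an extra `PartialMerge` phase, so it exercises 
the other aggregate mode normalized by this patch.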

