Repository: spark
Updated Branches:
  refs/heads/master 344e3aab8 -> 9b33dfc40


[SPARK-22951][SQL] fix aggregation after dropDuplicates on empty data frames

## What changes were proposed in this pull request?

(courtesy of liancheng)

Spark SQL supports both global aggregation and grouping aggregation. A global 
aggregation always returns a single row holding the initial aggregation state 
as its output, even if there are zero input rows. Spark implements this by 
simply checking the number of grouping keys: an aggregation with zero grouping 
keys is treated as a global aggregation.
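To make the distinction concrete, here is a quick spark-shell sketch (the empty 
one-column frame is just an illustration, not part of this patch):

```scala
// Assumes a spark-shell session (spark.implicits._ and functions._ in scope).
val empty = Seq.empty[Int].toDF("a")

// Global aggregation: zero grouping keys, so exactly one row comes back
// even though the input has no rows.
empty.agg(count($"a")).show()
// +--------+
// |count(a)|
// +--------+
// |       0|
// +--------+

// Grouping aggregation: one grouping key, so an empty input yields zero rows.
empty.groupBy($"a").agg(count($"a")).show()
// +---+--------+
// |  a|count(a)|
// +---+--------+
// +---+--------+
```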

However, this simple principle drops the ball in the following case:

```scala
spark.emptyDataFrame.dropDuplicates().agg(count($"*") as "c").show()
// +---+
// |  c|
// +---+
// |  1|
// +---+
```

The reason is that:

1. `df.dropDuplicates()` is roughly translated into something equivalent to:

```scala
// Group by every column, keeping one representative row per distinct combination.
val allColumns = df.columns.map(col)
df.groupBy(allColumns: _*).agg(allColumns.head, allColumns.tail: _*)
```

This translation is implemented in the rule `ReplaceDeduplicateWithAggregate`.

2. `spark.emptyDataFrame` contains zero columns and zero rows.

Therefore, the rule `ReplaceDeduplicateWithAggregate` produces a confusing 
transformation, roughly equivalent to the following:

```scala
spark.emptyDataFrame.dropDuplicates()
=> spark.emptyDataFrame.groupBy().agg(Map.empty[String, String])
```

The above transformation is confusing because the resulting aggregate operator 
contains no grouping keys (since `emptyDataFrame` contains no columns) and is 
therefore recognized as a global aggregation. As a result, Spark SQL allocates 
a single row filled with the initial aggregation state, uses it as the output, 
and returns an incorrect result.
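To trace the wrong answer step by step, here is a rough sketch (using the 
groupBy-based equivalent from above rather than the literal optimized plan):

```scala
// Step 1: the mis-translated Deduplicate behaves like a global aggregation,
// so it emits one (empty) row even though emptyDataFrame has zero rows.
val dedupEquivalent = spark.emptyDataFrame.groupBy().agg(Map.empty[String, String])
dedupEquivalent.count()  // 1

// Step 2: the outer count($"*") then runs over that single synthesized row,
// which is why the query above reports 1 instead of 0.
```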

To fix this issue, this PR tweaks `ReplaceDeduplicateWithAggregate` by 
appending a literal `1` to the grouping key list of the resulting `Aggregate` 
operator when the input plan contains zero output columns. In this way, 
`spark.emptyDataFrame.dropDuplicates()` is now translated into a grouping 
aggregation, roughly depicted as:

```scala
spark.emptyDataFrame.dropDuplicates()
=> spark.emptyDataFrame.groupBy(lit(1)).agg(Map.empty[String, String])
```

The query is now properly treated as a grouping aggregation and returns the 
correct answer.
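
With the patch applied, the motivating query now returns the expected zero 
count (output sketched to match the new tests below):

```scala
spark.emptyDataFrame.dropDuplicates().agg(count($"*") as "c").show()
// +---+
// |  c|
// +---+
// |  0|
// +---+
```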

## How was this patch tested?

New unit tests were added to `ReplaceOperatorSuite` and 
`DataFrameAggregateSuite` (see the diff below).

Author: Feng Liu <feng...@databricks.com>

Closes #20174 from liufengdb/fix-duplicate.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9b33dfc4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9b33dfc4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9b33dfc4

Branch: refs/heads/master
Commit: 9b33dfc408de986f4203bb0ac0c3f5c56effd69d
Parents: 344e3aa
Author: Feng Liu <feng...@databricks.com>
Authored: Wed Jan 10 14:25:04 2018 -0800
Committer: Cheng Lian <lian.cs....@gmail.com>
Committed: Wed Jan 10 14:25:04 2018 -0800

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/Optimizer.scala      |  8 ++++++-
 .../optimizer/ReplaceOperatorSuite.scala        | 10 +++++++-
 .../spark/sql/DataFrameAggregateSuite.scala     | 24 ++++++++++++++++++--
 3 files changed, 38 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9b33dfc4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index df0af82..c794ba8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1222,7 +1222,13 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
           Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
         }
       }
-      Aggregate(keys, aggCols, child)
+      // SPARK-22951: Physical aggregate operators distinguish global aggregation and grouping
+      // aggregations by checking the number of grouping keys. The key difference here is that a
+      // global aggregation always returns at least one row even if there are no input rows. Here
+      // we append a literal when the grouping key list is empty so that the result aggregate
+      // operator is properly treated as a grouping aggregation.
+      val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
+      Aggregate(nonemptyKeys, aggCols, child)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9b33dfc4/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index 0fa1aae..e9701ff 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Not}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not}
 import org.apache.spark.sql.catalyst.expressions.aggregate.First
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -198,6 +198,14 @@ class ReplaceOperatorSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("add one grouping key if necessary when replace Deduplicate with 
Aggregate") {
+    val input = LocalRelation()
+    val query = Deduplicate(Seq.empty, input) // dropDuplicates()
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = Aggregate(Seq(Literal(1)), input.output, input)
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("don't replace streaming Deduplicate") {
     val input = LocalRelation(Seq('a.int, 'b.int), isStreaming = true)
     val attrA = input.output(0)

http://git-wip-us.apache.org/repos/asf/spark/blob/9b33dfc4/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 06848e4..e7776e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql
 
 import scala.util.Random
 
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.Count
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -27,7 +29,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData.DecimalData
-import org.apache.spark.sql.types.{Decimal, DecimalType}
+import org.apache.spark.sql.types.DecimalType
 
 case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)
 
@@ -456,7 +458,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
 
   test("null moments") {
     val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
-
     checkAnswer(
       emptyTableData.agg(variance('a), var_samp('a), var_pop('a), skewness('a), kurtosis('a)),
       Row(null, null, null, null, null))
@@ -666,4 +667,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       assert(exchangePlans.length == 1)
     }
   }
+
+  Seq(true, false).foreach { codegen =>
+    test("SPARK-22951: dropDuplicates on empty dataFrames should produce 
correct aggregate " +
+      s"results when codegen is enabled: $codegen") {
+      withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegen.toString)) {
+        // explicit global aggregations
+        val emptyAgg = Map.empty[String, String]
+        checkAnswer(spark.emptyDataFrame.agg(emptyAgg), Seq(Row()))
+        checkAnswer(spark.emptyDataFrame.groupBy().agg(emptyAgg), Seq(Row()))
+        checkAnswer(spark.emptyDataFrame.groupBy().agg(count("*")), 
Seq(Row(0)))
+        checkAnswer(spark.emptyDataFrame.dropDuplicates().agg(emptyAgg), 
Seq(Row()))
+        
checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(emptyAgg), 
Seq(Row()))
+        
checkAnswer(spark.emptyDataFrame.dropDuplicates().groupBy().agg(count("*")), 
Seq(Row(0)))
+
+        // global aggregation is converted to grouping aggregation:
+        assert(spark.emptyDataFrame.dropDuplicates().count() == 0)
+      }
+    }
+  }
 }

