This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new c6df89079886 [SPARK-49000][SQL] Fix "select count(distinct 1) from t" where t is empty table by expanding RewriteDistinctAggregates
c6df89079886 is described below

commit c6df890798862d0863afbfff8fca0ee4df70354f
Author: Uros Bojanic <157381213+uros...@users.noreply.github.com>
AuthorDate: Wed Jul 31 22:37:42 2024 +0800

    [SPARK-49000][SQL] Fix "select count(distinct 1) from t" where t is empty table by expanding RewriteDistinctAggregates
    
    Fix the `RewriteDistinctAggregates` rule to properly handle aggregation on DISTINCT literals. Physical plan for `select count(distinct 1) from t`:
    ```
    -- count(distinct 1)
    == Physical Plan ==
    AdaptiveSparkPlan isFinalPlan=false
    +- HashAggregate(keys=[], functions=[count(distinct 1)], output=[count(DISTINCT 1)#2L])
       +- HashAggregate(keys=[], functions=[partial_count(distinct 1)], output=[count#6L])
          +- HashAggregate(keys=[], functions=[], output=[])
             +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=20]
                +- HashAggregate(keys=[], functions=[], output=[])
                   +- FileScan parquet spark_catalog.default.t[] Batched: false, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/Users/nikola.mandic/oss-spark/spark-warehouse/org.apache.spark.s..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<>
    ```
    The problem occurs when the `HashAggregate(keys=[], functions=[], output=[])` node yields one row to the `partial_count` node, which then counts that single row. This four-node structure is constructed by `AggUtils.planAggregateWithOneDistinct`.
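
    For intuition, a minimal Scala sketch of why a global aggregate emits one row even on empty input (illustrative only; these helpers are not the actual `HashAggregateExec` code):
    ```
    // A global aggregate (no grouping keys) always emits exactly one row,
    // even when the input is empty; that single row is what the parent
    // partial_count node then counts.
    def globalCount(rows: Iterator[Int]): Iterator[Long] =
      Iterator.single(rows.size.toLong) // one output row regardless of input

    // With a non-empty grouping key, empty input yields zero groups and
    // therefore zero output rows.
    def groupedCount(rows: Iterator[(Int, Int)]): Iterator[(Int, Long)] =
      rows.toSeq.groupBy(_._1).map { case (k, v) => (k, v.size.toLong) }.iterator
    ```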
    
    To fix the problem, we add an `Expand` node that forces non-empty grouping expressions in the `HashAggregateExec` nodes. This in turn enables streaming zero rows to the parent `partial_count` node, yielding the correct final result.
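
    A hedged way to check the fix (assuming a `SparkSession` named `spark` and an empty Parquet table `t`): after this change the optimized plan should contain an `Expand` node, which the new tests also assert:
    ```
    import org.apache.spark.sql.catalyst.plans.logical.Expand

    val df = spark.sql("SELECT COUNT(DISTINCT 1) FROM t")
    // The rewritten plan contains an Expand node that supplies the grouping id,
    // so the aggregate now streams zero rows for an empty table.
    val hasExpandNodeInPlan = df.queryExecution.optimizedPlan.collectFirst {
      case e: Expand => e
    }.nonEmpty
    assert(hasExpandNodeInPlan)
    ```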
    
    Aggregation with a DISTINCT literal gives wrong results. For example, when running on an empty table `t`, `select count(distinct 1) from t` returns 1, while the correct result should be 0.
    For reference:
    `select count(1) from t` returns 0, which is the correct and expected result.
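
    A minimal reproduction sketch for `spark-shell` (the table definition matches the new test suite below):
    ```
    spark.sql("CREATE TABLE t(col INT) USING parquet")  // empty table
    spark.sql("SELECT COUNT(DISTINCT 1) FROM t").show() // was 1, should be 0
    spark.sql("SELECT COUNT(1) FROM t").show()          // 0, as expected
    ```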
    
    Yes, this is a user-facing change: it fixes a critical correctness bug in Spark.
    
    New e2e SQL tests for aggregates with DISTINCT literals.
    
    No.
    
    Closes #47525 from nikolamand-db/SPARK-49000-spark-expand-approach.
    
    Lead-authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com>
    Co-authored-by: Nikola Mandic <nikola.man...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit dfa21332f20fff4aa6052ffa556d206497c066cf)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/RewriteDistinctAggregates.scala      |  13 ++-
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 114 +++++++++++++++++++++
 2 files changed, 125 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index da3cf782f668..e91493188873 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -197,6 +197,15 @@ import org.apache.spark.util.collection.Utils
  * techniques.
  */
 object RewriteDistinctAggregates extends Rule[LogicalPlan] {
+  private def mustRewrite(
+      aggregateExpressions: Seq[AggregateExpression],
+      groupingExpressions: Seq[Expression]): Boolean = {
+    // If there are any AggregateExpressions with filter, we need to rewrite the query.
+    // Also, if there are no grouping expressions and all aggregate expressions are foldable,
+    // we need to rewrite the query, e.g. SELECT COUNT(DISTINCT 1).
+    aggregateExpressions.exists(_.filter.isDefined) || (groupingExpressions.isEmpty &&
+      aggregateExpressions.exists(_.aggregateFunction.children.forall(_.foldable)))
+  }
 
   private def mayNeedtoRewrite(a: Aggregate): Boolean = {
     val aggExpressions = collectAggregateExprs(a)
@@ -205,7 +214,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     // clause for this rule because aggregation strategy can handle a single distinct aggregate
     // group without filter clause.
     // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
-    distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
+    distinctAggs.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
@@ -236,7 +245,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     }
 
     // Aggregation strategy can handle queries with a single distinct group without filter clause.
-    if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
+    if (distinctAggGroups.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = AttributeReference("gid", IntegerType, nullable = false)()
       val groupByMap = a.groupingExpressions.collect {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 1ba3f6c84d0a..d8e3a046655f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -24,6 +24,7 @@ import scala.util.Random
 import org.scalatest.matchers.must.Matchers.the
 
 import org.apache.spark.{SparkException, SparkThrowable}
+import org.apache.spark.sql.catalyst.plans.logical.Expand
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -2150,6 +2151,119 @@ class DataFrameAggregateSuite extends QueryTest
       checkAnswer(df, Row(1, 2, 2) :: Row(3, 1, 1) :: Nil)
     }
   }
+
+  test("aggregating with various distinct expressions") {
+    abstract class AggregateTestCaseBase(
+        val query: String,
+        val resultSeq: Seq[Seq[Row]],
+        val hasExpandNodeInPlan: Boolean)
+    case class AggregateTestCase(
+        override val query: String,
+        override val resultSeq: Seq[Seq[Row]],
+        override val hasExpandNodeInPlan: Boolean)
+      extends AggregateTestCaseBase(query, resultSeq, hasExpandNodeInPlan)
+    case class AggregateTestCaseDefault(
+        override val query: String)
+      extends AggregateTestCaseBase(
+        query,
+        Seq(Seq(Row(0)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = true)
+
+    val t = "t"
+    val testCases: Seq[AggregateTestCaseBase] = Seq(
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT "col") FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1) FROM $t"
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1 + 2) FROM $t"
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1, 2, 1 + 2) FROM $t"
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(1), COUNT(DISTINCT 1) FROM $t",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(2, 1))),
+        hasExpandNodeInPlan = true
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT 1, "col") FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT collation("abc")) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT current_date()) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT array(1, 2)[1]) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT map(1, 2)[1]) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT struct(1, 2).col1) FROM $t"""
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(DISTINCT 1) FROM $t GROUP BY col",
+        Seq(Seq(), Seq(Row(1)), Seq(Row(1), Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 1"
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 0",
+        Seq(Seq(Row(0)), Seq(Row(0)), Seq(Row(0))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $t)"),
+      AggregateTestCase(
+        s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT "col") FROM $t""",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 1))),
+        hasExpandNodeInPlan = true
+      ),
+      AggregateTestCase(
+        s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT col) FROM $t""",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 2))),
+        hasExpandNodeInPlan = true
+      )
+    )
+    withTable(t) {
+      sql(s"create table $t(col int) using parquet")
+      Seq(0, 1, 2).foreach(columnValue => {
+        if (columnValue != 0) {
+          sql(s"insert into $t(col) values($columnValue)")
+        }
+        testCases.foreach(testCase => {
+          val query = sql(testCase.query)
+          checkAnswer(query, testCase.resultSeq(columnValue))
+          val hasExpandNodeInPlan = query.queryExecution.optimizedPlan.collectFirst {
+            case _: Expand => true
+          }.nonEmpty
+          assert(hasExpandNodeInPlan == testCase.hasExpandNodeInPlan)
+        })
+      })
+    }
+  }
 }
 
 case class B(c: Option[Double])

