This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new c6df89079886 [SPARK-49000][SQL] Fix "select count(distinct 1) from t" where t is empty table by expanding RewriteDistinctAggregates
c6df89079886 is described below

commit c6df890798862d0863afbfff8fca0ee4df70354f
Author: Uros Bojanic <157381213+uros...@users.noreply.github.com>
AuthorDate: Wed Jul 31 22:37:42 2024 +0800

    [SPARK-49000][SQL] Fix "select count(distinct 1) from t" where t is empty table by expanding RewriteDistinctAggregates

    ### What changes were proposed in this pull request?
    Fix the `RewriteDistinctAggregates` rule to deal properly with aggregation on DISTINCT literals. Physical plan for `select count(distinct 1) from t`:
    ```
    -- count(distinct 1)
    == Physical Plan ==
    AdaptiveSparkPlan isFinalPlan=false
    +- HashAggregate(keys=[], functions=[count(distinct 1)], output=[count(DISTINCT 1)#2L])
       +- HashAggregate(keys=[], functions=[partial_count(distinct 1)], output=[count#6L])
          +- HashAggregate(keys=[], functions=[], output=[])
             +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=20]
                +- HashAggregate(keys=[], functions=[], output=[])
                   +- FileScan parquet spark_catalog.default.t[] Batched: false, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/Users/nikola.mandic/oss-spark/spark-warehouse/org.apache.spark.s..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<>
    ```

    The problem happens when the `HashAggregate(keys=[], functions=[], output=[])` node yields one row to the `partial_count` node, which then counts that one row. This four-node structure is constructed by `AggUtils.planAggregateWithOneDistinct`.

    To fix the problem, we add an `Expand` node which forces non-empty grouping expressions in the `HashAggregateExec` nodes. This in turn enables streaming zero rows to the parent `partial_count` node, yielding the correct final result.

    ### Why are the changes needed?
    Aggregation with a DISTINCT literal gives wrong results. For example, when running on an empty table `t`, `select count(distinct 1) from t` returns 1, while the correct result should be 0. For reference, `select count(1) from t` returns 0, which is the correct and expected result.

    ### Does this PR introduce _any_ user-facing change?
    Yes, this fixes a critical bug in Spark.

    ### How was this patch tested?
    New e2e SQL tests for aggregates with DISTINCT literals.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #47525 from nikolamand-db/SPARK-49000-spark-expand-approach.

    Lead-authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com>
    Co-authored-by: Nikola Mandic <nikola.man...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit dfa21332f20fff4aa6052ffa556d206497c066cf)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/RewriteDistinctAggregates.scala      |  13 ++-
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 114 +++++++++++++++++++++
 2 files changed, 125 insertions(+), 2 deletions(-)
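The behavior described above can be checked end to end with a short standalone app. The following is a minimal sketch, assuming a local Spark 3.5 build with this patch applied; the object name `Spark49000Repro` and the session settings are illustrative, not part of the patch:

```
import org.apache.spark.sql.SparkSession

// Minimal reproduction sketch for SPARK-49000: COUNT(DISTINCT 1) on an empty
// table should return 0, matching COUNT(1). Before this fix it returned 1.
object Spark49000Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("spark-49000-repro")
      .getOrCreate()
    try {
      spark.sql("CREATE TABLE t(col INT) USING parquet")
      spark.sql("SELECT COUNT(DISTINCT 1) FROM t").show() // fixed: 0 (was 1 before the patch)
      spark.sql("SELECT COUNT(1) FROM t").show()          // 0, the reference result
    } finally {
      spark.sql("DROP TABLE IF EXISTS t")
      spark.stop()
    }
  }
}
```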
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index da3cf782f668..e91493188873 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -197,6 +197,15 @@ import org.apache.spark.util.collection.Utils
  * techniques.
  */
 object RewriteDistinctAggregates extends Rule[LogicalPlan] {
+  private def mustRewrite(
+      aggregateExpressions: Seq[AggregateExpression],
+      groupingExpressions: Seq[Expression]): Boolean = {
+    // If there are any AggregateExpressions with filter, we need to rewrite the query.
+    // Also, if there are no grouping expressions and all aggregate expressions are foldable,
+    // we need to rewrite the query, e.g. SELECT COUNT(DISTINCT 1).
+    aggregateExpressions.exists(_.filter.isDefined) || (groupingExpressions.isEmpty &&
+      aggregateExpressions.exists(_.aggregateFunction.children.forall(_.foldable)))
+  }

   private def mayNeedtoRewrite(a: Aggregate): Boolean = {
     val aggExpressions = collectAggregateExprs(a)
@@ -205,7 +214,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     // clause for this rule because aggregation strategy can handle a single distinct aggregate
     // group without filter clause.
     // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a).
-    distinctAggs.size > 1 || distinctAggs.exists(_.filter.isDefined)
+    distinctAggs.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)
   }

   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
@@ -236,7 +245,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     }

     // Aggregation strategy can handle queries with a single distinct group without filter clause.
-    if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
+    if (distinctAggGroups.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions)) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = AttributeReference("gid", IntegerType, nullable = false)()
       val groupByMap = a.groupingExpressions.collect {
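The test added below asserts, for each query, whether the optimized logical plan contains an `Expand` node. As a standalone sketch of that same check (the object name `ExpandPlanCheck` and the local-mode session are assumptions for illustration, not part of the patch):

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.Expand

// Sketch: detect whether the optimizer inserted an Expand node for a
// DISTINCT-literal aggregate, mirroring the assertion in the new test.
object ExpandPlanCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("expand-plan-check")
      .getOrCreate()
    try {
      spark.sql("CREATE TABLE t(col INT) USING parquet")
      val df = spark.sql("SELECT COUNT(DISTINCT 1) FROM t")
      val hasExpand = df.queryExecution.optimizedPlan.collectFirst {
        case _: Expand => true
      }.nonEmpty
      println(s"Expand node present: $hasExpand") // expected: true with this fix
    } finally {
      spark.sql("DROP TABLE IF EXISTS t")
      spark.stop()
    }
  }
}
```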
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 1ba3f6c84d0a..d8e3a046655f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -24,6 +24,7 @@ import scala.util.Random
 import org.scalatest.matchers.must.Matchers.the

 import org.apache.spark.{SparkException, SparkThrowable}
+import org.apache.spark.sql.catalyst.plans.logical.Expand
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -2150,6 +2151,119 @@ class DataFrameAggregateSuite extends QueryTest
       checkAnswer(df, Row(1, 2, 2) :: Row(3, 1, 1) :: Nil)
     }
   }
+
+  test("aggregating with various distinct expressions") {
+    abstract class AggregateTestCaseBase(
+        val query: String,
+        val resultSeq: Seq[Seq[Row]],
+        val hasExpandNodeInPlan: Boolean)
+    case class AggregateTestCase(
+        override val query: String,
+        override val resultSeq: Seq[Seq[Row]],
+        override val hasExpandNodeInPlan: Boolean)
+      extends AggregateTestCaseBase(query, resultSeq, hasExpandNodeInPlan)
+    case class AggregateTestCaseDefault(
+        override val query: String)
+      extends AggregateTestCaseBase(
+        query,
+        Seq(Seq(Row(0)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = true)
+
+    val t = "t"
+    val testCases: Seq[AggregateTestCaseBase] = Seq(
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT "col") FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1) FROM $t"
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1 + 2) FROM $t"
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1, 2, 1 + 2) FROM $t"
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(1), COUNT(DISTINCT 1) FROM $t",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(2, 1))),
+        hasExpandNodeInPlan = true
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT 1, "col") FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT collation("abc")) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT current_date()) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT array(1, 2)[1]) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT map(1, 2)[1]) FROM $t"""
+      ),
+      AggregateTestCaseDefault(
+        s"""SELECT COUNT(DISTINCT struct(1, 2).col1) FROM $t"""
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(DISTINCT 1) FROM $t GROUP BY col",
+        Seq(Seq(), Seq(Row(1)), Seq(Row(1), Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 1"
+      ),
+      AggregateTestCase(
+        s"SELECT COUNT(DISTINCT 1) FROM $t WHERE 1 = 0",
+        Seq(Seq(Row(0)), Seq(Row(0)), Seq(Row(0))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(DISTINCT 1) FROM (SELECT COUNT(1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCase(
+        s"SELECT SUM(1) FROM (SELECT COUNT(DISTINCT 1) FROM $t)",
+        Seq(Seq(Row(1)), Seq(Row(1)), Seq(Row(1))),
+        hasExpandNodeInPlan = false
+      ),
+      AggregateTestCaseDefault(
+        s"SELECT SUM(x) FROM (SELECT COUNT(DISTINCT 1) AS x FROM $t)"),
+      AggregateTestCase(
+        s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT "col") FROM $t""",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 1))),
+        hasExpandNodeInPlan = true
+      ),
+      AggregateTestCase(
+        s"""SELECT COUNT(DISTINCT 1), COUNT(DISTINCT col) FROM $t""",
+        Seq(Seq(Row(0, 0)), Seq(Row(1, 1)), Seq(Row(1, 2))),
+        hasExpandNodeInPlan = true
+      )
+    )
+    withTable(t) {
+      sql(s"create table $t(col int) using parquet")
+      Seq(0, 1, 2).foreach(columnValue => {
+        if (columnValue != 0) {
+          sql(s"insert into $t(col) values($columnValue)")
+        }
+        testCases.foreach(testCase => {
+          val query = sql(testCase.query)
+          checkAnswer(query, testCase.resultSeq(columnValue))
+          val hasExpandNodeInPlan = query.queryExecution.optimizedPlan.collectFirst {
+            case _: Expand => true
+          }.nonEmpty
+          assert(hasExpandNodeInPlan == testCase.hasExpandNodeInPlan)
+        })
+      })
+    }
+  }
 }

 case class B(c: Option[Double])

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org