This is an automated email from the ASF dual-hosted git repository. yumwang pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new 5a1c6e6ffe2 [SPARK-44846][SQL] Convert the lower redundant Aggregate to Project in RemoveRedundantAggregates 5a1c6e6ffe2 is described below commit 5a1c6e6ffe244461d23de98ddb317904db19fc4b Author: zml1206 <zhuml1...@gmail.com> AuthorDate: Mon Sep 4 20:23:39 2023 +0800 [SPARK-44846][SQL] Convert the lower redundant Aggregate to Project in RemoveRedundantAggregates ### What changes were proposed in this pull request? This PR provides a safe way to remove a redundant `Aggregate` in rule `RemoveRedundantAggregates`: instead of merging it into the upper `Aggregate`, it simply converts the lower redundant `Aggregate` to a `Project`. ### Why are the changes needed? The merged aggregate can contain complex grouping expressions after `RemoveRedundantAggregates`. If `aggregateExpressions` has (if / case) branches, it is possible that `groupingExpressions` is no longer a subexpression of `aggregateExpressions` after executing the `PushFoldableIntoBranches` rule, which then causes a `boundReference` error. 
For example ``` SELECT c * 2 AS d FROM ( SELECT if(b > 1, 1, b) AS c FROM ( SELECT if(a < 0, 0, a) AS b FROM VALUES (-1), (1), (2) AS t1(a) ) t2 GROUP BY b ) t3 GROUP BY c ``` Before pr ``` == Optimized Logical Plan == Aggregate [if ((b#0 > 1)) 1 else b#0], [if ((b#0 > 1)) 2 else (b#0 * 2) AS d#2] +- Project [if ((a#3 < 0)) 0 else a#3 AS b#0] +- LocalRelation [a#3] ``` ``` == Error == Couldn't find b#0 in [if ((b#0 > 1)) 1 else b#0#7] java.lang.IllegalStateException: Couldn't find b#0 in [if ((b#0 > 1)) 1 else b#0#7] at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73) at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:461) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:76) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:461) at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:466) at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1241) at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1240) at org.apache.spark.sql.catalyst.expressions.BinaryExpression.mapChildren(Expression.scala:653) ...... ``` After pr ``` == Optimized Logical Plan == Aggregate [c#1], [(c#1 * 2) AS d#2] +- Project [if ((b#0 > 1)) 1 else b#0 AS c#1] +- Project [if ((a#3 < 0)) 0 else a#3 AS b#0] +- LocalRelation [a#3] ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #42633 from zml1206/SPARK-44846-2. 
Authored-by: zml1206 <zhuml1...@gmail.com> Signed-off-by: Yuming Wang <yumw...@ebay.com> (cherry picked from commit 32a87f03da7eef41161a5a7a3aba4a48e0421912) Signed-off-by: Yuming Wang <yumw...@ebay.com> --- .../optimizer/RemoveRedundantAggregates.scala | 19 ++----------------- .../optimizer/RemoveRedundantAggregatesSuite.scala | 21 ++++++++++++--------- .../test/resources/sql-tests/inputs/group-by.sql | 13 +++++++++++++ .../resources/sql-tests/results/group-by.sql.out | 18 ++++++++++++++++++ 4 files changed, 45 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala index 2104bce3711..0c3d5bcf01a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.PullOutNondeterministic import org.apache.spark.sql.catalyst.expressions.{AliasHelper, AttributeSet, ExpressionSet} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} @@ -32,22 +31,8 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper { def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( _.containsPattern(AGGREGATE), ruleId) { case upper @ Aggregate(_, _, lower: Aggregate) if isLowerRedundant(upper, lower) => - val aliasMap = getAliasMap(lower) - - val newAggregate = upper.copy( - child = lower.child, - groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)), - aggregateExpressions = upper.aggregateExpressions.map( - replaceAliasButKeepName(_, aliasMap)) - ) - - // We might have 
introduces non-deterministic grouping expression - if (newAggregate.groupingExpressions.exists(!_.deterministic)) { - PullOutNondeterministic.applyLocally.applyOrElse(newAggregate, identity[LogicalPlan]) - } else { - newAggregate - } - + val projectList = lower.aggregateExpressions.filter(upper.references.contains(_)) + upper.copy(child = Project(projectList, lower.child)) case agg @ Aggregate(groupingExps, _, child) if agg.groupOnly && child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) => Project(agg.aggregateExpressions, child) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala index 3fb67320f1f..5fdbf828f50 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala @@ -31,7 +31,8 @@ class RemoveRedundantAggregatesSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("RemoveRedundantAggregates", FixedPoint(10), - RemoveRedundantAggregates) :: Nil + RemoveRedundantAggregates, + RemoveNoopOperators) :: Nil } private val relation = LocalRelation($"a".int, $"b".int) @@ -53,6 +54,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .groupBy($"a")($"a") .analyze val expected = relation + .select($"a") .groupBy($"a")($"a") .analyze val optimized = Optimize.execute(query) @@ -68,6 +70,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .groupBy($"a")($"a") .analyze val expected = relation + .select($"a") .groupBy($"a")($"a") .analyze val optimized = Optimize.execute(query) @@ -81,6 +84,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .groupBy($"a")($"a") .analyze val expected = relation + .select($"a") .groupBy($"a")($"a") .analyze val optimized = 
Optimize.execute(query) @@ -94,7 +98,8 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .groupBy($"c")($"c") .analyze val expected = relation - .groupBy($"a" + $"b")(($"a" + $"b") as "c") + .select(($"a" + $"b") as "c") + .groupBy($"c")($"c") .analyze val optimized = Optimize.execute(query) comparePlans(optimized, expected) @@ -107,6 +112,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .groupBy($"a")($"a", rand(0) as "c") .analyze val expected = relation + .select($"a") .groupBy($"a")($"a", rand(0) as "c") .analyze val optimized = Optimize.execute(query) @@ -119,7 +125,9 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .groupBy($"a", $"c")($"a", $"c") .analyze val expected = relation - .groupBy($"a", $"c")($"a", rand(0) as "c") + .select($"a", $"b", rand(0) as "_nondeterministic") + .select($"a", $"_nondeterministic" as "c") + .groupBy($"a", $"c")($"a", $"c") .analyze val optimized = Optimize.execute(query) comparePlans(optimized, expected) @@ -152,7 +160,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest { test("Remove redundant aggregate - upper has contains foldable expressions") { val originalQuery = x.groupBy($"a", $"b")($"a", $"b").groupBy($"a")($"a", TrueLiteral).analyze - val correctAnswer = x.groupBy($"a")($"a", TrueLiteral).analyze + val correctAnswer = x.select($"a").groupBy($"a")($"a", TrueLiteral).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } @@ -175,7 +183,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .analyze val expected = relation .groupBy($"a")($"a", ($"a" + rand(0)) as "c") - .select($"a", $"c") .analyze val optimized = Optimize.execute(query) comparePlans(optimized, expected) @@ -188,7 +195,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .groupBy("x.a".attr, "x.b".attr)("x.a".attr, "x.b".attr) val correctAnswer = x.groupBy($"a", $"b")($"a", $"b") .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) - 
.select("x.a".attr, "x.b".attr) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, correctAnswer.analyze) @@ -202,7 +208,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .groupBy("x.a".attr, "d".attr)("x.a".attr, "d".attr) val correctAnswer = x.groupBy($"a", $"b")($"a", $"b".as("d")) .join(y, joinType, Some("x.a".attr === "y.a".attr && "d".attr === "y.b".attr)) - .select("x.a".attr, "d".attr) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, correctAnswer.analyze) @@ -232,7 +237,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .groupBy("x.a".attr, "x.b".attr)("x.a".attr) val correctAnswer = x.groupBy($"a", $"b")($"a", $"b") .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) - .select("x.a".attr, "x.b".attr) .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)) .select("x.a".attr) @@ -248,7 +252,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest { .analyze val correctAnswer = relation .groupBy($"a")($"a", count($"b").as("cnt")) - .select($"a", $"cnt") .analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index c812403ba2c..c35cdb0de27 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -251,3 +251,16 @@ GROUP BY a; SELECT mode(a), mode(b) FROM testData; SELECT a, mode(b) FROM testData GROUP BY a ORDER BY a; + + +-- SPARK-44846: PushFoldableIntoBranches in complex grouping expressions cause bindReference error +SELECT c * 2 AS d +FROM ( + SELECT if(b > 1, 1, b) AS c + FROM ( + SELECT if(a < 0, 0, a) AS b + FROM VALUES (-1), (1), (2) AS t1(a) + ) t2 + GROUP BY b + ) t3 +GROUP BY c; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out 
b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 6e7592d6978..3ebab783b75 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1103,3 +1103,21 @@ NULL 1 1 1 2 1 3 1 + + +-- !query +SELECT c * 2 AS d +FROM ( + SELECT if(b > 1, 1, b) AS c + FROM ( + SELECT if(a < 0, 0, a) AS b + FROM VALUES (-1), (1), (2) AS t1(a) + ) t2 + GROUP BY b + ) t3 +GROUP BY c +-- !query schema +struct<d:int> +-- !query output +0 +2 --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org