This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c1ba963e64a2 [SPARK-46378][SQL] Still remove Sort after converting Aggregate to Project c1ba963e64a2 is described below commit c1ba963e64a22dea28e17b1ed954e6d03d38da1e Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Tue Dec 12 10:04:40 2023 -0800 [SPARK-46378][SQL] Still remove Sort after converting Aggregate to Project ### What changes were proposed in this pull request? This is a follow-up of https://github.com/apache/spark/pull/33397 to avoid sub-optimal plans. After converting `Aggregate` to `Project`, information is lost: `Aggregate` doesn't care about the data order of its input, but `Project` does. `EliminateSorts` can remove a `Sort` below an `Aggregate`, but it no longer does so once we convert the `Aggregate` to a `Project`. This PR fixes the issue by tagging the `Project` as order-irrelevant when it is converted from an `Aggregate`. `EliminateSorts` then removes the `Sort` below the tagged `Project`. ### Why are the changes needed? To avoid sub-optimal plans. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? A new test in `EliminateSortsSuite`. ### Was this patch authored or co-authored using generative AI tooling? No Closes #44310 from cloud-fan/sort. 
Authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../scala/org/apache/spark/sql/catalyst/dsl/package.scala | 2 ++ .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 6 +++++- .../sql/catalyst/plans/logical/basicLogicalOperators.scala | 3 +++ .../spark/sql/catalyst/optimizer/EliminateSortsSuite.scala | 12 ++++++++++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 30d4c2dbb409..eb3047700215 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -395,6 +395,8 @@ package object dsl { def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) + def localLimit(limitExpr: Expression): LogicalPlan = LocalLimit(limitExpr, logicalPlan) + def offset(offsetExpr: Expression): LogicalPlan = Offset(offsetExpr, logicalPlan) def join( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 960f5e532c08..a4b25cbd1d2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -769,7 +769,9 @@ object LimitPushDown extends Rule[LogicalPlan] { LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join))) // Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only. 
case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly => - Limit(le, Project(a.aggregateExpressions, LocalLimit(le, a.child))) + val project = Project(a.aggregateExpressions, LocalLimit(le, a.child)) + project.setTagValue(Project.dataOrderIrrelevantTag, ()) + Limit(le, project) case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly => Limit(le, p.copy(child = Project(a.aggregateExpressions, LocalLimit(le, a.child)))) // Merge offset value and limit value into LocalLimit and pushes down LocalLimit through Offset. @@ -1583,6 +1585,8 @@ object EliminateSorts extends Rule[LogicalPlan] { right = recursiveRemoveSort(originRight, true)) case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) => g.copy(child = recursiveRemoveSort(originChild, true)) + case p: Project if p.getTagValue(Project.dataOrderIrrelevantTag).isDefined => + p.copy(child = recursiveRemoveSort(p.child, true)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 497f485b67fe..65f4151c0c96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -101,6 +101,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) object Project { val hiddenOutputTag: TreeNodeTag[Seq[Attribute]] = TreeNodeTag[Seq[Attribute]]("hidden_output") + // Project with this tag means it doesn't care about the data order of its input. We only set + // this tag when the Project was converted from grouping-only Aggregate. 
+ val dataOrderIrrelevantTag: TreeNodeTag[Unit] = TreeNodeTag[Unit]("data_order_irrelevant") def matchSchema(plan: LogicalPlan, schema: StructType, conf: SQLConf): Project = { assert(plan.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 7cbc308182c6..c6312fa1b1aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -478,4 +478,16 @@ class EliminateSortsSuite extends AnalysisTest { comparePlans(Optimize.execute(originalPlan.analyze), correctAnswer.analyze) } + + test("SPARK-46378: Still remove Sort after converting Aggregate to Project") { + val originalPlan = testRelation.orderBy($"a".asc) + .groupBy($"a")($"a") + .limit(1) + + val correctAnswer = testRelation.localLimit(1) + .select($"a") + .limit(1) + + comparePlans(Optimize.execute(originalPlan.analyze), correctAnswer.analyze) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org