This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c1ba963e64a2 [SPARK-46378][SQL] Still remove Sort after converting Aggregate to Project
c1ba963e64a2 is described below

commit c1ba963e64a22dea28e17b1ed954e6d03d38da1e
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Tue Dec 12 10:04:40 2023 -0800

    [SPARK-46378][SQL] Still remove Sort after converting Aggregate to Project
    
    ### What changes were proposed in this pull request?
    
    This is a follow-up of https://github.com/apache/spark/pull/33397 to
    avoid sub-optimal plans. After converting `Aggregate` to `Project`,
    information is lost: `Aggregate` doesn't care about the order of its
    input data, but `Project` does. `EliminateSorts` can remove a `Sort`
    below an `Aggregate`, but that no longer works once the `Aggregate`
    has been converted to a `Project`.
    
    This PR fixes the issue by tagging the `Project` as order-irrelevant
    when it is converted from an `Aggregate`, so that `EliminateSorts` can
    optimize the tagged `Project` as well.
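    
    For illustration, here is the plan shape this affects, written in the
    Catalyst test DSL (a minimal sketch mirroring the new test in this
    patch; `testRelation` and column `a` come from `EliminateSortsSuite`):
    
        // Before this patch: LimitPushDown rewrote the grouping-only
        // Aggregate into a Project, and the Sort below it survived.
        val originalPlan = testRelation.orderBy($"a".asc)
          .groupBy($"a")($"a")
          .limit(1)
    
        // With the order-irrelevant tag, EliminateSorts removes the Sort:
        val correctAnswer = testRelation.localLimit(1)
          .select($"a")
          .limit(1)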
    
    ### Why are the changes needed?
    
    Avoid sub-optimal plans: a removable `Sort` would otherwise stay in the plan.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    A new test in `EliminateSortsSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44310 from cloud-fan/sort.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../scala/org/apache/spark/sql/catalyst/dsl/package.scala    |  2 ++
 .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala  |  6 +++++-
 .../sql/catalyst/plans/logical/basicLogicalOperators.scala   |  3 +++
 .../spark/sql/catalyst/optimizer/EliminateSortsSuite.scala   | 12 ++++++++++++
 4 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 30d4c2dbb409..eb3047700215 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -395,6 +395,8 @@ package object dsl {
 
       def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)
 
+      def localLimit(limitExpr: Expression): LogicalPlan = LocalLimit(limitExpr, logicalPlan)
+
       def offset(offsetExpr: Expression): LogicalPlan = Offset(offsetExpr, logicalPlan)
 
       def join(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 960f5e532c08..a4b25cbd1d2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -769,7 +769,9 @@ object LimitPushDown extends Rule[LogicalPlan] {
       LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join)))
     // Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only.
     case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly =>
-      Limit(le, Project(a.aggregateExpressions, LocalLimit(le, a.child)))
+      val project = Project(a.aggregateExpressions, LocalLimit(le, a.child))
+      project.setTagValue(Project.dataOrderIrrelevantTag, ())
+      Limit(le, project)
     case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly =>
       Limit(le, p.copy(child = Project(a.aggregateExpressions, LocalLimit(le, a.child))))
     // Merge offset value and limit value into LocalLimit and push down LocalLimit through Offset.
@@ -1583,6 +1585,8 @@ object EliminateSorts extends Rule[LogicalPlan] {
         right = recursiveRemoveSort(originRight, true))
     case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
       g.copy(child = recursiveRemoveSort(originChild, true))
+    case p: Project if p.getTagValue(Project.dataOrderIrrelevantTag).isDefined =>
+      p.copy(child = recursiveRemoveSort(p.child, true))
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 497f485b67fe..65f4151c0c96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -101,6 +101,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
 
 object Project {
   val hiddenOutputTag: TreeNodeTag[Seq[Attribute]] = TreeNodeTag[Seq[Attribute]]("hidden_output")
+  // A Project with this tag doesn't care about the data order of its input. We only set
+  // this tag when the Project was converted from a grouping-only Aggregate.
+  val dataOrderIrrelevantTag: TreeNodeTag[Unit] = TreeNodeTag[Unit]("data_order_irrelevant")
 
   def matchSchema(plan: LogicalPlan, schema: StructType, conf: SQLConf): Project = {
     assert(plan.resolved)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
index 7cbc308182c6..c6312fa1b1aa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
@@ -478,4 +478,16 @@ class EliminateSortsSuite extends AnalysisTest {
 
     comparePlans(Optimize.execute(originalPlan.analyze), correctAnswer.analyze)
   }
+
+  test("SPARK-46378: Still remove Sort after converting Aggregate to Project") 
{
+    val originalPlan = testRelation.orderBy($"a".asc)
+      .groupBy($"a")($"a")
+      .limit(1)
+
+    val correctAnswer = testRelation.localLimit(1)
+      .select($"a")
+      .limit(1)
+
+    comparePlans(Optimize.execute(originalPlan.analyze), correctAnswer.analyze)
+  }
 }
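
For reference, the mechanism in the patch relies on Catalyst's TreeNodeTag.
The following is a minimal standalone sketch of the set-and-check pattern
it uses (simplified stand-in types, not Spark's actual TreeNode
implementation):

    import scala.collection.mutable

    // Simplified stand-in for Catalyst's TreeNodeTag: a typed key.
    final class TreeNodeTag[T](val name: String)

    // Simplified stand-in for a plan node that can carry tags.
    class Node {
      private val tags = mutable.Map.empty[TreeNodeTag[_], Any]
      def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = tags(tag) = value
      def getTagValue[T](tag: TreeNodeTag[T]): Option[T] =
        tags.get(tag).map(_.asInstanceOf[T])
    }

    object Demo extends App {
      val dataOrderIrrelevantTag = new TreeNodeTag[Unit]("data_order_irrelevant")
      val project = new Node
      // LimitPushDown marks the converted Project as order-irrelevant...
      project.setTagValue(dataOrderIrrelevantTag, ())
      // ...and EliminateSorts later checks for the tag:
      if (project.getTagValue(dataOrderIrrelevantTag).isDefined) {
        println("safe to remove Sort below this Project")
      }
    }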

