This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8facc992fe46 [SPARK-45352][SQL] Eliminate foldable window partitions
8facc992fe46 is described below

commit 8facc992fe465cdf7db64de604c1d688c95f767a
Author: zml1206 <zhuml1...@gmail.com>
AuthorDate: Mon Jan 8 17:18:15 2024 +0800

    [SPARK-45352][SQL] Eliminate foldable window partitions
    
    ### What changes were proposed in this pull request?
    This PR adds a new optimizer rule `EliminateWindowPartitions`, which removes
    window partition expressions that are foldable.
    sql1:
    `select row_number() over(order by a) b from t `
    sql2:
    `select row_number() over(partition by 1 order by a) b from t `
    After this PR, the `optimizedPlan` for sql1 and sql2 is the same.
    
    ### Why are the changes needed?
    Foldable partition expressions are redundant. Removing them not only simplifies
    the plan, but also lets some rules take effect once all partition expressions
    are foldable, such as `LimitPushDownThroughWindow`.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UT
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43144 from zml1206/SPARK-45352.
    
    Authored-by: zml1206 <zhuml1...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/EliminateWindowPartitions.scala      | 41 +++++++++++
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  1 +
 .../sql/catalyst/rules/RuleIdCollection.scala      |  1 +
 .../optimizer/EliminateWindowPartitionsSuite.scala | 80 ++++++++++++++++++++++
 .../spark/sql/DataFrameWindowFramesSuite.scala     | 22 ++++++
 5 files changed, 145 insertions(+)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateWindowPartitions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateWindowPartitions.scala
new file mode 100644
index 000000000000..e3d1b0544358
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/EliminateWindowPartitions.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{NamedExpression, 
WindowExpression, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{WINDOW, 
WINDOW_EXPRESSION}
+
+/**
+ * Removes foldable expressions from the partition spec of `Window` operators.
+ *
+ * A partition key that folds to a constant has the same value for every row, so it does not split
+ * the input and can be dropped: `row_number() OVER (PARTITION BY 1 ORDER BY a)` is rewritten to
+ * `row_number() OVER (ORDER BY a)`. Non-foldable partition keys are kept in their original order.
+ * The pruning is applied both to the `Window` node itself and to the `WindowSpecDefinition` of
+ * every `WindowExpression` it evaluates, so the two partition specs stay consistent.
+ */
+object EliminateWindowPartitions extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan =
+    plan.transformWithPruning(_.containsPattern(WINDOW), ruleId) {
+      case window: Window if window.partitionSpec.exists(_.foldable) =>
+        // Each window expression carries its own copy of the partition spec inside its
+        // WindowSpecDefinition; prune that copy too.
+        val prunedWindowExprs = window.windowExpressions.map { namedExpr =>
+          namedExpr.transformWithPruning(_.containsPattern(WINDOW_EXPRESSION)) {
+            case we @ WindowExpression(_, spec @ WindowSpecDefinition(keys, _, _))
+                if keys.exists(_.foldable) =>
+              we.copy(windowSpec = spec.copy(partitionSpec = keys.filterNot(_.foldable)))
+          }.asInstanceOf[NamedExpression]
+        }
+        window.copy(
+          windowExpressions = prunedWindowExprs,
+          partitionSpec = window.partitionSpec.filterNot(_.foldable))
+    }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1a831b958ef2..61791b35df85 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -96,6 +96,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         CombineUnions,
         // Constant folding and strength reduction
         OptimizeRepartition,
+        EliminateWindowPartitions,
         TransposeWindow,
         NullPropagation,
         // NullPropagation may introduce Exists subqueries, so 
RewriteNonCorrelatedExists must run
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 8eeea74b5376..08f728da2e9d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -124,6 +124,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.optimizer.EliminateMapObjects" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
+      "org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" ::
       "org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" ::
       "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateWindowPartitionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateWindowPartitionsSuite.scala
new file mode 100644
index 000000000000..b5ad729365aa
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateWindowPartitionsSuite.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+class EliminateWindowPartitionsSuite extends PlanTest {
+
+  private object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Eliminate window partitions", FixedPoint(20),
+        EliminateWindowPartitions) :: Nil
+  }
+
+  val testRelation = LocalRelation($"a".int, $"b".int)
+  private val a = testRelation.output(0)
+  private val b = testRelation.output(1)
+  private val windowFrame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)
+
+  /**
+   * Builds `SELECT a, b, row_number() OVER (PARTITION BY <partitionSpec> ORDER BY b DESC) AS rn`
+   * over `testRelation`.
+   */
+  private def rowNumberOver(partitionSpec: Seq[Expression]): LogicalPlan =
+    testRelation
+      .select(a, b,
+        windowExpr(RowNumber(), windowSpec(partitionSpec, b.desc :: Nil, windowFrame)).as("rn"))
+
+  /** Asserts that optimizing `input` yields `expected`; both plans are analyzed first. */
+  private def checkRule(input: LogicalPlan, expected: LogicalPlan): Unit =
+    comparePlans(Optimize.execute(input.analyze), expected.analyze)
+
+  test("Remove foldable window partitions") {
+    checkRule(rowNumberOver(Literal(1) :: Nil), rowNumberOver(Nil))
+  }
+
+  test("Remove part of window partitions which is foldable") {
+    checkRule(rowNumberOver(a :: Literal(1) :: Nil), rowNumberOver(a :: Nil))
+  }
+
+  test("Remove foldable window partitions that are not plain literals") {
+    // The rule keys off `foldable`, not `Literal`: a constant expression such as `1 + 2` must be
+    // removed even though ConstantFolding is not part of this batch.
+    checkRule(rowNumberOver(a :: Add(Literal(1), Literal(2)) :: Nil), rowNumberOver(a :: Nil))
+  }
+
+  test("Can't remove non-foldable window partitions") {
+    val query = rowNumberOver(a :: Nil)
+    checkRule(query, query)
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
index 0e3932cf1e1d..fe1393af8174 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala
@@ -18,9 +18,11 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, 
NonFoldableLiteral, RangeFrame, SortOrder, SpecifiedWindowFrame, UnaryMinus, 
UnspecifiedFrame}
+import org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions
 import org.apache.spark.sql.catalyst.plans.logical.{Window => WindowNode}
 import org.apache.spark.sql.expressions.{Window, WindowSpec}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.CalendarIntervalType
 
@@ -548,4 +550,24 @@ class DataFrameWindowFramesSuite extends QueryTest with 
SharedSparkSession {
       df,
       Row(1) :: Row(1) :: Nil)
   }
+
+  test("SPARK-45352: Eliminate foldable window partitions") {
+    val df = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("a", "b")
+
+    // Checks whether the optimized Window node still carries a foldable partition expression, so
+    // the test fails if the rule stops firing (results alone are identical either way).
+    def assertFoldablePartition(query: DataFrame, expected: Boolean): Unit = {
+      val optimized = query.queryExecution.optimizedPlan
+      optimized.collectFirst { case w: WindowNode => w } match {
+        case Some(w) =>
+          assert(w.partitionSpec.exists(_.foldable) == expected,
+            s"Unexpected partition spec ${w.partitionSpec} in $optimized")
+        case None =>
+          fail(s"No Window node in $optimized")
+      }
+    }
+
+    Seq(true, false).foreach { eliminateWindowPartitionsEnabled =>
+      val excludedRules =
+        if (eliminateWindowPartitionsEnabled) "" else EliminateWindowPartitions.ruleName
+      withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) {
+        // `b` alone has ties ((1,1)/(2,1) and (1,2)/(2,2)); order by `a` as well so that the
+        // row_number assigned within each tie is deterministic.
+        val window1 = Window.partitionBy(lit(1)).orderBy($"b", $"a")
+        val query1 = df.select($"a", $"b", row_number().over(window1))
+        assertFoldablePartition(query1, expected = !eliminateWindowPartitionsEnabled)
+        checkAnswer(
+          query1,
+          Seq(Row(1, 1, 1), Row(1, 2, 3), Row(1, 3, 5), Row(2, 1, 2), Row(2, 2, 4)))
+
+        val window2 = Window.partitionBy($"a", lit(1)).orderBy($"b")
+        val query2 = df.select($"a", $"b", row_number().over(window2))
+        assertFoldablePartition(query2, expected = !eliminateWindowPartitionsEnabled)
+        checkAnswer(
+          query2,
+          Seq(Row(1, 1, 1), Row(1, 2, 2), Row(1, 3, 3), Row(2, 1, 1), Row(2, 2, 2)))
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to