This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 42f2132d1fc9 [SPARK-48206][SQL][TESTS] Add tests for window rewrites with RewriteWithExpression
42f2132d1fc9 is described below

commit 42f2132d1fc99bf2ec5bd398d21dcbdbd5cbde47
Author: Kelvin Jiang <kelvin.ji...@databricks.com>
AuthorDate: Mon May 13 22:28:27 2024 +0800

    [SPARK-48206][SQL][TESTS] Add tests for window rewrites with RewriteWithExpression
    
    ### What changes were proposed in this pull request?
    
    This PR adds more testing for `RewriteWithExpression` around `Window` operators.
    
    ### Why are the changes needed?
    
    Adds more testing for `RewriteWithExpression`, which can be fragile around `WindowExpression`s.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46492 from kelvinjian-db/SPARK-48206-window.
    
    Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/RewriteWithExpressionSuite.scala     | 223 +++++++++++++--------
 1 file changed, 135 insertions(+), 88 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
index 8f023fa4156b..aa8ffb2b0454 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.IntegerType
 
 class RewriteWithExpressionSuite extends PlanTest {
 
@@ -37,6 +36,20 @@ class RewriteWithExpressionSuite extends PlanTest {
   private val testRelation = LocalRelation($"a".int, $"b".int)
   private val testRelation2 = LocalRelation($"x".int, $"y".int)
 
+  private def normalizeCommonExpressionIds(plan: LogicalPlan): LogicalPlan = {
+    plan.transformAllExpressions {
+      case a: Alias if a.name.startsWith("_common_expr") =>
+        a.withName("_common_expr_0")
+      case a: AttributeReference if a.name.startsWith("_common_expr") =>
+        a.withName("_common_expr_0")
+    }
+  }
+
+  override def comparePlans(
+    plan1: LogicalPlan, plan2: LogicalPlan, checkAnalysis: Boolean = true): 
Unit = {
+    super.comparePlans(normalizeCommonExpressionIds(plan1), 
normalizeCommonExpressionIds(plan2))
+  }
+
   test("simple common expression") {
     val a = testRelation.output.head
     val expr = With(a) { case Seq(ref) =>
@@ -52,65 +65,48 @@ class RewriteWithExpressionSuite extends PlanTest {
       ref * ref
     }
     val plan = testRelation.select(expr.as("col"))
-    val commonExprId = expr.defs.head.id.id
-    val commonExprName = s"_common_expr_$commonExprId"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
-        .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
-        .select(($"$commonExprName" * $"$commonExprName").as("col"))
+        .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
+        .select(($"_common_expr_0" * $"_common_expr_0").as("col"))
         .analyze
     )
   }
 
   test("nested WITH expression in the definition expression") {
-    val a = testRelation.output.head
+    val Seq(a, b) = testRelation.output
     val innerExpr = With(a + a) { case Seq(ref) =>
       ref + ref
     }
-    val innerCommonExprId = innerExpr.defs.head.id.id
-    val innerCommonExprName = s"_common_expr_$innerCommonExprId"
-
-    val b = testRelation.output.last
     val outerExpr = With(innerExpr + b) { case Seq(ref) =>
       ref * ref
     }
-    val outerCommonExprId = outerExpr.defs.head.id.id
-    val outerCommonExprName = s"_common_expr_$outerCommonExprId"
 
     val plan = testRelation.select(outerExpr.as("col"))
-    val rewrittenOuterExpr = ($"$innerCommonExprName" + 
$"$innerCommonExprName" + b)
-      .as(outerCommonExprName)
-    val outerExprAttr = AttributeReference(outerCommonExprName, IntegerType)(
-      exprId = rewrittenOuterExpr.exprId)
     comparePlans(
       Optimizer.execute(plan),
       testRelation
-        .select((testRelation.output :+ (a + a).as(innerCommonExprName)): _*)
-        .select((testRelation.output :+ $"$innerCommonExprName" :+ 
rewrittenOuterExpr): _*)
-        .select((outerExprAttr * outerExprAttr).as("col"))
+        .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
+        .select((testRelation.output ++ Seq($"_common_expr_0",
+          ($"_common_expr_0" + $"_common_expr_0" + b).as("_common_expr_1"))): 
_*)
+        .select(($"_common_expr_1" * $"_common_expr_1").as("col"))
         .analyze
     )
   }
 
   test("nested WITH expression in the main expression") {
-    val a = testRelation.output.head
+    val Seq(a, b) = testRelation.output
     val innerExpr = With(a + a) { case Seq(ref) =>
       ref + ref
     }
-    val innerCommonExprId = innerExpr.defs.head.id.id
-    val innerCommonExprName = s"_common_expr_$innerCommonExprId"
-
-    val b = testRelation.output.last
     val outerExpr = With(b + b) { case Seq(ref) =>
       ref * ref + innerExpr
     }
-    val outerCommonExprId = outerExpr.defs.head.id.id
-    val outerCommonExprName = s"_common_expr_$outerCommonExprId"
 
     val plan = testRelation.select(outerExpr.as("col"))
-    val rewrittenInnerExpr = (a + a).as(innerCommonExprName)
-    val rewrittenOuterExpr = (b + b).as(outerCommonExprName)
+    val rewrittenInnerExpr = (a + a).as("_common_expr_0")
+    val rewrittenOuterExpr = (b + b).as("_common_expr_1")
     val finalExpr = rewrittenOuterExpr.toAttribute * 
rewrittenOuterExpr.toAttribute +
       (rewrittenInnerExpr.toAttribute + rewrittenInnerExpr.toAttribute)
     comparePlans(
@@ -124,11 +120,10 @@ class RewriteWithExpressionSuite extends PlanTest {
   }
 
   test("correlated nested WITH expression is not supported") {
-    val b = testRelation.output.last
+    val Seq(a, b) = testRelation.output
     val outerCommonExprDef = CommonExpressionDef(b + b, CommonExpressionId(0))
     val outerRef = new CommonExpressionRef(outerCommonExprDef)
 
-    val a = testRelation.output.head
     // The inner expression definition references the outer expression
     val commonExprDef1 = CommonExpressionDef(a + a + outerRef, 
CommonExpressionId(1))
     val ref1 = new CommonExpressionRef(commonExprDef1)
@@ -152,13 +147,11 @@ class RewriteWithExpressionSuite extends PlanTest {
       ref < 10 && ref > 0
     }
     val plan = testRelation.where(condition)
-    val commonExprId = condition.defs.head.id.id
-    val commonExprName = s"_common_expr_$commonExprId"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
-        .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
-        .where($"$commonExprName" < 10 && $"$commonExprName" > 0)
+        .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
+        .where($"_common_expr_0" < 10 && $"_common_expr_0" > 0)
         .select(testRelation.output: _*)
         .analyze
     )
@@ -170,13 +163,11 @@ class RewriteWithExpressionSuite extends PlanTest {
       ref < 10 && ref > 0
     }
     val plan = testRelation.join(testRelation2, condition = Some(condition))
-    val commonExprId = condition.defs.head.id.id
-    val commonExprName = s"_common_expr_$commonExprId"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
-        .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
-        .join(testRelation2, condition = Some($"$commonExprName" < 10 && 
$"$commonExprName" > 0))
+        .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
+        .join(testRelation2, condition = Some($"_common_expr_0" < 10 && 
$"_common_expr_0" > 0))
         .select((testRelation.output ++ testRelation2.output): _*)
         .analyze
     )
@@ -188,14 +179,12 @@ class RewriteWithExpressionSuite extends PlanTest {
       ref < 10 && ref > 0
     }
     val plan = testRelation.join(testRelation2, condition = Some(condition))
-    val commonExprId = condition.defs.head.id.id
-    val commonExprName = s"_common_expr_$commonExprId"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
         .join(
-          testRelation2.select((testRelation2.output :+ (x + 
x).as(commonExprName)): _*),
-          condition = Some($"$commonExprName" < 10 && $"$commonExprName" > 0)
+          testRelation2.select((testRelation2.output :+ (x + 
x).as("_common_expr_0")): _*),
+          condition = Some($"_common_expr_0" < 10 && $"_common_expr_0" > 0)
         )
         .select((testRelation.output ++ testRelation2.output): _*)
         .analyze
@@ -234,14 +223,12 @@ class RewriteWithExpressionSuite extends PlanTest {
       ref * ref
     }, a))
     val plan2 = testRelation.select(expr2.as("col"))
-    val commonExprId = expr2.children.head.asInstanceOf[With].defs.head.id.id
-    val commonExprName = s"_common_expr_$commonExprId"
     // With in the always-evaluated branches can still be optimized.
     comparePlans(
       Optimizer.execute(plan2),
       testRelation
-        .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
-        .select(Coalesce(Seq(($"$commonExprName" * $"$commonExprName"), 
a)).as("col"))
+        .select((testRelation.output :+ (a + a).as("_common_expr_0")): _*)
+        .select(Coalesce(Seq(($"_common_expr_0" * $"_common_expr_0"), 
a)).as("col"))
         .analyze
     )
   }
@@ -261,38 +248,32 @@ class RewriteWithExpressionSuite extends PlanTest {
       (expr2 + 2).as("col1"),
       count(expr3 - 3).as("col2")
     )
-    val commonExpr1Id = expr1.defs.head.id.id
-    val commonExpr1Name = s"_common_expr_$commonExpr1Id"
-    // Note that the common expression in expr2 gets de-duplicated by 
PullOutGroupingExpressions.
-    val commonExpr3Id = expr3.defs.head.id.id
-    val commonExpr3Name = s"_common_expr_$commonExpr3Id"
-    val groupingExprName = "_groupingexpression"
-    val aggExprName = "_aggregateexpression"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
-        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*)
         .select(testRelation.output :+
-          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*)
-        .select(testRelation.output ++ Seq($"$groupingExprName", (a + 
1).as(commonExpr3Name)): _*)
-        .groupBy($"$groupingExprName")(
-          $"$groupingExprName",
-          count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as(aggExprName)
+          ($"_common_expr_0" * $"_common_expr_0").as("_groupingexpression"): 
_*)
+        .select(testRelation.output ++ Seq($"_groupingexpression",
+          (a + 1).as("_common_expr_1")): _*)
+        .groupBy($"_groupingexpression")(
+          $"_groupingexpression",
+          count($"_common_expr_1" * $"_common_expr_1" - 
3).as("_aggregateexpression")
         )
-        .select(($"$groupingExprName" + 2).as("col1"), 
$"`$aggExprName`".as("col2"))
+        .select(($"_groupingexpression" + 2).as("col1"), 
$"_aggregateexpression".as("col2"))
         .analyze
     )
     // Running CollapseProject after the rule cleans up the unnecessary 
projections.
     comparePlans(
       CollapseProject(Optimizer.execute(plan)),
       testRelation
-        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*)
         .select(testRelation.output ++ Seq(
-          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName),
-          (a + 1).as(commonExpr3Name)): _*)
-        .groupBy($"$groupingExprName")(
-          ($"$groupingExprName" + 2).as("col1"),
-          count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as("col2")
+          ($"_common_expr_0" * $"_common_expr_0").as("_groupingexpression"),
+          (a + 1).as("_common_expr_1")): _*)
+        .groupBy($"_groupingexpression")(
+          ($"_groupingexpression" + 2).as("col1"),
+          count($"_common_expr_1" * $"_common_expr_1" - 3).as("col2")
         )
         .analyze
     )
@@ -311,21 +292,16 @@ class RewriteWithExpressionSuite extends PlanTest {
       expr1.as("col2"),
       max(expr2).as("col3")
     )
-    val commonExpr1Id = expr1.defs.head.id.id
-    val commonExpr1Name = s"_common_expr_$commonExpr1Id"
-    val commonExpr2Id = expr2.defs.head.id.id
-    val commonExpr2Name = s"_common_expr_$commonExpr2Id"
-    val aggExprName = "_aggregateexpression"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
-        .select(testRelation.output :+ (b + 2).as(commonExpr2Name): _*)
-        .groupBy(a)(a, max($"$commonExpr2Name" * 
$"$commonExpr2Name").as(aggExprName))
-        .select(a, $"`$aggExprName`", (a + 1).as(commonExpr1Name))
+        .select(testRelation.output :+ (b + 2).as("_common_expr_0"): _*)
+        .groupBy(a)(a, max($"_common_expr_0" * 
$"_common_expr_0").as("_aggregateexpression"))
+        .select(a, $"_aggregateexpression", (a + 1).as("_common_expr_1"))
         .select(
           (a + 3).as("col1"),
-          ($"$commonExpr1Name" * $"$commonExpr1Name").as("col2"),
-          $"`$aggExprName`".as("col3")
+          ($"_common_expr_1" * $"_common_expr_1").as("col2"),
+          $"_aggregateexpression".as("col3")
         )
         .analyze
     )
@@ -340,14 +316,13 @@ class RewriteWithExpressionSuite extends PlanTest {
       (a - 1).as("col1"),
       expr.as("col2")
     )
-    val aggExprName = "_aggregateexpression"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
-        .groupBy(a)(a, count(a - 1).as(aggExprName))
+        .groupBy(a)(a, count(a - 1).as("_aggregateexpression"))
         .select(
           (a - 1).as("col1"),
-          ($"$aggExprName" * $"$aggExprName").as("col2")
+          ($"_aggregateexpression" * $"_aggregateexpression").as("col2")
         )
         .analyze
     )
@@ -376,19 +351,91 @@ class RewriteWithExpressionSuite extends PlanTest {
       ref * max(expr) + ref
     }
     val plan = testRelation.groupBy(a)(nestedExpr.as("col")).analyze
-    val commonExpr1Id = expr.defs.head.id.id
-    val commonExpr1Name = s"_common_expr_$commonExpr1Id"
-    val commonExpr2Id = nestedExpr.defs.head.id.id
-    val commonExpr2Name = s"_common_expr_$commonExpr2Id"
-    val aggExprName = "_aggregateexpression"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
-        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
-        .groupBy(a)(a, max($"$commonExpr1Name" * 
$"$commonExpr1Name").as(aggExprName))
-        .select($"a", $"$aggExprName", (a - 1).as(commonExpr2Name))
-        .select(($"$commonExpr2Name" * $"$aggExprName" + 
$"$commonExpr2Name").as("col"))
+        .select(testRelation.output :+ (a + 1).as("_common_expr_0"): _*)
+        .groupBy(a)(a, max($"_common_expr_0" * 
$"_common_expr_0").as("_aggregateexpression"))
+        .select($"a", $"_aggregateexpression", (a - 1).as("_common_expr_1"))
+        .select(($"_common_expr_1" * $"_aggregateexpression" + 
$"_common_expr_1").as("col"))
+        .analyze
+    )
+  }
+
+  test("WITH expression in window exprs") {
+    val Seq(a, b) = testRelation.output
+    val expr1 = With(a + 1) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr2 = With(b + 2) { case Seq(ref) =>
+      ref * ref
+    }
+    val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, 
UnboundedFollowing)
+    val plan = testRelation
+      .window(
+        Seq(windowExpr(count(a), windowSpec(Seq(expr2), Nil, 
frame)).as("col2")),
+        Seq(expr2),
+        Nil
+      )
+      .window(
+        Seq(windowExpr(sum(expr1), windowSpec(Seq(a), Nil, frame)).as("col3")),
+        Seq(a),
+        Nil
+      )
+      .select((a - 1).as("col1"), $"col2", $"col3")
+      .analyze
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select(a, b, (b + 2).as("_common_expr_0"))
+        .select(a, b, $"_common_expr_0", (b + 2).as("_common_expr_1"))
+        .window(
+          Seq(windowExpr(count(a), windowSpec(Seq($"_common_expr_0" * 
$"_common_expr_0"), Nil,
+            frame)).as("col2")),
+          Seq($"_common_expr_1" * $"_common_expr_1"),
+          Nil
+        )
+        .select(a, b, $"col2")
+        .select(a, b, $"col2", (a + 1).as("_common_expr_2"))
+        .window(
+          Seq(windowExpr(sum($"_common_expr_2" * $"_common_expr_2"),
+            windowSpec(Seq(a), Nil, frame)).as("col3")),
+          Seq(a),
+          Nil
+        )
+        .select(a, b, $"col2", $"col3")
+        .select((a - 1).as("col1"), $"col2", $"col3")
+        .analyze
+    )
+  }
+
+  test("WITH common expression is window function") {
+    val a = testRelation.output.head
+    val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, 
UnboundedFollowing)
+    val winExpr = windowExpr(sum(a), windowSpec(Seq(a), Nil, frame))
+    val expr = With(winExpr) {
+      case Seq(ref) => ref * ref
+    }
+    val plan = testRelation.select(expr.as("col")).analyze
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select(a)
+        .window(Seq(winExpr.as("_we0")), Seq(a), Nil)
+        .select(a, $"_we0", ($"_we0" * $"_we0").as("col"))
+        .select($"col")
         .analyze
     )
   }
+
+  test("window functions in child of WITH expression with ref is not 
supported") {
+    val a = testRelation.output.head
+    intercept[java.lang.AssertionError] {
+      val expr = With(a - 1) { case Seq(ref) =>
+        ref + windowExpr(sum(ref), windowSpec(Seq(a), Nil, UnspecifiedFrame))
+      }
+      val plan = testRelation.window(Seq(expr.as("col")), Seq(a), Nil)
+      Optimizer.execute(plan)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to