This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c4b0c260bb13 [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression
c4b0c260bb13 is described below

commit c4b0c260bb139f61901d5bd5f1d94dddaefc9207
Author: Kelvin Jiang <kelvin.ji...@databricks.com>
AuthorDate: Thu Apr 18 09:56:10 2024 +0800

    [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression
    
    ### What changes were proposed in this pull request?
    
    - Fixes a bug where `RewriteWithExpression` could rewrite an `Aggregate` into an invalid one. The fix separates the "result expressions" from the "aggregate expressions" in the `Aggregate` node and rewrites them separately (see the sketch after this list).
    - Some quality-of-life improvements around `With`:
      - Aliases created for `With` common expressions now use the `CommonExpressionId`, which avoids duplicate alias names (a conf was added to fall back to the old index-based behaviour, useful for keeping names consistent in golden file tests)
      - Implemented `QueryPlan.transformUpWithSubqueriesAndPruning`, which the new logic depends on
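    
    A minimal sketch (not part of this commit) of the aggregate scenario and the plan shape the fixed rule produces, written against the catalyst test DSL used in `RewriteWithExpressionSuite`; the `_common_expr_x` alias name in the comments is illustrative.
    
    ```scala
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.With
    import org.apache.spark.sql.catalyst.optimizer.RewriteWithExpression
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    
    val testRelation = LocalRelation($"a".int, $"b".int)
    val a = testRelation.output.head
    
    // A With whose common-expression reference is used directly in the aggregate list, i.e.
    // outside any aggregate function and outside the grouping expressions.
    val expr = With(a + 1) { case Seq(ref) => ref * ref }
    val plan = testRelation.groupBy(a)(a, expr.as("col"))
    
    // Previously the rule rewrote the reference into a `_common_expr_x` attribute produced by a
    // Project below the Aggregate, leaving the Aggregate's output referencing a column that is
    // neither a grouping column nor wrapped in an aggregate function, which is invalid.
    // With this fix, the final result computation is lifted into a Project above the Aggregate,
    // roughly:
    //
    //   Project [a, (_common_expr_x * _common_expr_x) AS col]
    //   +- Project [a, (a + 1) AS _common_expr_x]
    //      +- Aggregate [a], [a]
    //         +- LocalRelation <empty>, [a, b]
    val rewritten = RewriteWithExpression(plan)
    
    // For deterministic alias names (e.g. in golden file tests), the new conf can be disabled:
    //   spark.conf.set("spark.sql.useCommonExprIdForAlias", "false")
    ```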
    
    ### Why are the changes needed?
    
    See the [JIRA ticket](https://issues.apache.org/jira/browse/SPARK-47839) for more details on the bug that this fixes.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46034 from kelvinjian-db/SPARK-47839-with-aggregate.
    
    Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../explain-results/function_count_if.explain      |   7 +-
 .../sql/connect/ProtoToParsedPlanTestSuite.scala   |   1 +
 .../spark/sql/catalyst/expressions/With.scala      |   6 +-
 .../catalyst/optimizer/RewriteWithExpression.scala |  70 +++++--
 .../spark/sql/catalyst/plans/QueryPlan.scala       |  24 +++
 .../org/apache/spark/sql/internal/SQLConf.scala    |  11 +
 .../optimizer/RewriteWithExpressionSuite.scala     | 231 ++++++++++++++++-----
 7 files changed, 281 insertions(+), 69 deletions(-)

diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain
index f2ada15eccb7..a9fd2eeb669a 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain
@@ -1,3 +1,4 @@
-Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS count_if((a > 0))#0L]
-+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0]
-   +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
+Project [_aggregateexpression#0L AS count_if((a > 0))#0L]
++- Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS _aggregateexpression#0L]
+   +- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0]
+      +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
index cc9decb4c98b..d404779d7a92 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
@@ -126,6 +126,7 @@ class ProtoToParsedPlanTestSuite
         Connect.CONNECT_EXTENSIONS_EXPRESSION_CLASSES.key,
         "org.apache.spark.sql.connect.plugin.ExampleExpressionPlugin")
       .set(org.apache.spark.sql.internal.SQLConf.ANSI_ENABLED.key, false.toString)
+      .set(org.apache.spark.sql.internal.SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key, false.toString)
   }
 
   protected val suiteBaseResourcePath = commonResourcePath.resolve("query-tests")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
index 2745b663639f..14deedd9c70f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION}
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -27,6 +27,10 @@ import org.apache.spark.sql.types.DataType
  */
 case class With(child: Expression, defs: Seq[CommonExpressionDef])
   extends Expression with Unevaluable {
+  // We do not allow With to be created with an AggregateExpression in the child, as this would
+  // create a dangling CommonExpressionRef after rewriting it in RewriteWithExpression.
+  assert(!child.containsPattern(AGGREGATE_EXPRESSION))
+
   override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION)
   override def dataType: DataType = child.dataType
   override def nullable: Boolean = child.nullable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
index 934eadbcee55..393a66f7c1e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
@@ -21,36 +21,65 @@ import scala.collection.mutable
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PlanHelper, Project}
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, PlanHelper, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION}
+import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
  * just inline them if they are cheap.
  *
+ * Since this rule can introduce new `Project` operators, it is advised to run [[CollapseProject]]
+ * after this rule.
+ *
  * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its
  *       usage, we should support aggregate/window functions as well.
  */
 object RewriteWithExpression extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.transformDownWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+    plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) {
+      // For aggregates, separate the computation of the aggregations themselves from the final
+      // result by moving the final result computation into a projection above it. This prevents
+      // this rule from producing an invalid Aggregate operator.
+      case p @ PhysicalAggregation(
+          groupingExpressions, aggregateExpressions, resultExpressions, child)
+          if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
+        // PhysicalAggregation returns aggregateExpressions as attribute references, which we change
+        // to aliases so that they can be referred to by resultExpressions.
+        val aggExprs = aggregateExpressions.map(
+          ae => Alias(ae, "_aggregateexpression")(ae.resultId))
+        val aggExprIds = aggExprs.map(_.exprId).toSet
+        val resExprs = resultExpressions.map(_.transform {
+          case a: AttributeReference if aggExprIds.contains(a.exprId) =>
+            a.withName("_aggregateexpression")
+        }.asInstanceOf[NamedExpression])
+        // Rewrite the projection and the aggregate separately and then piece them together.
+        val agg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
+        val rewrittenAgg = applyInternal(agg)
+        val proj = Project(resExprs, rewrittenAgg)
+        applyInternal(proj)
       case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
-        val inputPlans = p.children.toArray
-        var newPlan: LogicalPlan = p.mapExpressions { expr =>
-          rewriteWithExprAndInputPlans(expr, inputPlans)
-        }
-        newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
-        // Since we add extra Projects with extra columns to pre-evaluate the common expressions,
-        // the current operator may have extra columns if it inherits the output columns from its
-        // child, and we need to project away the extra columns to keep the plan schema unchanged.
-        assert(p.output.length <= newPlan.output.length)
-        if (p.output.length < newPlan.output.length) {
-          assert(p.outputSet.subsetOf(newPlan.outputSet))
-          Project(p.output, newPlan)
-        } else {
-          newPlan
-        }
+        applyInternal(p)
+    }
+  }
+
+  private def applyInternal(p: LogicalPlan): LogicalPlan = {
+    val inputPlans = p.children.toArray
+    var newPlan: LogicalPlan = p.mapExpressions { expr =>
+      rewriteWithExprAndInputPlans(expr, inputPlans)
+    }
+    newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
+    // Since we add extra Projects with extra columns to pre-evaluate the common expressions,
+    // the current operator may have extra columns if it inherits the output columns from its
+    // child, and we need to project away the extra columns to keep the plan schema unchanged.
+    assert(p.output.length <= newPlan.output.length)
+    if (p.output.length < newPlan.output.length) {
+      assert(p.outputSet.subsetOf(newPlan.outputSet))
+      Project(p.output, newPlan)
+    } else {
+      newPlan
     }
   }
 
@@ -93,7 +122,12 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
               //       if it's ref count is 1.
               refToExpr(id) = child
             } else {
-              val alias = Alias(child, s"_common_expr_$index")()
+              val aliasName = if (SQLConf.get.getConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS)) {
+                s"_common_expr_${id.id}"
+              } else {
+                s"_common_expr_$index"
+              }
+              val alias = Alias(child, aliasName)()
               val fakeProj = Project(Seq(alias), inputPlans(childProjectionIndex))
               if (PlanHelper.specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) {
                 // We have to inline the common expression if it cannot be put in a Project.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 0f049103542e..505330d871cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -517,6 +517,30 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
     transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f)
   }
 
+  /**
+   * Same as `transformUpWithSubqueries` except allows for pruning opportunities.
+   */
+  def transformUpWithSubqueriesAndPruning(
+    cond: TreePatternBits => Boolean,
+    ruleId: RuleId = UnknownRuleId)
+    (f: PartialFunction[PlanType, PlanType]): PlanType = {
+    val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] {
+      override def isDefinedAt(x: PlanType): Boolean = true
+
+      override def apply(plan: PlanType): PlanType = {
+        val transformed = plan.transformExpressionsUpWithPruning(t =>
+          t.containsPattern(PLAN_EXPRESSION) && cond(t)) {
+          case planExpression: PlanExpression[PlanType@unchecked] =>
+            val newPlan = planExpression.plan.transformUpWithSubqueriesAndPruning(cond, ruleId)(f)
+            planExpression.withNewPlan(newPlan)
+        }
+        f.applyOrElse[PlanType, PlanType](transformed, identity)
+      }
+    }
+
+    transformUpWithPruning(cond, ruleId)(g)
+  }
+
   /**
    * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries.
    * Returns a copy of this node where the given partial function has been recursively applied
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 0691cd730939..1c7ae3d0bfa8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3443,6 +3443,17 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val USE_COMMON_EXPR_ID_FOR_ALIAS =
+    buildConf("spark.sql.useCommonExprIdForAlias")
+      .internal()
+      .doc("When true, use the common expression ID for the alias when 
rewriting With " +
+        "expressions. Otherwise, use the index of the common expression 
definition. When true " +
+        "this avoids duplicate alias names, but is helpful to set to false for 
testing to ensure" +
+        "that alias names are consistent.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES =
     buildConf("spark.sql.defaultColumn.useNullsForMissingDefaultValues")
       .internal()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
index a386e9bf4efe..d482b18d9331 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, CommonExpressionDef, CommonExpressionRef, With}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -29,7 +29,9 @@ import org.apache.spark.sql.types.IntegerType
 class RewriteWithExpressionSuite extends PlanTest {
 
   object Optimizer extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil
+    val batches = Batch("Rewrite With expression", Once,
+      PullOutGroupingExpressions,
+      RewriteWithExpression) :: Nil
   }
 
   private val testRelation = LocalRelation($"a".int, $"b".int)
@@ -37,18 +39,21 @@ class RewriteWithExpressionSuite extends PlanTest {
 
   test("simple common expression") {
     val a = testRelation.output.head
-    val commonExprDef = CommonExpressionDef(a)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val plan = testRelation.select(With(ref + ref, Seq(commonExprDef)).as("col"))
+    val expr = With(a) { case Seq(ref) =>
+      ref + ref
+    }
+    val plan = testRelation.select(expr.as("col"))
     comparePlans(Optimizer.execute(plan), testRelation.select((a + a).as("col")))
   }
 
   test("non-cheap common expression") {
     val a = testRelation.output.head
-    val commonExprDef = CommonExpressionDef(a + a)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val plan = testRelation.select(With(ref * ref, Seq(commonExprDef)).as("col"))
-    val commonExprName = "_common_expr_0"
+    val expr = With(a + a) { case Seq(ref) =>
+      ref * ref
+    }
+    val plan = testRelation.select(expr.as("col"))
+    val commonExprId = expr.defs.head.id.id
+    val commonExprName = s"_common_expr_$commonExprId"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
@@ -60,16 +65,18 @@ class RewriteWithExpressionSuite extends PlanTest {
 
   test("nested WITH expression in the definition expression") {
     val a = testRelation.output.head
-    val commonExprDef = CommonExpressionDef(a + a)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val innerExpr = With(ref + ref, Seq(commonExprDef))
-    val innerCommonExprName = "_common_expr_0"
+    val innerExpr = With(a + a) { case Seq(ref) =>
+      ref + ref
+    }
+    val innerCommonExprId = innerExpr.defs.head.id.id
+    val innerCommonExprName = s"_common_expr_$innerCommonExprId"
 
     val b = testRelation.output.last
-    val outerCommonExprDef = CommonExpressionDef(innerExpr + b)
-    val outerRef = new CommonExpressionRef(outerCommonExprDef)
-    val outerExpr = With(outerRef * outerRef, Seq(outerCommonExprDef))
-    val outerCommonExprName = "_common_expr_0"
+    val outerExpr = With(innerExpr + b) { case Seq(ref) =>
+      ref * ref
+    }
+    val outerCommonExprId = outerExpr.defs.head.id.id
+    val outerCommonExprName = s"_common_expr_$outerCommonExprId"
 
     val plan = testRelation.select(outerExpr.as("col"))
     val rewrittenOuterExpr = ($"$innerCommonExprName" + $"$innerCommonExprName" + b)
@@ -88,16 +95,18 @@ class RewriteWithExpressionSuite extends PlanTest {
 
   test("nested WITH expression in the main expression") {
     val a = testRelation.output.head
-    val commonExprDef = CommonExpressionDef(a + a)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val innerExpr = With(ref + ref, Seq(commonExprDef))
-    val innerCommonExprName = "_common_expr_0"
+    val innerExpr = With(a + a) { case Seq(ref) =>
+      ref + ref
+    }
+    val innerCommonExprId = innerExpr.defs.head.id.id
+    val innerCommonExprName = s"_common_expr_$innerCommonExprId"
 
     val b = testRelation.output.last
-    val outerCommonExprDef = CommonExpressionDef(b + b)
-    val outerRef = new CommonExpressionRef(outerCommonExprDef)
-    val outerExpr = With(outerRef * outerRef + innerExpr, Seq(outerCommonExprDef))
-    val outerCommonExprName = "_common_expr_0"
+    val outerExpr = With(b + b) { case Seq(ref) =>
+      ref * ref + innerExpr
+    }
+    val outerCommonExprId = outerExpr.defs.head.id.id
+    val outerCommonExprName = s"_common_expr_$outerCommonExprId"
 
     val plan = testRelation.select(outerExpr.as("col"))
     val rewrittenInnerExpr = (a + a).as(innerCommonExprName)
@@ -116,12 +125,12 @@ class RewriteWithExpressionSuite extends PlanTest {
 
   test("correlated nested WITH expression is not supported") {
     val b = testRelation.output.last
-    val outerCommonExprDef = CommonExpressionDef(b + b)
+    val outerCommonExprDef = CommonExpressionDef(b + b, CommonExpressionId(0))
     val outerRef = new CommonExpressionRef(outerCommonExprDef)
 
     val a = testRelation.output.head
     // The inner expression definition references the outer expression
-    val commonExprDef1 = CommonExpressionDef(a + a + outerRef)
+    val commonExprDef1 = CommonExpressionDef(a + a + outerRef, CommonExpressionId(1))
     val ref1 = new CommonExpressionRef(commonExprDef1)
     val innerExpr1 = With(ref1 + ref1, Seq(commonExprDef1))
 
@@ -139,10 +148,12 @@ class RewriteWithExpressionSuite extends PlanTest {
 
   test("WITH expression in filter") {
     val a = testRelation.output.head
-    val commonExprDef = CommonExpressionDef(a + a)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val plan = testRelation.where(With(ref < 10 && ref > 0, 
Seq(commonExprDef)))
-    val commonExprName = "_common_expr_0"
+    val condition = With(a + a) { case Seq(ref) =>
+      ref < 10 && ref > 0
+    }
+    val plan = testRelation.where(condition)
+    val commonExprId = condition.defs.head.id.id
+    val commonExprName = s"_common_expr_$commonExprId"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
@@ -155,11 +166,12 @@ class RewriteWithExpressionSuite extends PlanTest {
 
   test("WITH expression in join condition: only reference left child") {
     val a = testRelation.output.head
-    val commonExprDef = CommonExpressionDef(a + a)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val condition = With(ref < 10 && ref > 0, Seq(commonExprDef))
+    val condition = With(a + a) { case Seq(ref) =>
+      ref < 10 && ref > 0
+    }
     val plan = testRelation.join(testRelation2, condition = Some(condition))
-    val commonExprName = "_common_expr_0"
+    val commonExprId = condition.defs.head.id.id
+    val commonExprName = s"_common_expr_$commonExprId"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
@@ -172,11 +184,12 @@ class RewriteWithExpressionSuite extends PlanTest {
 
   test("WITH expression in join condition: only reference right child") {
     val x = testRelation2.output.head
-    val commonExprDef = CommonExpressionDef(x + x)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val condition = With(ref < 10 && ref > 0, Seq(commonExprDef))
+    val condition = With(x + x) { case Seq(ref) =>
+      ref < 10 && ref > 0
+    }
     val plan = testRelation.join(testRelation2, condition = Some(condition))
-    val commonExprName = "_common_expr_0"
+    val commonExprId = condition.defs.head.id.id
+    val commonExprName = s"_common_expr_$commonExprId"
     comparePlans(
       Optimizer.execute(plan),
       testRelation
@@ -192,9 +205,9 @@ class RewriteWithExpressionSuite extends PlanTest {
   test("WITH expression in join condition: reference both children") {
     val a = testRelation.output.head
     val x = testRelation2.output.head
-    val commonExprDef = CommonExpressionDef(a + x)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val condition = With(ref < 10 && ref > 0, Seq(commonExprDef))
+    val condition = With(a + x) { case Seq(ref) =>
+      ref < 10 && ref > 0
+    }
     val plan = testRelation.join(testRelation2, condition = Some(condition))
     comparePlans(
       Optimizer.execute(plan),
@@ -209,17 +222,20 @@ class RewriteWithExpressionSuite extends PlanTest {
 
   test("WITH expression inside conditional expression") {
     val a = testRelation.output.head
-    val commonExprDef = CommonExpressionDef(a + a)
-    val ref = new CommonExpressionRef(commonExprDef)
-    val expr = Coalesce(Seq(a, With(ref * ref, Seq(commonExprDef))))
+    val expr = Coalesce(Seq(a, With(a + a) { case Seq(ref) =>
+      ref * ref
+    }))
     val inlinedExpr = Coalesce(Seq(a, (a + a) * (a + a)))
     val plan = testRelation.select(expr.as("col"))
     // With in the conditional branches is always inlined.
     comparePlans(Optimizer.execute(plan), testRelation.select(inlinedExpr.as("col")))
 
-    val expr2 = Coalesce(Seq(With(ref * ref, Seq(commonExprDef)), a))
+    val expr2 = Coalesce(Seq(With(a + a) { case Seq(ref) =>
+      ref * ref
+    }, a))
     val plan2 = testRelation.select(expr2.as("col"))
-    val commonExprName = "_common_expr_0"
+    val commonExprId = expr2.children.head.asInstanceOf[With].defs.head.id.id
+    val commonExprName = s"_common_expr_$commonExprId"
     // With in the always-evaluated branches can still be optimized.
     comparePlans(
       Optimizer.execute(plan2),
@@ -229,4 +245,125 @@ class RewriteWithExpressionSuite extends PlanTest {
         .analyze
     )
   }
+
+  test("WITH expression in grouping exprs") {
+    val a = testRelation.output.head
+    val expr1 = With(a + 1) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr2 = With(a + 1) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr3 = With(a + 1) { case Seq(ref) =>
+      ref * ref
+    }
+    val plan = testRelation.groupBy(expr1)(
+      (expr2 + 2).as("col1"),
+      count(expr3 - 3).as("col2")
+    )
+    val commonExpr1Id = expr1.defs.head.id.id
+    val commonExpr1Name = s"_common_expr_$commonExpr1Id"
+    // Note that the common expression in expr2 gets de-duplicated by PullOutGroupingExpressions.
+    val commonExpr3Id = expr3.defs.head.id.id
+    val commonExpr3Name = s"_common_expr_$commonExpr3Id"
+    val groupingExprName = "_groupingexpression"
+    val aggExprName = "_aggregateexpression"
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output :+
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*)
+        .select(testRelation.output ++ Seq($"$groupingExprName", (a + 
1).as(commonExpr3Name)): _*)
+        .groupBy($"$groupingExprName")(
+          $"$groupingExprName",
+          count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as(aggExprName)
+        )
+        .select(($"$groupingExprName" + 2).as("col1"), 
$"`$aggExprName`".as("col2"))
+        .analyze
+    )
+    // Running CollapseProject after the rule cleans up the unnecessary projections.
+    comparePlans(
+      CollapseProject(Optimizer.execute(plan)),
+      testRelation
+        .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*)
+        .select(testRelation.output ++ Seq(
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName),
+          (a + 1).as(commonExpr3Name)): _*)
+        .groupBy($"$groupingExprName")(
+          ($"$groupingExprName" + 2).as("col1"),
+          count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as("col2")
+        )
+        .analyze
+    )
+  }
+
+  test("WITH expression in aggregate exprs") {
+    val Seq(a, b) = testRelation.output
+    val expr1 = With(a + 1) { case Seq(ref) =>
+      ref * ref
+    }
+    val expr2 = With(b + 2) { case Seq(ref) =>
+      ref * ref
+    }
+    val plan = testRelation.groupBy(a)(
+      (a + 3).as("col1"),
+      expr1.as("col2"),
+      max(expr2).as("col3")
+    )
+    val commonExpr1Id = expr1.defs.head.id.id
+    val commonExpr1Name = s"_common_expr_$commonExpr1Id"
+    val commonExpr2Id = expr2.defs.head.id.id
+    val commonExpr2Name = s"_common_expr_$commonExpr2Id"
+    val aggExprName = "_aggregateexpression"
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select(testRelation.output :+ (b + 2).as(commonExpr2Name): _*)
+        .groupBy(a)(a, max($"$commonExpr2Name" * 
$"$commonExpr2Name").as(aggExprName))
+        .select(a, $"`$aggExprName`", (a + 1).as(commonExpr1Name))
+        .select(
+          (a + 3).as("col1"),
+          ($"$commonExpr1Name" * $"$commonExpr1Name").as("col2"),
+          $"`$aggExprName`".as("col3")
+        )
+        .analyze
+    )
+  }
+
+  test("WITH common expression is aggregate function") {
+    val a = testRelation.output.head
+    val expr = With(count(a - 1)) { case Seq(ref) =>
+      ref * ref
+    }
+    val plan = testRelation.groupBy(a)(
+      (a - 1).as("col1"),
+      expr.as("col2")
+    )
+    val aggExprName = "_aggregateexpression"
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .groupBy(a)(a, count(a - 1).as(aggExprName))
+        .select(
+          (a - 1).as("col1"),
+          ($"$aggExprName" * $"$aggExprName").as("col2")
+        )
+        .analyze
+    )
+  }
+
+  test("aggregate functions in child of WITH expression is not supported") {
+    val a = testRelation.output.head
+    intercept[java.lang.AssertionError] {
+      val expr = With(a - 1) { case Seq(ref) =>
+        sum(ref * ref)
+      }
+      val plan = testRelation.groupBy(a)(
+        (a - 1).as("col1"),
+        expr.as("col2")
+      )
+      Optimizer.execute(plan)
+    }
+  }
 }

