This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new c4b0c260bb13 [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression c4b0c260bb13 is described below commit c4b0c260bb139f61901d5bd5f1d94dddaefc9207 Author: Kelvin Jiang <kelvin.ji...@databricks.com> AuthorDate: Thu Apr 18 09:56:10 2024 +0800 [SPARK-47839][SQL] Fix aggregate bug in RewriteWithExpression ### What changes were proposed in this pull request? - Fixes a bug where `RewriteWithExpression` can rewrite an `Aggregate` into an invalid one. The fix separates the "result expressions" from the "aggregate expressions" in the `Aggregate` node: the result expressions are moved into a `Project` above the `Aggregate`, and the two parts are rewritten separately. - Some quality-of-life improvements around `With`: - Changes the aliases created by `With` expressions to use the `CommonExpressionId`, which avoids duplicate alias names (a conf was added to fall back to the old behaviour, which is useful for keeping the IDs consistent in golden-file tests). - Implements `QueryPlan.transformUpWithSubqueriesAndPruning`, which the new logic depends on. ### Why are the changes needed? See [JIRA ticket](https://issues.apache.org/jira/browse/SPARK-47839) for more details on the bug that this fixes. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added new unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46034 from kelvinjian-db/SPARK-47839-with-aggregate. 
Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../explain-results/function_count_if.explain | 7 +- .../sql/connect/ProtoToParsedPlanTestSuite.scala | 1 + .../spark/sql/catalyst/expressions/With.scala | 6 +- .../catalyst/optimizer/RewriteWithExpression.scala | 70 +++++-- .../spark/sql/catalyst/plans/QueryPlan.scala | 24 +++ .../org/apache/spark/sql/internal/SQLConf.scala | 11 + .../optimizer/RewriteWithExpressionSuite.scala | 231 ++++++++++++++++----- 7 files changed, 281 insertions(+), 69 deletions(-) diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain index f2ada15eccb7..a9fd2eeb669a 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain @@ -1,3 +1,4 @@ -Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS count_if((a > 0))#0L] -+- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0] - +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] +Project [_aggregateexpression#0L AS count_if((a > 0))#0L] ++- Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) AS _aggregateexpression#0L] + +- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0] + +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala index cc9decb4c98b..d404779d7a92 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala +++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala @@ -126,6 +126,7 @@ class ProtoToParsedPlanTestSuite Connect.CONNECT_EXTENSIONS_EXPRESSION_CLASSES.key, "org.apache.spark.sql.connect.plugin.ExampleExpressionPlugin") .set(org.apache.spark.sql.internal.SQLConf.ANSI_ENABLED.key, false.toString) + .set(org.apache.spark.sql.internal.SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS.key, false.toString) } protected val suiteBaseResourcePath = commonResourcePath.resolve("query-tests") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala index 2745b663639f..14deedd9c70f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION} +import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, COMMON_EXPR_REF, TreePattern, WITH_EXPRESSION} import org.apache.spark.sql.types.DataType /** @@ -27,6 +27,10 @@ import org.apache.spark.sql.types.DataType */ case class With(child: Expression, defs: Seq[CommonExpressionDef]) extends Expression with Unevaluable { + // We do not allow With to be created with an AggregateExpression in the child, as this would + // create a dangling CommonExpressionRef after rewriting it in RewriteWithExpression. 
+ assert(!child.containsPattern(AGGREGATE_EXPRESSION)) + override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION) override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala index 934eadbcee55..393a66f7c1e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -21,36 +21,65 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PlanHelper, Project} +import org.apache.spark.sql.catalyst.planning.PhysicalAggregation +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, PlanHelper, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION} +import org.apache.spark.sql.internal.SQLConf /** * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or * just inline them if they are cheap. * + * Since this rule can introduce new `Project` operators, it is advised to run [[CollapseProject]] + * after this rule. + * * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its * usage, we should support aggregate/window functions as well. 
*/ object RewriteWithExpression extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { - plan.transformDownWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) { + plan.transformUpWithSubqueriesAndPruning(_.containsPattern(WITH_EXPRESSION)) { + // For aggregates, separate the computation of the aggregations themselves from the final + // result by moving the final result computation into a projection above it. This prevents + // this rule from producing an invalid Aggregate operator. + case p @ PhysicalAggregation( + groupingExpressions, aggregateExpressions, resultExpressions, child) + if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => + // PhysicalAggregation returns aggregateExpressions as attribute references, which we change + // to aliases so that they can be referred to by resultExpressions. + val aggExprs = aggregateExpressions.map( + ae => Alias(ae, "_aggregateexpression")(ae.resultId)) + val aggExprIds = aggExprs.map(_.exprId).toSet + val resExprs = resultExpressions.map(_.transform { + case a: AttributeReference if aggExprIds.contains(a.exprId) => + a.withName("_aggregateexpression") + }.asInstanceOf[NamedExpression]) + // Rewrite the projection and the aggregate separately and then piece them together. 
+ val agg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child) + val rewrittenAgg = applyInternal(agg) + val proj = Project(resExprs, rewrittenAgg) + applyInternal(proj) case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => - val inputPlans = p.children.toArray - var newPlan: LogicalPlan = p.mapExpressions { expr => - rewriteWithExprAndInputPlans(expr, inputPlans) - } - newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq) - // Since we add extra Projects with extra columns to pre-evaluate the common expressions, - // the current operator may have extra columns if it inherits the output columns from its - // child, and we need to project away the extra columns to keep the plan schema unchanged. - assert(p.output.length <= newPlan.output.length) - if (p.output.length < newPlan.output.length) { - assert(p.outputSet.subsetOf(newPlan.outputSet)) - Project(p.output, newPlan) - } else { - newPlan - } + applyInternal(p) + } + } + + private def applyInternal(p: LogicalPlan): LogicalPlan = { + val inputPlans = p.children.toArray + var newPlan: LogicalPlan = p.mapExpressions { expr => + rewriteWithExprAndInputPlans(expr, inputPlans) + } + newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq) + // Since we add extra Projects with extra columns to pre-evaluate the common expressions, + // the current operator may have extra columns if it inherits the output columns from its + // child, and we need to project away the extra columns to keep the plan schema unchanged. + assert(p.output.length <= newPlan.output.length) + if (p.output.length < newPlan.output.length) { + assert(p.outputSet.subsetOf(newPlan.outputSet)) + Project(p.output, newPlan) + } else { + newPlan } } @@ -93,7 +122,12 @@ object RewriteWithExpression extends Rule[LogicalPlan] { // if it's ref count is 1. 
refToExpr(id) = child } else { - val alias = Alias(child, s"_common_expr_$index")() + val aliasName = if (SQLConf.get.getConf(SQLConf.USE_COMMON_EXPR_ID_FOR_ALIAS)) { + s"_common_expr_${id.id}" + } else { + s"_common_expr_$index" + } + val alias = Alias(child, aliasName)() val fakeProj = Project(Seq(alias), inputPlans(childProjectionIndex)) if (PlanHelper.specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) { // We have to inline the common expression if it cannot be put in a Project. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0f049103542e..505330d871cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -517,6 +517,30 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] transformDownWithSubqueriesAndPruning(AlwaysProcess.fn, UnknownRuleId)(f) } + /** + * Same as `transformUpWithSubqueries` except allows for pruning opportunities. 
+ */ + def transformUpWithSubqueriesAndPruning( + cond: TreePatternBits => Boolean, + ruleId: RuleId = UnknownRuleId) + (f: PartialFunction[PlanType, PlanType]): PlanType = { + val g: PartialFunction[PlanType, PlanType] = new PartialFunction[PlanType, PlanType] { + override def isDefinedAt(x: PlanType): Boolean = true + + override def apply(plan: PlanType): PlanType = { + val transformed = plan.transformExpressionsUpWithPruning(t => + t.containsPattern(PLAN_EXPRESSION) && cond(t)) { + case planExpression: PlanExpression[PlanType@unchecked] => + val newPlan = planExpression.plan.transformUpWithSubqueriesAndPruning(cond, ruleId)(f) + planExpression.withNewPlan(newPlan) + } + f.applyOrElse[PlanType, PlanType](transformed, identity) + } + } + + transformUpWithPruning(cond, ruleId)(g) + } + /** * This method is the top-down (pre-order) counterpart of transformUpWithSubqueries. * Returns a copy of this node where the given partial function has been recursively applied diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0691cd730939..1c7ae3d0bfa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3443,6 +3443,17 @@ object SQLConf { .booleanConf .createWithDefault(false) + val USE_COMMON_EXPR_ID_FOR_ALIAS = + buildConf("spark.sql.useCommonExprIdForAlias") + .internal() + .doc("When true, use the common expression ID for the alias when rewriting With " + + "expressions. Otherwise, use the index of the common expression definition. 
When true " + + "this avoids duplicate alias names, but is helpful to set to false for testing to ensure" + + "that alias names are consistent.") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + val USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES = buildConf("spark.sql.defaultColumn.useNullsForMissingDefaultValues") .internal() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index a386e9bf4efe..d482b18d9331 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, CommonExpressionDef, CommonExpressionRef, With} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -29,7 +29,9 @@ import org.apache.spark.sql.types.IntegerType class RewriteWithExpressionSuite extends PlanTest { object Optimizer extends RuleExecutor[LogicalPlan] { - val batches = Batch("Rewrite With expression", Once, RewriteWithExpression) :: Nil + val batches = Batch("Rewrite With expression", Once, + PullOutGroupingExpressions, + RewriteWithExpression) :: Nil } private val testRelation = LocalRelation($"a".int, $"b".int) @@ -37,18 +39,21 @@ class RewriteWithExpressionSuite extends PlanTest { test("simple common expression") { val a = testRelation.output.head - val commonExprDef = 
CommonExpressionDef(a) - val ref = new CommonExpressionRef(commonExprDef) - val plan = testRelation.select(With(ref + ref, Seq(commonExprDef)).as("col")) + val expr = With(a) { case Seq(ref) => + ref + ref + } + val plan = testRelation.select(expr.as("col")) comparePlans(Optimizer.execute(plan), testRelation.select((a + a).as("col"))) } test("non-cheap common expression") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val plan = testRelation.select(With(ref * ref, Seq(commonExprDef)).as("col")) - val commonExprName = "_common_expr_0" + val expr = With(a + a) { case Seq(ref) => + ref * ref + } + val plan = testRelation.select(expr.as("col")) + val commonExprId = expr.defs.head.id.id + val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation @@ -60,16 +65,18 @@ class RewriteWithExpressionSuite extends PlanTest { test("nested WITH expression in the definition expression") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val innerExpr = With(ref + ref, Seq(commonExprDef)) - val innerCommonExprName = "_common_expr_0" + val innerExpr = With(a + a) { case Seq(ref) => + ref + ref + } + val innerCommonExprId = innerExpr.defs.head.id.id + val innerCommonExprName = s"_common_expr_$innerCommonExprId" val b = testRelation.output.last - val outerCommonExprDef = CommonExpressionDef(innerExpr + b) - val outerRef = new CommonExpressionRef(outerCommonExprDef) - val outerExpr = With(outerRef * outerRef, Seq(outerCommonExprDef)) - val outerCommonExprName = "_common_expr_0" + val outerExpr = With(innerExpr + b) { case Seq(ref) => + ref * ref + } + val outerCommonExprId = outerExpr.defs.head.id.id + val outerCommonExprName = s"_common_expr_$outerCommonExprId" val plan = testRelation.select(outerExpr.as("col")) val rewrittenOuterExpr = ($"$innerCommonExprName" + 
$"$innerCommonExprName" + b) @@ -88,16 +95,18 @@ class RewriteWithExpressionSuite extends PlanTest { test("nested WITH expression in the main expression") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val innerExpr = With(ref + ref, Seq(commonExprDef)) - val innerCommonExprName = "_common_expr_0" + val innerExpr = With(a + a) { case Seq(ref) => + ref + ref + } + val innerCommonExprId = innerExpr.defs.head.id.id + val innerCommonExprName = s"_common_expr_$innerCommonExprId" val b = testRelation.output.last - val outerCommonExprDef = CommonExpressionDef(b + b) - val outerRef = new CommonExpressionRef(outerCommonExprDef) - val outerExpr = With(outerRef * outerRef + innerExpr, Seq(outerCommonExprDef)) - val outerCommonExprName = "_common_expr_0" + val outerExpr = With(b + b) { case Seq(ref) => + ref * ref + innerExpr + } + val outerCommonExprId = outerExpr.defs.head.id.id + val outerCommonExprName = s"_common_expr_$outerCommonExprId" val plan = testRelation.select(outerExpr.as("col")) val rewrittenInnerExpr = (a + a).as(innerCommonExprName) @@ -116,12 +125,12 @@ class RewriteWithExpressionSuite extends PlanTest { test("correlated nested WITH expression is not supported") { val b = testRelation.output.last - val outerCommonExprDef = CommonExpressionDef(b + b) + val outerCommonExprDef = CommonExpressionDef(b + b, CommonExpressionId(0)) val outerRef = new CommonExpressionRef(outerCommonExprDef) val a = testRelation.output.head // The inner expression definition references the outer expression - val commonExprDef1 = CommonExpressionDef(a + a + outerRef) + val commonExprDef1 = CommonExpressionDef(a + a + outerRef, CommonExpressionId(1)) val ref1 = new CommonExpressionRef(commonExprDef1) val innerExpr1 = With(ref1 + ref1, Seq(commonExprDef1)) @@ -139,10 +148,12 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in filter") { val a = testRelation.output.head - 
val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val plan = testRelation.where(With(ref < 10 && ref > 0, Seq(commonExprDef))) - val commonExprName = "_common_expr_0" + val condition = With(a + a) { case Seq(ref) => + ref < 10 && ref > 0 + } + val plan = testRelation.where(condition) + val commonExprId = condition.defs.head.id.id + val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation @@ -155,11 +166,12 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in join condition: only reference left child") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) + val condition = With(a + a) { case Seq(ref) => + ref < 10 && ref > 0 + } val plan = testRelation.join(testRelation2, condition = Some(condition)) - val commonExprName = "_common_expr_0" + val commonExprId = condition.defs.head.id.id + val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation @@ -172,11 +184,12 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in join condition: only reference right child") { val x = testRelation2.output.head - val commonExprDef = CommonExpressionDef(x + x) - val ref = new CommonExpressionRef(commonExprDef) - val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) + val condition = With(x + x) { case Seq(ref) => + ref < 10 && ref > 0 + } val plan = testRelation.join(testRelation2, condition = Some(condition)) - val commonExprName = "_common_expr_0" + val commonExprId = condition.defs.head.id.id + val commonExprName = s"_common_expr_$commonExprId" comparePlans( Optimizer.execute(plan), testRelation @@ -192,9 +205,9 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression in join condition: reference both children") { val a 
= testRelation.output.head val x = testRelation2.output.head - val commonExprDef = CommonExpressionDef(a + x) - val ref = new CommonExpressionRef(commonExprDef) - val condition = With(ref < 10 && ref > 0, Seq(commonExprDef)) + val condition = With(a + x) { case Seq(ref) => + ref < 10 && ref > 0 + } val plan = testRelation.join(testRelation2, condition = Some(condition)) comparePlans( Optimizer.execute(plan), @@ -209,17 +222,20 @@ class RewriteWithExpressionSuite extends PlanTest { test("WITH expression inside conditional expression") { val a = testRelation.output.head - val commonExprDef = CommonExpressionDef(a + a) - val ref = new CommonExpressionRef(commonExprDef) - val expr = Coalesce(Seq(a, With(ref * ref, Seq(commonExprDef)))) + val expr = Coalesce(Seq(a, With(a + a) { case Seq(ref) => + ref * ref + })) val inlinedExpr = Coalesce(Seq(a, (a + a) * (a + a))) val plan = testRelation.select(expr.as("col")) // With in the conditional branches is always inlined. comparePlans(Optimizer.execute(plan), testRelation.select(inlinedExpr.as("col"))) - val expr2 = Coalesce(Seq(With(ref * ref, Seq(commonExprDef)), a)) + val expr2 = Coalesce(Seq(With(a + a) { case Seq(ref) => + ref * ref + }, a)) val plan2 = testRelation.select(expr2.as("col")) - val commonExprName = "_common_expr_0" + val commonExprId = expr2.children.head.asInstanceOf[With].defs.head.id.id + val commonExprName = s"_common_expr_$commonExprId" // With in the always-evaluated branches can still be optimized. 
comparePlans( Optimizer.execute(plan2), @@ -229,4 +245,125 @@ class RewriteWithExpressionSuite extends PlanTest { .analyze ) } + + test("WITH expression in grouping exprs") { + val a = testRelation.output.head + val expr1 = With(a + 1) { case Seq(ref) => + ref * ref + } + val expr2 = With(a + 1) { case Seq(ref) => + ref * ref + } + val expr3 = With(a + 1) { case Seq(ref) => + ref * ref + } + val plan = testRelation.groupBy(expr1)( + (expr2 + 2).as("col1"), + count(expr3 - 3).as("col2") + ) + val commonExpr1Id = expr1.defs.head.id.id + val commonExpr1Name = s"_common_expr_$commonExpr1Id" + // Note that the common expression in expr2 gets de-duplicated by PullOutGroupingExpressions. + val commonExpr3Id = expr3.defs.head.id.id + val commonExpr3Name = s"_common_expr_$commonExpr3Id" + val groupingExprName = "_groupingexpression" + val aggExprName = "_aggregateexpression" + comparePlans( + Optimizer.execute(plan), + testRelation + .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) + .select(testRelation.output :+ + ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName): _*) + .select(testRelation.output ++ Seq($"$groupingExprName", (a + 1).as(commonExpr3Name)): _*) + .groupBy($"$groupingExprName")( + $"$groupingExprName", + count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as(aggExprName) + ) + .select(($"$groupingExprName" + 2).as("col1"), $"`$aggExprName`".as("col2")) + .analyze + ) + // Running CollapseProject after the rule cleans up the unnecessary projections. 
+ comparePlans( + CollapseProject(Optimizer.execute(plan)), + testRelation + .select(testRelation.output :+ (a + 1).as(commonExpr1Name): _*) + .select(testRelation.output ++ Seq( + ($"$commonExpr1Name" * $"$commonExpr1Name").as(groupingExprName), + (a + 1).as(commonExpr3Name)): _*) + .groupBy($"$groupingExprName")( + ($"$groupingExprName" + 2).as("col1"), + count($"$commonExpr3Name" * $"$commonExpr3Name" - 3).as("col2") + ) + .analyze + ) + } + + test("WITH expression in aggregate exprs") { + val Seq(a, b) = testRelation.output + val expr1 = With(a + 1) { case Seq(ref) => + ref * ref + } + val expr2 = With(b + 2) { case Seq(ref) => + ref * ref + } + val plan = testRelation.groupBy(a)( + (a + 3).as("col1"), + expr1.as("col2"), + max(expr2).as("col3") + ) + val commonExpr1Id = expr1.defs.head.id.id + val commonExpr1Name = s"_common_expr_$commonExpr1Id" + val commonExpr2Id = expr2.defs.head.id.id + val commonExpr2Name = s"_common_expr_$commonExpr2Id" + val aggExprName = "_aggregateexpression" + comparePlans( + Optimizer.execute(plan), + testRelation + .select(testRelation.output :+ (b + 2).as(commonExpr2Name): _*) + .groupBy(a)(a, max($"$commonExpr2Name" * $"$commonExpr2Name").as(aggExprName)) + .select(a, $"`$aggExprName`", (a + 1).as(commonExpr1Name)) + .select( + (a + 3).as("col1"), + ($"$commonExpr1Name" * $"$commonExpr1Name").as("col2"), + $"`$aggExprName`".as("col3") + ) + .analyze + ) + } + + test("WITH common expression is aggregate function") { + val a = testRelation.output.head + val expr = With(count(a - 1)) { case Seq(ref) => + ref * ref + } + val plan = testRelation.groupBy(a)( + (a - 1).as("col1"), + expr.as("col2") + ) + val aggExprName = "_aggregateexpression" + comparePlans( + Optimizer.execute(plan), + testRelation + .groupBy(a)(a, count(a - 1).as(aggExprName)) + .select( + (a - 1).as("col1"), + ($"$aggExprName" * $"$aggExprName").as("col2") + ) + .analyze + ) + } + + test("aggregate functions in child of WITH expression is not supported") { + val a 
= testRelation.output.head + intercept[java.lang.AssertionError] { + val expr = With(a - 1) { case Seq(ref) => + sum(ref * ref) + } + val plan = testRelation.groupBy(a)( + (a - 1).as("col1"), + expr.as("col2") + ) + Optimizer.execute(plan) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org