This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new a6cda2302c29 [SPARK-45760][SQL][FOLLOWUP] Inline With inside conditional branches a6cda2302c29 is described below commit a6cda2302c2962072af104c5d012329b06cbf166 Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Tue Nov 28 12:53:13 2023 +0100 [SPARK-45760][SQL][FOLLOWUP] Inline With inside conditional branches ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/43623 to fix a regression. For `With` inside conditional branches, they may not be evaluated at all and we should not pull out the common expressions into a `Project`, but just inline. ### Why are the changes needed? avoid perf regression ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? new test ### Was this patch authored or co-authored using generative AI tooling? No Closes #43978 from cloud-fan/with. Authored-by: Wenchen Fan <wenc...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/expressions/Expression.scala | 5 + .../expressions/conditionalExpressions.scala | 19 +++- .../sql/catalyst/expressions/nullExpressions.scala | 8 ++ .../catalyst/optimizer/RewriteWithExpression.scala | 119 ++++++++++++++------- .../optimizer/RewriteWithExpressionSuite.scala | 79 +++++++++++++- 5 files changed, 185 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 0dc70c6c3947..2cc813bd3055 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -513,6 +513,11 @@ trait ConditionalExpression extends Expression { */ def alwaysEvaluatedInputs: Seq[Expression] + /** + * Return a copy of itself with a new `alwaysEvaluatedInputs`. + */ + def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): ConditionalExpression + /** * Return groups of branches. For each group, at least one branch will be hit at runtime, * so that we can eagerly evaluate the common expressions of a group. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 28a7db51621f..9ee2f2bb4141 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -56,6 +56,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi */ override def alwaysEvaluatedInputs: Seq[Expression] = predicate :: Nil + override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): If = { + copy(predicate = alwaysEvaluatedInputs.head) + } + override def branchGroups: Seq[Seq[Expression]] = Seq(Seq(trueValue, falseValue)) final override val nodePatterns : Seq[TreePattern] = Seq(IF) @@ -165,8 +169,15 @@ case class CaseWhen( final override val nodePatterns : Seq[TreePattern] = Seq(CASE_WHEN) - override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = - super.legacyWithNewChildren(newChildren) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): CaseWhen = { + if (newChildren.length % 2 == 0) { + copy(branches = newChildren.grouped(2).map { case Seq(a, b) => (a, b) }.toSeq) + } else { + copy( + branches = newChildren.dropRight(1).grouped(2).map { case Seq(a, b) => (a, b) }.toSeq, + elseValue = newChildren.lastOption) + } + } // both then and else expressions should be considered. @transient @@ -213,6 +224,10 @@ case class CaseWhen( */ override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil + override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): CaseWhen = { + withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1)) + } + override def branchGroups: Seq[Seq[Expression]] = { // We look at subexpressions in conditions and values of `CaseWhen` separately. It is // because a subexpression in conditions will be run no matter which condition is matched diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0e9e375b8acf..4ccb369f5e2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -70,6 +70,10 @@ case class Coalesce(children: Seq[Expression]) */ override def alwaysEvaluatedInputs: Seq[Expression] = children.head :: Nil + override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): Coalesce = { + withNewChildrenInternal(alwaysEvaluatedInputs.toIndexedSeq ++ children.drop(1)) + } + override def branchGroups: Seq[Seq[Expression]] = if (children.length > 1) { // If there is only one child, the first child is already covered by // `alwaysEvaluatedInputs` and we should exclude it here. @@ -290,6 +294,10 @@ case class NaNvl(left: Expression, right: Expression) */ override def alwaysEvaluatedInputs: Seq[Expression] = left :: Nil + override def withNewAlwaysEvaluatedInputs(alwaysEvaluatedInputs: Seq[Expression]): NaNvl = { + copy(left = alwaysEvaluatedInputs.head) + } + override def branchGroups: Seq[Seq[Expression]] = Seq(children) override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala index c5bd71b4a7d1..cf2c77069a19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Alias, CommonExpressionDef, CommonExpressionRef, Expression, With} +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, WITH_EXPRESSION} @@ -35,56 +36,92 @@ object RewriteWithExpression extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) { case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) => - var newChildren = p.children - var newPlan: LogicalPlan = p.transformExpressionsUp { - case With(child, defs) => - val refToExpr = mutable.HashMap.empty[Long, Expression] - val childProjections = Array.fill(newChildren.size)(mutable.ArrayBuffer.empty[Alias]) + val inputPlans = p.children.toArray + var newPlan: LogicalPlan = p.mapExpressions { expr => + rewriteWithExprAndInputPlans(expr, inputPlans) + } + newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq) + if (p.output == newPlan.output) { + newPlan + } else { + Project(p.output, newPlan) + } + } + } + + private def rewriteWithExprAndInputPlans( + e: Expression, + inputPlans: Array[LogicalPlan]): Expression = { + if (!e.containsPattern(WITH_EXPRESSION)) return e + e match { + case w: With => + // Rewrite nested With expressions first + val child = rewriteWithExprAndInputPlans(w.child, inputPlans) + val defs = w.defs.map(rewriteWithExprAndInputPlans(_, inputPlans)) + val refToExpr = mutable.HashMap.empty[Long, Expression] + val childProjections = Array.fill(inputPlans.length)(mutable.ArrayBuffer.empty[Alias]) + + defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) => + if (child.containsPattern(COMMON_EXPR_REF)) { + throw SparkException.internalError( + "Common expression definition cannot reference other Common expression definitions") + } - defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) => - if (CollapseProject.isCheap(child)) { - refToExpr(id) = child - } else { - val childProjectionIndex = newChildren.indexWhere( - c => child.references.subsetOf(c.outputSet) - ) - if (childProjectionIndex == -1) { - // When we cannot rewrite the common expressions, force to inline them so that the - // query can still run. This can happen if the join condition contains `With` and - // the common expression references columns from both join sides. - // TODO: things can go wrong if the common expression is nondeterministic. We - // don't fix it for now to match the old buggy behavior when certain - // `RuntimeReplaceable` did not use the `With` expression. - // TODO: we should calculate the ref count and also inline the common expression - // if it's ref count is 1. - refToExpr(id) = child - } else { - val alias = Alias(child, s"_common_expr_$index")() - childProjections(childProjectionIndex) += alias - refToExpr(id) = alias.toAttribute - } - } + if (CollapseProject.isCheap(child)) { + refToExpr(id) = child + } else { + val childProjectionIndex = inputPlans.indexWhere( + c => child.references.subsetOf(c.outputSet) + ) + if (childProjectionIndex == -1) { + // When we cannot rewrite the common expressions, force to inline them so that the + // query can still run. This can happen if the join condition contains `With` and + // the common expression references columns from both join sides. + // TODO: things can go wrong if the common expression is nondeterministic. We + // don't fix it for now to match the old buggy behavior when certain + // `RuntimeReplaceable` did not use the `With` expression. + // TODO: we should calculate the ref count and also inline the common expression + // if it's ref count is 1. + refToExpr(id) = child + } else { + val alias = Alias(child, s"_common_expr_$index")() + childProjections(childProjectionIndex) += alias + refToExpr(id) = alias.toAttribute } + } + } + + for (i <- inputPlans.indices) { + val projectList = childProjections(i) + if (projectList.nonEmpty) { + inputPlans(i) = Project(inputPlans(i).output ++ projectList, inputPlans(i)) + } + } - newChildren = newChildren.zip(childProjections).map { case (child, projections) => - if (projections.nonEmpty) { - Project(child.output ++ projections, child) - } else { - child - } + child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) { + case ref: CommonExpressionRef => + if (!refToExpr.contains(ref.id)) { + throw SparkException.internalError("Undefined common expression id " + ref.id) } + refToExpr(ref.id) + } + case c: ConditionalExpression => + val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map( + rewriteWithExprAndInputPlans(_, inputPlans)) + val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs) + // Use transformUp to handle nested With. + newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) { + case With(child, defs) => + // For With in the conditional branches, they may not be evaluated at all and we can't + // pull the common expressions into a project which will always be evaluated. Inline it. + val refToExpr = defs.map(d => d.id -> d.child).toMap child.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) { case ref: CommonExpressionRef => refToExpr(ref.id) } } - newPlan = newPlan.withNewChildren(newChildren) - if (p.output == newPlan.output) { - newPlan - } else { - Project(p.output, newPlan) - } + case other => other.mapChildren(rewriteWithExprAndInputPlans(_, inputPlans)) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala index c625379eb5ff..a386e9bf4efe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CommonExpressionDef, CommonExpressionRef, With} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, CommonExpressionDef, CommonExpressionRef, With} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -57,7 +58,7 @@ class RewriteWithExpressionSuite extends PlanTest { ) } - test("nested WITH expression") { + test("nested WITH expression in the definition expression") { val a = testRelation.output.head val commonExprDef = CommonExpressionDef(a + a) val ref = new CommonExpressionRef(commonExprDef) @@ -85,6 +86,57 @@ class RewriteWithExpressionSuite extends PlanTest { ) } + test("nested WITH expression in the main expression") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a + a) + val ref = new CommonExpressionRef(commonExprDef) + val innerExpr = With(ref + ref, Seq(commonExprDef)) + val innerCommonExprName = "_common_expr_0" + + val b = testRelation.output.last + val outerCommonExprDef = CommonExpressionDef(b + b) + val outerRef = new CommonExpressionRef(outerCommonExprDef) + val outerExpr = With(outerRef * outerRef + innerExpr, Seq(outerCommonExprDef)) + val outerCommonExprName = "_common_expr_0" + + val plan = testRelation.select(outerExpr.as("col")) + val rewrittenInnerExpr = (a + a).as(innerCommonExprName) + val rewrittenOuterExpr = (b + b).as(outerCommonExprName) + val finalExpr = rewrittenOuterExpr.toAttribute * rewrittenOuterExpr.toAttribute + + (rewrittenInnerExpr.toAttribute + rewrittenInnerExpr.toAttribute) + comparePlans( + Optimizer.execute(plan), + testRelation + .select((testRelation.output :+ rewrittenInnerExpr): _*) + .select((testRelation.output :+ rewrittenInnerExpr.toAttribute :+ rewrittenOuterExpr): _*) + .select(finalExpr.as("col")) + .analyze + ) + } + + test("correlated nested WITH expression is not supported") { + val b = testRelation.output.last + val outerCommonExprDef = CommonExpressionDef(b + b) + val outerRef = new CommonExpressionRef(outerCommonExprDef) + + val a = testRelation.output.head + // The inner expression definition references the outer expression + val commonExprDef1 = CommonExpressionDef(a + a + outerRef) + val ref1 = new CommonExpressionRef(commonExprDef1) + val innerExpr1 = With(ref1 + ref1, Seq(commonExprDef1)) + + val outerExpr1 = With(outerRef + innerExpr1, Seq(outerCommonExprDef)) + intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr1.as("col")))) + + val commonExprDef2 = CommonExpressionDef(a + a) + val ref2 = new CommonExpressionRef(commonExprDef2) + // The inner main expression references the outer expression + val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef1)) + + val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef)) + intercept[SparkException](Optimizer.execute(testRelation.select(outerExpr2.as("col")))) + } + test("WITH expression in filter") { val a = testRelation.output.head val commonExprDef = CommonExpressionDef(a + a) @@ -154,4 +206,27 @@ class RewriteWithExpressionSuite extends PlanTest { ) ) } + + test("WITH expression inside conditional expression") { + val a = testRelation.output.head + val commonExprDef = CommonExpressionDef(a + a) + val ref = new CommonExpressionRef(commonExprDef) + val expr = Coalesce(Seq(a, With(ref * ref, Seq(commonExprDef)))) + val inlinedExpr = Coalesce(Seq(a, (a + a) * (a + a))) + val plan = testRelation.select(expr.as("col")) + // With in the conditional branches is always inlined. + comparePlans(Optimizer.execute(plan), testRelation.select(inlinedExpr.as("col"))) + + val expr2 = Coalesce(Seq(With(ref * ref, Seq(commonExprDef)), a)) + val plan2 = testRelation.select(expr2.as("col")) + val commonExprName = "_common_expr_0" + // With in the always-evaluated branches can still be optimized. + comparePlans( + Optimizer.execute(plan2), + testRelation + .select((testRelation.output :+ (a + a).as(commonExprName)): _*) + .select(Coalesce(Seq(($"$commonExprName" * $"$commonExprName"), a)).as("col")) + .analyze + ) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org