This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 8a0927c07a14 [SPARK-48307][SQL] InlineCTE should keep not-inlined relations in the original WithCTE node 8a0927c07a14 is described below commit 8a0927c07a1483bcd9125bdc2062a63759b0a337 Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Tue Jun 4 15:04:22 2024 -0700 [SPARK-48307][SQL] InlineCTE should keep not-inlined relations in the original WithCTE node ### What changes were proposed in this pull request? I noticed an outdated comment in the rule `InlineCTE`: ``` // CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add // WithCTE as top node here. ``` This is not true anymore after https://github.com/apache/spark/pull/42036 . It's not a big deal, as we replace not-inlined CTE relations with `Repartition` during optimization, so it doesn't matter where we put the `WithCTE` node with not-inlined CTE relations, as it will disappear eventually. But it's still better to keep it in its original place, as third-party rules may be sensitive to the plan shape. ### Why are the changes needed? To keep the plan shape as much as possible after inlining CTE relations. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46617 from cloud-fan/cte. 
Lead-authored-by: Wenchen Fan <wenc...@databricks.com> Co-authored-by: Wenchen Fan <cloud0...@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/analysis/CheckAnalysis.scala | 45 +------ .../spark/sql/catalyst/optimizer/InlineCTE.scala | 133 +++++++++++++-------- .../sql/catalyst/optimizer/InlineCTESuite.scala | 42 +++++++ 3 files changed, 132 insertions(+), 88 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 1c2baa78be1b..8c380a7228c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -143,50 +143,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass, missingCol, orderedCandidates, a.origin) } - private def checkUnreferencedCTERelations( - cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], - visited: mutable.Map[Long, Boolean], - danglingCTERelations: mutable.ArrayBuffer[CTERelationDef], - cteId: Long): Unit = { - if (visited(cteId)) { - return - } - val (cteDef, _, refMap) = cteMap(cteId) - refMap.foreach { case (id, _) => - checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, id) - } - danglingCTERelations.append(cteDef) - visited(cteId) = true - } - def checkAnalysis(plan: LogicalPlan): Unit = { - val inlineCTE = InlineCTE(alwaysInline = true) - val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] - inlineCTE.buildCTEMap(plan, cteMap) - val danglingCTERelations = mutable.ArrayBuffer.empty[CTERelationDef] - val visited: mutable.Map[Long, Boolean] = mutable.Map.empty.withDefaultValue(false) - // If a CTE relation is never used, it will disappear after inline. 
Here we explicitly collect - // these dangling CTE relations, and put them back in the main query, to make sure the entire - // query plan is valid. - cteMap.foreach { case (cteId, (_, refCount, _)) => - // If a CTE relation ref count is 0, the other CTE relations that reference it should also be - // collected. This code will also guarantee the leaf relations that do not reference - // any others are collected first. - if (refCount == 0) { - checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, cteId) - } - } - // Inline all CTEs in the plan to help check query plan structures in subqueries. - var inlinedPlan: LogicalPlan = plan - try { - inlinedPlan = inlineCTE(plan) + // We should inline all CTE relations to restore the original plan shape, as the analysis check + // may need to match certain plan shapes. For dangling CTE relations, they will still be kept + // in the original `WithCTE` node, as we need to perform analysis check for them as well. + val inlineCTE = InlineCTE(alwaysInline = true, keepDanglingRelations = true) + val inlinedPlan: LogicalPlan = try { + inlineCTE(plan) } catch { case e: AnalysisException => throw new ExtendedAnalysisException(e, plan) } - if (danglingCTERelations.nonEmpty) { - inlinedPlan = WithCTE(inlinedPlan, danglingCTERelations.toSeq) - } try { checkAnalysis0(inlinedPlan) } catch { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index 8d7ff4cbf163..50828b945bb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -37,23 +37,19 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION} * query level. * * @param alwaysInline if true, inline all CTEs in the query plan. 
+ * @param keepDanglingRelations if true, dangling CTE relations will be kept in the original + * `WithCTE` node. */ -case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { +case class InlineCTE( + alwaysInline: Boolean = false, + keepDanglingRelations: Boolean = false) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) { - val cteMap = mutable.SortedMap.empty[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] + val cteMap = mutable.SortedMap.empty[Long, CTEReferenceInfo] buildCTEMap(plan, cteMap) cleanCTEMap(cteMap) - val notInlined = mutable.ArrayBuffer.empty[CTERelationDef] - val inlined = inlineCTE(plan, cteMap, notInlined) - // CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add - // WithCTE as top node here. - if (notInlined.isEmpty) { - inlined - } else { - WithCTE(inlined, notInlined.toSeq) - } + inlineCTE(plan, cteMap) } else { plan } @@ -74,22 +70,23 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { * * @param plan The plan to collect the CTEs from * @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE - * ids. The value of the map is tuple whose elements are: - * - The CTE definition - * - The number of incoming references to the CTE. This includes references from - * other CTEs and regular places. - * - A mutable inner map that tracks outgoing references (counts) to other CTEs. + * ids. * @param outerCTEId While collecting the map we use this optional CTE id to identify the * current outer CTE. 
*/ - def buildCTEMap( + private def buildCTEMap( plan: LogicalPlan, - cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], + cteMap: mutable.Map[Long, CTEReferenceInfo], outerCTEId: Option[Long] = None): Unit = { plan match { case WithCTE(child, cteDefs) => cteDefs.foreach { cteDef => - cteMap(cteDef.id) = (cteDef, 0, mutable.Map.empty.withDefaultValue(0)) + cteMap(cteDef.id) = CTEReferenceInfo( + cteDef = cteDef, + refCount = 0, + outgoingRefs = mutable.Map.empty.withDefaultValue(0), + shouldInline = true + ) } cteDefs.foreach { cteDef => buildCTEMap(cteDef, cteMap, Some(cteDef.id)) @@ -97,11 +94,9 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { buildCTEMap(child, cteMap, outerCTEId) case ref: CTERelationRef => - val (cteDef, refCount, refMap) = cteMap(ref.cteId) - cteMap(ref.cteId) = (cteDef, refCount + 1, refMap) + cteMap(ref.cteId) = cteMap(ref.cteId).withRefCountIncreased(1) outerCTEId.foreach { cteId => - val (_, _, outerRefMap) = cteMap(cteId) - outerRefMap(ref.cteId) += 1 + cteMap(cteId).increaseOutgoingRefCount(ref.cteId, 1) } case _ => @@ -129,15 +124,12 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { * @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE * ids. Needs to be sorted to speed up cleaning. 
*/ - private def cleanCTEMap( - cteMap: mutable.SortedMap[Long, (CTERelationDef, Int, mutable.Map[Long, Int])] - ) = { + private def cleanCTEMap(cteMap: mutable.SortedMap[Long, CTEReferenceInfo]): Unit = { cteMap.keys.toSeq.reverse.foreach { currentCTEId => - val (_, currentRefCount, refMap) = cteMap(currentCTEId) - if (currentRefCount == 0) { - refMap.foreach { case (referencedCTEId, uselessRefCount) => - val (cteDef, refCount, refMap) = cteMap(referencedCTEId) - cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, refMap) + val refInfo = cteMap(currentCTEId) + if (refInfo.refCount == 0) { + refInfo.outgoingRefs.foreach { case (referencedCTEId, uselessRefCount) => + cteMap(referencedCTEId) = cteMap(referencedCTEId).withRefCountDecreased(uselessRefCount) } } } @@ -145,30 +137,45 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { private def inlineCTE( plan: LogicalPlan, - cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])], - notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = { + cteMap: mutable.Map[Long, CTEReferenceInfo]): LogicalPlan = { plan match { case WithCTE(child, cteDefs) => - cteDefs.foreach { cteDef => - val (cte, refCount, refMap) = cteMap(cteDef.id) - if (refCount > 0) { - val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, notInlined)) - cteMap(cteDef.id) = (inlined, refCount, refMap) - if (!shouldInline(inlined, refCount)) { - notInlined.append(inlined) - } + val remainingDefs = cteDefs.filter { cteDef => + val refInfo = cteMap(cteDef.id) + if (refInfo.refCount > 0) { + val newDef = refInfo.cteDef.copy(child = inlineCTE(refInfo.cteDef.child, cteMap)) + val inlineDecision = shouldInline(newDef, refInfo.refCount) + cteMap(cteDef.id) = cteMap(cteDef.id).copy( + cteDef = newDef, shouldInline = inlineDecision + ) + // Retain the not-inlined CTE relations in place. 
+ !inlineDecision + } else { + keepDanglingRelations } } - inlineCTE(child, cteMap, notInlined) + val inlined = inlineCTE(child, cteMap) + if (remainingDefs.isEmpty) { + inlined + } else { + WithCTE(inlined, remainingDefs) + } case ref: CTERelationRef => - val (cteDef, refCount, _) = cteMap(ref.cteId) - if (shouldInline(cteDef, refCount)) { - if (ref.outputSet == cteDef.outputSet) { - cteDef.child + val refInfo = cteMap(ref.cteId) + if (refInfo.shouldInline) { + if (ref.outputSet == refInfo.cteDef.outputSet) { + refInfo.cteDef.child } else { val ctePlan = DeduplicateRelations( - Join(cteDef.child, cteDef.child, Inner, None, JoinHint(None, None))).children(1) + Join( + refInfo.cteDef.child, + refInfo.cteDef.child, + Inner, + None, + JoinHint(None, None) + ) + ).children(1) val projectList = ref.output.zip(ctePlan.output).map { case (tgtAttr, srcAttr) => if (srcAttr.semanticEquals(tgtAttr)) { tgtAttr @@ -184,13 +191,41 @@ case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] { case _ if plan.containsPattern(CTE) => plan - .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, notInlined))) + .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap))) .transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) { case e: SubqueryExpression => - e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined)) + e.withNewPlan(inlineCTE(e.plan, cteMap)) } case _ => plan } } } + +/** + * The bookkeeping information for tracking CTE relation references. + * + * @param cteDef The CTE relation definition + * @param refCount The number of incoming references to this CTE relation. This includes references + * from other CTE relations and regular places. + * @param outgoingRefs A mutable map that tracks outgoing reference counts to other CTE relations. + * @param shouldInline If true, this CTE relation should be inlined in the places that reference it. 
+ */ +case class CTEReferenceInfo( + cteDef: CTERelationDef, + refCount: Int, + outgoingRefs: mutable.Map[Long, Int], + shouldInline: Boolean) { + + def withRefCountIncreased(count: Int): CTEReferenceInfo = { + copy(refCount = refCount + count) + } + + def withRefCountDecreased(count: Int): CTEReferenceInfo = { + copy(refCount = refCount - count) + } + + def increaseOutgoingRefCount(cteDefId: Long, count: Int): Unit = { + outgoingRefs(cteDefId) += count + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala new file mode 100644 index 000000000000..9d775a5335c6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.TestRelation +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CTERelationDef, CTERelationRef, LogicalPlan, OneRowRelation, WithCTE} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class InlineCTESuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("inline CTE", FixedPoint(100), InlineCTE()) :: Nil + } + + test("SPARK-48307: not-inlined CTE relation in command") { + val cteDef = CTERelationDef(OneRowRelation().select(rand(0).as("a"))) + val cteRef = CTERelationRef(cteDef.id, cteDef.resolved, cteDef.output, cteDef.isStreaming) + val plan = AppendData.byName( + TestRelation(Seq($"a".double)), + WithCTE(cteRef.except(cteRef, isAll = true), Seq(cteDef)) + ).analyze + comparePlans(Optimize.execute(plan), plan) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org