This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8a0927c07a14 [SPARK-48307][SQL] InlineCTE should keep not-inlined 
relations in the original WithCTE node
8a0927c07a14 is described below

commit 8a0927c07a1483bcd9125bdc2062a63759b0a337
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Tue Jun 4 15:04:22 2024 -0700

    [SPARK-48307][SQL] InlineCTE should keep not-inlined relations in the original WithCTE node
    
    ### What changes were proposed in this pull request?
    
    I noticed an outdated comment in the rule `InlineCTE`
    ```
          // CTEs in SQL Commands have been inlined by `CTESubstitution` already, so it is safe to add
          // WithCTE as top node here.
    ```
    
    This is not true anymore after https://github.com/apache/spark/pull/42036.
    It's not a big deal, because we replace not-inlined CTE relations with
    `Repartition` during optimization: it doesn't matter where we put the
    `WithCTE` node with not-inlined CTE relations, as it will disappear
    eventually. But it's still better to keep it at its original place, as
    third-party rules may be sensitive to the plan shape.
    
    ### Why are the changes needed?
    
    To keep the plan shape as close to the original as possible after inlining CTE relations.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    new test
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #46617 from cloud-fan/cte.
    
    Lead-authored-by: Wenchen Fan <wenc...@databricks.com>
    Co-authored-by: Wenchen Fan <cloud0...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/analysis/CheckAnalysis.scala      |  45 +------
 .../spark/sql/catalyst/optimizer/InlineCTE.scala   | 133 +++++++++++++--------
 .../sql/catalyst/optimizer/InlineCTESuite.scala    |  42 +++++++
 3 files changed, 132 insertions(+), 88 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1c2baa78be1b..8c380a7228c6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -143,50 +143,17 @@ trait CheckAnalysis extends PredicateHelper with 
LookupCatalog with QueryErrorsB
       errorClass, missingCol, orderedCandidates, a.origin)
   }
 
-  private def checkUnreferencedCTERelations(
-      cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
-      visited: mutable.Map[Long, Boolean],
-      danglingCTERelations: mutable.ArrayBuffer[CTERelationDef],
-      cteId: Long): Unit = {
-    if (visited(cteId)) {
-      return
-    }
-    val (cteDef, _, refMap) = cteMap(cteId)
-    refMap.foreach { case (id, _) =>
-      checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, id)
-    }
-    danglingCTERelations.append(cteDef)
-    visited(cteId) = true
-  }
-
   def checkAnalysis(plan: LogicalPlan): Unit = {
-    val inlineCTE = InlineCTE(alwaysInline = true)
-    val cteMap = mutable.HashMap.empty[Long, (CTERelationDef, Int, 
mutable.Map[Long, Int])]
-    inlineCTE.buildCTEMap(plan, cteMap)
-    val danglingCTERelations = mutable.ArrayBuffer.empty[CTERelationDef]
-    val visited: mutable.Map[Long, Boolean] = 
mutable.Map.empty.withDefaultValue(false)
-    // If a CTE relation is never used, it will disappear after inline. Here 
we explicitly collect
-    // these dangling CTE relations, and put them back in the main query, to 
make sure the entire
-    // query plan is valid.
-    cteMap.foreach { case (cteId, (_, refCount, _)) =>
-      // If a CTE relation ref count is 0, the other CTE relations that 
reference it should also be
-      // collected. This code will also guarantee the leaf relations that do 
not reference
-      // any others are collected first.
-      if (refCount == 0) {
-        checkUnreferencedCTERelations(cteMap, visited, danglingCTERelations, 
cteId)
-      }
-    }
-    // Inline all CTEs in the plan to help check query plan structures in 
subqueries.
-    var inlinedPlan: LogicalPlan = plan
-    try {
-      inlinedPlan = inlineCTE(plan)
+    // We should inline all CTE relations to restore the original plan shape, 
as the analysis check
+    // may need to match certain plan shapes. For dangling CTE relations, they 
will still be kept
+    // in the original `WithCTE` node, as we need to perform analysis check 
for them as well.
+    val inlineCTE = InlineCTE(alwaysInline = true, keepDanglingRelations = 
true)
+    val inlinedPlan: LogicalPlan = try {
+      inlineCTE(plan)
     } catch {
       case e: AnalysisException =>
         throw new ExtendedAnalysisException(e, plan)
     }
-    if (danglingCTERelations.nonEmpty) {
-      inlinedPlan = WithCTE(inlinedPlan, danglingCTERelations.toSeq)
-    }
     try {
       checkAnalysis0(inlinedPlan)
     } catch {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
index 8d7ff4cbf163..50828b945bb4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala
@@ -37,23 +37,19 @@ import 
org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION}
  * query level.
  *
  * @param alwaysInline if true, inline all CTEs in the query plan.
+ * @param keepDanglingRelations if true, dangling CTE relations will be kept 
in the original
+ *                              `WithCTE` node.
  */
-case class InlineCTE(alwaysInline: Boolean = false) extends Rule[LogicalPlan] {
+case class InlineCTE(
+    alwaysInline: Boolean = false,
+    keepDanglingRelations: Boolean = false) extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (!plan.isInstanceOf[Subquery] && plan.containsPattern(CTE)) {
-      val cteMap = mutable.SortedMap.empty[Long, (CTERelationDef, Int, 
mutable.Map[Long, Int])]
+      val cteMap = mutable.SortedMap.empty[Long, CTEReferenceInfo]
       buildCTEMap(plan, cteMap)
       cleanCTEMap(cteMap)
-      val notInlined = mutable.ArrayBuffer.empty[CTERelationDef]
-      val inlined = inlineCTE(plan, cteMap, notInlined)
-      // CTEs in SQL Commands have been inlined by `CTESubstitution` already, 
so it is safe to add
-      // WithCTE as top node here.
-      if (notInlined.isEmpty) {
-        inlined
-      } else {
-        WithCTE(inlined, notInlined.toSeq)
-      }
+      inlineCTE(plan, cteMap)
     } else {
       plan
     }
@@ -74,22 +70,23 @@ case class InlineCTE(alwaysInline: Boolean = false) extends 
Rule[LogicalPlan] {
    *
    * @param plan The plan to collect the CTEs from
    * @param cteMap A mutable map that accumulates the CTEs and their reference 
information by CTE
-   *               ids. The value of the map is tuple whose elements are:
-   *               - The CTE definition
-   *               - The number of incoming references to the CTE. This 
includes references from
-   *                 other CTEs and regular places.
-   *               - A mutable inner map that tracks outgoing references 
(counts) to other CTEs.
+   *               ids.
    * @param outerCTEId While collecting the map we use this optional CTE id to 
identify the
    *                   current outer CTE.
    */
-  def buildCTEMap(
+  private def buildCTEMap(
       plan: LogicalPlan,
-      cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
+      cteMap: mutable.Map[Long, CTEReferenceInfo],
       outerCTEId: Option[Long] = None): Unit = {
     plan match {
       case WithCTE(child, cteDefs) =>
         cteDefs.foreach { cteDef =>
-          cteMap(cteDef.id) = (cteDef, 0, 
mutable.Map.empty.withDefaultValue(0))
+          cteMap(cteDef.id) = CTEReferenceInfo(
+            cteDef = cteDef,
+            refCount = 0,
+            outgoingRefs = mutable.Map.empty.withDefaultValue(0),
+            shouldInline = true
+          )
         }
         cteDefs.foreach { cteDef =>
           buildCTEMap(cteDef, cteMap, Some(cteDef.id))
@@ -97,11 +94,9 @@ case class InlineCTE(alwaysInline: Boolean = false) extends 
Rule[LogicalPlan] {
         buildCTEMap(child, cteMap, outerCTEId)
 
       case ref: CTERelationRef =>
-        val (cteDef, refCount, refMap) = cteMap(ref.cteId)
-        cteMap(ref.cteId) = (cteDef, refCount + 1, refMap)
+        cteMap(ref.cteId) = cteMap(ref.cteId).withRefCountIncreased(1)
         outerCTEId.foreach { cteId =>
-          val (_, _, outerRefMap) = cteMap(cteId)
-          outerRefMap(ref.cteId) += 1
+          cteMap(cteId).increaseOutgoingRefCount(ref.cteId, 1)
         }
 
       case _ =>
@@ -129,15 +124,12 @@ case class InlineCTE(alwaysInline: Boolean = false) 
extends Rule[LogicalPlan] {
    * @param cteMap A mutable map that accumulates the CTEs and their reference 
information by CTE
    *               ids. Needs to be sorted to speed up cleaning.
    */
-  private def cleanCTEMap(
-      cteMap: mutable.SortedMap[Long, (CTERelationDef, Int, mutable.Map[Long, 
Int])]
-    ) = {
+  private def cleanCTEMap(cteMap: mutable.SortedMap[Long, CTEReferenceInfo]): 
Unit = {
     cteMap.keys.toSeq.reverse.foreach { currentCTEId =>
-      val (_, currentRefCount, refMap) = cteMap(currentCTEId)
-      if (currentRefCount == 0) {
-        refMap.foreach { case (referencedCTEId, uselessRefCount) =>
-          val (cteDef, refCount, refMap) = cteMap(referencedCTEId)
-          cteMap(referencedCTEId) = (cteDef, refCount - uselessRefCount, 
refMap)
+      val refInfo = cteMap(currentCTEId)
+      if (refInfo.refCount == 0) {
+        refInfo.outgoingRefs.foreach { case (referencedCTEId, uselessRefCount) 
=>
+          cteMap(referencedCTEId) = 
cteMap(referencedCTEId).withRefCountDecreased(uselessRefCount)
         }
       }
     }
@@ -145,30 +137,45 @@ case class InlineCTE(alwaysInline: Boolean = false) 
extends Rule[LogicalPlan] {
 
   private def inlineCTE(
       plan: LogicalPlan,
-      cteMap: mutable.Map[Long, (CTERelationDef, Int, mutable.Map[Long, Int])],
-      notInlined: mutable.ArrayBuffer[CTERelationDef]): LogicalPlan = {
+      cteMap: mutable.Map[Long, CTEReferenceInfo]): LogicalPlan = {
     plan match {
       case WithCTE(child, cteDefs) =>
-        cteDefs.foreach { cteDef =>
-          val (cte, refCount, refMap) = cteMap(cteDef.id)
-          if (refCount > 0) {
-            val inlined = cte.copy(child = inlineCTE(cte.child, cteMap, 
notInlined))
-            cteMap(cteDef.id) = (inlined, refCount, refMap)
-            if (!shouldInline(inlined, refCount)) {
-              notInlined.append(inlined)
-            }
+        val remainingDefs = cteDefs.filter { cteDef =>
+          val refInfo = cteMap(cteDef.id)
+          if (refInfo.refCount > 0) {
+            val newDef = refInfo.cteDef.copy(child = 
inlineCTE(refInfo.cteDef.child, cteMap))
+            val inlineDecision = shouldInline(newDef, refInfo.refCount)
+            cteMap(cteDef.id) = cteMap(cteDef.id).copy(
+              cteDef = newDef, shouldInline = inlineDecision
+            )
+            // Retain the not-inlined CTE relations in place.
+            !inlineDecision
+          } else {
+            keepDanglingRelations
           }
         }
-        inlineCTE(child, cteMap, notInlined)
+        val inlined = inlineCTE(child, cteMap)
+        if (remainingDefs.isEmpty) {
+          inlined
+        } else {
+          WithCTE(inlined, remainingDefs)
+        }
 
       case ref: CTERelationRef =>
-        val (cteDef, refCount, _) = cteMap(ref.cteId)
-        if (shouldInline(cteDef, refCount)) {
-          if (ref.outputSet == cteDef.outputSet) {
-            cteDef.child
+        val refInfo = cteMap(ref.cteId)
+        if (refInfo.shouldInline) {
+          if (ref.outputSet == refInfo.cteDef.outputSet) {
+            refInfo.cteDef.child
           } else {
             val ctePlan = DeduplicateRelations(
-              Join(cteDef.child, cteDef.child, Inner, None, JoinHint(None, 
None))).children(1)
+              Join(
+                refInfo.cteDef.child,
+                refInfo.cteDef.child,
+                Inner,
+                None,
+                JoinHint(None, None)
+              )
+            ).children(1)
             val projectList = ref.output.zip(ctePlan.output).map { case 
(tgtAttr, srcAttr) =>
               if (srcAttr.semanticEquals(tgtAttr)) {
                 tgtAttr
@@ -184,13 +191,41 @@ case class InlineCTE(alwaysInline: Boolean = false) 
extends Rule[LogicalPlan] {
 
       case _ if plan.containsPattern(CTE) =>
         plan
-          .withNewChildren(plan.children.map(child => inlineCTE(child, cteMap, 
notInlined)))
+          .withNewChildren(plan.children.map(child => inlineCTE(child, 
cteMap)))
           
.transformExpressionsWithPruning(_.containsAllPatterns(PLAN_EXPRESSION, CTE)) {
             case e: SubqueryExpression =>
-              e.withNewPlan(inlineCTE(e.plan, cteMap, notInlined))
+              e.withNewPlan(inlineCTE(e.plan, cteMap))
           }
 
       case _ => plan
     }
   }
 }
+
+/**
+ * The bookkeeping information for tracking CTE relation references.
+ *
+ * @param cteDef The CTE relation definition
+ * @param refCount The number of incoming references to this CTE relation. 
This includes references
+ *                 from other CTE relations and regular places.
+ * @param outgoingRefs A mutable map that tracks outgoing reference counts to 
other CTE relations.
+ * @param shouldInline If true, this CTE relation should be inlined in the 
places that reference it.
+ */
+case class CTEReferenceInfo(
+    cteDef: CTERelationDef,
+    refCount: Int,
+    outgoingRefs: mutable.Map[Long, Int],
+    shouldInline: Boolean) {
+
+  def withRefCountIncreased(count: Int): CTEReferenceInfo = {
+    copy(refCount = refCount + count)
+  }
+
+  def withRefCountDecreased(count: Int): CTEReferenceInfo = {
+    copy(refCount = refCount - count)
+  }
+
+  def increaseOutgoingRefCount(cteDefId: Long, count: Int): Unit = {
+    outgoingRefs(cteDefId) += count
+  }
+}
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala
new file mode 100644
index 000000000000..9d775a5335c6
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTESuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.TestRelation
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, 
CTERelationDef, CTERelationRef, LogicalPlan, OneRowRelation, WithCTE}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class InlineCTESuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("inline CTE", FixedPoint(100), InlineCTE()) :: Nil
+  }
+
+  test("SPARK-48307: not-inlined CTE relation in command") {
+    val cteDef = CTERelationDef(OneRowRelation().select(rand(0).as("a")))
+    val cteRef = CTERelationRef(cteDef.id, cteDef.resolved, cteDef.output, 
cteDef.isStreaming)
+    val plan = AppendData.byName(
+      TestRelation(Seq($"a".double)),
+      WithCTE(cteRef.except(cteRef, isAll = true), Seq(cteDef))
+    ).analyze
+    comparePlans(Optimize.execute(plan), plan)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to