cloud-fan commented on code in PR #55985:
URL: https://github.com/apache/spark/pull/55985#discussion_r3288769328


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala:
##########
@@ -62,6 +63,17 @@ object NormalizeCTEIds extends Rule[LogicalPlan] {
         unionLoop.copy(id = defIdToNewId(unionLoop.id))
       case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) =>
         unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId))
+      case other => other
     }
+
+    normalizedPlan
+      .withNewChildren(normalizedPlan.children.map {
+        case withCTE: WithCTE => withCTE

Review Comment:
   The nested-`WithCTE` skip here is asymmetric — it only fires for `WithCTE`s 
reached as plan children. The `transformExpressionsDown` block on lines 74–76 
recurses into `subqueryExpression.plan` unconditionally, so if that plan IS a 
`WithCTE` at the top, its `cteDefs` and inner plan are walked and Refs are 
rewritten — the same double-rewrite pattern this PR fixes, just reached through 
a `SubqueryExpression` instead of a child. None of the existing tests trigger 
this (the test scenario reaches the nested `WithCTE` through 
`SubqueryAlias`/`View`, which are children), so it's a latent hole rather than 
a regression.
   
   Suggest hoisting the skip to a single guard at the top of `canonicalizeCTE` 
so it covers both paths uniformly:
   
   ```scala
   private def canonicalizeCTE(
       plan: LogicalPlan,
       defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = plan match {
     // Nested WithCTEs are normalized separately by applyInternal, so never
     // descend into one here. Safe at the normalization phase because a nested
     // WithCTE in the analyzed plan comes from plan composition (temp views,
     // DataFrame ops, etc.) and is self-contained: its bodies don't reference
     // outer-scope cteDef ids.
     case _: WithCTE => plan
     case _ =>
       // Rename this node's own id first, then recurse. UnionLoop has children
       // (anchor and recursion), so it must not return before recursing, or the
       // UnionLoopRefs and CTERelationRefs beneath it would be left unrenamed.
       val renamed = plan match {
         case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) =>
           ref.copy(cteId = defIdToNewId(ref.cteId))
         case unionLoop: UnionLoop if defIdToNewId.contains(unionLoop.id) =>
           unionLoop.copy(id = defIdToNewId(unionLoop.id))
         case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) =>
           unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId))
         case other => other
       }
       renamed
         .withNewChildren(renamed.children.map(canonicalizeCTE(_, defIdToNewId)))
         .transformExpressionsDown {
           case sub: SubqueryExpression =>
             sub.withNewPlan(canonicalizeCTE(sub.plan, defIdToNewId))
         }
   }
   ```
   
   Two payoffs:
   - The skip is symmetric — any `WithCTE` reached through `canonicalizeCTE` 
returns immediately, regardless of whether it came via children or a subquery.
   - The reader sees the invariant in one line. The existing comment at lines 
58–59 ("For nested WithCTE, if defIndex didn't contain the cteId...") described 
the OLD code's safeguard, where `canonicalizeCTE` would walk into nested 
`WithCTE`s and rely on the `contains` guard to skip refs that didn't belong to 
the current scope. With the new code's child-skip, nested `WithCTE` refs are no 
longer visited that way, so that comment is somewhat misleading. The 
early-return version subsumes it.
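   To make the asymmetry concrete, here is a toy model. Every name in it is 
hypothetical and none of it is Spark's API: `Scope` stands in for `WithCTE`, 
`Ref` for `CTERelationRef`, and `Op` for any operator whose plans are reached 
either as children or through subquery expressions. `childOnlySkip` follows the 
PR's shape (skip a nested scope only when it is a child); `earlyReturn` follows 
the single top-of-function guard suggested above.
   
   ```scala
   object CteSkipDemo {
     sealed trait Plan
     // Hypothetical stand-in for CTERelationRef: points at a CTE def by id.
     final case class Ref(cteId: Long) extends Plan
     // Hypothetical stand-in for a self-contained nested WithCTE.
     final case class Scope(body: Plan) extends Plan
     // Hypothetical operator: plan children plus plans hanging off subqueries.
     final case class Op(children: Seq[Plan], subqueries: Seq[Plan] = Nil) extends Plan
   
     // PR's shape: a Scope is skipped only when reached as a child.
     def childOnlySkip(p: Plan, ids: Map[Long, Long]): Plan = p match {
       case Ref(id) if ids.contains(id) => Ref(ids(id))
       case Op(cs, subs) =>
         Op(
           cs.map {
             case s: Scope => s // skipped: reached as a child
             case c => childOnlySkip(c, ids)
           },
           subs.map(childOnlySkip(_, ids)) // a Scope here is still walked
         )
       case Scope(body) => Scope(childOnlySkip(body, ids))
       case other => other
     }
   
     // Suggested shape: one early-return guard covers both paths.
     def earlyReturn(p: Plan, ids: Map[Long, Long]): Plan = p match {
       case s: Scope => s
       case Ref(id) if ids.contains(id) => Ref(ids(id))
       case Op(cs, subs) => Op(cs.map(earlyReturn(_, ids)), subs.map(earlyReturn(_, ids)))
       case other => other
     }
   }
   ```
   
   With `ids = Map(1L -> 100L)` and a plan carrying `Scope(Ref(1))` both as a 
child and in a subquery slot, both functions rename the outer `Ref(1)`, but only 
`childOnlySkip` also rewrites the `Ref` inside the subquery's scope, which is the 
double-rewrite described above.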



##########
sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala:
##########
@@ -261,6 +261,45 @@ abstract class CTEInlineSuiteBase
     }
   }
 
+  test("SPARK-56921: plan normalization handles nested CTEs under union") {
+    withTempView("input", "common") {
+      Seq((1, 1, 10), (1, 2, 20), (2, 1, 30))
+        .toDF("a", "b", "value")
+        .createOrReplaceTempView("input")
+
+      sql(
+        s"""with cte_common as (
+           |  select a, b, sum(value) as value
+           |  from input
+           |  group by a, b
+           |)
+           |select * from cte_common
+         """.stripMargin).createOrReplaceTempView("common")
+
+      val left = sql(
+        s"""with cte_a as (
+           |  select a, sum(value) as value
+           |  from common
+           |  group by a
+           |)
+           |select a as id, value from cte_a
+         """.stripMargin)
+
+      val right = sql(
+        s"""with cte_b as (
+           |  select b, sum(value) as value
+           |  from common
+           |  group by b
+           |)
+           |select b as id, value from cte_b
+         """.stripMargin)
+
+      val df = left.union(right)
+      df.queryExecution.normalized

Review Comment:
   Minor: this call is functionally redundant — `checkAnswer` triggers 
`executedPlan` → `optimizedPlan` → `withCachedData` → `normalized` (see 
`QueryExecution.scala:321-353`), so the test still catches the same failure 
without this line. Fine to keep for documentation, but the PR description's 
framing that it "forces" something extra is a bit misleading. Either drop the 
line or add a brief comment that it's there to make the normalization step 
explicit.
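   A toy model of why the explicit call adds nothing, assuming (as the chain above 
suggests) that each phase is a memoized `lazy val` computed from the previous one. 
`ToyQueryExecution` and `normalizationRuns` are hypothetical names, not Spark's 
`QueryExecution`:
   
   ```scala
   // Hypothetical stand-in for QueryExecution's phase chain.
   final class ToyQueryExecution {
     // Counts how many times the normalization step actually runs.
     var normalizationRuns = 0
   
     lazy val normalized: String = {
       normalizationRuns += 1
       "normalized"
     }
     // Each later phase forces the one before it on first access.
     lazy val withCachedData: String = s"$normalized -> withCachedData"
     lazy val optimizedPlan: String = s"$withCachedData -> optimizedPlan"
     lazy val executedPlan: String = s"$optimizedPlan -> executedPlan"
   }
   ```
   
   Touching only `executedPlan` runs normalization exactly once; calling 
`normalized` first does not add a second run, so the explicit line changes 
nothing beyond making the step visible in the test.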



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

