This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 79d1cded8555 [SPARK-46337][SQL][CONNECT][PYTHON] Make 
`CTESubstitution` retain the `PLAN_ID_TAG`
79d1cded8555 is described below

commit 79d1cded8555c5a0cc97b76747753785477eab8f
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sat Dec 9 16:12:14 2023 +0900

    [SPARK-46337][SQL][CONNECT][PYTHON] Make `CTESubstitution` retain the `PLAN_ID_TAG`
    
    ### What changes were proposed in this pull request?
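    Make `CTESubstitution` retain the `PLAN_ID_TAG`: when an `UnresolvedWith` node is substituted, its tags (including `PLAN_ID_TAG`) are now copied onto the plan that replaces it, in both the legacy and the non-legacy substitution paths. In addition, the "fail to find subplan" error message now prints the plan on a new line for readability.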
    
    ### Why are the changes needed?
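    Before this PR, in Spark Connect, joining a DataFrame with one created by `spark.sql` from a query that uses a CTE failed, because the plan ID of the SQL DataFrame was lost during CTE substitution: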
    ```
    df1 = spark.range(10)
    df2 = spark.sql("with dt as (select 1 as ida) select ida as id from dt")
    df1.join(df2, df1.id == df2.id)
    
    AnalysisException: When resolving 'id, fail to find subplan with plan_id=2 
in 'Join Inner, '`==`('id, 'id)
    :- Range (0, 10, step=1, splits=Some(12))
    +- WithCTE
       :- CTERelationDef 4, false
       :  +- SubqueryAlias dt
       :     +- Project [1 AS ida#22]
       :        +- OneRowRelation
       +- Project [ida#22 AS id#21]
          +- SubqueryAlias dt
             +- CTERelationRef 4, true, [ida#22], false
    ```
    
    ### Does this PR introduce _any_ user-facing change?
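    Yes. In Spark Connect, columns of a DataFrame created by `spark.sql` from a query containing a CTE can now be referenced (e.g. in a join condition, as in the example above) instead of failing with an `AnalysisException`. The "fail to find subplan" error message is also formatted slightly differently.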
    
    ### How was this patch tested?
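    Added a unit test, `test_join_with_cte`, which compares the schema and results of the join under classic PySpark and Spark Connect.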
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44268 from zhengruifeng/connect_plan_id_cte.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/tests/connect/test_connect_basic.py     | 14 ++++++++++++++
 .../spark/sql/catalyst/analysis/CTESubstitution.scala      |  9 ++++++---
 .../sql/catalyst/analysis/ColumnResolutionHelper.scala     |  2 +-
 3 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 2431b948f9da..32cd4ed62495 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -515,6 +515,20 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(cdf7.schema, sdf7.schema)
         self.assertEqual(cdf7.collect(), sdf7.collect())
 
+    def test_join_with_cte(self):
+        cte_query = "with dt as (select 1 as ida) select ida as id from dt"
+
+        sdf1 = self.spark.range(10)
+        sdf2 = self.spark.sql(cte_query)
+        sdf3 = sdf1.join(sdf2, sdf1.id == sdf2.id)
+
+        cdf1 = self.connect.range(10)
+        cdf2 = self.connect.sql(cte_query)
+        cdf3 = cdf1.join(cdf2, cdf1.id == cdf2.id)
+
+        self.assertEqual(sdf3.schema, cdf3.schema)
+        self.assertEqual(sdf3.collect(), cdf3.collect())
+
     def test_invalid_column(self):
         # SPARK-41812: fail df1.select(df2.col)
         data1 = [Row(a=1, b=2, c=3)]
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
index 2982d8477fcc..173c9d44a2b3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
@@ -149,10 +149,12 @@ object CTESubstitution extends Rule[LogicalPlan] {
       plan: LogicalPlan,
       cteDefs: ArrayBuffer[CTERelationDef]): LogicalPlan = {
     plan.resolveOperatorsUp {
-      case UnresolvedWith(child, relations) =>
+      case cte @ UnresolvedWith(child, relations) =>
         val resolvedCTERelations =
           resolveCTERelations(relations, isLegacy = true, forceInline = false, 
Seq.empty, cteDefs)
-        substituteCTE(child, alwaysInline = true, resolvedCTERelations)
+        val substituted = substituteCTE(child, alwaysInline = true, 
resolvedCTERelations)
+        substituted.copyTagsFrom(cte)
+        substituted
     }
   }
 
@@ -202,7 +204,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
     var firstSubstituted: Option[LogicalPlan] = None
     val newPlan = plan.resolveOperatorsDownWithPruning(
         _.containsAnyPattern(UNRESOLVED_WITH, PLAN_EXPRESSION)) {
-      case UnresolvedWith(child: LogicalPlan, relations) =>
+      case cte @ UnresolvedWith(child: LogicalPlan, relations) =>
         val resolvedCTERelations =
           resolveCTERelations(relations, isLegacy = false, forceInline, 
outerCTEDefs, cteDefs) ++
             outerCTEDefs
@@ -213,6 +215,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
         if (firstSubstituted.isEmpty) {
           firstSubstituted = Some(substituted)
         }
+        substituted.copyTagsFrom(cte)
         substituted
 
       case other =>
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index edfc60fc6eaa..70b44fbfa79f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -514,7 +514,7 @@ trait ColumnResolutionHelper extends Logging with 
DataTypeErrorsBase {
         //  df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
         //  df1.select(df2.a)   <-   illegal reference df2.a
         throw new AnalysisException(s"When resolving $u, " +
-          s"fail to find subplan with plan_id=$planId in $q")
+          s"fail to find subplan with plan_id=$planId in\n$q")
       }
     })
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to