This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 686f428dc104 [SPARK-46541][SQL][CONNECT] Fix the ambiguous column reference in self join
686f428dc104 is described below

commit 686f428dc10410e95d4421d4cbe0dd509335c9f2
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Jan 10 10:38:00 2024 +0800

    [SPARK-46541][SQL][CONNECT] Fix the ambiguous column reference in self join
    
    ### What changes were proposed in this pull request?
    Fix the logic of ambiguous column detection in Spark Connect.
    
    ### Why are the changes needed?
    ```
    In [24]: df1 = spark.range(10).withColumn("a", sf.lit(0))
    
    In [25]: df2 = df1.withColumnRenamed("a", "b")
    
    In [26]: df1.join(df2, df1["a"] == df2["b"])
    Out[26]: 23/12/22 09:33:28 ERROR ErrorUtils: Spark Connect RPC error during: analyze. UserId: ruifeng.zheng. SessionId: eaa2161f-4b64-4dbf-9809-af6b696d3005.
    org.apache.spark.sql.AnalysisException: [AMBIGUOUS_COLUMN_REFERENCE] Column a is ambiguous. It's because you joined several DataFrame together, and some of these DataFrames are the same.
    This column points to one of the DataFrame but Spark is unable to figure out which one.
    Please alias the DataFrames with different names via DataFrame.alias before joining them,
    and specify the column using qualified name, e.g. df.alias("a").join(df.alias("b"), col("a.id") > col("b.id")). SQLSTATE: 42702
            at org.apache.spark.sql.catalyst.analysis.ColumnResolutionHelper.findPlanById(ColumnResolutionHelper.scala:555)
            at
    
    ```
    
    ### Does this PR introduce _any_ user-facing change?
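    Yes, it fixes a bug: in Spark Connect, a self-join condition such as `df1["a"] == df2["b"]` no longer fails with AMBIGUOUS_COLUMN_REFERENCE.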
    
    ### How was this patch tested?
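    Added unit tests (`test_self_join` in `test_dataframe.py`, and new error-class assertions in `test_connect_basic.py`).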
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44532 from zhengruifeng/sql_connect_find_plan_id.
    
    Lead-authored-by: Ruifeng Zheng <ruife...@apache.org>
    Co-authored-by: Ruifeng Zheng <ruife...@foxmail.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../src/main/resources/error/error-classes.json    |  11 +-
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |   2 +-
 docs/sql-error-conditions.md                       |   6 +
 .../sql/tests/connect/test_connect_basic.py        |  13 +-
 python/pyspark/sql/tests/test_dataframe.py         |   9 +-
 .../catalyst/analysis/ColumnResolutionHelper.scala | 139 ++++++++++++---------
 .../spark/sql/errors/QueryCompilationErrors.scala  |  18 ++-
 7 files changed, 133 insertions(+), 65 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json 
b/common/utils/src/main/resources/error/error-classes.json
index c7f8f59a7679..e770b9c7053e 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -324,6 +324,12 @@
     ],
     "sqlState" : "0AKD0"
   },
+  "CANNOT_RESOLVE_DATAFRAME_COLUMN" : {
+    "message" : [
+      "Cannot resolve dataframe column <name>. It's probably because of 
illegal references like `df1.select(df2.col(\"a\"))`."
+    ],
+    "sqlState" : "42704"
+  },
   "CANNOT_RESOLVE_STAR_EXPAND" : {
     "message" : [
       "Cannot resolve <targetString>.* given input columns <columns>. Please 
check that the specified table or struct exists and is accessible in the input 
columns."
@@ -6843,11 +6849,6 @@
       "Cannot modify the value of a static config: <k>"
     ]
   },
-  "_LEGACY_ERROR_TEMP_3051" : {
-    "message" : [
-      "When resolving <u>, fail to find subplan with plan_id=<planId> in <q>"
-    ]
-  },
   "_LEGACY_ERROR_TEMP_3052" : {
     "message" : [
       "Unexpected resolved action: <other>"
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 0740334724e8..288964a084ba 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -894,7 +894,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
       // df1("i") is not ambiguous, but it's not valid in the projected df.
       df1.select((df1("i") + 1).as("plus")).select(df1("i")).collect()
     }
-    
assert(e1.getMessage.contains("MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_MISSING_FROM_INPUT"))
+    assert(e1.getMessage.contains("UNRESOLVED_COLUMN.WITH_SUGGESTION"))
 
     checkSameResult(
       Seq(Row(1, "a")),
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index f58b7f607a0b..db8ecf5b2a30 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -282,6 +282,12 @@ Cannot recognize hive type string: `<fieldType>`, column: 
`<fieldName>`. The spe
 
 Renaming a `<type>` across schemas is not allowed.
 
+### CANNOT_RESOLVE_DATAFRAME_COLUMN
+
+[SQLSTATE: 
42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Cannot resolve dataframe column `<name>`. It's probably because of illegal 
references like `df1.select(df2.col("a"))`.
+
 ### CANNOT_RESOLVE_STAR_EXPAND
 
 [SQLSTATE: 
42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 045ba8f0060d..a1cd00e79e1a 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -543,10 +543,21 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         with self.assertRaises(AnalysisException):
             cdf2.withColumn("x", cdf1.a + 1).schema
 
-        with self.assertRaisesRegex(AnalysisException, "attribute.*missing"):
+        # Can find the target plan node, but fail to resolve with it
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+        ):
             cdf3 = cdf1.select(cdf1.a)
             cdf3.select(cdf1.b).schema
 
+        # Can not find the target plan node by plan id
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+        ):
+            cdf1.select(cdf2.a).schema
+
     def test_collect(self):
         cdf = self.connect.read.table(self.tbl_name)
         sdf = self.spark.read.table(self.tbl_name)
diff --git a/python/pyspark/sql/tests/test_dataframe.py 
b/python/pyspark/sql/tests/test_dataframe.py
index f1d690751ead..c77e7fd89d01 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -26,7 +26,6 @@ from typing import cast
 import io
 from contextlib import redirect_stdout
 
-from pyspark import StorageLevel
 from pyspark.sql import SparkSession, Row, functions
 from pyspark.sql.functions import col, lit, count, sum, mean, struct
 from pyspark.sql.types import (
@@ -70,6 +69,14 @@ class DataFrameTestsMixin:
         self.assertEqual(self.spark.range(-2).count(), 0)
         self.assertEqual(self.spark.range(3).count(), 3)
 
+    def test_self_join(self):
+        df1 = self.spark.range(10).withColumn("a", lit(0))
+        df2 = df1.withColumnRenamed("a", "b")
+        df = df1.join(df2, df1["a"] == df2["b"])
+        self.assertTrue(df.count() == 100)
+        df = df2.join(df1, df2["b"] == df1["a"])
+        self.assertTrue(df.count() == 100)
+
     def test_duplicated_column_names(self):
         df = self.spark.createDataFrame([(1, 2)], ["c", "c"])
         row = df.select("*").first()
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index a90c61565039..3261aa51b9be 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -426,7 +426,7 @@ trait ColumnResolutionHelper extends Logging with 
DataTypeErrorsBase {
       throws: Boolean = false,
       includeLastResort: Boolean = false): Expression = {
     resolveExpression(
-      tryResolveColumnByPlanId(expr, plan),
+      tryResolveDataFrameColumns(expr, Seq(plan)),
       resolveColumnByName = nameParts => {
         plan.resolve(nameParts, conf.resolver)
       },
@@ -448,7 +448,7 @@ trait ColumnResolutionHelper extends Logging with 
DataTypeErrorsBase {
       q: LogicalPlan,
       includeLastResort: Boolean = false): Expression = {
     resolveExpression(
-      tryResolveColumnByPlanId(e, q),
+      tryResolveDataFrameColumns(e, q.children),
       resolveColumnByName = nameParts => {
         q.resolveChildren(nameParts, conf.resolver)
       },
@@ -485,80 +485,107 @@ trait ColumnResolutionHelper extends Logging with 
DataTypeErrorsBase {
   //    4. if more than one matching nodes are found, fail due to ambiguous 
column reference;
   //    5. resolve the expression with the matching node, if any error occurs 
here, return the
   //       original expression as it is.
-  private def tryResolveColumnByPlanId(
+  private def tryResolveDataFrameColumns(
       e: Expression,
-      q: LogicalPlan,
-      idToPlan: mutable.HashMap[Long, LogicalPlan] = mutable.HashMap.empty): 
Expression = e match {
+      q: Seq[LogicalPlan]): Expression = e match {
     case u: UnresolvedAttribute =>
-      resolveUnresolvedAttributeByPlanId(
-        u, q, idToPlan: mutable.HashMap[Long, LogicalPlan]
-      ).getOrElse(u)
+      resolveDataFrameColumn(u, q).getOrElse(u)
     case _ if e.containsPattern(UNRESOLVED_ATTRIBUTE) =>
-      e.mapChildren(c => tryResolveColumnByPlanId(c, q, idToPlan))
+      e.mapChildren(c => tryResolveDataFrameColumns(c, q))
     case _ => e
   }
 
-  private def resolveUnresolvedAttributeByPlanId(
+  private def resolveDataFrameColumn(
       u: UnresolvedAttribute,
-      q: LogicalPlan,
-      idToPlan: mutable.HashMap[Long, LogicalPlan]): Option[NamedExpression] = 
{
+      q: Seq[LogicalPlan]): Option[NamedExpression] = {
     val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
     if (planIdOpt.isEmpty) return None
     val planId = planIdOpt.get
     logDebug(s"Extract plan_id $planId from $u")
 
-    val plan = idToPlan.getOrElseUpdate(planId, {
-      findPlanById(u, planId, q).getOrElse {
-        // For example:
-        //  df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
-        //  df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
-        //  df1.select(df2.a)   <-   illegal reference df2.a
-        throw new AnalysisException(
-          errorClass = "_LEGACY_ERROR_TEMP_3051",
-          messageParameters = Map(
-            "u" -> u.toString,
-            "planId" -> planId.toString,
-            "q" -> q.toString))
-      }
-    })
+    val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty
+    val (resolved, matched) = resolveDataFrameColumnByPlanId(u, planId, 
isMetadataAccess, q)
+    if (!matched) {
+      // Can not find the target plan node with plan id, e.g.
+      //  df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
+      //  df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
+      //  df1.select(df2.a)   <-   illegal reference df2.a
+      throw QueryCompilationErrors.cannotResolveColumn(u)
+    }
+    resolved
+  }
 
-    val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).isDefined
-    try {
-      if (!isMetadataAccess) {
-        plan.resolve(u.nameParts, conf.resolver)
-      } else if (u.nameParts.size == 1) {
-        plan.getMetadataAttributeByNameOpt(u.nameParts.head)
-      } else {
-        None
+  private def resolveDataFrameColumnByPlanId(
+      u: UnresolvedAttribute,
+      id: Long,
+      isMetadataAccess: Boolean,
+      q: Seq[LogicalPlan]): (Option[NamedExpression], Boolean) = {
+    q.iterator.map(resolveDataFrameColumnRecursively(u, id, isMetadataAccess, 
_))
+      .foldLeft((Option.empty[NamedExpression], false)) {
+        case ((r1, m1), (r2, m2)) =>
+          if (r1.nonEmpty && r2.nonEmpty) {
+            throw QueryCompilationErrors.ambiguousColumnReferences(u)
+          }
+          (if (r1.nonEmpty) r1 else r2, m1 | m2)
       }
-    } catch {
-      case e: AnalysisException =>
-        logDebug(s"Fail to resolve $u with $plan due to $e")
-        None
-    }
   }
 
-  private def findPlanById(
+  private def resolveDataFrameColumnRecursively(
       u: UnresolvedAttribute,
       id: Long,
-      plan: LogicalPlan): Option[LogicalPlan] = {
-    if (plan.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
-      Some(plan)
-    } else if (plan.children.length == 1) {
-      findPlanById(u, id, plan.children.head)
-    } else if (plan.children.length > 1) {
-      val matched = plan.children.flatMap(findPlanById(u, id, _))
-      if (matched.length > 1) {
-        throw new AnalysisException(
-          errorClass = "AMBIGUOUS_COLUMN_REFERENCE",
-          messageParameters = Map("name" -> toSQLId(u.nameParts)),
-          origin = u.origin
-        )
-      } else {
-        matched.headOption
+      isMetadataAccess: Boolean,
+      p: LogicalPlan): (Option[NamedExpression], Boolean) = {
+    val (resolved, matched) = if 
(p.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) {
+      val resolved = try {
+        if (!isMetadataAccess) {
+          p.resolve(u.nameParts, conf.resolver)
+        } else if (u.nameParts.size == 1) {
+          p.getMetadataAttributeByNameOpt(u.nameParts.head)
+        } else {
+          None
+        }
+      } catch {
+        case e: AnalysisException =>
+          logDebug(s"Fail to resolve $u with $p due to $e")
+          None
       }
+      (resolved, true)
     } else {
-      None
+      resolveDataFrameColumnByPlanId(u, id, isMetadataAccess, p.children)
+    }
+
+    // In self join case like:
+    //   df1 = spark.range(10).withColumn("a", sf.lit(0))
+    //   df2 = df1.withColumnRenamed("a", "b")
+    //   df1.join(df2, df1["a"] == df2["b"])
+    //
+    // the logical plan would be like:
+    //
+    // 'Join Inner, '`==`('a, 'b)                             [plan_id=5]
+    //    :- Project [id#22L, 0 AS a#25]                      [plan_id=1]
+    //      :  +- Range (0, 10, step=1, splits=Some(12))
+    //    +- Project [id#28L, a#31 AS b#36]                   [plan_id=2]
+    //         +- Project [id#28L, 0 AS a#31]                 [plan_id=1]
+    //            +- Range (0, 10, step=1, splits=Some(12))
+    //
+    // When resolving the column reference df1.a, the target node with 
plan_id=1
+    // can be found in both sides of the Join node.
+    // To correctly resolve df1.a, the analyzer discards the resolved attribute
+    // in the right side, by filtering out the result by the output attributes 
of
+    // Project plan_id=2.
+    //
+    // However, there are analyzer rules (e.g. ResolveReferencesInSort)
+    // supporting missing column resolution. Then a valid resolved attribute
+    // maybe filtered out here. In this case, resolveDataFrameColumnByPlanId
+    // returns None, the dataframe column will remain unresolved, and the 
analyzer
+    // will try to resolve it without plan id later.
+    val filtered = resolved.filter { r =>
+      if (isMetadataAccess) {
+        r.references.subsetOf(AttributeSet(p.output ++ p.metadataOutput))
+      } else {
+        r.references.subsetOf(p.outputSet)
+      }
     }
+    (filtered, matched)
   }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 387064695770..91d18788fd4c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkThrowable, 
SparkUnsupportedOperationException}
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, 
FunctionIdentifier, InternalRow, QualifiedTableName, TableIdentifier}
-import 
org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, 
FunctionAlreadyExistsException, NamespaceAlreadyExistsException, 
NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, 
NoSuchTableException, ResolvedTable, Star, TableAlreadyExistsException, 
UnresolvedRegex}
+import 
org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, 
FunctionAlreadyExistsException, NamespaceAlreadyExistsException, 
NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, 
NoSuchTableException, ResolvedTable, Star, TableAlreadyExistsException, 
UnresolvedAttribute, UnresolvedRegex}
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, 
InvalidUDFClassException}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, 
AttributeReference, AttributeSet, CreateMap, CreateStruct, Expression, 
GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, 
WindowSpecDefinition}
@@ -3940,4 +3940,20 @@ private[sql] object QueryCompilationErrors extends 
QueryErrorsBase with Compilat
         "dsSchema" -> toSQLType(dsSchema),
         "expectedSchema" -> toSQLType(expectedSchema)))
   }
+
+  def cannotResolveColumn(u: UnresolvedAttribute): Throwable = {
+    new AnalysisException(
+      errorClass = "CANNOT_RESOLVE_DATAFRAME_COLUMN",
+      messageParameters = Map("name" -> toSQLId(u.nameParts)),
+      origin = u.origin
+    )
+  }
+
+  def ambiguousColumnReferences(u: UnresolvedAttribute): Throwable = {
+    new AnalysisException(
+      errorClass = "AMBIGUOUS_COLUMN_REFERENCE",
+      messageParameters = Map("name" -> toSQLId(u.nameParts)),
+      origin = u.origin
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to