This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 686f428dc104 [SPARK-46541][SQL][CONNECT] Fix the ambiguous column reference in self join 686f428dc104 is described below commit 686f428dc10410e95d4421d4cbe0dd509335c9f2 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Wed Jan 10 10:38:00 2024 +0800 [SPARK-46541][SQL][CONNECT] Fix the ambiguous column reference in self join ### What changes were proposed in this pull request? Fix the logic of ambiguous column detection in Spark Connect. ### Why are the changes needed? A self join that references columns through the original DataFrames currently fails with a spurious AMBIGUOUS_COLUMN_REFERENCE error: ``` In [24]: df1 = spark.range(10).withColumn("a", sf.lit(0)) In [25]: df2 = df1.withColumnRenamed("a", "b") In [26]: df1.join(df2, df1["a"] == df2["b"]) Out[26]: 23/12/22 09:33:28 ERROR ErrorUtils: Spark Connect RPC error during: analyze. UserId: ruifeng.zheng. SessionId: eaa2161f-4b64-4dbf-9809-af6b696d3005. org.apache.spark.sql.AnalysisException: [AMBIGUOUS_COLUMN_REFERENCE] Column a is ambiguous. It's because you joined several DataFrame together, and some of these DataFrames are the same. This column points to one of the DataFrame but Spark is unable to figure out which one. Please alias the DataFrames with different names via DataFrame.alias before joining them, and specify the column using qualified name, e.g. df.alias("a").join(df.alias("b"), col("a.id") > col("b.id")). SQLSTATE: 42702 at org.apache.spark.sql.catalyst.analysis.ColumnResolutionHelper.findPlanById(ColumnResolutionHelper.scala:555) at ``` ### Does this PR introduce _any_ user-facing change? Yes, it fixes a bug: the self join above no longer fails with AMBIGUOUS_COLUMN_REFERENCE. ### How was this patch tested? Added unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #44532 from zhengruifeng/sql_connect_find_plan_id. 
Lead-authored-by: Ruifeng Zheng <ruife...@apache.org> Co-authored-by: Ruifeng Zheng <ruife...@foxmail.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../src/main/resources/error/error-classes.json | 11 +- .../org/apache/spark/sql/ClientE2ETestSuite.scala | 2 +- docs/sql-error-conditions.md | 6 + .../sql/tests/connect/test_connect_basic.py | 13 +- python/pyspark/sql/tests/test_dataframe.py | 9 +- .../catalyst/analysis/ColumnResolutionHelper.scala | 139 ++++++++++++--------- .../spark/sql/errors/QueryCompilationErrors.scala | 18 ++- 7 files changed, 133 insertions(+), 65 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index c7f8f59a7679..e770b9c7053e 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -324,6 +324,12 @@ ], "sqlState" : "0AKD0" }, + "CANNOT_RESOLVE_DATAFRAME_COLUMN" : { + "message" : [ + "Cannot resolve dataframe column <name>. It's probably because of illegal references like `df1.select(df2.col(\"a\"))`." + ], + "sqlState" : "42704" + }, "CANNOT_RESOLVE_STAR_EXPAND" : { "message" : [ "Cannot resolve <targetString>.* given input columns <columns>. Please check that the specified table or struct exists and is accessible in the input columns." 
@@ -6843,11 +6849,6 @@ "Cannot modify the value of a static config: <k>" ] }, - "_LEGACY_ERROR_TEMP_3051" : { - "message" : [ - "When resolving <u>, fail to find subplan with plan_id=<planId> in <q>" - ] - }, "_LEGACY_ERROR_TEMP_3052" : { "message" : [ "Unexpected resolved action: <other>" diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 0740334724e8..288964a084ba 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -894,7 +894,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateM // df1("i") is not ambiguous, but it's not valid in the projected df. df1.select((df1("i") + 1).as("plus")).select(df1("i")).collect() } - assert(e1.getMessage.contains("MISSING_ATTRIBUTES.RESOLVED_ATTRIBUTE_MISSING_FROM_INPUT")) + assert(e1.getMessage.contains("UNRESOLVED_COLUMN.WITH_SUGGESTION")) checkSameResult( Seq(Row(1, "a")), diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index f58b7f607a0b..db8ecf5b2a30 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -282,6 +282,12 @@ Cannot recognize hive type string: `<fieldType>`, column: `<fieldName>`. The spe Renaming a `<type>` across schemas is not allowed. +### CANNOT_RESOLVE_DATAFRAME_COLUMN + +[SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) + +Cannot resolve dataframe column `<name>`. It's probably because of illegal references like `df1.select(df2.col("a"))`. 
+ ### CANNOT_RESOLVE_STAR_EXPAND [SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 045ba8f0060d..a1cd00e79e1a 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -543,10 +543,21 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): with self.assertRaises(AnalysisException): cdf2.withColumn("x", cdf1.a + 1).schema - with self.assertRaisesRegex(AnalysisException, "attribute.*missing"): + # Can find the target plan node, but fail to resolve with it + with self.assertRaisesRegex( + AnalysisException, + "UNRESOLVED_COLUMN.WITH_SUGGESTION", + ): cdf3 = cdf1.select(cdf1.a) cdf3.select(cdf1.b).schema + # Can not find the target plan node by plan id + with self.assertRaisesRegex( + AnalysisException, + "CANNOT_RESOLVE_DATAFRAME_COLUMN", + ): + cdf1.select(cdf2.a).schema + def test_collect(self): cdf = self.connect.read.table(self.tbl_name) sdf = self.spark.read.table(self.tbl_name) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index f1d690751ead..c77e7fd89d01 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -26,7 +26,6 @@ from typing import cast import io from contextlib import redirect_stdout -from pyspark import StorageLevel from pyspark.sql import SparkSession, Row, functions from pyspark.sql.functions import col, lit, count, sum, mean, struct from pyspark.sql.types import ( @@ -70,6 +69,14 @@ class DataFrameTestsMixin: self.assertEqual(self.spark.range(-2).count(), 0) self.assertEqual(self.spark.range(3).count(), 3) + def test_self_join(self): + df1 = self.spark.range(10).withColumn("a", lit(0)) + df2 = df1.withColumnRenamed("a", "b") + df = df1.join(df2, df1["a"] == df2["b"]) + 
self.assertTrue(df.count() == 100) + df = df2.join(df1, df2["b"] == df1["a"]) + self.assertTrue(df.count() == 100) + def test_duplicated_column_names(self): df = self.spark.createDataFrame([(1, 2)], ["c", "c"]) row = df.select("*").first() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index a90c61565039..3261aa51b9be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -426,7 +426,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { throws: Boolean = false, includeLastResort: Boolean = false): Expression = { resolveExpression( - tryResolveColumnByPlanId(expr, plan), + tryResolveDataFrameColumns(expr, Seq(plan)), resolveColumnByName = nameParts => { plan.resolve(nameParts, conf.resolver) }, @@ -448,7 +448,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { q: LogicalPlan, includeLastResort: Boolean = false): Expression = { resolveExpression( - tryResolveColumnByPlanId(e, q), + tryResolveDataFrameColumns(e, q.children), resolveColumnByName = nameParts => { q.resolveChildren(nameParts, conf.resolver) }, @@ -485,80 +485,107 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { // 4. if more than one matching nodes are found, fail due to ambiguous column reference; // 5. resolve the expression with the matching node, if any error occurs here, return the // original expression as it is. 
- private def tryResolveColumnByPlanId( + private def tryResolveDataFrameColumns( e: Expression, - q: LogicalPlan, - idToPlan: mutable.HashMap[Long, LogicalPlan] = mutable.HashMap.empty): Expression = e match { + q: Seq[LogicalPlan]): Expression = e match { case u: UnresolvedAttribute => - resolveUnresolvedAttributeByPlanId( - u, q, idToPlan: mutable.HashMap[Long, LogicalPlan] - ).getOrElse(u) + resolveDataFrameColumn(u, q).getOrElse(u) case _ if e.containsPattern(UNRESOLVED_ATTRIBUTE) => - e.mapChildren(c => tryResolveColumnByPlanId(c, q, idToPlan)) + e.mapChildren(c => tryResolveDataFrameColumns(c, q)) case _ => e } - private def resolveUnresolvedAttributeByPlanId( + private def resolveDataFrameColumn( u: UnresolvedAttribute, - q: LogicalPlan, - idToPlan: mutable.HashMap[Long, LogicalPlan]): Option[NamedExpression] = { + q: Seq[LogicalPlan]): Option[NamedExpression] = { val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG) if (planIdOpt.isEmpty) return None val planId = planIdOpt.get logDebug(s"Extract plan_id $planId from $u") - val plan = idToPlan.getOrElseUpdate(planId, { - findPlanById(u, planId, q).getOrElse { - // For example: - // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]]) - // df2 = spark.createDataFrame([Row(a = 1, b = 2)]]) - // df1.select(df2.a) <- illegal reference df2.a - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3051", - messageParameters = Map( - "u" -> u.toString, - "planId" -> planId.toString, - "q" -> q.toString)) - } - }) + val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty + val (resolved, matched) = resolveDataFrameColumnByPlanId(u, planId, isMetadataAccess, q) + if (!matched) { + // Can not find the target plan node with plan id, e.g. 
+ // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]]) + // df2 = spark.createDataFrame([Row(a = 1, b = 2)]]) + // df1.select(df2.a) <- illegal reference df2.a + throw QueryCompilationErrors.cannotResolveColumn(u) + } + resolved + } - val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).isDefined - try { - if (!isMetadataAccess) { - plan.resolve(u.nameParts, conf.resolver) - } else if (u.nameParts.size == 1) { - plan.getMetadataAttributeByNameOpt(u.nameParts.head) - } else { - None + private def resolveDataFrameColumnByPlanId( + u: UnresolvedAttribute, + id: Long, + isMetadataAccess: Boolean, + q: Seq[LogicalPlan]): (Option[NamedExpression], Boolean) = { + q.iterator.map(resolveDataFrameColumnRecursively(u, id, isMetadataAccess, _)) + .foldLeft((Option.empty[NamedExpression], false)) { + case ((r1, m1), (r2, m2)) => + if (r1.nonEmpty && r2.nonEmpty) { + throw QueryCompilationErrors.ambiguousColumnReferences(u) + } + (if (r1.nonEmpty) r1 else r2, m1 | m2) } - } catch { - case e: AnalysisException => - logDebug(s"Fail to resolve $u with $plan due to $e") - None - } } - private def findPlanById( + private def resolveDataFrameColumnRecursively( u: UnresolvedAttribute, id: Long, - plan: LogicalPlan): Option[LogicalPlan] = { - if (plan.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) { - Some(plan) - } else if (plan.children.length == 1) { - findPlanById(u, id, plan.children.head) - } else if (plan.children.length > 1) { - val matched = plan.children.flatMap(findPlanById(u, id, _)) - if (matched.length > 1) { - throw new AnalysisException( - errorClass = "AMBIGUOUS_COLUMN_REFERENCE", - messageParameters = Map("name" -> toSQLId(u.nameParts)), - origin = u.origin - ) - } else { - matched.headOption + isMetadataAccess: Boolean, + p: LogicalPlan): (Option[NamedExpression], Boolean) = { + val (resolved, matched) = if (p.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(id)) { + val resolved = try { + if (!isMetadataAccess) { + p.resolve(u.nameParts, 
conf.resolver) + } else if (u.nameParts.size == 1) { + p.getMetadataAttributeByNameOpt(u.nameParts.head) + } else { + None + } + } catch { + case e: AnalysisException => + logDebug(s"Fail to resolve $u with $p due to $e") + None } + (resolved, true) } else { - None + resolveDataFrameColumnByPlanId(u, id, isMetadataAccess, p.children) + } + + // In self join case like: + // df1 = spark.range(10).withColumn("a", sf.lit(0)) + // df2 = df1.withColumnRenamed("a", "b") + // df1.join(df2, df1["a"] == df2["b"]) + // + // the logical plan would be like: + // + // 'Join Inner, '`==`('a, 'b) [plan_id=5] + // :- Project [id#22L, 0 AS a#25] [plan_id=1] + // : +- Range (0, 10, step=1, splits=Some(12)) + // +- Project [id#28L, a#31 AS b#36] [plan_id=2] + // +- Project [id#28L, 0 AS a#31] [plan_id=1] + // +- Range (0, 10, step=1, splits=Some(12)) + // + // When resolving the column reference df1.a, the target node with plan_id=1 + // can be found in both sides of the Join node. + // To correctly resolve df1.a, the analyzer discards the resolved attribute + // in the right side, by filtering out the result by the output attributes of + // Project plan_id=2. + // + // However, there are analyzer rules (e.g. ResolveReferencesInSort) + // supporting missing column resolution. Then a valid resolved attribute + // may be filtered out here. In this case, resolveDataFrameColumnByPlanId + // returns None, the dataframe column will remain unresolved, and the analyzer + // will try to resolve it without plan id later. 
+ val filtered = resolved.filter { r => + if (isMetadataAccess) { + r.references.subsetOf(AttributeSet(p.output ++ p.metadataOutput)) + } else { + r.references.subsetOf(p.outputSet) + } } + (filtered, matched) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 387064695770..91d18788fd4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, FunctionIdentifier, InternalRow, QualifiedTableName, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, FunctionAlreadyExistsException, NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, NoSuchTableException, ResolvedTable, Star, TableAlreadyExistsException, UnresolvedRegex} +import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, FunctionAlreadyExistsException, NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, NoSuchTableException, ResolvedTable, Star, TableAlreadyExistsException, UnresolvedAttribute, UnresolvedRegex} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, InvalidUDFClassException} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, CreateStruct, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, 
WindowSpecDefinition} @@ -3940,4 +3940,20 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "dsSchema" -> toSQLType(dsSchema), "expectedSchema" -> toSQLType(expectedSchema))) } + + def cannotResolveColumn(u: UnresolvedAttribute): Throwable = { + new AnalysisException( + errorClass = "CANNOT_RESOLVE_DATAFRAME_COLUMN", + messageParameters = Map("name" -> toSQLId(u.nameParts)), + origin = u.origin + ) + } + + def ambiguousColumnReferences(u: UnresolvedAttribute): Throwable = { + new AnalysisException( + errorClass = "AMBIGUOUS_COLUMN_REFERENCE", + messageParameters = Map("name" -> toSQLId(u.nameParts)), + origin = u.origin + ) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org