This is an automated email from the ASF dual-hosted git repository. ueshin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 2cf937f9bac [SPARK-45266][PYTHON] Refactor ResolveFunctions analyzer rule to delay making lateral join when table arguments are used 2cf937f9bac is described below commit 2cf937f9bac2131f3657660a8d65d07ab4ece490 Author: Takuya UESHIN <ues...@databricks.com> AuthorDate: Thu Sep 28 10:37:18 2023 -0700 [SPARK-45266][PYTHON] Refactor ResolveFunctions analyzer rule to delay making lateral join when table arguments are used ### What changes were proposed in this pull request? Refactors `ResolveFunctions` analyzer rule to delay making lateral join when table arguments are used. - Delay making lateral join when table arguments are used to after all the children are resolved - Resolve `UnresolvedPolymorphicPythonUDTF` in one place - Introduce a new error class `UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT` if table arguments are used improperly. ### Why are the changes needed? The analyzer rule `ResolveFunctions` became complicated. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #43042 from ueshin/issues/SPARK-45266/analyzer. 
Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- .../src/main/resources/error/error-classes.json | 5 + ...ted-subquery-expression-category-error-class.md | 4 + .../spark/sql/catalyst/analysis/Analyzer.scala | 155 +++++++++------------ .../sql/catalyst/analysis/CheckAnalysis.scala | 5 + .../spark/sql/catalyst/expressions/PythonUDF.scala | 6 +- .../named-function-arguments.sql.out | 16 +-- .../results/named-function-arguments.sql.out | 16 +-- .../sql/execution/python/PythonUDTFSuite.scala | 20 ++- 8 files changed, 103 insertions(+), 124 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 0882e387176..58fcedae332 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3484,6 +3484,11 @@ "message" : [ "IN/EXISTS predicate subqueries can only be used in filters, joins, aggregations, window functions, projections, and UPDATE/MERGE/DELETE commands<treeNode>." ] + }, + "UNSUPPORTED_TABLE_ARGUMENT" : { + "message" : [ + "Table arguments are used in a function where they are not supported<treeNode>." + ] } }, "sqlState" : "0A000" diff --git a/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md b/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md index f61ea721aa0..45ad386c666 100644 --- a/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md +++ b/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md @@ -73,4 +73,8 @@ Correlated scalar subqueries can only be used in filters, aggregations, projecti IN/EXISTS predicate subqueries can only be used in filters, joins, aggregations, window functions, projections, and UPDATE/MERGE/DELETE commands`<treeNode>`. 
+## UNSUPPORTED_TABLE_ARGUMENT + +Table arguments are used in a function where they are not supported`<treeNode>`. + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 67a958d73f7..cc0bfd3fc31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2082,7 +2082,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => withPosition(u) { try { - val resolvedTvf = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse { + val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse { val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.name) if (CatalogV2Util.isSessionCatalog(catalog)) { v1SessionCatalog.resolvePersistentTableFunction( @@ -2092,93 +2092,19 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor catalog, "table-valued functions") } } - // Resolve Python UDTF calls if needed. 
- val resolvedFunc = resolvedTvf match { - case g @ Generate(u: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) => - val analyzeResult: PythonUDTFAnalyzeResult = - u.resolveElementMetadata(u.func, u.children) - g.copy(generator = - PythonUDTF(u.name, u.func, analyzeResult.schema, u.children, - u.evalType, u.udfDeterministic, u.resultId, u.pythonUDTFPartitionColumnIndexes, - analyzeResult = Some(analyzeResult))) - case other => - other - } - val tableArgs = mutable.ArrayBuffer.empty[LogicalPlan] - val functionTableSubqueryArgs = - mutable.ArrayBuffer.empty[FunctionTableSubqueryArgumentExpression] - val tvf = resolvedFunc.transformAllExpressionsWithPruning( - _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION), ruleId) { + resolvedFunc.transformAllExpressionsWithPruning( + _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)) { case t: FunctionTableSubqueryArgumentExpression => - val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}") - val ( - pythonUDTFName: String, - pythonUDTFAnalyzeResult: Option[PythonUDTFAnalyzeResult]) = - resolvedFunc match { - case Generate(p: PythonUDTF, _, _, _, _, _) => - (p.name, - p.analyzeResult) - case _ => - assert(!t.hasRepartitioning, - "Cannot evaluate the table-valued function call because it included the " + - "PARTITION BY clause, but only Python table functions support this " + - "clause") - ("", None) - } - // Check if this is a call to a Python user-defined table function whose polymorphic - // 'analyze' method returned metadata indicated requested partitioning and/or - // ordering properties of the input relation. In that event, make sure that the UDTF - // call did not include any explicit PARTITION BY and/or ORDER BY clauses for the - // corresponding TABLE argument, and then update the TABLE argument representation - // to apply the requested partitioning and/or ordering. 
- pythonUDTFAnalyzeResult.map { analyzeResult => - val newTableArgument: FunctionTableSubqueryArgumentExpression = - analyzeResult.applyToTableArgument(pythonUDTFName, t) - tableArgs.append(SubqueryAlias(alias, newTableArgument.evaluable)) - functionTableSubqueryArgs.append(newTableArgument) - }.getOrElse { - tableArgs.append(SubqueryAlias(alias, t.evaluable)) - functionTableSubqueryArgs.append(t) + resolvedFunc match { + case Generate(_: PythonUDTF, _, _, _, _, _) => + case Generate(_: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) => + case _ => + assert(!t.hasRepartitioning, + "Cannot evaluate the table-valued function call because it included the " + + "PARTITION BY clause, but only Python table functions support this " + + "clause") } - UnresolvedAttribute(Seq(alias, "c")) - } - if (tableArgs.nonEmpty) { - if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) { - throw QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError( - tableArgs.size) - } - val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}") - // Propagate the column indexes for TABLE arguments to the PythonUDTF instance. 
- def assignUDTFPartitionColumnIndexes( - fn: PythonUDTFPartitionColumnIndexes => LogicalPlan): Option[LogicalPlan] = { - val indexes: Seq[Int] = functionTableSubqueryArgs.headOption - .map(_.partitioningExpressionIndexes).getOrElse(Seq.empty) - if (indexes.nonEmpty) { - Some(fn(PythonUDTFPartitionColumnIndexes(indexes))) - } else { - None - } - } - val tvfWithTableColumnIndexes: LogicalPlan = tvf match { - case g@Generate(p: PythonUDTF, _, _, _, _, _) => - assignUDTFPartitionColumnIndexes( - i => g.copy(generator = p.copy(pythonUDTFPartitionColumnIndexes = Some(i)))) - .getOrElse(g) - case g@Generate(p: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) => - assignUDTFPartitionColumnIndexes( - i => g.copy(generator = p.copy(pythonUDTFPartitionColumnIndexes = Some(i)))) - .getOrElse(g) - case _ => - tvf - } - Project( - Seq(UnresolvedStar(Some(Seq(alias)))), - LateralJoin( - tableArgs.reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)), - LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None) - ) - } else { - tvf + t } } catch { case _: NoSuchFunctionException => @@ -2206,6 +2132,46 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } Project(aliases, u.child) + case p: LogicalPlan + if p.resolved && p.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION) => + withPosition(p) { + val tableArgs = + mutable.ArrayBuffer.empty[(FunctionTableSubqueryArgumentExpression, LogicalPlan)] + + val tvf = p.transformExpressionsWithPruning( + _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)) { + case t: FunctionTableSubqueryArgumentExpression => + val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}") + tableArgs.append((t, SubqueryAlias(alias, t.evaluable))) + UnresolvedAttribute(Seq(alias, "c")) + } + + assert(tableArgs.nonEmpty) + if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) { + throw QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError( + tableArgs.size) 
+ } + val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}") + + // Propagate the column indexes for TABLE arguments to the PythonUDTF instance. + val tvfWithTableColumnIndexes = tvf match { + case g @ Generate(pyudtf: PythonUDTF, _, _, _, _, _) + if tableArgs.head._1.partitioningExpressionIndexes.nonEmpty => + val partitionColumnIndexes = + PythonUDTFPartitionColumnIndexes(tableArgs.head._1.partitioningExpressionIndexes) + g.copy(generator = pyudtf.copy( + pythonUDTFPartitionColumnIndexes = Some(partitionColumnIndexes))) + case _ => tvf + } + + Project( + Seq(UnresolvedStar(Some(Seq(alias)))), + LateralJoin( + tableArgs.map(_._2).reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)), + LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None) + ) + } + case q: LogicalPlan => q.transformExpressionsUpWithPruning( _.containsAnyPattern(UNRESOLVED_FUNCTION, GENERATOR), @@ -2251,9 +2217,20 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } case u: UnresolvedPolymorphicPythonUDTF => withPosition(u) { - val elementSchema = u.resolveElementMetadata(u.func, u.children).schema - PythonUDTF(u.name, u.func, elementSchema, u.children, - u.evalType, u.udfDeterministic, u.resultId, u.pythonUDTFPartitionColumnIndexes) + // Check if this is a call to a Python user-defined table function whose polymorphic + // 'analyze' method returned metadata indicated requested partitioning and/or + // ordering properties of the input relation. In that event, make sure that the UDTF + // call did not include any explicit PARTITION BY and/or ORDER BY clauses for the + // corresponding TABLE argument, and then update the TABLE argument representation + // to apply the requested partitioning and/or ordering. 
+ val analyzeResult = u.resolveElementMetadata(u.func, u.children) + val newChildren = u.children.map { + case t: FunctionTableSubqueryArgumentExpression => + analyzeResult.applyToTableArgument(u.name, t) + case c => c + } + PythonUDTF(u.name, u.func, analyzeResult.schema, newChildren, + u.evalType, u.udfDeterministic, u.resultId) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 83b682bc917..de453f6bc49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1075,6 +1075,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB // allowed by spark. checkCorrelationsInSubquery(expr.plan, isLateral = true) + case _: FunctionTableSubqueryArgumentExpression => + expr.failAnalysis( + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT", + messageParameters = Map("treeNode" -> planToString(plan))) + case inSubqueryOrExistsSubquery => plan match { case _: Filter | _: SupportsSubquery | _: Join | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index a615348bc6e..bc74572444c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -180,8 +180,7 @@ case class PythonUDTF( evalType: Int, udfDeterministic: Boolean, resultId: ExprId = NamedExpression.newExprId, - pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None, - analyzeResult: Option[PythonUDTFAnalyzeResult] = None) + pythonUDTFPartitionColumnIndexes: 
Option[PythonUDTFPartitionColumnIndexes] = None) extends UnevaluableGenerator with PythonFuncExpression { override lazy val canonicalized: Expression = { @@ -210,8 +209,7 @@ case class UnresolvedPolymorphicPythonUDTF( evalType: Int, udfDeterministic: Boolean, resolveElementMetadata: (PythonFunction, Seq[Expression]) => PythonUDTFAnalyzeResult, - resultId: ExprId = NamedExpression.newExprId, - pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None) + resultId: ExprId = NamedExpression.newExprId) extends UnevaluableGenerator with PythonFuncExpression { override lazy val resolved = false diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out index 11e2651c6f2..4ba47e9e1b4 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out @@ -202,21 +202,17 @@ SELECT * FROM explode(collection => TABLE(v)) -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT", + "sqlState" : "0A000", "messageParameters" : { - "inputSql" : "\"outer(__auto_generated_subquery_name_0.c)\"", - "inputType" : "\"STRUCT<id: BIGINT>\"", - "paramIndex" : "1", - "requiredType" : "(\"ARRAY\" or \"MAP\")", - "sqlExpr" : "\"explode(outer(__auto_generated_subquery_name_0.c))\"" + "treeNode" : "'Generate explode(table-argument#x []), false\n: +- SubqueryAlias v\n: +- View (`v`, [id#xL])\n: +- Project [cast(id#xL as bigint) AS id#xL]\n: +- Project [id#xL]\n: +- Range (0, 8, step=1, splits=None)\n+- OneRowRelation\n" }, "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 15, - "stopIndex" : 45, - 
"fragment" : "explode(collection => TABLE(v))" + "startIndex" : 37, + "stopIndex" : 44, + "fragment" : "TABLE(v)" } ] } diff --git a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out index 60301862a35..03963ac3ef9 100644 --- a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out @@ -185,21 +185,17 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - "sqlState" : "42K09", + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT", + "sqlState" : "0A000", "messageParameters" : { - "inputSql" : "\"outer(__auto_generated_subquery_name_0.c)\"", - "inputType" : "\"STRUCT<id: BIGINT>\"", - "paramIndex" : "1", - "requiredType" : "(\"ARRAY\" or \"MAP\")", - "sqlExpr" : "\"explode(outer(__auto_generated_subquery_name_0.c))\"" + "treeNode" : "'Generate explode(table-argument#x []), false\n: +- SubqueryAlias v\n: +- View (`v`, [id#xL])\n: +- Project [cast(id#xL as bigint) AS id#xL]\n: +- Project [id#xL]\n: +- Range (0, 8, step=1, splits=None)\n+- OneRowRelation\n" }, "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 15, - "stopIndex" : 45, - "fragment" : "explode(collection => TABLE(v))" + "startIndex" : 37, + "stopIndex" : 44, + "fragment" : "TABLE(v)" } ] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala index cf687f90287..cdc3ef9e417 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala @@ -128,6 +128,8 @@ class PythonUDTFSuite extends QueryTest 
with SharedSparkSession { def failure(plan: LogicalPlan): Unit = { fail(s"Unexpected plan: $plan") } + + spark.udtf.registerPython("testUDTF", pythonUDTF) sql( """ |SELECT * FROM testUDTF( @@ -187,19 +189,15 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession { withTable("t") { sql("create table t(col array<int>) using parquet") val query = "select * from explode(table(t))" - checkError( + checkErrorMatchPVals( exception = intercept[AnalysisException](sql(query)), - errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - parameters = Map( - "sqlExpr" -> "\"explode(outer(__auto_generated_subquery_name_0.c))\"", - "paramIndex" -> "1", - "inputSql" -> "\"outer(__auto_generated_subquery_name_0.c)\"", - "inputType" -> "\"STRUCT<col: ARRAY<INT>>\"", - "requiredType" -> "(\"ARRAY\" or \"MAP\")"), + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT", + sqlState = None, + parameters = Map("treeNode" -> "(?s).*"), context = ExpectedContext( - fragment = "explode(table(t))", - start = 14, - stop = 30)) + fragment = "table(t)", + start = 22, + stop = 29)) } spark.udtf.registerPython("UDTFCountSumLast", pythonUDTFCountSumLast) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org