This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2cf937f9bac [SPARK-45266][PYTHON] Refactor ResolveFunctions analyzer rule to delay making lateral join when table arguments are used
2cf937f9bac is described below

commit 2cf937f9bac2131f3657660a8d65d07ab4ece490
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Thu Sep 28 10:37:18 2023 -0700

    [SPARK-45266][PYTHON] Refactor ResolveFunctions analyzer rule to delay making lateral join when table arguments are used
    
    ### What changes were proposed in this pull request?
    
    Refactors the `ResolveFunctions` analyzer rule to delay making the lateral join when table arguments are used.
    
    - Delay making the lateral join when table arguments are used until after all the children are resolved
    - Resolve `UnresolvedPolymorphicPythonUDTF` in one place
    - Introduce a new error class `UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT`, raised when table arguments are used improperly (see the example below)
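    
    For illustration, a minimal sketch of a query that now fails with the new error class, modeled on the updated test in `PythonUDTFSuite` below (the table `t` is taken from that test; `explode` stands in for any function that is not a Python UDTF):
    
        import org.apache.spark.sql.{AnalysisException, SparkSession}
    
        val spark = SparkSession.builder().master("local[*]").getOrCreate()
        spark.sql("create table t(col array<int>) using parquet")
        try {
          // `explode` is a builtin generator, not a Python UDTF, so a TABLE
          // argument is not supported here and analysis fails.
          spark.sql("select * from explode(table(t))").collect()
        } catch {
          case e: AnalysisException =>
            // prints UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT
            println(e.getErrorClass)
        }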
    
    ### Why are the changes needed?
    
    The analyzer rule `ResolveFunctions` had become complicated; this refactoring simplifies it.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43042 from ueshin/issues/SPARK-45266/analyzer.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 .../src/main/resources/error/error-classes.json    |   5 +
 ...ted-subquery-expression-category-error-class.md |   4 +
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 155 +++++++++------------
 .../sql/catalyst/analysis/CheckAnalysis.scala      |   5 +
 .../spark/sql/catalyst/expressions/PythonUDF.scala |   6 +-
 .../named-function-arguments.sql.out               |  16 +--
 .../results/named-function-arguments.sql.out       |  16 +--
 .../sql/execution/python/PythonUDTFSuite.scala     |  20 ++-
 8 files changed, 103 insertions(+), 124 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 0882e387176..58fcedae332 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -3484,6 +3484,11 @@
         "message" : [
           "IN/EXISTS predicate subqueries can only be used in filters, joins, 
aggregations, window functions, projections, and UPDATE/MERGE/DELETE 
commands<treeNode>."
         ]
+      },
+      "UNSUPPORTED_TABLE_ARGUMENT" : {
+        "message" : [
+          "Table arguments are used in a function where they are not 
supported<treeNode>."
+        ]
       }
     },
     "sqlState" : "0A000"
diff --git a/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md b/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md
index f61ea721aa0..45ad386c666 100644
--- a/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md
+++ b/docs/sql-error-conditions-unsupported-subquery-expression-category-error-class.md
@@ -73,4 +73,8 @@ Correlated scalar subqueries can only be used in filters, aggregations, projecti
 
 IN/EXISTS predicate subqueries can only be used in filters, joins, aggregations, window functions, projections, and UPDATE/MERGE/DELETE commands`<treeNode>`.
 
+## UNSUPPORTED_TABLE_ARGUMENT
+
+Table arguments are used in a function where they are not supported`<treeNode>`.
+
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 67a958d73f7..cc0bfd3fc31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2082,7 +2082,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
         withPosition(u) {
           try {
-            val resolvedTvf = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse {
+            val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse {
              val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.name)
               if (CatalogV2Util.isSessionCatalog(catalog)) {
                 v1SessionCatalog.resolvePersistentTableFunction(
@@ -2092,93 +2092,19 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                   catalog, "table-valued functions")
               }
             }
-            // Resolve Python UDTF calls if needed.
-            val resolvedFunc = resolvedTvf match {
-              case g @ Generate(u: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) =>
-                val analyzeResult: PythonUDTFAnalyzeResult =
-                  u.resolveElementMetadata(u.func, u.children)
-                g.copy(generator =
-                  PythonUDTF(u.name, u.func, analyzeResult.schema, u.children,
-                    u.evalType, u.udfDeterministic, u.resultId, u.pythonUDTFPartitionColumnIndexes,
-                    analyzeResult = Some(analyzeResult)))
-              case other =>
-                other
-            }
-            val tableArgs = mutable.ArrayBuffer.empty[LogicalPlan]
-            val functionTableSubqueryArgs =
-              mutable.ArrayBuffer.empty[FunctionTableSubqueryArgumentExpression]
-            val tvf = resolvedFunc.transformAllExpressionsWithPruning(
-              _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION), ruleId)  {
+            resolvedFunc.transformAllExpressionsWithPruning(
+              _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION))  {
               case t: FunctionTableSubqueryArgumentExpression =>
-                val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
-                val (
-                  pythonUDTFName: String,
-                  pythonUDTFAnalyzeResult: Option[PythonUDTFAnalyzeResult]) =
-                  resolvedFunc match {
-                    case Generate(p: PythonUDTF, _, _, _, _, _) =>
-                      (p.name,
-                        p.analyzeResult)
-                    case _ =>
-                      assert(!t.hasRepartitioning,
-                        "Cannot evaluate the table-valued function call because it included the " +
-                          "PARTITION BY clause, but only Python table functions support this " +
-                          "clause")
-                      ("", None)
-                  }
-                // Check if this is a call to a Python user-defined table function whose polymorphic
-                // 'analyze' method returned metadata indicated requested partitioning and/or
-                // ordering properties of the input relation. In that event, make sure that the UDTF
-                // call did not include any explicit PARTITION BY and/or ORDER BY clauses for the
-                // corresponding TABLE argument, and then update the TABLE argument representation
-                // to apply the requested partitioning and/or ordering.
-                pythonUDTFAnalyzeResult.map { analyzeResult =>
-                  val newTableArgument: FunctionTableSubqueryArgumentExpression =
-                    analyzeResult.applyToTableArgument(pythonUDTFName, t)
-                  tableArgs.append(SubqueryAlias(alias, newTableArgument.evaluable))
-                  functionTableSubqueryArgs.append(newTableArgument)
-                }.getOrElse {
-                  tableArgs.append(SubqueryAlias(alias, t.evaluable))
-                  functionTableSubqueryArgs.append(t)
+                resolvedFunc match {
+                  case Generate(_: PythonUDTF, _, _, _, _, _) =>
+                  case Generate(_: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) =>
+                  case _ =>
+                    assert(!t.hasRepartitioning,
+                      "Cannot evaluate the table-valued function call because it included the " +
+                        "PARTITION BY clause, but only Python table functions support this " +
+                        "clause")
                 }
-                UnresolvedAttribute(Seq(alias, "c"))
-            }
-            if (tableArgs.nonEmpty) {
-              if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) {
-                throw QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError(
-                  tableArgs.size)
-              }
-              val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
-              // Propagate the column indexes for TABLE arguments to the PythonUDTF instance.
-              def assignUDTFPartitionColumnIndexes(
-                  fn: PythonUDTFPartitionColumnIndexes => LogicalPlan): Option[LogicalPlan] = {
-                val indexes: Seq[Int] = functionTableSubqueryArgs.headOption
-                  .map(_.partitioningExpressionIndexes).getOrElse(Seq.empty)
-                if (indexes.nonEmpty) {
-                  Some(fn(PythonUDTFPartitionColumnIndexes(indexes)))
-                } else {
-                  None
-                }
-              }
-              val tvfWithTableColumnIndexes: LogicalPlan = tvf match {
-                case g@Generate(p: PythonUDTF, _, _, _, _, _) =>
-                  assignUDTFPartitionColumnIndexes(
-                    i => g.copy(generator = p.copy(pythonUDTFPartitionColumnIndexes = Some(i))))
-                    .getOrElse(g)
-                case g@Generate(p: UnresolvedPolymorphicPythonUDTF, _, _, _, _, _) =>
-                  assignUDTFPartitionColumnIndexes(
-                    i => g.copy(generator = p.copy(pythonUDTFPartitionColumnIndexes = Some(i))))
-                    .getOrElse(g)
-                case _ =>
-                  tvf
-              }
-              Project(
-                Seq(UnresolvedStar(Some(Seq(alias)))),
-                LateralJoin(
-                  tableArgs.reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
-                  LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
-              )
-            } else {
-              tvf
+                t
             }
           } catch {
             case _: NoSuchFunctionException =>
@@ -2206,6 +2132,46 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         }
         Project(aliases, u.child)
 
+      case p: LogicalPlan
+          if p.resolved && p.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION) =>
+        withPosition(p) {
+          val tableArgs =
+            mutable.ArrayBuffer.empty[(FunctionTableSubqueryArgumentExpression, LogicalPlan)]
+
+          val tvf = p.transformExpressionsWithPruning(
+            _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)) {
+            case t: FunctionTableSubqueryArgumentExpression =>
+              val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+              tableArgs.append((t, SubqueryAlias(alias, t.evaluable)))
+              UnresolvedAttribute(Seq(alias, "c"))
+          }
+
+          assert(tableArgs.nonEmpty)
+          if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) {
+            throw QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError(
+              tableArgs.size)
+          }
+          val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
+
+          // Propagate the column indexes for TABLE arguments to the PythonUDTF instance.
+          val tvfWithTableColumnIndexes = tvf match {
+            case g @ Generate(pyudtf: PythonUDTF, _, _, _, _, _)
+                if tableArgs.head._1.partitioningExpressionIndexes.nonEmpty =>
+              val partitionColumnIndexes =
+                PythonUDTFPartitionColumnIndexes(tableArgs.head._1.partitioningExpressionIndexes)
+              g.copy(generator = pyudtf.copy(
+                pythonUDTFPartitionColumnIndexes = Some(partitionColumnIndexes)))
+            case _ => tvf
+          }
+
+          Project(
+            Seq(UnresolvedStar(Some(Seq(alias)))),
+            LateralJoin(
+              tableArgs.map(_._2).reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
+              LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
+          )
+        }
+
       case q: LogicalPlan =>
         q.transformExpressionsUpWithPruning(
           _.containsAnyPattern(UNRESOLVED_FUNCTION, GENERATOR),
@@ -2251,9 +2217,20 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           }
 
           case u: UnresolvedPolymorphicPythonUDTF => withPosition(u) {
-            val elementSchema = u.resolveElementMetadata(u.func, u.children).schema
-            PythonUDTF(u.name, u.func, elementSchema, u.children,
-              u.evalType, u.udfDeterministic, u.resultId, u.pythonUDTFPartitionColumnIndexes)
+            // Check if this is a call to a Python user-defined table function whose polymorphic
+            // 'analyze' method returned metadata indicating requested partitioning and/or
+            // ordering properties of the input relation. In that event, make sure that the UDTF
+            // call did not include any explicit PARTITION BY and/or ORDER BY clauses for the
+            // corresponding TABLE argument, and then update the TABLE argument representation
+            // to apply the requested partitioning and/or ordering.
+            val analyzeResult = u.resolveElementMetadata(u.func, u.children)
+            val newChildren = u.children.map {
+              case t: FunctionTableSubqueryArgumentExpression =>
+                analyzeResult.applyToTableArgument(u.name, t)
+              case c => c
+            }
+            PythonUDTF(u.name, u.func, analyzeResult.schema, newChildren,
+              u.evalType, u.udfDeterministic, u.resultId)
           }
         }
     }
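
For orientation, a minimal sketch, assuming a Python UDTF registered as `testUDTF` and a one-column table `t` (both illustrative names, not from this patch), of the plan shape the new `case p: LogicalPlan` branch produces once the table argument's children are resolved:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    spark.sql("create table t(id bigint) using parquet")
    // Assumes `testUDTF` was registered beforehand via spark.udtf.registerPython.
    // The TABLE argument becomes the left side of a lateral join, the function
    // call becomes a LateralSubquery, and a Project of the subquery alias's star
    // is placed on top, roughly:
    //
    //   Project [__auto_generated_subquery_name_1.*]
    //   +- LateralJoin Inner
    //      :- SubqueryAlias __auto_generated_subquery_name_0        <- TABLE(t)
    //      +- LateralSubquery
    //         +- SubqueryAlias __auto_generated_subquery_name_1
    //            +- Generate testUDTF(__auto_generated_subquery_name_0.c)
    println(spark.sql("SELECT * FROM testUDTF(TABLE(t))").queryExecution.analyzed.treeString)
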
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 83b682bc917..de453f6bc49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -1075,6 +1075,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
         // allowed by spark.
         checkCorrelationsInSubquery(expr.plan, isLateral = true)
 
+      case _: FunctionTableSubqueryArgumentExpression =>
+        expr.failAnalysis(
+          errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
+          messageParameters = Map("treeNode" -> planToString(plan)))
+
       case inSubqueryOrExistsSubquery =>
         plan match {
           case _: Filter | _: SupportsSubquery | _: Join |
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index a615348bc6e..bc74572444c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -180,8 +180,7 @@ case class PythonUDTF(
     evalType: Int,
     udfDeterministic: Boolean,
     resultId: ExprId = NamedExpression.newExprId,
-    pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None,
-    analyzeResult: Option[PythonUDTFAnalyzeResult] = None)
+    pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None)
   extends UnevaluableGenerator with PythonFuncExpression {
 
   override lazy val canonicalized: Expression = {
@@ -210,8 +209,7 @@ case class UnresolvedPolymorphicPythonUDTF(
     evalType: Int,
     udfDeterministic: Boolean,
     resolveElementMetadata: (PythonFunction, Seq[Expression]) => PythonUDTFAnalyzeResult,
-    resultId: ExprId = NamedExpression.newExprId,
-    pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None)
+    resultId: ExprId = NamedExpression.newExprId)
   extends UnevaluableGenerator with PythonFuncExpression {
 
   override lazy val resolved = false
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out
index 11e2651c6f2..4ba47e9e1b4 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/named-function-arguments.sql.out
@@ -202,21 +202,17 @@ SELECT * FROM explode(collection => TABLE(v))
 -- !query analysis
 org.apache.spark.sql.catalyst.ExtendedAnalysisException
 {
-  "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
-  "sqlState" : "42K09",
+  "errorClass" : 
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
+  "sqlState" : "0A000",
   "messageParameters" : {
-    "inputSql" : "\"outer(__auto_generated_subquery_name_0.c)\"",
-    "inputType" : "\"STRUCT<id: BIGINT>\"",
-    "paramIndex" : "1",
-    "requiredType" : "(\"ARRAY\" or \"MAP\")",
-    "sqlExpr" : "\"explode(outer(__auto_generated_subquery_name_0.c))\""
+    "treeNode" : "'Generate explode(table-argument#x []), false\n:  +- 
SubqueryAlias v\n:     +- View (`v`, [id#xL])\n:        +- Project [cast(id#xL 
as bigint) AS id#xL]\n:           +- Project [id#xL]\n:              +- Range 
(0, 8, step=1, splits=None)\n+- OneRowRelation\n"
   },
   "queryContext" : [ {
     "objectType" : "",
     "objectName" : "",
-    "startIndex" : 15,
-    "stopIndex" : 45,
-    "fragment" : "explode(collection => TABLE(v))"
+    "startIndex" : 37,
+    "stopIndex" : 44,
+    "fragment" : "TABLE(v)"
   } ]
 }
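
Note the expected-output change above: besides the new error class and SQLSTATE 0A000, the query context now points at the TABLE(v) fragment (startIndex 37, stopIndex 44) rather than the whole explode(...) call, so the error is attributed to the unsupported table argument itself.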
 
diff --git a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out
index 60301862a35..03963ac3ef9 100644
--- a/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/named-function-arguments.sql.out
@@ -185,21 +185,17 @@ struct<>
 -- !query output
 org.apache.spark.sql.catalyst.ExtendedAnalysisException
 {
-  "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
-  "sqlState" : "42K09",
+  "errorClass" : 
"UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
+  "sqlState" : "0A000",
   "messageParameters" : {
-    "inputSql" : "\"outer(__auto_generated_subquery_name_0.c)\"",
-    "inputType" : "\"STRUCT<id: BIGINT>\"",
-    "paramIndex" : "1",
-    "requiredType" : "(\"ARRAY\" or \"MAP\")",
-    "sqlExpr" : "\"explode(outer(__auto_generated_subquery_name_0.c))\""
+    "treeNode" : "'Generate explode(table-argument#x []), false\n:  +- 
SubqueryAlias v\n:     +- View (`v`, [id#xL])\n:        +- Project [cast(id#xL 
as bigint) AS id#xL]\n:           +- Project [id#xL]\n:              +- Range 
(0, 8, step=1, splits=None)\n+- OneRowRelation\n"
   },
   "queryContext" : [ {
     "objectType" : "",
     "objectName" : "",
-    "startIndex" : 15,
-    "stopIndex" : 45,
-    "fragment" : "explode(collection => TABLE(v))"
+    "startIndex" : 37,
+    "stopIndex" : 44,
+    "fragment" : "TABLE(v)"
   } ]
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
index cf687f90287..cdc3ef9e417 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -128,6 +128,8 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
     def failure(plan: LogicalPlan): Unit = {
       fail(s"Unexpected plan: $plan")
     }
+
+    spark.udtf.registerPython("testUDTF", pythonUDTF)
     sql(
       """
         |SELECT * FROM testUDTF(
@@ -187,19 +189,15 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
     withTable("t") {
       sql("create table t(col array<int>) using parquet")
       val query = "select * from explode(table(t))"
-      checkError(
+      checkErrorMatchPVals(
         exception = intercept[AnalysisException](sql(query)),
-        errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
-        parameters = Map(
-          "sqlExpr" -> 
"\"explode(outer(__auto_generated_subquery_name_0.c))\"",
-          "paramIndex" -> "1",
-          "inputSql" -> "\"outer(__auto_generated_subquery_name_0.c)\"",
-          "inputType" -> "\"STRUCT<col: ARRAY<INT>>\"",
-          "requiredType" -> "(\"ARRAY\" or \"MAP\")"),
+        errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.UNSUPPORTED_TABLE_ARGUMENT",
+        sqlState = None,
+        parameters = Map("treeNode" -> "(?s).*"),
         context = ExpectedContext(
-          fragment = "explode(table(t))",
-          start = 14,
-          stop = 30))
+          fragment = "table(t)",
+          start = 22,
+          stop = 29))
     }
 
     spark.udtf.registerPython("UDTFCountSumLast", pythonUDTFCountSumLast)
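
A note on the test change above: `checkError` compares message parameters verbatim, but the new `treeNode` parameter embeds the analyzed plan string, which varies between runs. `checkErrorMatchPVals` instead treats each expected parameter value as a regular expression, so "(?s).*" (where "(?s)" lets '.' match newlines) accepts any plan text while the error class, fragment, and offsets are still checked exactly.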

