This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 69cf80d25f0e [SPARK-45402][SQL][PYTHON] Add UDTF API for 'eval' and 'terminate' methods to consume previous 'analyze' result
69cf80d25f0e is described below

commit 69cf80d25f0e4ed46ec38a63e063471988c31732
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Wed Oct 11 18:52:06 2023 -0700

    [SPARK-45402][SQL][PYTHON] Add UDTF API for 'eval' and 'terminate' methods to consume previous 'analyze' result
    
    ### What changes were proposed in this pull request?
    
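    This PR adds a Python UDTF API for the `eval` and `terminate` methods to consume the result of the previous `analyze` call: if the UDTF class defines `__init__(self, analyze_result)`, each new instance receives the `AnalyzeResult` returned by `analyze`.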
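    For reference, the `python/pyspark/worker.py` hunk below wraps the UDTF constructor so that it first passes a copy of the unpickled `AnalyzeResult` to `__init__`, falls back to a no-argument call on `TypeError`, and remembers that decision for later instances. The standalone sketch below mirrors that fallback without Spark; `FakeAnalyzeResult`, `make_constructor`, `WantsResult`, and `IgnoresResult` are hypothetical names used only for illustration, not PySpark APIs.

    ```python
    import dataclasses
    from dataclasses import dataclass
    from typing import Any, Callable, Optional


    @dataclass
    class FakeAnalyzeResult:
        # Hypothetical stand-in for pyspark.sql.udtf.AnalyzeResult (schema as a DDL string).
        schema: str


    def make_constructor(handler: Callable[..., Any], analyze_result: Optional[Any]) -> Callable[[], Any]:
        """Return a zero-argument factory mirroring the fallback added to worker.py."""
        if analyze_result is None:
            # Nothing was pickled by 'analyze', so keep the plain constructor.
            return handler

        accepts_result = True  # optimistic until __init__ rejects the extra argument

        def construct() -> Any:
            nonlocal accepts_result
            if not accepts_result:
                return handler()
            try:
                # Each instance gets its own shallow copy of the analyze result.
                return handler(dataclasses.replace(analyze_result))
            except TypeError:
                # __init__ takes no AnalyzeResult: remember that and fall back.
                accepts_result = False
                return handler()

        return construct


    class WantsResult:
        def __init__(self, analyze_result):
            self.schema = analyze_result.schema


    class IgnoresResult:
        def __init__(self):
            self.schema = None


    result = FakeAnalyzeResult(schema="total INT, buffer STRING")
    print(make_constructor(WantsResult, result)().schema)    # total INT, buffer STRING
    print(make_constructor(IgnoresResult, result)().schema)  # None
    ```

    One consequence of catching `TypeError` this way, as written in the hunk, is that a `TypeError` raised from inside an `__init__` that does accept the argument also triggers the no-argument fallback.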
    
    This also works for subclasses of the `AnalyzeResult` class, allowing the UDTF to return custom state from `analyze` to be consumed later.
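    The diff below drops `frozen=True` from `AnalyzeResult` so that users can subclass it, and the unit test declares its extra field as `buffer: str = ""`. Both follow from standard `dataclasses` rules: a non-frozen dataclass cannot inherit from a frozen one, and because the base class already declares fields with defaults, any field a subclass adds needs a default too. The sketch below checks both rules with `dataclasses.make_dataclass`; `FrozenResult`, `AnalyzeResultStandIn`, and `try_subclass` are hypothetical stand-ins, not PySpark classes.

    ```python
    from dataclasses import dataclass, field, make_dataclass
    from typing import Sequence


    @dataclass(frozen=True)
    class FrozenResult:
        # Hypothetical stand-in for AnalyzeResult before this change (frozen=True).
        schema: str
        with_single_partition: bool = False


    @dataclass
    class AnalyzeResultStandIn:
        # Hypothetical stand-in mirroring the non-frozen AnalyzeResult fields after this change.
        schema: str
        with_single_partition: bool = False
        partition_by: Sequence[str] = field(default_factory=tuple)
        order_by: Sequence[str] = field(default_factory=tuple)


    def try_subclass(parent: type, with_default: bool = True) -> str:
        """Build a plain (non-frozen) dataclass subclass that adds a 'buffer' field."""
        spec = ("buffer", str, field(default="")) if with_default else ("buffer", str)
        try:
            make_dataclass("AnalyzeResultWithBuffer", [spec], bases=(parent,))
            return "ok"
        except TypeError as exc:
            return f"TypeError: {exc}"


    print(try_subclass(FrozenResult))                              # frozen-base error
    print(try_subclass(AnalyzeResultStandIn, with_default=False))  # field-ordering error
    print(try_subclass(AnalyzeResultStandIn))                      # ok
    ```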
    
    For example, we can now define a UDTF that performs complex initialization in the `analyze` method and then returns the result of that initialization from the `terminate` method:
    
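    ```
    from dataclasses import dataclass

    from pyspark.sql import Row
    from pyspark.sql.functions import udtf
    from pyspark.sql.types import IntegerType, StringType, StructType
    from pyspark.sql.udtf import AnalyzeResult

    @dataclass
    class AnalyzeResultWithBuffer(AnalyzeResult):
        # AnalyzeResult's optional fields have defaults, so fields added by a
        # subclass need defaults too.
        buffer: str = ""

    @udtf
    class MyUDTF:
        def __init__(self, analyze_result):
            self._total = 0
            # Reuse the value computed once in 'analyze' instead of recomputing it.
            self._buffer = analyze_result.buffer

        @staticmethod
        def analyze(argument, _):
            return AnalyzeResultWithBuffer(
                schema=StructType()
                    .add("total", IntegerType())
                    .add("buffer", StringType()),
                with_single_partition=True,
                # do_complex_initialization stands for any user-defined setup logic.
                buffer=do_complex_initialization(argument.value),
            )

        def eval(self, argument, row: Row):
            self._total += 1

        def terminate(self):
            yield self._total, self._buffer

    spark.udtf.register("my_udtf", MyUDTF)
    ```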
    
    Then the results might look like:
    
    ```
    spark.sql(
        """
        WITH t AS (
          SELECT id FROM range(1, 21)
        )
        SELECT total, buffer
        FROM my_udtf("abc", TABLE(t))
        """
    ).collect()

    > [Row(total=20, buffer=<result of do_complex_initialization("abc")>)]
    ```
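    To make the expected numbers concrete, here is a pure-Python simulation of the single-partition lifecycle implied by `with_single_partition=True`: `analyze` runs once, one instance is constructed with its result, `eval` is called for each of the 20 rows from `range(1, 21)`, and `terminate` emits the final row. This is a sketch under stated assumptions rather than Spark's execution path: `run_single_partition` is a hypothetical driver, `analyze` is simplified to take a plain string instead of `AnalyzeArgument` objects, and `str.upper()` stands in for `do_complex_initialization`.

    ```python
    from dataclasses import dataclass
    from typing import Any, Iterable, List, Tuple


    @dataclass
    class AnalyzeResultWithBuffer:
        # Hypothetical stand-in; in PySpark this would subclass pyspark.sql.udtf.AnalyzeResult.
        schema: str
        with_single_partition: bool = False
        buffer: str = ""


    class MyUDTF:
        def __init__(self, analyze_result: AnalyzeResultWithBuffer):
            self._total = 0
            self._buffer = analyze_result.buffer  # computed once, in analyze

        @staticmethod
        def analyze(value: str) -> AnalyzeResultWithBuffer:
            # Simplified: the real method receives AnalyzeArgument objects, not a str.
            # str.upper() stands in for the hypothetical do_complex_initialization.
            return AnalyzeResultWithBuffer(
                schema="total INT, buffer STRING",
                with_single_partition=True,
                buffer=value.upper(),
            )

        def eval(self, argument: str, row: Any) -> None:
            self._total += 1

        def terminate(self):
            yield self._total, self._buffer


    def run_single_partition(udtf_cls, literal: str, rows: Iterable[Any]) -> List[Tuple[int, str]]:
        """Hypothetical driver: analyze once, then a single instance consumes every row."""
        result = udtf_cls.analyze(literal)
        assert result.with_single_partition
        instance = udtf_cls(result)
        out: List[Tuple[int, str]] = []
        for row in rows:
            emitted = instance.eval(literal, row)
            if emitted is not None:
                out.extend(emitted)
        out.extend(instance.terminate())
        return out


    print(run_single_partition(MyUDTF, "abc", range(1, 21)))  # [(20, 'ABC')]
    ```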
    
    ### Why are the changes needed?
    
    In this way, the UDTF can perform potentially expensive initialization logic in the `analyze` method just once and return the result of that initialization, rather than repeating the initialization in `eval`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, see above.
    
    ### How was this patch tested?
    
    This PR adds new unit test coverage.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43204 from dtenedor/prepare-string.
    
    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/docs/source/user_guide/sql/python_udtf.rst  | 124 ++++++++++++++++++++-
 python/pyspark/sql/tests/test_udtf.py              |  53 +++++++++
 python/pyspark/sql/udtf.py                         |   5 +-
 python/pyspark/sql/worker/analyze_udtf.py          |   2 +
 python/pyspark/worker.py                           |  34 +++++-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   5 +-
 .../spark/sql/catalyst/expressions/PythonUDF.scala |  20 +++-
 .../execution/python/BatchEvalPythonUDTFExec.scala |   8 ++
 .../python/UserDefinedPythonFunction.scala         |   7 +-
 .../sql-tests/analyzer-results/udtf/udtf.sql.out   |  26 +++--
 .../test/resources/sql-tests/inputs/udtf/udtf.sql  |   9 +-
 .../resources/sql-tests/results/udtf/udtf.sql.out  |  28 +++--
 .../apache/spark/sql/IntegratedUDFTestUtils.scala  |  64 ++++++++++-
 .../sql/execution/python/PythonUDTFSuite.scala     |  42 +++++--
 14 files changed, 374 insertions(+), 53 deletions(-)

diff --git a/python/docs/source/user_guide/sql/python_udtf.rst b/python/docs/source/user_guide/sql/python_udtf.rst
index 74d8eb889861..fb42644dc702 100644
--- a/python/docs/source/user_guide/sql/python_udtf.rst
+++ b/python/docs/source/user_guide/sql/python_udtf.rst
@@ -50,10 +50,108 @@ To implement a Python UDTF, you first need to define a class implementing the me
 
             Notes
             -----
-            - This method does not accept any extra arguments. Only the default
-              constructor is supported.
             - You cannot create or reference the Spark session within the UDTF. Any
               attempt to do so will result in a serialization error.
+            - If the below `analyze` method is implemented, it is also possible to define this
+              method as: `__init__(self, analyze_result: AnalyzeResult)`. In this case, the result
+              of the `analyze` method is passed into all future instantiations of this UDTF class.
+              In this way, the UDTF may inspect the schema and metadata of the output table as
+              needed during execution of other methods in this class. Note that it is possible to
+              create a subclass of the `AnalyzeResult` class if desired for purposes of passing
+              custom information generated just once during UDTF analysis to other method calls;
+              this can be especially useful if this initialization is expensive.
+            """
+            ...
+
+        def analyze(self, *args: Any) -> AnalyzeResult:
+            """
+            Computes the output schema of a particular call to this function in response to the
+            arguments provided.
+
+            This method is optional and only needed if the registration of the UDTF did not provide
+            a static output schema to be use for all calls to the function. In this context,
+            `output schema` refers to the ordered list of the names and types of the columns in the
+            function's result table.
+
+            This method accepts zero or more parameters mapping 1:1 with the arguments provided to
+            the particular UDTF call under consideration. Each parameter is an instance of the
+            `AnalyzeArgument` class, which contains fields including the provided argument's data
+            type and value (in the case of literal scalar arguments only). For table arguments, the
+            `is_table` field is set to true and the `data_type` field is a StructType representing
+            the table's column types:
+
+                data_type: DataType
+                value: Optional[Any]
+                is_table: bool
+
+            This method returns an instance of the `AnalyzeResult` class which includes the result
+            table's schema as a StructType. If the UDTF accepts an input table argument, then the
+            `AnalyzeResult` can also include a requested way to partition the rows of the input
+            table across several UDTF calls. If `with_single_partition` is set to True, the query
+            planner will arrange a repartitioning operation from the previous execution stage such
+            that all rows of the input table are consumed by the `eval` method from exactly one
+            instance of the UDTF class. On the other hand, if the `partition_by` list is non-empty,
+            the query planner will arrange a repartitioning such that all rows with each unique
+            combination of values of the partitioning columns are consumed by a separate unique
+            instance of the UDTF class. If `order_by` is non-empty, this specifies the requested
+            ordering of rows within each partition.
+
+                schema: StructType
+                with_single_partition: bool = False
+                partition_by: Sequence[PartitioningColumn] = field(default_factory=tuple)
+                order_by: Sequence[OrderingColumn] = field(default_factory=tuple)
+
+            Examples
+            --------
+            analyze implementation that returns one output column for each word in the input string
+            argument.
+
+            >>> def analyze(self, text: str) -> AnalyzeResult:
+            ...     schema = StructType()
+            ...     for index, word in enumerate(text.split(" ")):
+            ...         schema = schema.add(f"word_{index}")
+            ...     return AnalyzeResult(schema=schema)
+
+            Same as above, but using *args to accept the arguments.
+
+            >>> def analyze(self, *args) -> AnalyzeResult:
+            ...     assert len(args) == 1, "This function accepts one argument only"
+            ...     assert args[0].data_type == StringType(), "Only string arguments are supported"
+            ...     text = args[0]
+            ...     schema = StructType()
+            ...     for index, word in enumerate(text.split(" ")):
+            ...         schema = schema.add(f"word_{index}")
+            ...     return AnalyzeResult(schema=schema)
+
+            Same as above, but using **kwargs to accept the arguments.
+
+            >>> def analyze(self, **kwargs) -> AnalyzeResult:
+            ...     assert len(kwargs) == 1, "This function accepts one argument only"
+            ...     assert "text" in kwargs, "An argument named 'text' is required"
+            ...     assert kwargs["text"].data_type == StringType(), "Only strings are supported"
+            ...     text = args["text"]
+            ...     schema = StructType()
+            ...     for index, word in enumerate(text.split(" ")):
+            ...         schema = schema.add(f"word_{index}")
+            ...     return AnalyzeResult(schema=schema)
+
+            analyze implementation that returns a constant output schema, but add custom information
+            in the result metadata to be consumed by future __init__ method calls:
+
+            >>> def analyze(self, text: str) -> AnalyzeResult:
+            ...     @dataclass
+            ...     class AnalyzeResultWithOtherMetadata(AnalyzeResult):
+            ...         num_words: int
+            ...         num_articles: int
+            ...     words = text.split(" ")
+            ...     return AnalyzeResultWithOtherMetadata(
+            ...         schema=StructType()
+            ...             .add("word", StringType())
+            ...             .add('total", IntegerType()),
+            ...         num_words=len(words),
+            ...         num_articles=len((
+            ...             word for word in words
+            ...             if word == 'a' or word == 'an' or word == 'the')))
             """
             ...
 
@@ -89,7 +187,9 @@ To implement a Python UDTF, you first need to define a class implementing the me
             -----
             - The result of the function must be a tuple representing a single row
               in the UDTF result table.
-            - UDTFs currently do not accept keyword arguments during the function call.
+            - It is also possible for UDTFs to accept the exact arguments expected, along with
+              their types.
+            - UDTFs can instead accept keyword arguments during the function call if needed.
 
             Examples
             --------
@@ -103,6 +203,24 @@ To implement a Python UDTF, you first need to define a class implementing the me
             >>> def eval(self, x: int, y: int):
             ...     yield (x + y, x - y)
             ...     yield (y + x, y - x)
+
+            Same as above, but using *args to accept the arguments:
+
+            >>> def eval(self, *args):
+            ...     assert len(args) == 2, "This function accepts two integer arguments only"
+            ...     x = args[0]
+            ...     y = args[1]
+            ...     yield (x + y, x - y)
+            ...     yield (y + x, y - x)
+
+            Same as above, but using **kwargs to accept the arguments:
+
+            >>> def eval(self, **kwargs):
+            ...     assert len(kwargs) == 2, "This function accepts two integer arguments only"
+            ...     x = kwargs["x"]
+            ...     y = kwargs["y"]
+            ...     yield (x + y, x - y)
+            ...     yield (y + x, y - x)
             """
             ...
 
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 9c821f4bde9c..98676bd7be49 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -18,6 +18,7 @@ import os
 import shutil
 import tempfile
 import unittest
+from dataclasses import dataclass
 from typing import Iterator
 
 from py4j.protocol import Py4JJavaError
@@ -2365,6 +2366,58 @@ class BaseUDTFTestsMixin:
             + [Row(partition_col=42, count=3, total=3, last=None)],
         )
 
+    def test_udtf_with_prepare_string_from_analyze(self):
+        @dataclass
+        class AnalyzeResultWithBuffer(AnalyzeResult):
+            buffer: str = ""
+
+        @udtf
+        class TestUDTF:
+            def __init__(self, analyze_result=None):
+                self._total = 0
+                if analyze_result is not None:
+                    self._buffer = analyze_result.buffer
+                else:
+                    self._buffer = ""
+
+            @staticmethod
+            def analyze(argument, _):
+                if (
+                    argument.value is None
+                    or argument.is_table
+                    or not isinstance(argument.value, str)
+                    or len(argument.value) == 0
+                ):
+                    raise Exception("The first argument must be non-empty string")
+                assert argument.data_type == StringType()
+                assert not argument.is_table
+                return AnalyzeResultWithBuffer(
+                    schema=StructType().add("total", IntegerType()).add("buffer", StringType()),
+                    with_single_partition=True,
+                    buffer=argument.value,
+                )
+
+            def eval(self, argument, row: Row):
+                self._total += 1
+
+            def terminate(self):
+                yield self._total, self._buffer
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total, buffer
+                FROM test_udtf("abc", TABLE(t))
+                """
+            ).collect(),
+            [Row(count=20, buffer="abc")],
+        )
+
 
 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index ba4bac2ffdfa..26ce68111db8 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -85,7 +85,10 @@ class OrderingColumn:
     overrideNullsFirst: Optional[bool] = None
 
 
-@dataclass(frozen=True)
+# Note: this class is a "dataclass" for purposes of convenience, but it is not marked "frozen"
+# because the intention is that users may create subclasses of it for purposes of returning custom
+# information from the "analyze" method.
+@dataclass
 class AnalyzeResult:
     """
     The return of Python UDTF's analyze static method.
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index 6fb3ca995e5d..a6aa381eb14a 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -126,6 +126,8 @@ def main(infile: IO, outfile: IO) -> None:
 
         # Return the analyzed schema.
         write_with_length(result.schema.json().encode("utf-8"), outfile)
+        # Return the pickled 'AnalyzeResult' class instance.
+        pickleSer._write_with_length(result, outfile)
         # Return whether the "with single partition" property is requested.
         write_int(1 if result.with_single_partition else 0, outfile)
         # Return the list of partitioning columns, if any.
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 3d08f6c4baea..df7dd1bc2f73 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -20,6 +20,7 @@ Worker that receives input from Piped RDD.
 """
 import os
 import sys
+import dataclasses
 import time
 from inspect import getfullargspec
 import json
@@ -666,7 +667,7 @@ def read_udtf(pickleSer, infile, eval_type):
         # Each row is a group so do not batch but send one by one.
         ser = BatchedSerializer(CPickleSerializer(), 1)
 
-    # See `PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
+    # See 'PythonUDTFRunner.PythonUDFWriterThread.writeCommand'
     num_arg = read_int(infile)
     args_offsets = []
     kwargs_offsets = {}
@@ -679,6 +680,14 @@ def read_udtf(pickleSer, infile, eval_type):
             args_offsets.append(offset)
     num_partition_child_indexes = read_int(infile)
     partition_child_indexes = [read_int(infile) for i in range(num_partition_child_indexes)]
+    has_pickled_analyze_result = read_bool(infile)
+    if has_pickled_analyze_result:
+        pickled_analyze_result = pickleSer._read_with_length(infile)
+    else:
+        pickled_analyze_result = None
+    # Initially we assume that the UDTF __init__ method accepts the pickled AnalyzeResult,
+    # although we may set this to false later if we find otherwise.
+    udtf_init_method_accepts_analyze_result = True
     handler = read_command(pickleSer, infile)
     if not isinstance(handler, type):
         raise PySparkRuntimeError(
@@ -692,6 +701,29 @@ def read_udtf(pickleSer, infile, eval_type):
             f"The return type of a UDTF must be a struct type, but got 
{type(return_type)}."
         )
 
+    # Update the handler that creates a new UDTF instance to first try calling the UDTF constructor
+    # with one argument containing the previous AnalyzeResult. If that fails, then try a constructor
+    # with no arguments. In this way each UDTF class instance can decide if it wants to inspect the
+    # AnalyzeResult.
+    if has_pickled_analyze_result:
+        prev_handler = handler
+
+        def construct_udtf():
+            nonlocal udtf_init_method_accepts_analyze_result
+            if not udtf_init_method_accepts_analyze_result:
+                return prev_handler()
+            else:
+                try:
+                    # Here we pass the AnalyzeResult to the UDTF's __init__ method.
+                    return prev_handler(dataclasses.replace(pickled_analyze_result))
+                except TypeError:
+                    # This means that the UDTF handler does not accept an AnalyzeResult object in
+                    # its __init__ method.
+                    udtf_init_method_accepts_analyze_result = False
+                    return prev_handler()
+
+        handler = construct_udtf
+
     class UDTFWithPartitions:
         """
         This implements the logic of a UDTF that accepts an input TABLE argument with one or more
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cc0bfd3fc31b..18a0aec8fc61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2229,8 +2229,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
                 analyzeResult.applyToTableArgument(u.name, t)
               case c => c
             }
-            PythonUDTF(u.name, u.func, analyzeResult.schema, newChildren,
-              u.evalType, u.udfDeterministic, u.resultId)
+            PythonUDTF(
+              u.name, u.func, analyzeResult.schema, Some(analyzeResult.pickledAnalyzeResult),
+              newChildren, u.evalType, u.udfDeterministic, u.resultId)
           }
         }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
index 539505543a40..f886b50e8a23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -159,6 +159,10 @@ abstract class UnevaluableGenerator extends Generator {
  * @param name name of the Python UDTF being called
  * @param func string contents of the Python code in the UDTF, along with other environment state
  * @param elementSchema result schema of the function call
+ * @param pickledAnalyzeResult if the UDTF defined an 'analyze' method, this contains the pickled
+ *                             'AnalyzeResult' instance from that method, which contains all
+ *                             metadata returned including the result schema of the function call as
+ *                             well as optional other information
  * @param children input arguments to the UDTF call; for scalar arguments these are the expressions
  *                 themeselves, and for TABLE arguments, these are instances of
  *                 [[FunctionTableSubqueryArgumentExpression]]
@@ -167,15 +171,15 @@ abstract class UnevaluableGenerator extends Generator {
  * @param udfDeterministic true if this function is deterministic wherein it returns the same result
  *                         rows for every call with the same input arguments
  * @param resultId unique expression ID for this function invocation
- * @param pythonUDTFPartitionColumnIndexes holds the indexes of the TABLE argument to the Python
- *                                         UDTF call, if applicable
- * @param analyzeResult holds the result of the polymorphic Python UDTF 'analze' method, if the UDTF
- *                      defined one
+ * @param pythonUDTFPartitionColumnIndexes holds the zero-based indexes of the projected results of
+ *                                         all PARTITION BY expressions within the TABLE argument of
+ *                                         the Python UDTF call, if applicable
  */
 case class PythonUDTF(
     name: String,
     func: PythonFunction,
     elementSchema: StructType,
+    pickledAnalyzeResult: Option[Array[Byte]],
     children: Seq[Expression],
     evalType: Int,
     udfDeterministic: Boolean,
@@ -224,6 +228,7 @@ case class UnresolvedPolymorphicPythonUDTF(
 /**
  * Represents the result of invoking the polymorphic 'analyze' method on a Python user-defined table
  * function. This returns the table function's output schema in addition to other optional metadata.
+ *
  * @param schema result schema of this particular function call in response to the particular
  *               arguments provided, including the types of any provided scalar arguments (and
  *               their values, in the case of literals) as well as the names and types of columns of
@@ -241,12 +246,17 @@ case class UnresolvedPolymorphicPythonUDTF(
  * @param orderByExpressions if non-empty, this contains the list of ordering items that the
  *                           'analyze' method explicitly indicated that the UDTF call should consume
  *                           the input table rows by
+ * @param pickledAnalyzeResult this is the pickled 'AnalyzeResult' instance from the UDTF, which
+ *                             contains all metadata returned by the Python UDTF 'analyze' method
+ *                             including the result schema of the function call as well as optional
+ *                             other information
  */
 case class PythonUDTFAnalyzeResult(
     schema: StructType,
     withSinglePartition: Boolean,
     partitionByExpressions: Seq[Expression],
-    orderByExpressions: Seq[SortOrder]) {
+    orderByExpressions: Seq[SortOrder],
+    pickledAnalyzeResult: Array[Byte]) {
   /**
    * Applies the requested properties from this analysis result to the target TABLE argument
    * expression of a UDTF call, throwing an error if any properties of the UDTF call are
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
index a70d16dc7e89..40993f96e7a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala
@@ -112,6 +112,7 @@ object PythonUDTFRunner {
       dataOut: DataOutputStream,
       udtf: PythonUDTF,
       argMetas: Array[ArgumentMetadata]): Unit = {
+    // Write the argument types of the UDTF.
     dataOut.writeInt(argMetas.length)
     argMetas.foreach {
       case ArgumentMetadata(offset, name) =>
@@ -124,6 +125,8 @@ object PythonUDTFRunner {
             dataOut.writeBoolean(false)
         }
     }
+    // Write the zero-based indexes of the projected results of all PARTITION BY expressions within
+    // the TABLE argument of the Python UDTF call, if applicable.
     udtf.pythonUDTFPartitionColumnIndexes match {
       case Some(partitionColumnIndexes) =>
         dataOut.writeInt(partitionColumnIndexes.partitionChildIndexes.length)
@@ -132,7 +135,12 @@ object PythonUDTFRunner {
       case None =>
         dataOut.writeInt(0)
     }
+    // Write the pickled AnalyzeResult buffer from the UDTF "analyze" method, if any.
+    dataOut.writeBoolean(udtf.pickledAnalyzeResult.nonEmpty)
+    udtf.pickledAnalyzeResult.foreach(PythonWorkerUtils.writeBytes(_, dataOut))
+    // Write the contents of the Python script itself.
     PythonWorkerUtils.writePythonFunction(udtf.func, dataOut)
+    // Write the result schema of the UDTF call.
     PythonWorkerUtils.writeUTF(udtf.elementSchema.json, dataOut)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index b03942cdf43c..d8d3cc9b7fc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -129,6 +129,7 @@ case class UserDefinedPythonTableFunction(
           name = name,
           func = func,
           elementSchema = rt,
+          pickledAnalyzeResult = None,
           children = exprs,
           evalType = pythonEvalType,
           udfDeterministic = udfDeterministic)
@@ -283,6 +284,9 @@ object UserDefinedPythonTableFunction {
       val schema = DataType.fromJson(
         PythonWorkerUtils.readUTF(length, dataIn)).asInstanceOf[StructType]
 
+      // Receive the pickled AnalyzeResult buffer, if any.
+      val pickledAnalyzeResult: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
+
       // Receive whether the "with single partition" property is requested.
       val withSinglePartition = dataIn.readInt() == 1
       // Receive the list of requested partitioning columns, if any.
@@ -324,7 +328,8 @@ object UserDefinedPythonTableFunction {
         schema = schema,
         withSinglePartition = withSinglePartition,
         partitionByExpressions = partitionByColumns.toSeq,
-        orderByExpressions = orderBy.toSeq)
+        orderByExpressions = orderBy.toSeq,
+        pickledAnalyzeResult = pickledAnalyzeResult)
     } catch {
       case eof: EOFException =>
         throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
index f7b2bada26ec..1b923442207e 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out
@@ -123,13 +123,19 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 
 
 -- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2))
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2))
 -- !query analysis
 [Analyzer test output redacted due to nondeterminism]
 
 
 -- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)
+SELECT * FROM UDTFWithSinglePartition(1, TABLE(t2))
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
+-- !query
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION)
 -- !query analysis
 org.apache.spark.sql.AnalysisException
 {
@@ -144,14 +150,14 @@ org.apache.spark.sql.AnalysisException
     "objectType" : "",
     "objectName" : "",
     "startIndex" : 15,
-    "stopIndex" : 70,
-    "fragment" : "UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)"
+    "stopIndex" : 73,
+    "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION)"
   } ]
 }
 
 
 -- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)
 -- !query analysis
 org.apache.spark.sql.AnalysisException
 {
@@ -166,8 +172,8 @@ org.apache.spark.sql.AnalysisException
     "objectType" : "",
     "objectName" : "",
     "startIndex" : 15,
-    "stopIndex" : 75,
-    "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)"
+    "stopIndex" : 78,
+    "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)"
   } ]
 }
 
@@ -176,7 +182,7 @@ org.apache.spark.sql.AnalysisException
 SELECT * FROM
     VALUES (0), (1) AS t(col)
     JOIN LATERAL
-    UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)
+    UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)
 -- !query analysis
 org.apache.spark.sql.AnalysisException
 {
@@ -191,8 +197,8 @@ org.apache.spark.sql.AnalysisException
     "objectType" : "",
     "objectName" : "",
     "startIndex" : 66,
-    "stopIndex" : 126,
-    "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)"
+    "stopIndex" : 129,
+    "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)"
   } ]
 }
 
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
index 6d49177c4f6a..6d34b91e2f16 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
@@ -47,13 +47,14 @@ SELECT * FROM
 --           order_by=[
 --               OrderingColumn("input"),
 --               OrderingColumn("partition_col")])
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2));
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION);
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col);
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2));
+SELECT * FROM UDTFWithSinglePartition(1, TABLE(t2));
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION);
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col);
 SELECT * FROM
     VALUES (0), (1) AS t(col)
     JOIN LATERAL
-    UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col);
+    UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col);
 -- As a reminder, the UDTFPartitionByOrderBy function returns this analyze result:
 --     AnalyzeResult(
 --         schema=StructType()
diff --git a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
index a93aac945015..11295c43d8cb 100644
--- a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
@@ -161,7 +161,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 
 
 -- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2))
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2))
 -- !query schema
 struct<count:int,total:int,last:int>
 -- !query output
@@ -169,7 +169,15 @@ struct<count:int,total:int,last:int>
 
 
 -- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)
+SELECT * FROM UDTFWithSinglePartition(1, TABLE(t2))
+-- !query schema
+struct<count:int,total:int,last:int>
+-- !query output
+3      6       3
+
+
+-- !query
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION)
 -- !query schema
 struct<>
 -- !query output
@@ -186,14 +194,14 @@ org.apache.spark.sql.AnalysisException
     "objectType" : "",
     "objectName" : "",
     "startIndex" : 15,
-    "stopIndex" : 70,
-    "fragment" : "UDTFWithSinglePartition(TABLE(t2) WITH SINGLE PARTITION)"
+    "stopIndex" : 73,
+    "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) WITH SINGLE PARTITION)"
   } ]
 }
 
 
 -- !query
-SELECT * FROM UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)
+SELECT * FROM UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)
 -- !query schema
 struct<>
 -- !query output
@@ -210,8 +218,8 @@ org.apache.spark.sql.AnalysisException
     "objectType" : "",
     "objectName" : "",
     "startIndex" : 15,
-    "stopIndex" : 75,
-    "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)"
+    "stopIndex" : 78,
+    "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)"
   } ]
 }
 
@@ -220,7 +228,7 @@ org.apache.spark.sql.AnalysisException
 SELECT * FROM
     VALUES (0), (1) AS t(col)
     JOIN LATERAL
-    UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)
+    UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)
 -- !query schema
 struct<>
 -- !query output
@@ -237,8 +245,8 @@ org.apache.spark.sql.AnalysisException
     "objectType" : "",
     "objectName" : "",
     "startIndex" : 66,
-    "stopIndex" : 126,
-    "fragment" : "UDTFWithSinglePartition(TABLE(t2) PARTITION BY partition_col)"
+    "stopIndex" : 129,
+    "fragment" : "UDTFWithSinglePartition(0, TABLE(t2) PARTITION BY partition_col)"
   } ]
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
index ef4606b70cae..3c30c414f81f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala
@@ -524,8 +524,15 @@ object IntegratedUDFTestUtils extends SQLHelper {
     val name: String = "UDTFWithSinglePartition"
     val pythonScript: String =
       s"""
+        |import json
+        |from dataclasses import dataclass
         |from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn
         |from pyspark.sql.types import IntegerType, Row, StructType
+        |
+        |@dataclass
+        |class AnalyzeResultWithBuffer(AnalyzeResult):
+        |    buffer: str = ""
+        |
         |class $name:
         |    def __init__(self):
         |        self._count = 0
@@ -533,8 +540,14 @@ object IntegratedUDFTestUtils extends SQLHelper {
         |        self._last = None
         |
         |    @staticmethod
-        |    def analyze(self):
-        |        return AnalyzeResult(
+        |    def analyze(initial_count, input_table):
+        |        buffer = ""
+        |        if initial_count.value is not None:
+        |            assert(not initial_count.is_table)
+        |            assert(initial_count.data_type == IntegerType())
+        |            count = initial_count.value
+        |            buffer = json.dumps({"initial_count": count})
+        |        return AnalyzeResultWithBuffer(
         |            schema=StructType()
         |                .add("count", IntegerType())
         |                .add("total", IntegerType())
@@ -542,9 +555,10 @@ object IntegratedUDFTestUtils extends SQLHelper {
         |            with_single_partition=True,
         |            order_by=[
         |                OrderingColumn("input"),
-        |                OrderingColumn("partition_col")])
+        |                OrderingColumn("partition_col")],
+        |            buffer=buffer)
         |
-        |    def eval(self, row: Row):
+        |    def eval(self, initial_count, row):
         |        self._count += 1
         |        self._last = row["input"]
         |        self._sum += row["input"]
@@ -693,6 +707,48 @@ object IntegratedUDFTestUtils extends SQLHelper {
         "without a corresponding partitioning table requirement"
   }
 
+  object TestPythonUDTFForwardStateFromAnalyze extends TestUDTF {
+    val name: String = "TestPythonUDTFForwardStateFromAnalyze"
+    val pythonScript: String =
+      s"""
+         |from dataclasses import dataclass
+         |from pyspark.sql.functions import AnalyzeResult
+         |from pyspark.sql.types import StringType, StructType
+         |
+         |@dataclass
+         |class AnalyzeResultWithBuffer(AnalyzeResult):
+         |    buffer: str = ""
+         |
+         |class $name:
+         |    def __init__(self, analyze_result):
+         |        self._analyze_result = analyze_result
+         |
+         |    @staticmethod
+         |    def analyze(argument):
+         |        assert(argument.data_type == StringType())
+         |        return AnalyzeResultWithBuffer(
+         |            schema=StructType()
+         |                .add("result", StringType()),
+         |            buffer=argument.value)
+         |
+         |    def eval(self, argument):
+         |        pass
+         |
+         |    def terminate(self):
+         |        yield self._analyze_result.buffer,
+         |""".stripMargin
+
+    val udtf: UserDefinedPythonTableFunction = createUserDefinedPythonTableFunction(
+      name = name,
+      pythonScript = pythonScript,
+      returnType = None)
+
+    def apply(session: SparkSession, exprs: Column*): DataFrame =
+      udtf.apply(session, exprs: _*)
+
+    val prettyName: String = "Python UDTF whose 'analyze' method sets state and reads it later"
+  }
+
   /**
    * A Scalar Pandas UDF that takes one column, casts into string, executes the
    * Python native function, and casts back to the type of input column.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
index cdc3ef9e4178..efab685236de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala
@@ -48,15 +48,15 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
 
   private val pythonUDTFCountSumLast: UserDefinedPythonTableFunction =
     createUserDefinedPythonTableFunction(
-      "UDTFCountSumLast", TestPythonUDTFCountSumLast.pythonScript, None)
+      TestPythonUDTFCountSumLast.name, TestPythonUDTFCountSumLast.pythonScript, None)
 
   private val pythonUDTFWithSinglePartition: UserDefinedPythonTableFunction =
     createUserDefinedPythonTableFunction(
-      "UDTFWithSinglePartition", TestPythonUDTFWithSinglePartition.pythonScript, None)
+      TestPythonUDTFWithSinglePartition.name, TestPythonUDTFWithSinglePartition.pythonScript, None)
 
   private val pythonUDTFPartitionByOrderBy: UserDefinedPythonTableFunction =
     createUserDefinedPythonTableFunction(
-      "UDTFPartitionByOrderBy", TestPythonUDTFPartitionBy.pythonScript, None)
+      TestPythonUDTFPartitionBy.name, TestPythonUDTFPartitionBy.pythonScript, None)
 
   private val arrowPythonUDTF: UserDefinedPythonTableFunction =
     createUserDefinedPythonTableFunction(
@@ -65,6 +65,11 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
       Some(returnType),
       evalType = PythonEvalType.SQL_ARROW_TABLE_UDF)
 
+  private val pythonUDTFForwardStateFromAnalyze: UserDefinedPythonTableFunction =
+    createUserDefinedPythonTableFunction(
+      TestPythonUDTFForwardStateFromAnalyze.name,
+      TestPythonUDTFForwardStateFromAnalyze.pythonScript, None)
+
   test("Simple PythonUDTF") {
     assume(shouldTestPythonUDFs)
     val df = pythonUDTF(spark, lit(1), lit(2))
@@ -200,14 +205,14 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
           stop = 29))
     }
 
-    spark.udtf.registerPython("UDTFCountSumLast", pythonUDTFCountSumLast)
+    spark.udtf.registerPython(TestPythonUDTFCountSumLast.name, pythonUDTFCountSumLast)
     var plan = sql(
-      """
+      s"""
         |WITH t AS (
         |  VALUES (0, 1), (1, 2), (1, 3) t(partition_col, input)
         |)
         |SELECT count, total, last
-        |FROM UDTFCountSumLast(TABLE(t) WITH SINGLE PARTITION)
+        |FROM ${TestPythonUDTFCountSumLast.name}(TABLE(t) WITH SINGLE PARTITION)
         |ORDER BY 1, 2
         |""".stripMargin).queryExecution.analyzed
     plan.collectFirst { case r: Repartition => r } match {
@@ -216,16 +221,16 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
         failure(plan)
     }
 
-    spark.udtf.registerPython("UDTFWithSinglePartition", pythonUDTFWithSinglePartition)
+    spark.udtf.registerPython(TestPythonUDTFWithSinglePartition.name, pythonUDTFWithSinglePartition)
     plan = sql(
-      """
+      s"""
         |WITH t AS (
         |    SELECT id AS partition_col, 1 AS input FROM range(1, 21)
         |    UNION ALL
         |    SELECT id AS partition_col, 2 AS input FROM range(1, 21)
         |)
         |SELECT count, total, last
-        |FROM UDTFWithSinglePartition(TABLE(t))
+        |FROM ${TestPythonUDTFWithSinglePartition.name}(0, TABLE(t))
         |ORDER BY 1, 2
         |""".stripMargin).queryExecution.analyzed
     plan.collectFirst { case r: Repartition => r } match {
@@ -234,16 +239,16 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
         failure(plan)
     }
 
-    spark.udtf.registerPython("UDTFPartitionByOrderBy", pythonUDTFPartitionByOrderBy)
+    spark.udtf.registerPython(TestPythonUDTFPartitionBy.name, pythonUDTFPartitionByOrderBy)
     plan = sql(
-      """
+      s"""
         |WITH t AS (
         |    SELECT id AS partition_col, 1 AS input FROM range(1, 21)
         |    UNION ALL
         |    SELECT id AS partition_col, 2 AS input FROM range(1, 21)
         |)
         |SELECT partition_col, count, total, last
-        |FROM UDTFPartitionByOrderBy(TABLE(t))
+        |FROM ${TestPythonUDTFPartitionBy.name}(TABLE(t))
         |ORDER BY 1, 2
         |""".stripMargin).queryExecution.analyzed
     plan.collectFirst { case r: RepartitionByExpression => r } match {
@@ -345,4 +350,17 @@ class PythonUDTFSuite extends QueryTest with SharedSparkSession {
       Literal("abc"))) ==
       Seq(2, 3))
   }
+
+  test("SPARK-45402: Add UDTF API for 'analyze' to return a buffer to consume on class creation") {
+    spark.udtf.registerPython(
+      TestPythonUDTFForwardStateFromAnalyze.name,
+      pythonUDTFForwardStateFromAnalyze)
+    withTable("t") {
+      sql("create table t(col array<int>) using parquet")
+      val query = s"select * from ${TestPythonUDTFForwardStateFromAnalyze.name}('abc')"
+      checkAnswer(
+        sql(query),
+        Row("abc"))
+    }
+  }
 }

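The new `TestPythonUDTFForwardStateFromAnalyze` fixture exercises one contract: `analyze` returns an `AnalyzeResult` subclass carrying extra state (`buffer`), the UDTF's `__init__` receives that object as `analyze_result`, and `terminate` reads the state back, so `select * from TestPythonUDTFForwardStateFromAnalyze('abc')` yields `Row("abc")`. The plain-Python sketch below simulates that lifecycle without Spark so it runs standalone. `AnalyzeArgument`, `AnalyzeResult`, and `run_udtf` here are simplified, hypothetical stand-ins rather than the PySpark classes; the string `"string"` stands in for `StringType()`; and the driver's rule for whether to pass the result to `__init__` is an assumption modeled on the two fixtures in this diff (one `__init__` takes `analyze_result`, the other does not).

```python
import inspect
from dataclasses import dataclass
from typing import Any, Iterator, List, Tuple


# Hypothetical stand-ins for PySpark's AnalyzeArgument / AnalyzeResult,
# reduced to the fields this sketch needs.
@dataclass(frozen=True)
class AnalyzeArgument:
    data_type: str  # stand-in for a DataType such as StringType()
    value: Any      # the constant argument value, if any
    is_table: bool = False


@dataclass(frozen=True)
class AnalyzeResult:
    schema: str  # stand-in for a StructType
    with_single_partition: bool = False


# Mirrors the subclass in the test scripts: an extra dataclass field
# carries custom state from 'analyze' to the UDTF instance.
@dataclass(frozen=True)
class AnalyzeResultWithBuffer(AnalyzeResult):
    buffer: str = ""


class ForwardStateFromAnalyze:
    def __init__(self, analyze_result: AnalyzeResultWithBuffer):
        self._analyze_result = analyze_result

    @staticmethod
    def analyze(argument: AnalyzeArgument) -> AnalyzeResultWithBuffer:
        assert argument.data_type == "string"
        return AnalyzeResultWithBuffer(schema="result string", buffer=argument.value)

    def eval(self, argument: Any) -> None:
        pass  # no per-row output; everything happens in terminate()

    def terminate(self) -> Iterator[Tuple[str]]:
        yield (self._analyze_result.buffer,)


def run_udtf(udtf_cls, args, rows):
    """Hypothetical driver: analyze once, then construct, eval each row, terminate."""
    # Planning time: 'analyze' sees argument metadata and constant values.
    result = udtf_cls.analyze(*args)
    # Execution time: pass the result to __init__ only if it declares a
    # parameter besides 'self' (assumes the class defines its own __init__).
    params = inspect.signature(udtf_cls.__init__).parameters
    instance = udtf_cls(result) if len(params) > 1 else udtf_cls()
    out: List[tuple] = []
    for row in rows:
        out.extend(instance.eval(*row) or [])
    out.extend(instance.terminate() or [])
    return result, out


_, output = run_udtf(
    ForwardStateFromAnalyze, [AnalyzeArgument(data_type="string", value="abc")], [("abc",)]
)
print(output)  # [('abc',)]
```

Because the extra state lives on a dataclass subclass, the subclass's new field needs a default (`buffer: str = ""`) whenever the base class already has defaulted fields, which matches how both test scripts declare `AnalyzeResultWithBuffer`.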
