This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 36983443112 [SPARK-45810][PYTHON] Create Python UDTF API to stop 
consuming rows from the input table
36983443112 is described below

commit 36983443112799dc2ee4462828e7c0552a63a229
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Wed Nov 15 13:47:04 2023 -0800

    [SPARK-45810][PYTHON] Create Python UDTF API to stop consuming rows from 
the input table
    
    ### What changes were proposed in this pull request?
    
    This PR creates a Python UDTF API to stop consuming rows from the input 
table.
    
    If the UDTF raises a `SkipRestOfInputTableException` exception in the 
`eval` method, then the UDTF stops consuming rows from the input table for that 
input partition, and finally calls the `terminate` method (if any) to represent 
a successful UDTF call.
    
    For example:
    
    ```
    @udtf(returnType="total: int")
    class TestUDTF:
        def __init__(self):
            self._total = 0
    
        def eval(self, _: Row):
            self._total += 1
            if self._total >= 3:
                raise SkipRestOfInputTableException("Stop at self._total >= 3")
    
        def terminate(self):
            yield self._total,
    ```
    
    ### Why are the changes needed?
    
    This is useful when the UDTF logic knows that it no longer needs to scan the 
input table, so the rest of the I/O can be skipped in that case.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, see above.
    
    ### How was this patch tested?
    
    This PR adds test coverage.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43682 from dtenedor/udtf-api-stop-consuming-input-rows.
    
    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/docs/source/user_guide/sql/python_udtf.rst | 38 ++++++++++++-----
 python/pyspark/sql/functions.py                   |  1 +
 python/pyspark/sql/tests/test_udtf.py             | 51 +++++++++++++++++++++++
 python/pyspark/sql/udtf.py                        | 19 ++++++++-
 python/pyspark/worker.py                          | 30 ++++++++++---
 5 files changed, 123 insertions(+), 16 deletions(-)

diff --git a/python/docs/source/user_guide/sql/python_udtf.rst 
b/python/docs/source/user_guide/sql/python_udtf.rst
index 0e0c6e28578..3e3c7634438 100644
--- a/python/docs/source/user_guide/sql/python_udtf.rst
+++ b/python/docs/source/user_guide/sql/python_udtf.rst
@@ -65,8 +65,8 @@ To implement a Python UDTF, you first need to define a class 
implementing the me
 
         def analyze(self, *args: Any) -> AnalyzeResult:
             """
-            Computes the output schema of a particular call to this function 
in response to the
-            arguments provided.
+            Static method to compute the output schema of a particular call to 
this function in
+            response to the arguments provided.
 
             This method is optional and only needed if the registration of the 
UDTF did not provide
             a static output schema to be used for all calls to the function. In 
this context,
@@ -101,12 +101,20 @@ To implement a Python UDTF, you first need to define a 
class implementing the me
                 partitionBy: Sequence[PartitioningColumn] = 
field(default_factory=tuple)
                 orderBy: Sequence[OrderingColumn] = 
field(default_factory=tuple)
 
+            Notes
+            -----
+            - It is possible for the `analyze` method to accept the exact 
arguments expected,
+              mapping 1:1 with the arguments provided to the UDTF call.
+            - The `analyze` method can instead choose to accept positional 
arguments if desired
+              (using `*args`) or keyword arguments (using `**kwargs`).
+
             Examples
             --------
-            analyze implementation that returns one output column for each 
word in the input string
-            argument.
+            This is an `analyze` implementation that returns one output column 
for each word in the
+            input string argument.
 
-            >>> def analyze(self, text: str) -> AnalyzeResult:
+            >>> @staticmethod
+            ... def analyze(text: str) -> AnalyzeResult:
             ...     schema = StructType()
             ...     for index, word in enumerate(text.split(" ")):
             ...         schema = schema.add(f"word_{index}")
@@ -114,7 +122,8 @@ To implement a Python UDTF, you first need to define a 
class implementing the me
 
             Same as above, but using *args to accept the arguments.
 
-            >>> def analyze(self, *args) -> AnalyzeResult:
+            >>> @staticmethod
+            ... def analyze(*args) -> AnalyzeResult:
             ...     assert len(args) == 1, "This function accepts one argument 
only"
             ...     assert args[0].dataType == StringType(), "Only string 
arguments are supported"
             ...     text = args[0]
@@ -125,7 +134,8 @@ To implement a Python UDTF, you first need to define a 
class implementing the me
 
             Same as above, but using **kwargs to accept the arguments.
 
-            >>> def analyze(self, **kwargs) -> AnalyzeResult:
+            >>> @staticmethod
+            ... def analyze(**kwargs) -> AnalyzeResult:
             ...     assert len(kwargs) == 1, "This function accepts one 
argument only"
             ...     assert "text" in kwargs, "An argument named 'text' is 
required"
             ...     assert kwargs["text"].dataType == StringType(), "Only 
strings are supported"
@@ -135,10 +145,11 @@ To implement a Python UDTF, you first need to define a 
class implementing the me
             ...         schema = schema.add(f"word_{index}")
             ...     return AnalyzeResult(schema=schema)
 
-            analyze implementation that returns a constant output schema, but 
add custom information
-            in the result metadata to be consumed by future __init__ method 
calls:
+            An `analyze` implementation that returns a constant output schema, 
but adds custom
+            information in the result metadata to be consumed by future 
__init__ method calls:
 
-            >>> def analyze(self, text: str) -> AnalyzeResult:
+            >>> @staticmethod
+            ... def analyze(text: str) -> AnalyzeResult:
             ...     @dataclass
             ...     class AnalyzeResultWithOtherMetadata(AnalyzeResult):
             ...         num_words: int
@@ -190,6 +201,13 @@ To implement a Python UDTF, you first need to define a 
class implementing the me
             - It is also possible for UDTFs to accept the exact arguments 
expected, along with
               their types.
             - UDTFs can instead accept keyword arguments during the function 
call if needed.
+            - The `eval` method can raise a `SkipRestOfInputTableException` to 
indicate that the
+              UDTF wants to skip consuming all remaining rows from the current 
partition of the
+              input table. This will cause the UDTF to proceed directly to the 
`terminate` method.
+            - The `eval` method can raise any other exception to indicate that 
the UDTF should be
+              aborted entirely. This will cause the UDTF to skip the 
`terminate` method and proceed
+              directly to the `cleanup` method, and then the exception will be 
propagated to the
+              query processor causing the invoking query to fail.
 
             Examples
             --------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index ae0f1e70be6..e3b8e4965e4 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -51,6 +51,7 @@ from pyspark.sql.types import ArrayType, DataType, 
StringType, StructType, _from
 from pyspark.sql.udf import UserDefinedFunction, _create_py_udf  # noqa: F401
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult  # noqa: F401
 from pyspark.sql.udtf import OrderingColumn, PartitioningColumn  # noqa: F401
+from pyspark.sql.udtf import SkipRestOfInputTableException  # noqa: F401
 from pyspark.sql.udtf import UserDefinedTableFunction, _create_py_udtf
 
 # Keep pandas_udf and PandasUDFType import for backwards compatible import; 
moved in SPARK-28264
diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index 3beb916de66..2794b51eb70 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -44,6 +44,7 @@ from pyspark.sql.functions import (
     AnalyzeResult,
     OrderingColumn,
     PartitioningColumn,
+    SkipRestOfInputTableException,
 )
 from pyspark.sql.types import (
     ArrayType,
@@ -2467,6 +2468,56 @@ class BaseUDTFTestsMixin:
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf(returnType="current: int, total: int")
+        class TestUDTF:
+            def __init__(self):
+                self._current = 0
+                self._total = 0
+
+            def eval(self, input: Row):
+                self._current = input["id"]
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total 
>= 4")
+
+            def terminate(self):
+                yield self._current, self._total
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows after the fourth 
input row is consumed.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT current, total
+                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION ORDER BY id)
+                """
+            ),
+            [Row(current=4, total=4)],
+        )
+
+        # Run a test case including PARTITION BY on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows for each of the 
three partitions
+        # separately.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT current, total
+                FROM test_udtf(TABLE(t) PARTITION BY floor(id / 10) ORDER BY 
id)
+                ORDER BY ALL
+                """
+            ),
+            [Row(current=4, total=4), Row(current=13, total=4), 
Row(current=20, total=1)],
+        )
+
 
 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index aac212ffde9..ab330141514 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -38,7 +38,14 @@ if TYPE_CHECKING:
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.session import SparkSession
 
-__all__ = ["AnalyzeArgument", "AnalyzeResult", "UDTFRegistration"]
+__all__ = [
+    "AnalyzeArgument",
+    "AnalyzeResult",
+    "PartitioningColumn",
+    "OrderingColumn",
+    "SkipRestOfInputTableException",
+    "UDTFRegistration",
+]
 
 
 @dataclass(frozen=True)
@@ -118,6 +125,16 @@ class AnalyzeResult:
     orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
 
 
+class SkipRestOfInputTableException(Exception):
+    """
+    This represents an exception that the 'eval' method may raise to indicate 
that it is done
+    consuming rows from the current partition of the input table. Then the 
UDTF's 'terminate'
+    method runs (if any).
+    """
+
+    pass
+
+
 def _create_udtf(
     cls: Type,
     returnType: Optional[Union[StructType, str]],
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index f6208032d9a..195c989c410 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -43,6 +43,7 @@ from pyspark.serializers import (
     CPickleSerializer,
     BatchedSerializer,
 )
+from pyspark.sql.functions import SkipRestOfInputTableException
 from pyspark.sql.pandas.serializers import (
     ArrowStreamPandasUDFSerializer,
     ArrowStreamPandasUDTFSerializer,
@@ -763,6 +764,7 @@ def read_udtf(pickleSer, infile, eval_type):
             self._udtf = create_udtf()
             self._prev_arguments: list = list()
             self._partition_child_indexes: list = partition_child_indexes
+            self._eval_raised_skip_rest_of_input_table: bool = False
 
         def eval(self, *args, **kwargs) -> Iterator:
             changed_partitions = self._check_partition_boundaries(
@@ -775,16 +777,24 @@ def read_udtf(pickleSer, infile, eval_type):
                         for row in result:
                             yield row
                 self._udtf = self._create_udtf()
-            if self._udtf.eval is not None:
+                self._eval_raised_skip_rest_of_input_table = False
+            if self._udtf.eval is not None and not 
self._eval_raised_skip_rest_of_input_table:
                 # Filter the arguments to exclude projected PARTITION BY 
values added by Catalyst.
                 filtered_args = [self._remove_partition_by_exprs(arg) for arg 
in args]
                 filtered_kwargs = {
                     key: self._remove_partition_by_exprs(value) for (key, 
value) in kwargs.items()
                 }
-                result = self._udtf.eval(*filtered_args, **filtered_kwargs)
-                if result is not None:
-                    for row in result:
-                        yield row
+                try:
+                    result = self._udtf.eval(*filtered_args, **filtered_kwargs)
+                    if result is not None:
+                        for row in result:
+                            yield row
+                except SkipRestOfInputTableException:
+                    # If the 'eval' method raised this exception, then we 
should skip the rest of
+                    # the rows in the current partition. Set this field to 
True here and then for
+                    # each subsequent row in the partition, we will skip 
calling the 'eval' method
+                    # until we see a change in the partition boundaries.
+                    self._eval_raised_skip_rest_of_input_table = True
 
         def terminate(self) -> Iterator:
             if self._udtf.terminate is not None:
@@ -995,6 +1005,8 @@ def read_udtf(pickleSer, infile, eval_type):
             def func(*a: Any) -> Any:
                 try:
                     return f(*a)
+                except SkipRestOfInputTableException:
+                    raise
                 except Exception as e:
                     raise PySparkRuntimeError(
                         error_class="UDTF_EXEC_ERROR",
@@ -1057,6 +1069,9 @@ def read_udtf(pickleSer, infile, eval_type):
                     yield from eval(*[a[o] for o in args_kwargs_offsets])
                 if terminate is not None:
                     yield from terminate()
+            except SkipRestOfInputTableException:
+                if terminate is not None:
+                    yield from terminate()
             finally:
                 if cleanup is not None:
                     cleanup()
@@ -1098,6 +1113,8 @@ def read_udtf(pickleSer, infile, eval_type):
             def evaluate(*a) -> tuple:
                 try:
                     res = f(*a)
+                except SkipRestOfInputTableException:
+                    raise
                 except Exception as e:
                     raise PySparkRuntimeError(
                         error_class="UDTF_EXEC_ERROR",
@@ -1144,6 +1161,9 @@ def read_udtf(pickleSer, infile, eval_type):
                     yield eval(*[a[o] for o in args_kwargs_offsets])
                 if terminate is not None:
                     yield terminate()
+            except SkipRestOfInputTableException:
+                if terminate is not None:
+                    yield terminate()
             finally:
                 if cleanup is not None:
                     cleanup()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to