This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 36983443112 [SPARK-45810][PYTHON] Create Python UDTF API to stop consuming rows from the input table
36983443112 is described below

commit 36983443112799dc2ee4462828e7c0552a63a229
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Wed Nov 15 13:47:04 2023 -0800

    [SPARK-45810][PYTHON] Create Python UDTF API to stop consuming rows from the input table

    ### What changes were proposed in this pull request?

    This PR creates a Python UDTF API to stop consuming rows from the input table.

    If the UDTF raises a `SkipRestOfInputTableException` exception in the `eval` method, then the UDTF stops consuming rows from the input table for that input partition and finally calls the `terminate` method (if any) to represent a successful UDTF call. For example:

    ```
    @udtf(returnType="total: int")
    class TestUDTF:
        def __init__(self):
            self._total = 0

        def eval(self, _: Row):
            self._total += 1
            if self._total >= 3:
                raise SkipRestOfInputTableException("Stop at self._total >= 3")

        def terminate(self):
            yield self._total,
    ```

    ### Why are the changes needed?

    This is useful when the UDTF logic knows that it does not need to scan the rest of the input table, and can skip the remaining I/O in that case.

    ### Does this PR introduce _any_ user-facing change?

    Yes, see above.

    ### How was this patch tested?

    This PR adds test coverage.

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #43682 from dtenedor/udtf-api-stop-consuming-input-rows.

    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/docs/source/user_guide/sql/python_udtf.rst | 38 ++++++++++++-----
 python/pyspark/sql/functions.py                   |  1 +
 python/pyspark/sql/tests/test_udtf.py             | 51 +++++++++++++++++++++++
 python/pyspark/sql/udtf.py                        | 19 ++++++++-
 python/pyspark/worker.py                          | 30 ++++++++++---
 5 files changed, 123 insertions(+), 16 deletions(-)

diff --git a/python/docs/source/user_guide/sql/python_udtf.rst b/python/docs/source/user_guide/sql/python_udtf.rst
index 0e0c6e28578..3e3c7634438 100644
--- a/python/docs/source/user_guide/sql/python_udtf.rst
+++ b/python/docs/source/user_guide/sql/python_udtf.rst
@@ -65,8 +65,8 @@ To implement a Python UDTF, you first need to define a class implementing the me
 
     def analyze(self, *args: Any) -> AnalyzeResult:
         """
-        Computes the output schema of a particular call to this function in response to the
-        arguments provided.
+        Static method to compute the output schema of a particular call to this function in
+        response to the arguments provided.
 
         This method is optional and only needed if the registration of the UDTF did not
         provide a static output schema to be used for all calls to the function. In this context,
@@ -101,12 +101,20 @@ To implement a Python UDTF, you first need to define a class implementing the me
             partitionBy: Sequence[PartitioningColumn] = field(default_factory=tuple)
             orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
 
+        Notes
+        -----
+        - It is possible for the `analyze` method to accept the exact arguments expected,
+          mapping 1:1 with the arguments provided to the UDTF call.
+        - The `analyze` method can instead choose to accept positional arguments if desired
+          (using `*args`) or keyword arguments (using `**kwargs`).
+
         Examples
        --------
-        analyze implementation that returns one output column for each word in the input string
-        argument.
+        This is an `analyze` implementation that returns one output column for each word in the
+        input string argument.
 
-        >>> def analyze(self, text: str) -> AnalyzeResult:
+        >>> @staticmethod
+        ... def analyze(text: str) -> AnalyzeResult:
         ...     schema = StructType()
         ...     for index, word in enumerate(text.split(" ")):
         ...         schema = schema.add(f"word_{index}")
@@ -114,7 +122,8 @@ To implement a Python UDTF, you first need to define a class implementing the me
 
         Same as above, but using *args to accept the arguments.
 
-        >>> def analyze(self, *args) -> AnalyzeResult:
+        >>> @staticmethod
+        ... def analyze(*args) -> AnalyzeResult:
         ...     assert len(args) == 1, "This function accepts one argument only"
         ...     assert args[0].dataType == StringType(), "Only string arguments are supported"
         ...     text = args[0]
@@ -125,7 +134,8 @@ To implement a Python UDTF, you first need to define a class implementing the me
 
         Same as above, but using **kwargs to accept the arguments.
 
-        >>> def analyze(self, **kwargs) -> AnalyzeResult:
+        >>> @staticmethod
+        ... def analyze(**kwargs) -> AnalyzeResult:
         ...     assert len(kwargs) == 1, "This function accepts one argument only"
         ...     assert "text" in kwargs, "An argument named 'text' is required"
         ...     assert kwargs["text"].dataType == StringType(), "Only strings are supported"
@@ -135,10 +145,11 @@ To implement a Python UDTF, you first need to define a class implementing the me
         ...     schema = schema.add(f"word_{index}")
         ...     return AnalyzeResult(schema=schema)
 
-        analyze implementation that returns a constant output schema, but add custom information
-        in the result metadata to be consumed by future __init__ method calls:
+        An `analyze` implementation that returns a constant output schema, but adds custom
+        information in the result metadata to be consumed by future __init__ method calls:
 
-        >>> def analyze(self, text: str) -> AnalyzeResult:
+        >>> @staticmethod
+        ... def analyze(text: str) -> AnalyzeResult:
         ...     @dataclass
         ...     class AnalyzeResultWithOtherMetadata(AnalyzeResult):
         ...         num_words: int
@@ -190,6 +201,13 @@ To implement a Python UDTF, you first need to define a class implementing the me
         - It is also possible for UDTFs to accept the exact arguments expected, along with
           their types.
         - UDTFs can instead accept keyword arguments during the function call if needed.
+        - The `eval` method can raise a `SkipRestOfInputTableException` to indicate that the
+          UDTF wants to skip consuming all remaining rows from the current partition of the
+          input table. This will cause the UDTF to proceed directly to the `terminate` method.
+        - The `eval` method can raise any other exception to indicate that the UDTF should be
+          aborted entirely. This will cause the UDTF to skip the `terminate` method and proceed
+          directly to the `cleanup` method, and then the exception will be propagated to the
+          query processor, causing the invoking query to fail.
 
         Examples
         --------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index ae0f1e70be6..e3b8e4965e4 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -51,6 +51,7 @@ from pyspark.sql.types import ArrayType, DataType, StringType, StructType, _from
 from pyspark.sql.udf import UserDefinedFunction, _create_py_udf  # noqa: F401
 from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult  # noqa: F401
 from pyspark.sql.udtf import OrderingColumn, PartitioningColumn  # noqa: F401
+from pyspark.sql.udtf import SkipRestOfInputTableException  # noqa: F401
 from pyspark.sql.udtf import UserDefinedTableFunction, _create_py_udtf
 
 # Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 3beb916de66..2794b51eb70 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -44,6 +44,7 @@ from pyspark.sql.functions import (
     AnalyzeResult,
     OrderingColumn,
     PartitioningColumn,
+    SkipRestOfInputTableException,
 )
 from pyspark.sql.types import (
     ArrayType,
@@ -2467,6 +2468,56 @@ class BaseUDTFTestsMixin:
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf(returnType="current: int, total: int")
+        class TestUDTF:
+            def __init__(self):
+                self._current = 0
+                self._total = 0
+
+            def eval(self, input: Row):
+                self._current = input["id"]
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._current, self._total
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT current, total
+                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION ORDER BY id)
+                """
+            ),
+            [Row(current=4, total=4)],
+        )
+
+        # Run a test case including PARTITION BY on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows for each of the three partitions
+        # separately.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT current, total
+                FROM test_udtf(TABLE(t) PARTITION BY floor(id / 10) ORDER BY id)
+                ORDER BY ALL
+                """
+            ),
+            [Row(current=4, total=4), Row(current=13, total=4), Row(current=20, total=1)],
+        )
+
 
 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index aac212ffde9..ab330141514 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -38,7 +38,14 @@ if TYPE_CHECKING:
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.session import SparkSession
 
-__all__ = ["AnalyzeArgument", "AnalyzeResult", "UDTFRegistration"]
+__all__ = [
+    "AnalyzeArgument",
+    "AnalyzeResult",
+    "PartitioningColumn",
+    "OrderingColumn",
+    "SkipRestOfInputTableException",
+    "UDTFRegistration",
+]
 
 
 @dataclass(frozen=True)
@@ -118,6 +125,16 @@ class AnalyzeResult:
     orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
 
 
+class SkipRestOfInputTableException(Exception):
+    """
+    This represents an exception that the 'eval' method may raise to indicate that it is done
+    consuming rows from the current partition of the input table. Then the UDTF's 'terminate'
+    method runs (if any).
+    """
+
+    pass
+
+
 def _create_udtf(
     cls: Type,
     returnType: Optional[Union[StructType, str]],
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index f6208032d9a..195c989c410 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -43,6 +43,7 @@ from pyspark.serializers import (
     CPickleSerializer,
     BatchedSerializer,
 )
+from pyspark.sql.functions import SkipRestOfInputTableException
 from pyspark.sql.pandas.serializers import (
     ArrowStreamPandasUDFSerializer,
     ArrowStreamPandasUDTFSerializer,
@@ -763,6 +764,7 @@ def read_udtf(pickleSer, infile, eval_type):
                 self._udtf = create_udtf()
                 self._prev_arguments: list = list()
                 self._partition_child_indexes: list = partition_child_indexes
+                self._eval_raised_skip_rest_of_input_table: bool = False
 
             def eval(self, *args, **kwargs) -> Iterator:
                 changed_partitions = self._check_partition_boundaries(
@@ -775,16 +777,24 @@ def read_udtf(pickleSer, infile, eval_type):
                         for row in result:
                             yield row
                     self._udtf = self._create_udtf()
-                if self._udtf.eval is not None:
+                    self._eval_raised_skip_rest_of_input_table = False
+                if self._udtf.eval is not None and not self._eval_raised_skip_rest_of_input_table:
                     # Filter the arguments to exclude projected PARTITION BY values added by Catalyst.
                     filtered_args = [self._remove_partition_by_exprs(arg) for arg in args]
                     filtered_kwargs = {
                         key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items()
                     }
-                    result = self._udtf.eval(*filtered_args, **filtered_kwargs)
-                    if result is not None:
-                        for row in result:
-                            yield row
+                    try:
+                        result = self._udtf.eval(*filtered_args, **filtered_kwargs)
+                        if result is not None:
+                            for row in result:
+                                yield row
+                    except SkipRestOfInputTableException:
+                        # If the 'eval' method raised this exception, then we should skip the rest of
+                        # the rows in the current partition. Set this field to True here and then for
+                        # each subsequent row in the partition, we will skip calling the 'eval' method
+                        # until we see a change in the partition boundaries.
+                        self._eval_raised_skip_rest_of_input_table = True
 
             def terminate(self) -> Iterator:
                 if self._udtf.terminate is not None:
@@ -995,6 +1005,8 @@ def read_udtf(pickleSer, infile, eval_type):
         def func(*a: Any) -> Any:
             try:
                 return f(*a)
+            except SkipRestOfInputTableException:
+                raise
             except Exception as e:
                 raise PySparkRuntimeError(
                     error_class="UDTF_EXEC_ERROR",
@@ -1057,6 +1069,9 @@ def read_udtf(pickleSer, infile, eval_type):
                     yield from eval(*[a[o] for o in args_kwargs_offsets])
                 if terminate is not None:
                     yield from terminate()
+            except SkipRestOfInputTableException:
+                if terminate is not None:
+                    yield from terminate()
             finally:
                 if cleanup is not None:
                     cleanup()
@@ -1098,6 +1113,8 @@ def read_udtf(pickleSer, infile, eval_type):
         def evaluate(*a) -> tuple:
             try:
                 res = f(*a)
+            except SkipRestOfInputTableException:
+                raise
             except Exception as e:
                 raise PySparkRuntimeError(
                     error_class="UDTF_EXEC_ERROR",
@@ -1144,6 +1161,9 @@ def read_udtf(pickleSer, infile, eval_type):
                     yield eval(*[a[o] for o in args_kwargs_offsets])
                 if terminate is not None:
                     yield terminate()
+            except SkipRestOfInputTableException:
+                if terminate is not None:
+                    yield terminate()
             finally:
                 if cleanup is not None:
                     cleanup()
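
For reference, a minimal end-to-end sketch of the new API, assembled from the example in the commit message and the semantics exercised by the new test; it is illustrative only, not part of the patch. It assumes a local `SparkSession`, and the `TestUDTF` name and table contents are placeholders:

```
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import SkipRestOfInputTableException, udtf

spark = SparkSession.builder.getOrCreate()

@udtf(returnType="total: int")
class TestUDTF:
    def __init__(self):
        self._total = 0

    def eval(self, _: Row):
        # Count input rows; once three have been consumed, tell Spark to stop
        # feeding this partition. terminate() still runs afterwards, so the
        # UDTF call is treated as successful rather than failed.
        self._total += 1
        if self._total >= 3:
            raise SkipRestOfInputTableException("Stop at self._total >= 3")

    def terminate(self):
        yield self._total,

spark.udtf.register("test_udtf", TestUDTF)

# With a single partition of 20 rows, only the first three are consumed, so
# this should return one row with total=3 per the semantics described above.
spark.sql(
    """
    WITH t AS (SELECT id FROM range(1, 21))
    SELECT total FROM test_udtf(TABLE(t) WITH SINGLE PARTITION ORDER BY id)
    """
).show()
```

Because the exception routes through `terminate` rather than `cleanup`, any buffered state can still be emitted as output rows, which is what distinguishes this early-exit path from an ordinary exception in `eval`.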