This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f2a6c97d718 [SPARK-44876][PYTHON][FOLLOWUP] Fix Arrow-optimized Python UDF to delay wrapping the function with fail_on_stopiteration
f2a6c97d718 is described below

commit f2a6c97d718839896343feaa520396f328f2f866
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Mon Sep 4 15:24:33 2023 +0800

    [SPARK-44876][PYTHON][FOLLOWUP] Fix Arrow-optimized Python UDF to delay wrapping the function with fail_on_stopiteration

    ### What changes were proposed in this pull request?

    Fixes Arrow-optimized Python UDF to delay wrapping the function with `fail_on_stopiteration`.

    Also removes the unnecessary `verify_result_type` verification.

    ### Why are the changes needed?

    For Arrow-optimized Python UDF, `fail_on_stopiteration` can be applied only to the wrapped function to avoid unnecessary overhead.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Added the related test.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #42784 from ueshin/issues/SPARK-44876/fail_on_stopiteration.
Authored-by: Takuya UESHIN <ues...@databricks.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/sql/tests/test_udf.py | 15 +++++++++++++++ python/pyspark/worker.py | 22 ++++++---------------- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 32ea05bd00a..1f895b1780b 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -1005,6 +1005,21 @@ class BaseUDFTestsMixin(object): with self.subTest(with_b=True, query_no=i): assertDataFrameEqual(df, [Row(0), Row(101)]) + def test_raise_stop_iteration(self): + @udf("int") + def test_udf(a): + if a < 5: + return a + else: + raise StopIteration() + + assertDataFrameEqual( + self.spark.range(5).select(test_udf(col("id"))), [Row(i) for i in range(5)] + ) + + with self.assertRaisesRegex(PythonException, "StopIteration"): + self.spark.range(10).select(test_udf(col("id"))).show() + class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index fff99f1de3d..92bc622775b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -139,6 +139,7 @@ def wrap_arrow_batch_udf(f, return_type): elif type(return_type) == BinaryType: result_func = lambda r: bytes(r) if r is not None else r # noqa: E731 + @fail_on_stopiteration def evaluate(*args: pd.Series, **kwargs: pd.Series) -> pd.Series: keys = list(kwargs.keys()) len_args = len(args) @@ -151,18 +152,6 @@ def wrap_arrow_batch_udf(f, return_type): ] ) - def verify_result_type(result): - if not hasattr(result, "__len__"): - pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" - raise PySparkTypeError( - error_class="UDF_RETURN_TYPE", - message_parameters={ - "expected": pd_type, - "actual": type(result).__name__, - }, - ) - return result - def verify_result_length(result, length): if len(result) != length: raise 
PySparkRuntimeError( @@ -175,9 +164,7 @@ def wrap_arrow_batch_udf(f, return_type): return result return lambda *a, **kw: ( - verify_result_length( - verify_result_type(evaluate(*a, **kw)), len((list(a) + list(kw.values()))[0]) - ), + verify_result_length(evaluate(*a, **kw), len((list(a) + list(kw.values()))[0])), arrow_return_type, ) @@ -562,7 +549,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): else: chained_func = chain(chained_func, f) - if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF: + if eval_type in ( + PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, + PythonEvalType.SQL_ARROW_BATCHED_UDF, + ): func = chained_func else: # make sure StopIteration's raised in the user code are not ignored --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org