This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new fe3a20a5e23 [SPARK-44876][PYTHON][FOLLOWUP][3.5] Fix Arrow-optimized Python UDF to delay wrapping the function with fail_on_stopiteration
fe3a20a5e23 is described below

commit fe3a20a5e231fd1516666b141e72ea8a1090647a
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Mon Sep 4 15:25:33 2023 +0800

    [SPARK-44876][PYTHON][FOLLOWUP][3.5] Fix Arrow-optimized Python UDF to delay wrapping the function with fail_on_stopiteration
    
    ### What changes were proposed in this pull request?
    
    This is a backport of https://github.com/apache/spark/pull/42784.
    
    Fixes the Arrow-optimized Python UDF to delay wrapping the function with `fail_on_stopiteration`.
    
    Also removes the unnecessary `verify_result_type` check: `evaluate` now always returns a `pandas.Series`, which has `__len__`, so the check can never fail.
    
    ### Why are the changes needed?
    
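    For Arrow-optimized Python UDFs, `fail_on_stopiteration` only needs to wrap the function that evaluates a whole batch, not the user function that is called once per row. Wrapping it once per batch avoids unnecessary per-row overhead.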
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added a new test, `test_raise_stop_iteration`, in `python/pyspark/sql/tests/test_udf.py`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #42785 from ueshin/issues/SPARK-44876/3.5/fail_on_stopiteration.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/tests/test_udf.py | 21 ++++++++++++++++++---
 python/pyspark/worker.py             | 22 +++++++---------------
 2 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udf.py 
b/python/pyspark/sql/tests/test_udf.py
index 239ff27813b..2f8c1cd2136 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -24,7 +24,7 @@ import datetime
 
 from pyspark import SparkContext, SQLContext
 from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import udf, assert_true, lit, rand
+from pyspark.sql.functions import col, udf, assert_true, lit, rand
 from pyspark.sql.udf import UserDefinedFunction
 from pyspark.sql.types import (
     StringType,
@@ -38,9 +38,9 @@ from pyspark.sql.types import (
     TimestampNTZType,
     DayTimeIntervalType,
 )
-from pyspark.errors import AnalysisException, PySparkTypeError
+from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, 
test_not_compiled_message
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
 
 
 class BaseUDFTestsMixin(object):
@@ -898,6 +898,21 @@ class BaseUDFTestsMixin(object):
         self.assertEquals(row[1], {"a": "b"})
         self.assertEquals(row[2], Row(col1=1, col2=2))
 
+    def test_raise_stop_iteration(self):
+        @udf("int")
+        def test_udf(a):
+            if a < 5:
+                return a
+            else:
+                raise StopIteration()
+
+        assertDataFrameEqual(
+            self.spark.range(5).select(test_udf(col("id"))), [Row(i) for i in 
range(5)]
+        )
+
+        with self.assertRaisesRegex(PythonException, "StopIteration"):
+            self.spark.range(10).select(test_udf(col("id"))).show()
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index edbfad4a5dc..d2ea18c45c9 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -154,20 +154,9 @@ def wrap_arrow_batch_udf(f, return_type):
     elif type(return_type) == BinaryType:
         result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
 
+    @fail_on_stopiteration
     def evaluate(*args: pd.Series) -> pd.Series:
-        return pd.Series(result_func(f(*a)) for a in zip(*args))
-
-    def verify_result_type(result):
-        if not hasattr(result, "__len__"):
-            pd_type = "pandas.DataFrame" if type(return_type) == StructType 
else "pandas.Series"
-            raise PySparkTypeError(
-                error_class="UDF_RETURN_TYPE",
-                message_parameters={
-                    "expected": pd_type,
-                    "actual": type(result).__name__,
-                },
-            )
-        return result
+        return pd.Series([result_func(f(*a)) for a in zip(*args)])
 
     def verify_result_length(result, length):
         if len(result) != length:
@@ -181,7 +170,7 @@ def wrap_arrow_batch_udf(f, return_type):
         return result
 
     return lambda *a: (
-        verify_result_length(verify_result_type(evaluate(*a)), len(a[0])),
+        verify_result_length(evaluate(*a), len(a[0])),
         arrow_return_type,
     )
 
@@ -543,7 +532,10 @@ def read_single_udf(pickleSer, infile, eval_type, 
runner_conf, udf_index):
         else:
             chained_func = chain(chained_func, f)
 
-    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
+    if eval_type in (
+        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+        PythonEvalType.SQL_ARROW_BATCHED_UDF,
+    ):
         func = chained_func
     else:
         # make sure StopIteration's raised in the user code are not ignored


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to