This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 380c0f2033f [SPARK-44640][PYTHON] Improve error messages for Python UDTF returning non Iterable
380c0f2033f is described below

commit 380c0f2033fb83b5e4f13693d2576d72c5cc01f2
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Fri Aug 4 10:22:46 2023 +0900

    [SPARK-44640][PYTHON] Improve error messages for Python UDTF returning non Iterable
    
    ### What changes were proposed in this pull request?
    
    This PR improves the error message raised when the result of a Python UDTF is not an Iterable, and the error message raised when a UDTF throws an exception while executing `eval`.
    
    ### Why are the changes needed?
    
    To make Python UDTFs more user-friendly.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. For example, this UDTF:
    ```
    @udtf(returnType="x: int")
    class TestUDTF:
        def eval(self, a):
            return a
    ```
    Before this PR, it fails with this error for regular UDTFs:
    ```
        return tuple(map(verify_and_convert_result, res))
    TypeError: 'int' object is not iterable
    ```
    And this error for arrow-optimized UDTFs:
    ```
        raise ValueError("DataFrame constructor not properly called!")
    ValueError: DataFrame constructor not properly called!
    ```
    
    After this PR, the error message will be:
    `pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_RETURN_NOT_ITERABLE] The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got 'int'. Please make sure that the UDTF returns one of these types.`
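
    A fixed version of the UDTF above returns an iterable instead; a minimal sketch (one accepted style, yielding one tuple per output row):
    ```
    @udtf(returnType="x: int")
    class TestUDTF:
        def eval(self, a):
            # Yield a tuple per output row instead of returning a bare int.
            yield (a,)
    ```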
    
    ### How was this patch tested?
    
    New unit tests.
    
    Closes #42302 from allisonwang-db/spark-44640-udtf-non-iterable.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/error_classes.py |  5 ++++
 python/pyspark/sql/tests/test_udtf.py  | 42 +++++++++++++++++++++++++--
 python/pyspark/sql/udtf.py             | 40 ++++++++++++++++++++-----
 python/pyspark/worker.py               | 53 ++++++++++++++++++----------------
 4 files changed, 105 insertions(+), 35 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index d6f093246da..84448f1507d 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -738,6 +738,11 @@ ERROR_CLASSES_JSON = """
       "User defined table function encountered an error in the '<method_name>' 
method: <error>"
     ]
   },
+  "UDTF_RETURN_NOT_ITERABLE" : {
+    "message" : [
+      "The return value of the UDTF is invalid. It should be an iterable 
(e.g., generator or list), but got '<type>'. Please make sure that the UDTF 
returns one of these types."
+    ]
+  },
   "UDTF_RETURN_SCHEMA_MISMATCH" : {
     "message" : [
       "The number of columns in the result does not match the specified 
schema. Expected column count: <expected>, Actual column count: <actual>. 
Please make sure the values returned by the function have the same number of 
columns as specified in the output schema."
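
For context, these message templates are filled in via `message_parameters` when the error is raised; a minimal sketch of how the `<type>` placeholder is substituted (illustrative only, mirroring the worker change below):
```
from pyspark.errors import PySparkRuntimeError

# Raising the new error class; '<type>' is replaced by the
# offending return type's name, e.g. "got 'int'".
raise PySparkRuntimeError(
    error_class="UDTF_RETURN_NOT_ITERABLE",
    message_parameters={"type": type(1).__name__},
)
```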
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 65184549573..26da83980e1 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -180,6 +180,15 @@ class BaseUDTFTestsMixin:
         with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"):
             func(lit(1)).collect()
 
+    def test_udtf_with_invalid_return_value(self):
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self, a):
+                return a
+
+        with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
+            TestUDTF(lit(1)).collect()
+
     def test_udtf_eval_with_no_return(self):
         @udtf(returnType="a: int")
         class TestUDTF:
@@ -375,6 +384,35 @@ class BaseUDTFTestsMixin:
             ],
         )
 
+    def test_init_with_exception(self):
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def __init__(self):
+                raise Exception("error")
+
+            def eval(self):
+                yield 1,
+
+        with self.assertRaisesRegex(
+            PythonException,
+            r"\[UDTF_EXEC_ERROR\] User defined table function encountered an 
error "
+            r"in the '__init__' method: error",
+        ):
+            TestUDTF().show()
+
+    def test_eval_with_exception(self):
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self):
+                raise Exception("error")
+
+        with self.assertRaisesRegex(
+            PythonException,
+            r"\[UDTF_EXEC_ERROR\] User defined table function encountered an 
error "
+            r"in the 'eval' method: error",
+        ):
+            TestUDTF().show()
+
     def test_terminate_with_exceptions(self):
         @udtf(returnType="a: int, b: int")
         class TestUDTF:
@@ -386,8 +424,8 @@ class BaseUDTFTestsMixin:
 
         with self.assertRaisesRegex(
             PythonException,
-            "User defined table function encountered an error in the 
'terminate' "
-            "method: terminate error",
+            r"\[UDTF_EXEC_ERROR\] User defined table function encountered an 
error "
+            r"in the 'terminate' method: terminate error",
         ):
             TestUDTF(lit(1)).collect()
 
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index d14a263f839..74a9084c6cd 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -18,14 +18,15 @@
 User-defined table function related classes and functions
 """
 from dataclasses import dataclass
+from functools import wraps
 import inspect
 import sys
 import warnings
-from typing import Any, Iterator, Type, TYPE_CHECKING, Optional, Union
+from typing import Any, Iterable, Iterator, Type, TYPE_CHECKING, Optional, Union, Callable
 
 from py4j.java_gateway import JavaObject
 
-from pyspark.errors import PySparkAttributeError, PySparkTypeError
+from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.column import _to_java_column, _to_seq
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
@@ -143,6 +144,20 @@ def _vectorize_udtf(cls: Type) -> Type:
     """Vectorize a Python UDTF handler class."""
     import pandas as pd
 
+    # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
+    def wrap_func(f: Callable[..., Any]) -> Callable[..., Any]:
+        @wraps(f)
+        def evaluate(*a: Any) -> Any:
+            try:
+                return f(*a)
+            except Exception as e:
+                raise PySparkRuntimeError(
+                    error_class="UDTF_EXEC_ERROR",
+                    message_parameters={"method_name": f.__name__, "error": 
str(e)},
+                )
+
+        return evaluate
+
     class VectorizedUDTF:
         def __init__(self) -> None:
             self.func = cls()
@@ -157,17 +172,26 @@ def _vectorize_udtf(cls: Type) -> Type:
 
         def eval(self, *args: pd.Series) -> Iterator[pd.DataFrame]:
             if len(args) == 0:
-                yield pd.DataFrame(self.func.eval())
+                yield pd.DataFrame(wrap_func(self.func.eval)())
             else:
                 # Create tuples from the input pandas Series, each tuple
                 # represents a row across all Series.
                 row_tuples = zip(*args)
                 for row in row_tuples:
-                    yield pd.DataFrame(self.func.eval(*row))
-
-        def terminate(self) -> Iterator[pd.DataFrame]:
-            if hasattr(self.func, "terminate"):
-                yield pd.DataFrame(self.func.terminate())
+                    res = wrap_func(self.func.eval)(*row)
+                    if res is not None and not isinstance(res, Iterable):
+                        raise PySparkRuntimeError(
+                            error_class="UDTF_RETURN_NOT_ITERABLE",
+                            message_parameters={
+                                "type": type(res).__name__,
+                            },
+                        )
+                    yield pd.DataFrame(res)
+
+        if hasattr(cls, "terminate"):
+
+            def terminate(self) -> Iterator[pd.DataFrame]:
+                yield pd.DataFrame(wrap_func(self.func.terminate)())
 
     vectorized_udtf = VectorizedUDTF
     vectorized_udtf.__name__ = cls.__name__
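
The `wrap_func` helper above is a standard wrap-and-reraise decorator; because it applies `functools.wraps`, the resulting `UDTF_EXEC_ERROR` reports the name of the method that actually failed. A standalone sketch of the pattern (generic names, outside Spark):
```
from functools import wraps

def wrap_func(f):
    @wraps(f)  # preserves f.__name__ so the error names the failing method
    def evaluate(*a):
        try:
            return f(*a)
        except Exception as e:
            raise RuntimeError(f"error in '{f.__name__}': {e}")
    return evaluate

def eval_method():
    raise ValueError("boom")

wrap_func(eval_method)()  # RuntimeError: error in 'eval_method': boom
```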
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 20e856c9add..3acfa58b6fb 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,7 +23,7 @@ import sys
 import time
 from inspect import currentframe, getframeinfo, getfullargspec
 import json
-from typing import Iterator
+from typing import Iterable, Iterator
 
 # 'resource' is a Unix specific module.
 has_resource_module = True
@@ -591,6 +591,7 @@ def read_udtf(pickleSer, infile, eval_type):
 
         def wrap_arrow_udtf(f, return_type):
             arrow_return_type = to_arrow_type(return_type)
+            return_type_size = len(return_type)
 
             def verify_result(result):
                 import pandas as pd
@@ -599,7 +600,7 @@ def read_udtf(pickleSer, infile, eval_type):
                     raise PySparkTypeError(
                         error_class="INVALID_ARROW_UDTF_RETURN_TYPE",
                         message_parameters={
-                            "type_name": type(result).__name_,
+                            "type_name": type(result).__name__,
                             "value": str(result),
                         },
                     )
@@ -609,11 +610,11 @@ def read_udtf(pickleSer, infile, eval_type):
                 # result dataframe may contain an empty row. For example, when a UDTF is
                 # defined as follows: def eval(self): yield tuple().
                 if len(result) > 0 or len(result.columns) > 0:
-                    if len(result.columns) != len(return_type):
+                    if len(result.columns) != return_type_size:
                         raise PySparkRuntimeError(
                             error_class="UDTF_RETURN_SCHEMA_MISMATCH",
                             message_parameters={
-                                "expected": str(len(return_type)),
+                                "expected": str(return_type_size),
                                 "actual": str(len(result.columns)),
                             },
                         )
@@ -641,13 +642,7 @@ def read_udtf(pickleSer, infile, eval_type):
                     yield from eval(*[a[o] for o in arg_offsets])
             finally:
                 if terminate is not None:
-                    try:
-                        yield from terminate()
-                    except BaseException as e:
-                        raise PySparkRuntimeError(
-                            error_class="UDTF_EXEC_ERROR",
-                            message_parameters={"method_name": "terminate", "error": str(e)},
-                        )
+                    yield from terminate()
 
         return mapper, None, ser, ser
 
@@ -656,15 +651,16 @@ def read_udtf(pickleSer, infile, eval_type):
         def wrap_udtf(f, return_type):
             assert return_type.needConversion()
             toInternal = return_type.toInternal
+            return_type_size = len(return_type)
 
             def verify_and_convert_result(result):
                 # TODO(SPARK-44005): support returning non-tuple values
                 if result is not None and hasattr(result, "__len__"):
-                    if len(result) != len(return_type):
+                    if len(result) != return_type_size:
                         raise PySparkRuntimeError(
                             error_class="UDTF_RETURN_SCHEMA_MISMATCH",
                             message_parameters={
-                                "expected": str(len(return_type)),
+                                "expected": str(return_type_size),
                                 "actual": str(len(result)),
                             },
                         )
@@ -672,16 +668,29 @@ def read_udtf(pickleSer, infile, eval_type):
 
             # Evaluate the function and return a tuple back to the executor.
             def evaluate(*a) -> tuple:
-                res = f(*a)
+                try:
+                    res = f(*a)
+                except Exception as e:
+                    raise PySparkRuntimeError(
+                        error_class="UDTF_EXEC_ERROR",
+                        message_parameters={"method_name": f.__name__, "error": str(e)},
+                    )
+
                 if res is None:
                     # If the function returns None or does not have an explicit return statement,
                     # an empty tuple is returned to the executor.
                     # This is because directly constructing tuple(None) results in an exception.
                     return tuple()
-                else:
-                    # If the function returns a result, we map it to the internal representation and
-                    # returns the results as a tuple.
-                    return tuple(map(verify_and_convert_result, res))
+
+                if not isinstance(res, Iterable):
+                    raise PySparkRuntimeError(
+                        error_class="UDTF_RETURN_NOT_ITERABLE",
+                        message_parameters={"type": type(res).__name__},
+                    )
+
+                # If the function returns a result, we map it to the internal representation and
+                # return the results as a tuple.
+                return tuple(map(verify_and_convert_result, res))
 
             return evaluate
 
@@ -699,13 +708,7 @@ def read_udtf(pickleSer, infile, eval_type):
                     yield eval(*[a[o] for o in arg_offsets])
             finally:
                 if terminate is not None:
-                    try:
-                        yield terminate()
-                    except BaseException as e:
-                        raise PySparkRuntimeError(
-                            error_class="UDTF_EXEC_ERROR",
-                            message_parameters={"method_name": "terminate", "error": str(e)},
-                        )
+                    yield terminate()
 
         return mapper, None, ser, ser
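
For reference, the `isinstance(res, Iterable)` check used above accepts generators, lists, and tuples while rejecting scalars; a quick illustration in plain Python (`typing.Iterable` supports runtime isinstance checks):
```
from typing import Iterable

assert isinstance([(1,)], Iterable)                 # list of rows: accepted
assert isinstance((r for r in [(1,)]), Iterable)    # generator: accepted
assert not isinstance(1, Iterable)                  # bare int: rejected
```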
 

