This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e22c8653d7 [SPARK-44640][PYTHON][FOLLOW-UP] Update UDTF error messages to include method name
3e22c8653d7 is described below

commit 3e22c8653d728a6b8523051faddcca437accfc22
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Sat Sep 2 16:07:09 2023 -0700

    [SPARK-44640][PYTHON][FOLLOW-UP] Update UDTF error messages to include method name
    
    ### What changes were proposed in this pull request?
    
    This PR is a follow-up to SPARK-44640 that makes a few UDTF error messages more informative by including the name of the failing method (`eval` or `terminate`).
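    
    For illustration (a minimal sketch mirroring the new tests below; the `ReturnsScalar` class is hypothetical and not part of this patch), a UDTF whose `eval` returns a non-iterable value now reports the failing method name in the `UDTF_RETURN_NOT_ITERABLE` error:
    
    ```python
    from pyspark.sql.functions import lit, udtf
    
    @udtf(returnType="x: int")
    class ReturnsScalar:
        def eval(self, a):
            # Deliberate bug: `eval` must return an iterable of rows,
            # e.g. `yield (a,)` or `return [(a,)]`, not a bare scalar.
            return 1
    
    # Assumes an active SparkSession. This raises a PythonException with
    # UDTF_RETURN_NOT_ITERABLE; the message now names the 'eval' method.
    ReturnsScalar(lit(1)).collect()
    ```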
    
    ### Why are the changes needed?
    
    To improve error messages.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #42726 from allisonwang-db/SPARK-44640-follow-up.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/errors/error_classes.py |  8 ++++----
 python/pyspark/sql/tests/test_udtf.py  | 21 +++++++++++++++++++
 python/pyspark/worker.py               | 37 +++++++++++++++++++++++++---------
 3 files changed, 52 insertions(+), 14 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index ca448a169e8..74f52c416e9 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -244,7 +244,7 @@ ERROR_CLASSES_JSON = """
   },
   "INVALID_ARROW_UDTF_RETURN_TYPE" : {
     "message" : [
-      "The return type of the arrow-optimized Python UDTF should be of type 
'pandas.DataFrame', but the function returned a value of type <type_name> with 
value: <value>."
+      "The return type of the arrow-optimized Python UDTF should be of type 
'pandas.DataFrame', but the '<func>' method returned a value of type 
<type_name> with value: <value>."
     ]
   },
   "INVALID_BROADCAST_OPERATION": {
@@ -745,17 +745,17 @@ ERROR_CLASSES_JSON = """
   },
   "UDTF_INVALID_OUTPUT_ROW_TYPE" : {
     "message" : [
-        "The type of an individual output row in the UDTF is invalid. Each row 
should be a tuple, list, or dict, but got '<type>'. Please make sure that the 
output rows are of the correct type."
+        "The type of an individual output row in the '<func>' method of the 
UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. 
Please make sure that the output rows are of the correct type."
     ]
   },
   "UDTF_RETURN_NOT_ITERABLE" : {
     "message" : [
-      "The return value of the UDTF is invalid. It should be an iterable 
(e.g., generator or list), but got '<type>'. Please make sure that the UDTF 
returns one of these types."
+      "The return value of the '<func>' method of the UDTF is invalid. It 
should be an iterable (e.g., generator or list), but got '<type>'. Please make 
sure that the UDTF returns one of these types."
     ]
   },
   "UDTF_RETURN_SCHEMA_MISMATCH" : {
     "message" : [
-      "The number of columns in the result does not match the specified 
schema. Expected column count: <expected>, Actual column count: <actual>. 
Please make sure the values returned by the function have the same number of 
columns as specified in the output schema."
+      "The number of columns in the result does not match the specified 
schema. Expected column count: <expected>, Actual column count: <actual>. 
Please make sure the values returned by the '<func>' method have the same 
number of columns as specified in the output schema."
     ]
   },
   "UDTF_RETURN_TYPE_MISMATCH" : {
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index c5f8b7693c2..97d5190a506 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -190,6 +190,27 @@ class BaseUDTFTestsMixin:
         with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
             TestUDTF(lit(1)).collect()
 
+    def test_udtf_with_zero_arg_and_invalid_return_value(self):
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self):
+                return 1
+
+        with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
+            TestUDTF().collect()
+
+    def test_udtf_with_invalid_return_value_in_terminate(self):
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self, a):
+                ...
+
+            def terminate(self):
+                return 1
+
+        with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
+            TestUDTF(lit(1)).collect()
+
     def test_udtf_eval_with_no_return(self):
         @udtf(returnType="a: int")
         class TestUDTF:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d95a5c4672f..fff99f1de3d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -773,6 +773,7 @@ def read_udtf(pickleSer, infile, eval_type):
                         message_parameters={
                             "type_name": type(result).__name__,
                             "value": str(result),
+                            "func": f.__name__,
                         },
                     )
 
@@ -787,6 +788,7 @@ def read_udtf(pickleSer, infile, eval_type):
                             message_parameters={
                                 "expected": str(return_type_size),
                                 "actual": str(len(result.columns)),
+                                "func": f.__name__,
                             },
                         )
 
@@ -806,9 +808,23 @@ def read_udtf(pickleSer, infile, eval_type):
                         message_parameters={"method_name": f.__name__, 
"error": str(e)},
                     )
 
+            def check_return_value(res):
+                # Check whether the result of an arrow UDTF is iterable before
+                # using it to construct a pandas DataFrame.
+                if res is not None and not isinstance(res, Iterable):
+                    raise PySparkRuntimeError(
+                        error_class="UDTF_RETURN_NOT_ITERABLE",
+                        message_parameters={
+                            "type": type(res).__name__,
+                            "func": f.__name__,
+                        },
+                    )
+
             def evaluate(*args: pd.Series, **kwargs: pd.Series):
                 if len(args) == 0 and len(kwargs) == 0:
-                    yield verify_result(pd.DataFrame(func())), arrow_return_type
+                    res = func()
+                    check_return_value(res)
+                    yield verify_result(pd.DataFrame(res)), arrow_return_type
                 else:
                     # Create tuples from the input pandas Series, each tuple
                     # represents a row across all Series.
@@ -820,13 +836,7 @@ def read_udtf(pickleSer, infile, eval_type):
                             *row[:len_args],
                             **{key: row[len_args + i] for i, key in enumerate(keys)},
                         )
-                        if res is not None and not isinstance(res, Iterable):
-                            raise PySparkRuntimeError(
-                                error_class="UDTF_RETURN_NOT_ITERABLE",
-                                message_parameters={
-                                    "type": type(res).__name__,
-                                },
-                            )
+                        check_return_value(res)
                         yield verify_result(pd.DataFrame(res)), arrow_return_type
 
             return evaluate
@@ -868,13 +878,17 @@ def read_udtf(pickleSer, infile, eval_type):
                             message_parameters={
                                 "expected": str(return_type_size),
                                 "actual": str(len(result)),
+                                "func": f.__name__,
                             },
                         )
 
                     if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")):
                         raise PySparkRuntimeError(
                             error_class="UDTF_INVALID_OUTPUT_ROW_TYPE",
-                            message_parameters={"type": type(result).__name__},
+                            message_parameters={
+                                "type": type(result).__name__,
+                                "func": f.__name__,
+                            },
                         )
 
                 return toInternal(result)
@@ -898,7 +912,10 @@ def read_udtf(pickleSer, infile, eval_type):
                 if not isinstance(res, Iterable):
                     raise PySparkRuntimeError(
                         error_class="UDTF_RETURN_NOT_ITERABLE",
-                        message_parameters={"type": type(res).__name__},
+                        message_parameters={
+                            "type": type(res).__name__,
+                            "func": f.__name__,
+                        },
                     )
 
                 # If the function returns a result, we map it to the internal representation and

