This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 380c0f2033f [SPARK-44640][PYTHON] Improve error messages for Python UDTF returning non Iterable 380c0f2033f is described below commit 380c0f2033fb83b5e4f13693d2576d72c5cc01f2 Author: allisonwang-db <allison.w...@databricks.com> AuthorDate: Fri Aug 4 10:22:46 2023 +0900 [SPARK-44640][PYTHON] Improve error messages for Python UDTF returning non Iterable ### What changes were proposed in this pull request? This PR improves error messages when the result of a Python UDTF is not an Iterable. It also improves the error messages when a UDTF encounters an exception when executing `eval`. ### Why are the changes needed? To make Python UDTFs more user-friendly. ### Does this PR introduce _any_ user-facing change? Yes. For example this UDTF: ``` @udtf(returnType="x: int") class TestUDTF: def eval(self, a): return a ``` Before this PR, it fails with this error for regular UDTFs: ``` return tuple(map(verify_and_convert_result, res)) TypeError: 'int' object is not iterable ``` And this error for arrow-optimized UDTFs: ``` raise ValueError("DataFrame constructor not properly called!") ValueError: DataFrame constructor not properly called! ``` After this PR, the error message will be: `pyspark.errors.exceptions.base.PySparkRuntimeError: [UDTF_RETURN_NOT_ITERABLE] The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got 'int'. Please make sure that the UDTF returns one of these types.` ### How was this patch tested? New UTs. Closes #42302 from allisonwang-db/spark-44640-udtf-non-iterable. 
Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/errors/error_classes.py | 5 ++++ python/pyspark/sql/tests/test_udtf.py | 42 +++++++++++++++++++++++++-- python/pyspark/sql/udtf.py | 40 ++++++++++++++++++++----- python/pyspark/worker.py | 53 ++++++++++++++++++---------------- 4 files changed, 105 insertions(+), 35 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index d6f093246da..84448f1507d 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -738,6 +738,11 @@ ERROR_CLASSES_JSON = """ "User defined table function encountered an error in the '<method_name>' method: <error>" ] }, + "UDTF_RETURN_NOT_ITERABLE" : { + "message" : [ + "The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types." + ] + }, "UDTF_RETURN_SCHEMA_MISMATCH" : { "message" : [ "The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the function have the same number of columns as specified in the output schema." 
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 65184549573..26da83980e1 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -180,6 +180,15 @@ class BaseUDTFTestsMixin: with self.assertRaisesRegex(PythonException, "Unexpected tuple 1 with StructType"): func(lit(1)).collect() + def test_udtf_with_invalid_return_value(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self, a): + return a + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): + TestUDTF(lit(1)).collect() + def test_udtf_eval_with_no_return(self): @udtf(returnType="a: int") class TestUDTF: @@ -375,6 +384,35 @@ class BaseUDTFTestsMixin: ], ) + def test_init_with_exception(self): + @udtf(returnType="x: int") + class TestUDTF: + def __init__(self): + raise Exception("error") + + def eval(self): + yield 1, + + with self.assertRaisesRegex( + PythonException, + r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the '__init__' method: error", + ): + TestUDTF().show() + + def test_eval_with_exception(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + raise Exception("error") + + with self.assertRaisesRegex( + PythonException, + r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the 'eval' method: error", + ): + TestUDTF().show() + def test_terminate_with_exceptions(self): @udtf(returnType="a: int, b: int") class TestUDTF: @@ -386,8 +424,8 @@ class BaseUDTFTestsMixin: with self.assertRaisesRegex( PythonException, - "User defined table function encountered an error in the 'terminate' " - "method: terminate error", + r"\[UDTF_EXEC_ERROR\] User defined table function encountered an error " + r"in the 'terminate' method: terminate error", ): TestUDTF(lit(1)).collect() diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index d14a263f839..74a9084c6cd 100644 --- a/python/pyspark/sql/udtf.py 
+++ b/python/pyspark/sql/udtf.py @@ -18,14 +18,15 @@ User-defined table function related classes and functions """ from dataclasses import dataclass +from functools import wraps import inspect import sys import warnings -from typing import Any, Iterator, Type, TYPE_CHECKING, Optional, Union +from typing import Any, Iterable, Iterator, Type, TYPE_CHECKING, Optional, Union, Callable from py4j.java_gateway import JavaObject -from pyspark.errors import PySparkAttributeError, PySparkTypeError +from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError from pyspark.rdd import PythonEvalType from pyspark.sql.column import _to_java_column, _to_seq from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version @@ -143,6 +144,20 @@ def _vectorize_udtf(cls: Type) -> Type: """Vectorize a Python UDTF handler class.""" import pandas as pd + # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. + def wrap_func(f: Callable[..., Any]) -> Callable[..., Any]: + @wraps(f) + def evaluate(*a: Any) -> Any: + try: + return f(*a) + except Exception as e: + raise PySparkRuntimeError( + error_class="UDTF_EXEC_ERROR", + message_parameters={"method_name": f.__name__, "error": str(e)}, + ) + + return evaluate + class VectorizedUDTF: def __init__(self) -> None: self.func = cls() @@ -157,17 +172,26 @@ def _vectorize_udtf(cls: Type) -> Type: def eval(self, *args: pd.Series) -> Iterator[pd.DataFrame]: if len(args) == 0: - yield pd.DataFrame(self.func.eval()) + yield pd.DataFrame(wrap_func(self.func.eval)()) else: # Create tuples from the input pandas Series, each tuple # represents a row across all Series. 
row_tuples = zip(*args) for row in row_tuples: - yield pd.DataFrame(self.func.eval(*row)) - - def terminate(self) -> Iterator[pd.DataFrame]: - if hasattr(self.func, "terminate"): - yield pd.DataFrame(self.func.terminate()) + res = wrap_func(self.func.eval)(*row) + if res is not None and not isinstance(res, Iterable): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_NOT_ITERABLE", + message_parameters={ + "type": type(res).__name__, + }, + ) + yield pd.DataFrame(res) + + if hasattr(cls, "terminate"): + + def terminate(self) -> Iterator[pd.DataFrame]: + yield pd.DataFrame(wrap_func(self.func.terminate)()) vectorized_udtf = VectorizedUDTF vectorized_udtf.__name__ = cls.__name__ diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 20e856c9add..3acfa58b6fb 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,7 +23,7 @@ import sys import time from inspect import currentframe, getframeinfo, getfullargspec import json -from typing import Iterator +from typing import Iterable, Iterator # 'resource' is a Unix specific module. has_resource_module = True @@ -591,6 +591,7 @@ def read_udtf(pickleSer, infile, eval_type): def wrap_arrow_udtf(f, return_type): arrow_return_type = to_arrow_type(return_type) + return_type_size = len(return_type) def verify_result(result): import pandas as pd @@ -599,7 +600,7 @@ def read_udtf(pickleSer, infile, eval_type): raise PySparkTypeError( error_class="INVALID_ARROW_UDTF_RETURN_TYPE", message_parameters={ - "type_name": type(result).__name_, + "type_name": type(result).__name__, "value": str(result), }, ) @@ -609,11 +610,11 @@ def read_udtf(pickleSer, infile, eval_type): # result dataframe may contain an empty row. For example, when a UDTF is # defined as follows: def eval(self): yield tuple(). 
if len(result) > 0 or len(result.columns) > 0: - if len(result.columns) != len(return_type): + if len(result.columns) != return_type_size: raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", message_parameters={ - "expected": str(len(return_type)), + "expected": str(return_type_size), "actual": str(len(result.columns)), }, ) @@ -641,13 +642,7 @@ def read_udtf(pickleSer, infile, eval_type): yield from eval(*[a[o] for o in arg_offsets]) finally: if terminate is not None: - try: - yield from terminate() - except BaseException as e: - raise PySparkRuntimeError( - error_class="UDTF_EXEC_ERROR", - message_parameters={"method_name": "terminate", "error": str(e)}, - ) + yield from terminate() return mapper, None, ser, ser @@ -656,15 +651,16 @@ def read_udtf(pickleSer, infile, eval_type): def wrap_udtf(f, return_type): assert return_type.needConversion() toInternal = return_type.toInternal + return_type_size = len(return_type) def verify_and_convert_result(result): # TODO(SPARK-44005): support returning non-tuple values if result is not None and hasattr(result, "__len__"): - if len(result) != len(return_type): + if len(result) != return_type_size: raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", message_parameters={ - "expected": str(len(return_type)), + "expected": str(return_type_size), "actual": str(len(result)), }, ) @@ -672,16 +668,29 @@ def read_udtf(pickleSer, infile, eval_type): # Evaluate the function and return a tuple back to the executor. def evaluate(*a) -> tuple: - res = f(*a) + try: + res = f(*a) + except Exception as e: + raise PySparkRuntimeError( + error_class="UDTF_EXEC_ERROR", + message_parameters={"method_name": f.__name__, "error": str(e)}, + ) + if res is None: # If the function returns None or does not have an explicit return statement, # an empty tuple is returned to the executor. # This is because directly constructing tuple(None) results in an exception. 
return tuple() - else: - # If the function returns a result, we map it to the internal representation and - # returns the results as a tuple. - return tuple(map(verify_and_convert_result, res)) + + if not isinstance(res, Iterable): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_NOT_ITERABLE", + message_parameters={"type": type(res).__name__}, + ) + + # If the function returns a result, we map it to the internal representation and + # returns the results as a tuple. + return tuple(map(verify_and_convert_result, res)) return evaluate @@ -699,13 +708,7 @@ def read_udtf(pickleSer, infile, eval_type): yield eval(*[a[o] for o in arg_offsets]) finally: if terminate is not None: - try: - yield terminate() - except BaseException as e: - raise PySparkRuntimeError( - error_class="UDTF_EXEC_ERROR", - message_parameters={"method_name": "terminate", "error": str(e)}, - ) + yield terminate() return mapper, None, ser, ser --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org