This is an automated email from the ASF dual-hosted git repository. ueshin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 3e22c8653d7 [SPARK-44640][PYTHON][FOLLOW-UP] Update UDTF error messages to include method name 3e22c8653d7 is described below commit 3e22c8653d728a6b8523051faddcca437accfc22 Author: allisonwang-db <allison.w...@databricks.com> AuthorDate: Sat Sep 2 16:07:09 2023 -0700 [SPARK-44640][PYTHON][FOLLOW-UP] Update UDTF error messages to include method name ### What changes were proposed in this pull request? This PR is a follow-up for SPARK-44640 to make the error message of a few UDTF errors more informative by including the method name in the error message (`eval` or `terminate`). ### Why are the changes needed? To improve error messages. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #42726 from allisonwang-db/SPARK-44640-follow-up. Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- python/pyspark/errors/error_classes.py | 8 ++++---- python/pyspark/sql/tests/test_udtf.py | 21 +++++++++++++++++++ python/pyspark/worker.py | 37 +++++++++++++++++++++++++--------- 3 files changed, 52 insertions(+), 14 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index ca448a169e8..74f52c416e9 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -244,7 +244,7 @@ ERROR_CLASSES_JSON = """ }, "INVALID_ARROW_UDTF_RETURN_TYPE" : { "message" : [ - "The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the function returned a value of type <type_name> with value: <value>." + "The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the '<func>' method returned a value of type <type_name> with value: <value>." ] }, "INVALID_BROADCAST_OPERATION": { @@ -745,17 +745,17 @@ ERROR_CLASSES_JSON = """ }, "UDTF_INVALID_OUTPUT_ROW_TYPE" : { "message" : [ - "The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type." + "The type of an individual output row in the '<func>' method of the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type." ] }, "UDTF_RETURN_NOT_ITERABLE" : { "message" : [ - "The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types." + "The return value of the '<func>' method of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types." ] }, "UDTF_RETURN_SCHEMA_MISMATCH" : { "message" : [ - "The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the function have the same number of columns as specified in the output schema." + "The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the '<func>' method have the same number of columns as specified in the output schema." ] }, "UDTF_RETURN_TYPE_MISMATCH" : { diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index c5f8b7693c2..97d5190a506 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -190,6 +190,27 @@ class BaseUDTFTestsMixin: with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): TestUDTF(lit(1)).collect() + def test_udtf_with_zero_arg_and_invalid_return_value(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + return 1 + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): + TestUDTF().collect() + + def test_udtf_with_invalid_return_value_in_terminate(self): + @udtf(returnType="x: int") + class TestUDTF: + def eval(self, a): + ... + + def terminate(self): + return 1 + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"): + TestUDTF(lit(1)).collect() + def test_udtf_eval_with_no_return(self): @udtf(returnType="a: int") class TestUDTF: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d95a5c4672f..fff99f1de3d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -773,6 +773,7 @@ def read_udtf(pickleSer, infile, eval_type): message_parameters={ "type_name": type(result).__name__, "value": str(result), + "func": f.__name__, }, ) @@ -787,6 +788,7 @@ def read_udtf(pickleSer, infile, eval_type): message_parameters={ "expected": str(return_type_size), "actual": str(len(result.columns)), + "func": f.__name__, }, ) @@ -806,9 +808,23 @@ def read_udtf(pickleSer, infile, eval_type): message_parameters={"method_name": f.__name__, "error": str(e)}, ) + def check_return_value(res): + # Check whether the result of an arrow UDTF is iterable before + # using it to construct a pandas DataFrame. + if res is not None and not isinstance(res, Iterable): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_NOT_ITERABLE", + message_parameters={ + "type": type(res).__name__, + "func": f.__name__, + }, + ) + def evaluate(*args: pd.Series, **kwargs: pd.Series): if len(args) == 0 and len(kwargs) == 0: - yield verify_result(pd.DataFrame(func())), arrow_return_type + res = func() + check_return_value(res) + yield verify_result(pd.DataFrame(res)), arrow_return_type else: # Create tuples from the input pandas Series, each tuple # represents a row across all Series. @@ -820,13 +836,7 @@ def read_udtf(pickleSer, infile, eval_type): *row[:len_args], **{key: row[len_args + i] for i, key in enumerate(keys)}, ) - if res is not None and not isinstance(res, Iterable): - raise PySparkRuntimeError( - error_class="UDTF_RETURN_NOT_ITERABLE", - message_parameters={ - "type": type(res).__name__, - }, - ) + check_return_value(res) yield verify_result(pd.DataFrame(res)), arrow_return_type return evaluate @@ -868,13 +878,17 @@ def read_udtf(pickleSer, infile, eval_type): message_parameters={ "expected": str(return_type_size), "actual": str(len(result)), + "func": f.__name__, }, ) if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")): raise PySparkRuntimeError( error_class="UDTF_INVALID_OUTPUT_ROW_TYPE", - message_parameters={"type": type(result).__name__}, + message_parameters={ + "type": type(result).__name__, + "func": f.__name__, + }, ) return toInternal(result) @@ -898,7 +912,10 @@ def read_udtf(pickleSer, infile, eval_type): if not isinstance(res, Iterable): raise PySparkRuntimeError( error_class="UDTF_RETURN_NOT_ITERABLE", - message_parameters={"type": type(res).__name__}, + message_parameters={ + "type": type(res).__name__, + "func": f.__name__, + }, ) # If the function returns a result, we map it to the internal representation and --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org