This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 7059b69e67d [SPARK-43968][PYTHON][3.5] Improve error messages for Python UDTFs with wrong number of outputs
7059b69e67d is described below

commit 7059b69e67db8126dafc3d4b1f3b39e947c4c3ca
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Fri Jul 28 15:30:17 2023 +0900

[SPARK-43968][PYTHON][3.5] Improve error messages for Python UDTFs with wrong number of outputs

### What changes were proposed in this pull request?

This PR cherry-picks 7194ce9263fe1683c039a1aaf9462657b1672a99. It improves the error messages for Python UDTFs when the number of output columns does not match the number of columns specified in the UDTF's return type.

### Why are the changes needed?

To make Python UDTFs more user-friendly.

### Does this PR introduce _any_ user-facing change?

Yes. This PR improves the error messages. Before this change, the error thrown by Spark was a Java `IllegalStateException`:

```
java.lang.IllegalStateException: Input row doesn't have expected number of values required by the schema
```

After this PR, Spark throws a clearer error message with an error class:

```
[UDTF_RETURN_SCHEMA_MISMATCH] The number of columns in the result does not match the specified schema
```

### How was this patch tested?

Existing tests and new unit tests.

Closes #42192 from allisonwang-db/spark-43968-3.5.
Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- python/pyspark/errors/error_classes.py | 5 + python/pyspark/sql/connect/udtf.py | 4 +- .../pyspark/sql/tests/connect/test_parity_udtf.py | 50 -------- python/pyspark/sql/tests/test_udtf.py | 133 +++++++++++---------- python/pyspark/sql/udtf.py | 9 +- python/pyspark/worker.py | 22 +++- 6 files changed, 99 insertions(+), 124 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index b1bf6b47af9..f6411fac1da 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -320,6 +320,11 @@ ERROR_CLASSES_JSON = """ "The eval type for the UDTF '<name>' is invalid. It must be one of <eval_type>." ] }, + "INVALID_UDTF_HANDLER_TYPE" : { + "message" : [ + "The UDTF is invalid. The function handler must be a class, but got '<type>'. Please provide a class as the function handler." + ] + }, "INVALID_UDTF_NO_EVAL" : { "message" : [ "The UDTF '<name>' is invalid. It does not implement the required 'eval' method. Please implement the 'eval' method in '<name>' and try again." 
diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index 1fe8e1024ee..3747e37459e 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -124,6 +124,8 @@ class UserDefinedTableFunction: evalType: int = PythonEvalType.SQL_TABLE_UDF, deterministic: bool = True, ) -> None: + _validate_udtf_handler(func) + self.func = func self.returnType: DataType = ( UnparsedDataType(returnType) if isinstance(returnType, str) else returnType @@ -132,8 +134,6 @@ class UserDefinedTableFunction: self.evalType = evalType self.deterministic = deterministic - _validate_udtf_handler(func) - def _build_common_inline_user_defined_table_function( self, *cols: "ColumnOrName" ) -> CommonInlineUserDefinedTableFunction: diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index e18e116e003..355f5288d2c 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -54,56 +54,6 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase): ): TestUDTF(lit(1)).collect() - def test_udtf_with_wrong_num_output(self): - err_msg = ( - "java.lang.IllegalStateException: Input row doesn't have expected number of " - + "values required by the schema." - ) - - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, - - with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): - TestUDTF(lit(1)).collect() - - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): - TestUDTF(lit(1)).collect() - - def test_udtf_terminate_with_wrong_num_output(self): - err_msg = ( - "java.lang.IllegalStateException: Input row doesn't have expected number of " - "values required by the schema." 
- ) - - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - def terminate(self): - yield 1, 2, 3 - - with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): - TestUDTF(lit(1)).show() - - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - def terminate(self): - yield 1, - - with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): - TestUDTF(lit(1)).show() - class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests): @classmethod diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 0aa769e506d..b3e832b8b97 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -19,8 +19,6 @@ import unittest from typing import Iterator -from py4j.protocol import Py4JJavaError - from pyspark.errors import ( PySparkAttributeError, PythonException, @@ -234,29 +232,76 @@ class BaseUDTFTestsMixin: ): TestUDTF(lit(1), lit(2)).collect() + def test_udtf_init_with_additional_args(self): + @udtf(returnType="x int") + class TestUDTF: + def __init__(self, a: int): + ... + + def eval(self, a: int): + yield a, + + with self.assertRaisesRegex( + PythonException, r"__init__\(\) missing 1 required positional argument: 'a'" + ): + TestUDTF(lit(1)).show() + + def test_udtf_terminate_with_additional_args(self): + @udtf(returnType="x int") + class TestUDTF: + def eval(self, a: int): + yield a, + + def terminate(self, a: int): + ... + + with self.assertRaisesRegex( + PythonException, r"terminate\(\) missing 1 required positional argument: 'a'" + ): + TestUDTF(lit(1)).show() + def test_udtf_with_wrong_num_output(self): - # TODO(SPARK-43968): check this during compile time instead of runtime err_msg = ( - "java.lang.IllegalStateException: Input row doesn't have expected number of " - + "values required by the schema." 
+ r"\[UDTF_RETURN_SCHEMA_MISMATCH\] The number of columns in the " + "result does not match the specified schema." ) + # Output less columns than specified return schema @udtf(returnType="a: int, b: int") class TestUDTF: def eval(self, a: int): yield a, - with self.assertRaisesRegex(Py4JJavaError, err_msg): + with self.assertRaisesRegex(PythonException, err_msg): TestUDTF(lit(1)).collect() + # Output more columns than specified return schema @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): yield a, a + 1 - with self.assertRaisesRegex(Py4JJavaError, err_msg): + with self.assertRaisesRegex(PythonException, err_msg): TestUDTF(lit(1)).collect() + def test_udtf_with_empty_output_schema_and_non_empty_output(self): + @udtf(returnType=StructType()) + class TestUDTF: + def eval(self): + yield 1, + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): + TestUDTF().collect() + + def test_udtf_with_non_empty_output_schema_and_empty_output(self): + @udtf(returnType="a: int") + class TestUDTF: + def eval(self): + yield tuple() + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): + TestUDTF().collect() + def test_udtf_init(self): @udtf(returnType="a: int, b: int, c: string") class TestUDTF: @@ -323,8 +368,8 @@ class BaseUDTFTestsMixin: def test_udtf_terminate_with_wrong_num_output(self): err_msg = ( - "java.lang.IllegalStateException: Input row doesn't have expected number of " - "values required by the schema." + r"\[UDTF_RETURN_SCHEMA_MISMATCH\] The number of columns in the result " + "does not match the specified schema." 
) @udtf(returnType="a: int, b: int") @@ -335,7 +380,7 @@ class BaseUDTFTestsMixin: def terminate(self): yield 1, 2, 3 - with self.assertRaisesRegex(Py4JJavaError, err_msg): + with self.assertRaisesRegex(PythonException, err_msg): TestUDTF(lit(1)).show() @udtf(returnType="a: int, b: int") @@ -346,7 +391,7 @@ class BaseUDTFTestsMixin: def terminate(self): yield 1, - with self.assertRaisesRegex(Py4JJavaError, err_msg): + with self.assertRaisesRegex(PythonException, err_msg): TestUDTF(lit(1)).show() def test_nondeterministic_udtf(self): @@ -523,18 +568,26 @@ class BaseUDTFTestsMixin: ) def test_udtf_with_no_handler_class(self): - err_msg = "the function handler must be a class" - with self.assertRaisesRegex(TypeError, err_msg): + with self.assertRaises(PySparkTypeError) as e: @udtf(returnType="a: int") def test_udtf(a: int): yield a, - def test_udtf(a: int): - yield a + self.check_error( + exception=e.exception, + error_class="INVALID_UDTF_HANDLER_TYPE", + message_parameters={"type": "function"}, + ) + + with self.assertRaises(PySparkTypeError) as e: + udtf(1, returnType="a: int") - with self.assertRaisesRegex(TypeError, err_msg): - udtf(test_udtf, returnType="a: int") + self.check_error( + exception=e.exception, + error_class="INVALID_UDTF_HANDLER_TYPE", + message_parameters={"type": "int"}, + ) def test_udtf_with_table_argument_query(self): class TestUDTF: @@ -804,52 +857,6 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin): func = udtf(TestUDTF, returnType="a: int") self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) - def test_udtf_terminate_with_wrong_num_output(self): - # The error message for arrow-optimized UDTF is different from regular UDTF. - err_msg = "The number of columns in the result does not match the specified schema." 
- - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - def terminate(self): - yield 1, 2, 3 - - with self.assertRaisesRegex(PythonException, err_msg): - TestUDTF(lit(1)).show() - - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - def terminate(self): - yield 1, - - with self.assertRaisesRegex(PythonException, err_msg): - TestUDTF(lit(1)).show() - - def test_udtf_with_wrong_num_output(self): - # The error message for arrow-optimized UDTF is different. - err_msg = "The number of columns in the result does not match the specified schema." - - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, - - with self.assertRaisesRegex(PythonException, err_msg): - TestUDTF(lit(1)).collect() - - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - with self.assertRaisesRegex(PythonException, err_msg): - TestUDTF(lit(1)).collect() - class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index 3ab74193093..50bba56880c 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -139,13 +139,10 @@ def _vectorize_udtf(cls: Type) -> Type: def _validate_udtf_handler(cls: Any) -> None: """Validate the handler class of a UDTF.""" - # TODO(SPARK-43968): add more compile time checks for UDTFs if not isinstance(cls, type): raise PySparkTypeError( - f"Invalid user defined table function: the function handler " - f"must be a class, but got {type(cls).__name__}. Please provide " - "a class as the handler." 
+ error_class="INVALID_UDTF_HANDLER_TYPE", message_parameters={"type": type(cls).__name__} ) if not hasattr(cls, "eval"): @@ -176,6 +173,8 @@ class UserDefinedTableFunction: evalType: int = PythonEvalType.SQL_TABLE_UDF, deterministic: bool = True, ): + _validate_udtf_handler(func) + self.func = func self._returnType = returnType self._returnType_placeholder: Optional[StructType] = None @@ -185,8 +184,6 @@ class UserDefinedTableFunction: self.evalType = evalType self.deterministic = deterministic - _validate_udtf_handler(func) - @property def returnType(self) -> StructType: # `_parse_datatype_string` accesses to JVM for parsing a DDL formatted string. diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2445b46970c..cbc9faad47c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -532,8 +532,11 @@ def read_udtf(pickleSer, infile, eval_type): }, ) - # Check when the dataframe has both rows and columns. - if not result.empty or len(result.columns) != 0: + # Validate the output schema when the result dataframe has either output + # rows or columns. Note that we avoid using `df.empty` here because the + # result dataframe may contain an empty row. For example, when a UDTF is + # defined as follows: def eval(self): yield tuple(). 
+ if len(result) > 0 or len(result.columns) > 0: if len(result.columns) != len(return_type): raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", @@ -580,6 +583,19 @@ def read_udtf(pickleSer, infile, eval_type): assert return_type.needConversion() toInternal = return_type.toInternal + def verify_and_convert_result(result): + # TODO(SPARK-44005): support returning non-tuple values + if result is not None and hasattr(result, "__len__"): + if len(result) != len(return_type): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_SCHEMA_MISMATCH", + message_parameters={ + "expected": str(len(return_type)), + "actual": str(len(result)), + }, + ) + return toInternal(result) + # Evaluate the function and return a tuple back to the executor. def evaluate(*a) -> tuple: res = f(*a) @@ -591,7 +607,7 @@ def read_udtf(pickleSer, infile, eval_type): else: # If the function returns a result, we map it to the internal representation and # returns the results as a tuple. - return tuple(map(toInternal, res)) + return tuple(map(verify_and_convert_result, res)) return evaluate --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org