This is an automated email from the ASF dual-hosted git repository. ueshin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 7194ce9263f [SPARK-43968][PYTHON] Improve error messages for Python UDTFs with wrong number of outputs 7194ce9263f is described below commit 7194ce9263fe1683c039a1aaf9462657b1672a99 Author: allisonwang-db <allison.w...@databricks.com> AuthorDate: Thu Jul 27 13:18:39 2023 -0700 [SPARK-43968][PYTHON] Improve error messages for Python UDTFs with wrong number of outputs ### What changes were proposed in this pull request? This PR improves the error messages for Python UDTFs when the number of output columns does not match the number of columns specified in the UDTF's return type. ### Why are the changes needed? To make Python UDTFs more user-friendly. ### Does this PR introduce _any_ user-facing change? Yes. This PR improves the error messages. Before this change, Spark threw a Java IllegalStateException: ``` java.lang.IllegalStateException: Input row doesn't have expected number of values required by the schema ``` After this PR, it throws a clearer error message with an error class: ``` [UDTF_RETURN_SCHEMA_MISMATCH] The number of columns in the result does not match the specified schema ``` ### How was this patch tested? Existing tests and new unit tests. Closes #42157 from allisonwang-db/spark-43968-py-udtf-checks. 
Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> --- python/pyspark/errors/error_classes.py | 5 + python/pyspark/sql/connect/udtf.py | 4 +- .../pyspark/sql/tests/connect/test_parity_udtf.py | 50 -------- python/pyspark/sql/tests/test_udtf.py | 133 +++++++++++---------- python/pyspark/sql/udtf.py | 9 +- python/pyspark/worker.py | 22 +++- 6 files changed, 99 insertions(+), 124 deletions(-) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index e0d1c30b604..f4b643f1d32 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -283,6 +283,11 @@ ERROR_CLASSES_JSON = """ "The eval type for the UDTF '<name>' is invalid. It must be one of <eval_type>." ] }, + "INVALID_UDTF_HANDLER_TYPE" : { + "message" : [ + "The UDTF is invalid. The function handler must be a class, but got '<type>'. Please provide a class as the function handler." + ] + }, "INVALID_UDTF_NO_EVAL" : { "message" : [ "The UDTF '<name>' is invalid. It does not implement the required 'eval' method. Please implement the 'eval' method in '<name>' and try again." 
diff --git a/python/pyspark/sql/connect/udtf.py b/python/pyspark/sql/connect/udtf.py index 74c55cc42cd..919994401c8 100644 --- a/python/pyspark/sql/connect/udtf.py +++ b/python/pyspark/sql/connect/udtf.py @@ -124,6 +124,8 @@ class UserDefinedTableFunction: evalType: int = PythonEvalType.SQL_TABLE_UDF, deterministic: bool = True, ) -> None: + _validate_udtf_handler(func, returnType) + self.func = func self.returnType: Optional[DataType] = ( None @@ -136,8 +138,6 @@ class UserDefinedTableFunction: self.evalType = evalType self.deterministic = deterministic - _validate_udtf_handler(func, returnType) - def _build_common_inline_user_defined_table_function( self, *cols: "ColumnOrName" ) -> CommonInlineUserDefinedTableFunction: diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py b/python/pyspark/sql/tests/connect/test_parity_udtf.py index 1aff1bd0686..748b611e667 100644 --- a/python/pyspark/sql/tests/connect/test_parity_udtf.py +++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py @@ -56,56 +56,6 @@ class UDTFParityTests(BaseUDTFTestsMixin, ReusedConnectTestCase): ): TestUDTF(lit(1)).collect() - def test_udtf_with_wrong_num_output(self): - err_msg = ( - "java.lang.IllegalStateException: Input row doesn't have expected number of " - + "values required by the schema." - ) - - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, - - with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): - TestUDTF(lit(1)).collect() - - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): - TestUDTF(lit(1)).collect() - - def test_udtf_terminate_with_wrong_num_output(self): - err_msg = ( - "java.lang.IllegalStateException: Input row doesn't have expected number of " - "values required by the schema." 
- ) - - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - def terminate(self): - yield 1, 2, 3 - - with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): - TestUDTF(lit(1)).show() - - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - def terminate(self): - yield 1, - - with self.assertRaisesRegex(SparkConnectGrpcException, err_msg): - TestUDTF(lit(1)).show() - @unittest.skip("Spark Connect does not support broadcast but the test depends on it.") def test_udtf_with_analyze_using_broadcast(self): super().test_udtf_with_analyze_using_broadcast() diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index e5b29b36034..120d46491c9 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -21,8 +21,6 @@ import unittest from typing import Iterator -from py4j.protocol import Py4JJavaError - from pyspark.errors import ( PySparkAttributeError, PythonException, @@ -259,29 +257,76 @@ class BaseUDTFTestsMixin: ): TestUDTF(lit(1), lit(2)).collect() + def test_udtf_init_with_additional_args(self): + @udtf(returnType="x int") + class TestUDTF: + def __init__(self, a: int): + ... + + def eval(self, a: int): + yield a, + + with self.assertRaisesRegex( + PythonException, r"__init__\(\) missing 1 required positional argument: 'a'" + ): + TestUDTF(lit(1)).show() + + def test_udtf_terminate_with_additional_args(self): + @udtf(returnType="x int") + class TestUDTF: + def eval(self, a: int): + yield a, + + def terminate(self, a: int): + ... 
+ + with self.assertRaisesRegex( + PythonException, r"terminate\(\) missing 1 required positional argument: 'a'" + ): + TestUDTF(lit(1)).show() + def test_udtf_with_wrong_num_output(self): - # TODO(SPARK-43968): check this during compile time instead of runtime err_msg = ( - "java.lang.IllegalStateException: Input row doesn't have expected number of " - + "values required by the schema." + r"\[UDTF_RETURN_SCHEMA_MISMATCH\] The number of columns in the " + "result does not match the specified schema." ) + # Output less columns than specified return schema @udtf(returnType="a: int, b: int") class TestUDTF: def eval(self, a: int): yield a, - with self.assertRaisesRegex(Py4JJavaError, err_msg): + with self.assertRaisesRegex(PythonException, err_msg): TestUDTF(lit(1)).collect() + # Output more columns than specified return schema @udtf(returnType="a: int") class TestUDTF: def eval(self, a: int): yield a, a + 1 - with self.assertRaisesRegex(Py4JJavaError, err_msg): + with self.assertRaisesRegex(PythonException, err_msg): TestUDTF(lit(1)).collect() + def test_udtf_with_empty_output_schema_and_non_empty_output(self): + @udtf(returnType=StructType()) + class TestUDTF: + def eval(self): + yield 1, + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): + TestUDTF().collect() + + def test_udtf_with_non_empty_output_schema_and_empty_output(self): + @udtf(returnType="a: int") + class TestUDTF: + def eval(self): + yield tuple() + + with self.assertRaisesRegex(PythonException, "UDTF_RETURN_SCHEMA_MISMATCH"): + TestUDTF().collect() + def test_udtf_init(self): @udtf(returnType="a: int, b: int, c: string") class TestUDTF: @@ -348,8 +393,8 @@ class BaseUDTFTestsMixin: def test_udtf_terminate_with_wrong_num_output(self): err_msg = ( - "java.lang.IllegalStateException: Input row doesn't have expected number of " - "values required by the schema." + r"\[UDTF_RETURN_SCHEMA_MISMATCH\] The number of columns in the result " + "does not match the specified schema." 
) @udtf(returnType="a: int, b: int") @@ -360,7 +405,7 @@ class BaseUDTFTestsMixin: def terminate(self): yield 1, 2, 3 - with self.assertRaisesRegex(Py4JJavaError, err_msg): + with self.assertRaisesRegex(PythonException, err_msg): TestUDTF(lit(1)).show() @udtf(returnType="a: int, b: int") @@ -371,7 +416,7 @@ class BaseUDTFTestsMixin: def terminate(self): yield 1, - with self.assertRaisesRegex(Py4JJavaError, err_msg): + with self.assertRaisesRegex(PythonException, err_msg): TestUDTF(lit(1)).show() def test_nondeterministic_udtf(self): @@ -548,18 +593,26 @@ class BaseUDTFTestsMixin: ) def test_udtf_with_no_handler_class(self): - err_msg = "the function handler must be a class" - with self.assertRaisesRegex(TypeError, err_msg): + with self.assertRaises(PySparkTypeError) as e: @udtf(returnType="a: int") def test_udtf(a: int): yield a, - def test_udtf(a: int): - yield a + self.check_error( + exception=e.exception, + error_class="INVALID_UDTF_HANDLER_TYPE", + message_parameters={"type": "function"}, + ) + + with self.assertRaises(PySparkTypeError) as e: + udtf(1, returnType="a: int") - with self.assertRaisesRegex(TypeError, err_msg): - udtf(test_udtf, returnType="a: int") + self.check_error( + exception=e.exception, + error_class="INVALID_UDTF_HANDLER_TYPE", + message_parameters={"type": "int"}, + ) def test_udtf_with_table_argument_query(self): class TestUDTF: @@ -1498,52 +1551,6 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin): func = udtf(TestUDTF, returnType="a: int") self.assertEqual(func(lit(1)).collect(), [Row(a=1)]) - def test_udtf_terminate_with_wrong_num_output(self): - # The error message for arrow-optimized UDTF is different from regular UDTF. - err_msg = "The number of columns in the result does not match the specified schema." 
- - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - def terminate(self): - yield 1, 2, 3 - - with self.assertRaisesRegex(PythonException, err_msg): - TestUDTF(lit(1)).show() - - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - def terminate(self): - yield 1, - - with self.assertRaisesRegex(PythonException, err_msg): - TestUDTF(lit(1)).show() - - def test_udtf_with_wrong_num_output(self): - # The error message for arrow-optimized UDTF is different. - err_msg = "The number of columns in the result does not match the specified schema." - - @udtf(returnType="a: int, b: int") - class TestUDTF: - def eval(self, a: int): - yield a, - - with self.assertRaisesRegex(PythonException, err_msg): - TestUDTF(lit(1)).collect() - - @udtf(returnType="a: int") - class TestUDTF: - def eval(self, a: int): - yield a, a + 1 - - with self.assertRaisesRegex(PythonException, err_msg): - TestUDTF(lit(1)).collect() - class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index 67d4ef33777..d14a263f839 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -186,13 +186,10 @@ def _vectorize_udtf(cls: Type) -> Type: def _validate_udtf_handler(cls: Any, returnType: Optional[Union[StructType, str]]) -> None: """Validate the handler class of a UDTF.""" - # TODO(SPARK-43968): add more compile time checks for UDTFs if not isinstance(cls, type): raise PySparkTypeError( - f"Invalid user defined table function: the function handler " - f"must be a class, but got {type(cls).__name__}. Please provide " - "a class as the handler." 
+ error_class="INVALID_UDTF_HANDLER_TYPE", message_parameters={"type": type(cls).__name__} ) if not hasattr(cls, "eval"): @@ -237,6 +234,8 @@ class UserDefinedTableFunction: evalType: int = PythonEvalType.SQL_TABLE_UDF, deterministic: bool = True, ): + _validate_udtf_handler(func, returnType) + self.func = func self._returnType = returnType self._returnType_placeholder: Optional[StructType] = None @@ -246,8 +245,6 @@ class UserDefinedTableFunction: self.evalType = evalType self.deterministic = deterministic - _validate_udtf_handler(func, returnType) - @property def returnType(self) -> Optional[StructType]: if self._returnType is None: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index bfe788faf6d..20e856c9add 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -604,8 +604,11 @@ def read_udtf(pickleSer, infile, eval_type): }, ) - # Check when the dataframe has both rows and columns. - if not result.empty or len(result.columns) != 0: + # Validate the output schema when the result dataframe has either output + # rows or columns. Note that we avoid using `df.empty` here because the + # result dataframe may contain an empty row. For example, when a UDTF is + # defined as follows: def eval(self): yield tuple(). 
+ if len(result) > 0 or len(result.columns) > 0: if len(result.columns) != len(return_type): raise PySparkRuntimeError( error_class="UDTF_RETURN_SCHEMA_MISMATCH", @@ -654,6 +657,19 @@ def read_udtf(pickleSer, infile, eval_type): assert return_type.needConversion() toInternal = return_type.toInternal + def verify_and_convert_result(result): + # TODO(SPARK-44005): support returning non-tuple values + if result is not None and hasattr(result, "__len__"): + if len(result) != len(return_type): + raise PySparkRuntimeError( + error_class="UDTF_RETURN_SCHEMA_MISMATCH", + message_parameters={ + "expected": str(len(return_type)), + "actual": str(len(result)), + }, + ) + return toInternal(result) + # Evaluate the function and return a tuple back to the executor. def evaluate(*a) -> tuple: res = f(*a) @@ -665,7 +681,7 @@ def read_udtf(pickleSer, infile, eval_type): else: # If the function returns a result, we map it to the internal representation and # returns the results as a tuple. - return tuple(map(toInternal, res)) + return tuple(map(verify_and_convert_result, res)) return evaluate --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org