This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7194ce9263f [SPARK-43968][PYTHON] Improve error messages for Python 
UDTFs with wrong number of outputs
7194ce9263f is described below

commit 7194ce9263fe1683c039a1aaf9462657b1672a99
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Thu Jul 27 13:18:39 2023 -0700

    [SPARK-43968][PYTHON] Improve error messages for Python UDTFs with wrong 
number of outputs
    
    ### What changes were proposed in this pull request?
    
    This PR improves the error messages for Python UDTFs when the number of 
output columns does not match the number specified in the return type of the 
UDTF.
    
    ### Why are the changes needed?
    
    To make Python UDTFs more user-friendly.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. This PR improves the error messages.
    Before this change, the error thrown by Spark was a Java 
IllegalStateException:
    ```
    java.lang.IllegalStateException: Input row doesn't have expected number of 
values required by the schema
    ```
    After this PR, it will throw a clearer error message with an error class:
    ```
    [UDTF_RETURN_SCHEMA_MISMATCH] The number of columns in the result does not 
match the specified schema
    ```
    
    ### How was this patch tested?
    
    Existing tests and new unit tests.
    
    Closes #42157 from allisonwang-db/spark-43968-py-udtf-checks.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/errors/error_classes.py             |   5 +
 python/pyspark/sql/connect/udtf.py                 |   4 +-
 .../pyspark/sql/tests/connect/test_parity_udtf.py  |  50 --------
 python/pyspark/sql/tests/test_udtf.py              | 133 +++++++++++----------
 python/pyspark/sql/udtf.py                         |   9 +-
 python/pyspark/worker.py                           |  22 +++-
 6 files changed, 99 insertions(+), 124 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index e0d1c30b604..f4b643f1d32 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -283,6 +283,11 @@ ERROR_CLASSES_JSON = """
       "The eval type for the UDTF '<name>' is invalid. It must be one of 
<eval_type>."
     ]
   },
+  "INVALID_UDTF_HANDLER_TYPE" : {
+    "message" : [
+      "The UDTF is invalid. The function handler must be a class, but got 
'<type>'. Please provide a class as the function handler."
+    ]
+  },
   "INVALID_UDTF_NO_EVAL" : {
     "message" : [
       "The UDTF '<name>' is invalid. It does not implement the required 'eval' 
method. Please implement the 'eval' method in '<name>' and try again."
diff --git a/python/pyspark/sql/connect/udtf.py 
b/python/pyspark/sql/connect/udtf.py
index 74c55cc42cd..919994401c8 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -124,6 +124,8 @@ class UserDefinedTableFunction:
         evalType: int = PythonEvalType.SQL_TABLE_UDF,
         deterministic: bool = True,
     ) -> None:
+        _validate_udtf_handler(func, returnType)
+
         self.func = func
         self.returnType: Optional[DataType] = (
             None
@@ -136,8 +138,6 @@ class UserDefinedTableFunction:
         self.evalType = evalType
         self.deterministic = deterministic
 
-        _validate_udtf_handler(func, returnType)
-
     def _build_common_inline_user_defined_table_function(
         self, *cols: "ColumnOrName"
     ) -> CommonInlineUserDefinedTableFunction:
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py 
b/python/pyspark/sql/tests/connect/test_parity_udtf.py
index 1aff1bd0686..748b611e667 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udtf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -56,56 +56,6 @@ class UDTFParityTests(BaseUDTFTestsMixin, 
ReusedConnectTestCase):
         ):
             TestUDTF(lit(1)).collect()
 
-    def test_udtf_with_wrong_num_output(self):
-        err_msg = (
-            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
-            + "values required by the schema."
-        )
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a,
-
-        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
-            TestUDTF(lit(1)).collect()
-
-        @udtf(returnType="a: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
-            TestUDTF(lit(1)).collect()
-
-    def test_udtf_terminate_with_wrong_num_output(self):
-        err_msg = (
-            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
-            "values required by the schema."
-        )
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-            def terminate(self):
-                yield 1, 2, 3
-
-        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
-            TestUDTF(lit(1)).show()
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-            def terminate(self):
-                yield 1,
-
-        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
-            TestUDTF(lit(1)).show()
-
     @unittest.skip("Spark Connect does not support broadcast but the test 
depends on it.")
     def test_udtf_with_analyze_using_broadcast(self):
         super().test_udtf_with_analyze_using_broadcast()
diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index e5b29b36034..120d46491c9 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -21,8 +21,6 @@ import unittest
 
 from typing import Iterator
 
-from py4j.protocol import Py4JJavaError
-
 from pyspark.errors import (
     PySparkAttributeError,
     PythonException,
@@ -259,29 +257,76 @@ class BaseUDTFTestsMixin:
         ):
             TestUDTF(lit(1), lit(2)).collect()
 
+    def test_udtf_init_with_additional_args(self):
+        @udtf(returnType="x int")
+        class TestUDTF:
+            def __init__(self, a: int):
+                ...
+
+            def eval(self, a: int):
+                yield a,
+
+        with self.assertRaisesRegex(
+            PythonException, r"__init__\(\) missing 1 required positional 
argument: 'a'"
+        ):
+            TestUDTF(lit(1)).show()
+
+    def test_udtf_terminate_with_additional_args(self):
+        @udtf(returnType="x int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+            def terminate(self, a: int):
+                ...
+
+        with self.assertRaisesRegex(
+            PythonException, r"terminate\(\) missing 1 required positional 
argument: 'a'"
+        ):
+            TestUDTF(lit(1)).show()
+
     def test_udtf_with_wrong_num_output(self):
-        # TODO(SPARK-43968): check this during compile time instead of runtime
         err_msg = (
-            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
-            + "values required by the schema."
+            r"\[UDTF_RETURN_SCHEMA_MISMATCH\] The number of columns in the "
+            "result does not match the specified schema."
         )
 
+        # Output less columns than specified return schema
         @udtf(returnType="a: int, b: int")
         class TestUDTF:
             def eval(self, a: int):
                 yield a,
 
-        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+        with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).collect()
 
+        # Output more columns than specified return schema
         @udtf(returnType="a: int")
         class TestUDTF:
             def eval(self, a: int):
                 yield a, a + 1
 
-        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+        with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).collect()
 
+    def test_udtf_with_empty_output_schema_and_non_empty_output(self):
+        @udtf(returnType=StructType())
+        class TestUDTF:
+            def eval(self):
+                yield 1,
+
+        with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
+            TestUDTF().collect()
+
+    def test_udtf_with_non_empty_output_schema_and_empty_output(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self):
+                yield tuple()
+
+        with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
+            TestUDTF().collect()
+
     def test_udtf_init(self):
         @udtf(returnType="a: int, b: int, c: string")
         class TestUDTF:
@@ -348,8 +393,8 @@ class BaseUDTFTestsMixin:
 
     def test_udtf_terminate_with_wrong_num_output(self):
         err_msg = (
-            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
-            "values required by the schema."
+            r"\[UDTF_RETURN_SCHEMA_MISMATCH\] The number of columns in the 
result "
+            "does not match the specified schema."
         )
 
         @udtf(returnType="a: int, b: int")
@@ -360,7 +405,7 @@ class BaseUDTFTestsMixin:
             def terminate(self):
                 yield 1, 2, 3
 
-        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+        with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).show()
 
         @udtf(returnType="a: int, b: int")
@@ -371,7 +416,7 @@ class BaseUDTFTestsMixin:
             def terminate(self):
                 yield 1,
 
-        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+        with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).show()
 
     def test_nondeterministic_udtf(self):
@@ -548,18 +593,26 @@ class BaseUDTFTestsMixin:
         )
 
     def test_udtf_with_no_handler_class(self):
-        err_msg = "the function handler must be a class"
-        with self.assertRaisesRegex(TypeError, err_msg):
+        with self.assertRaises(PySparkTypeError) as e:
 
             @udtf(returnType="a: int")
             def test_udtf(a: int):
                 yield a,
 
-        def test_udtf(a: int):
-            yield a
+        self.check_error(
+            exception=e.exception,
+            error_class="INVALID_UDTF_HANDLER_TYPE",
+            message_parameters={"type": "function"},
+        )
+
+        with self.assertRaises(PySparkTypeError) as e:
+            udtf(1, returnType="a: int")
 
-        with self.assertRaisesRegex(TypeError, err_msg):
-            udtf(test_udtf, returnType="a: int")
+        self.check_error(
+            exception=e.exception,
+            error_class="INVALID_UDTF_HANDLER_TYPE",
+            message_parameters={"type": "int"},
+        )
 
     def test_udtf_with_table_argument_query(self):
         class TestUDTF:
@@ -1498,52 +1551,6 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
         func = udtf(TestUDTF, returnType="a: int")
         self.assertEqual(func(lit(1)).collect(), [Row(a=1)])
 
-    def test_udtf_terminate_with_wrong_num_output(self):
-        # The error message for arrow-optimized UDTF is different from regular 
UDTF.
-        err_msg = "The number of columns in the result does not match the 
specified schema."
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-            def terminate(self):
-                yield 1, 2, 3
-
-        with self.assertRaisesRegex(PythonException, err_msg):
-            TestUDTF(lit(1)).show()
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-            def terminate(self):
-                yield 1,
-
-        with self.assertRaisesRegex(PythonException, err_msg):
-            TestUDTF(lit(1)).show()
-
-    def test_udtf_with_wrong_num_output(self):
-        # The error message for arrow-optimized UDTF is different.
-        err_msg = "The number of columns in the result does not match the 
specified schema."
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a,
-
-        with self.assertRaisesRegex(PythonException, err_msg):
-            TestUDTF(lit(1)).collect()
-
-        @udtf(returnType="a: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-        with self.assertRaisesRegex(PythonException, err_msg):
-            TestUDTF(lit(1)).collect()
-
 
 class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index 67d4ef33777..d14a263f839 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -186,13 +186,10 @@ def _vectorize_udtf(cls: Type) -> Type:
 
 def _validate_udtf_handler(cls: Any, returnType: Optional[Union[StructType, 
str]]) -> None:
     """Validate the handler class of a UDTF."""
-    # TODO(SPARK-43968): add more compile time checks for UDTFs
 
     if not isinstance(cls, type):
         raise PySparkTypeError(
-            f"Invalid user defined table function: the function handler "
-            f"must be a class, but got {type(cls).__name__}. Please provide "
-            "a class as the handler."
+            error_class="INVALID_UDTF_HANDLER_TYPE", 
message_parameters={"type": type(cls).__name__}
         )
 
     if not hasattr(cls, "eval"):
@@ -237,6 +234,8 @@ class UserDefinedTableFunction:
         evalType: int = PythonEvalType.SQL_TABLE_UDF,
         deterministic: bool = True,
     ):
+        _validate_udtf_handler(func, returnType)
+
         self.func = func
         self._returnType = returnType
         self._returnType_placeholder: Optional[StructType] = None
@@ -246,8 +245,6 @@ class UserDefinedTableFunction:
         self.evalType = evalType
         self.deterministic = deterministic
 
-        _validate_udtf_handler(func, returnType)
-
     @property
     def returnType(self) -> Optional[StructType]:
         if self._returnType is None:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index bfe788faf6d..20e856c9add 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -604,8 +604,11 @@ def read_udtf(pickleSer, infile, eval_type):
                         },
                     )
 
-                # Check when the dataframe has both rows and columns.
-                if not result.empty or len(result.columns) != 0:
+                # Validate the output schema when the result dataframe has 
either output
+                # rows or columns. Note that we avoid using `df.empty` here 
because the
+                # result dataframe may contain an empty row. For example, when 
a UDTF is
+                # defined as follows: def eval(self): yield tuple().
+                if len(result) > 0 or len(result.columns) > 0:
                     if len(result.columns) != len(return_type):
                         raise PySparkRuntimeError(
                             error_class="UDTF_RETURN_SCHEMA_MISMATCH",
@@ -654,6 +657,19 @@ def read_udtf(pickleSer, infile, eval_type):
             assert return_type.needConversion()
             toInternal = return_type.toInternal
 
+            def verify_and_convert_result(result):
+                # TODO(SPARK-44005): support returning non-tuple values
+                if result is not None and hasattr(result, "__len__"):
+                    if len(result) != len(return_type):
+                        raise PySparkRuntimeError(
+                            error_class="UDTF_RETURN_SCHEMA_MISMATCH",
+                            message_parameters={
+                                "expected": str(len(return_type)),
+                                "actual": str(len(result)),
+                            },
+                        )
+                return toInternal(result)
+
             # Evaluate the function and return a tuple back to the executor.
             def evaluate(*a) -> tuple:
                 res = f(*a)
@@ -665,7 +681,7 @@ def read_udtf(pickleSer, infile, eval_type):
                 else:
                     # If the function returns a result, we map it to the 
internal representation and
                     # returns the results as a tuple.
-                    return tuple(map(toInternal, res))
+                    return tuple(map(verify_and_convert_result, res))
 
             return evaluate
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to