This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 7059b69e67d [SPARK-43968][PYTHON][3.5] Improve error messages for 
Python UDTFs with wrong number of outputs
7059b69e67d is described below

commit 7059b69e67db8126dafc3d4b1f3b39e947c4c3ca
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Fri Jul 28 15:30:17 2023 +0900

    [SPARK-43968][PYTHON][3.5] Improve error messages for Python UDTFs with 
wrong number of outputs
    
    ### What changes were proposed in this pull request?
    
    This PR cherry-picks 7194ce9263fe1683c039a1aaf9462657b1672a99. It improves 
the error messages for Python UDTFs when the number of output columns does not 
match the number of columns specified in the return type of the UDTFs.
    
    ### Why are the changes needed?
    
    To make Python UDTFs more user-friendly.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. This PR improves the error messages.
    Before this change, the error thrown by Spark was a Java 
IllegalStateException:
    ```
    java.lang.IllegalStateException: Input row doesn't have expected number of 
values required by the schema
    ```
    After this PR, it will throw an exception with a clearer message and an error class:
    ```
    [UDTF_RETURN_SCHEMA_MISMATCH] The number of columns in the result does not 
match the specified schema
    ```
    
    ### How was this patch tested?
    
    Existing tests and new unit tests.
    
    Closes #42192 from allisonwang-db/spark-43968-3.5.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/errors/error_classes.py             |   5 +
 python/pyspark/sql/connect/udtf.py                 |   4 +-
 .../pyspark/sql/tests/connect/test_parity_udtf.py  |  50 --------
 python/pyspark/sql/tests/test_udtf.py              | 133 +++++++++++----------
 python/pyspark/sql/udtf.py                         |   9 +-
 python/pyspark/worker.py                           |  22 +++-
 6 files changed, 99 insertions(+), 124 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index b1bf6b47af9..f6411fac1da 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -320,6 +320,11 @@ ERROR_CLASSES_JSON = """
       "The eval type for the UDTF '<name>' is invalid. It must be one of 
<eval_type>."
     ]
   },
+  "INVALID_UDTF_HANDLER_TYPE" : {
+    "message" : [
+      "The UDTF is invalid. The function handler must be a class, but got 
'<type>'. Please provide a class as the function handler."
+    ]
+  },
   "INVALID_UDTF_NO_EVAL" : {
     "message" : [
       "The UDTF '<name>' is invalid. It does not implement the required 'eval' 
method. Please implement the 'eval' method in '<name>' and try again."
diff --git a/python/pyspark/sql/connect/udtf.py 
b/python/pyspark/sql/connect/udtf.py
index 1fe8e1024ee..3747e37459e 100644
--- a/python/pyspark/sql/connect/udtf.py
+++ b/python/pyspark/sql/connect/udtf.py
@@ -124,6 +124,8 @@ class UserDefinedTableFunction:
         evalType: int = PythonEvalType.SQL_TABLE_UDF,
         deterministic: bool = True,
     ) -> None:
+        _validate_udtf_handler(func)
+
         self.func = func
         self.returnType: DataType = (
             UnparsedDataType(returnType) if isinstance(returnType, str) else 
returnType
@@ -132,8 +134,6 @@ class UserDefinedTableFunction:
         self.evalType = evalType
         self.deterministic = deterministic
 
-        _validate_udtf_handler(func)
-
     def _build_common_inline_user_defined_table_function(
         self, *cols: "ColumnOrName"
     ) -> CommonInlineUserDefinedTableFunction:
diff --git a/python/pyspark/sql/tests/connect/test_parity_udtf.py 
b/python/pyspark/sql/tests/connect/test_parity_udtf.py
index e18e116e003..355f5288d2c 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udtf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udtf.py
@@ -54,56 +54,6 @@ class UDTFParityTests(BaseUDTFTestsMixin, 
ReusedConnectTestCase):
         ):
             TestUDTF(lit(1)).collect()
 
-    def test_udtf_with_wrong_num_output(self):
-        err_msg = (
-            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
-            + "values required by the schema."
-        )
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a,
-
-        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
-            TestUDTF(lit(1)).collect()
-
-        @udtf(returnType="a: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
-            TestUDTF(lit(1)).collect()
-
-    def test_udtf_terminate_with_wrong_num_output(self):
-        err_msg = (
-            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
-            "values required by the schema."
-        )
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-            def terminate(self):
-                yield 1, 2, 3
-
-        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
-            TestUDTF(lit(1)).show()
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-            def terminate(self):
-                yield 1,
-
-        with self.assertRaisesRegex(SparkConnectGrpcException, err_msg):
-            TestUDTF(lit(1)).show()
-
 
 class ArrowUDTFParityTests(UDTFArrowTestsMixin, UDTFParityTests):
     @classmethod
diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index 0aa769e506d..b3e832b8b97 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -19,8 +19,6 @@ import unittest
 
 from typing import Iterator
 
-from py4j.protocol import Py4JJavaError
-
 from pyspark.errors import (
     PySparkAttributeError,
     PythonException,
@@ -234,29 +232,76 @@ class BaseUDTFTestsMixin:
         ):
             TestUDTF(lit(1), lit(2)).collect()
 
+    def test_udtf_init_with_additional_args(self):
+        @udtf(returnType="x int")
+        class TestUDTF:
+            def __init__(self, a: int):
+                ...
+
+            def eval(self, a: int):
+                yield a,
+
+        with self.assertRaisesRegex(
+            PythonException, r"__init__\(\) missing 1 required positional 
argument: 'a'"
+        ):
+            TestUDTF(lit(1)).show()
+
+    def test_udtf_terminate_with_additional_args(self):
+        @udtf(returnType="x int")
+        class TestUDTF:
+            def eval(self, a: int):
+                yield a,
+
+            def terminate(self, a: int):
+                ...
+
+        with self.assertRaisesRegex(
+            PythonException, r"terminate\(\) missing 1 required positional 
argument: 'a'"
+        ):
+            TestUDTF(lit(1)).show()
+
     def test_udtf_with_wrong_num_output(self):
-        # TODO(SPARK-43968): check this during compile time instead of runtime
         err_msg = (
-            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
-            + "values required by the schema."
+            r"\[UDTF_RETURN_SCHEMA_MISMATCH\] The number of columns in the "
+            "result does not match the specified schema."
         )
 
+        # Output less columns than specified return schema
         @udtf(returnType="a: int, b: int")
         class TestUDTF:
             def eval(self, a: int):
                 yield a,
 
-        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+        with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).collect()
 
+        # Output more columns than specified return schema
         @udtf(returnType="a: int")
         class TestUDTF:
             def eval(self, a: int):
                 yield a, a + 1
 
-        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+        with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).collect()
 
+    def test_udtf_with_empty_output_schema_and_non_empty_output(self):
+        @udtf(returnType=StructType())
+        class TestUDTF:
+            def eval(self):
+                yield 1,
+
+        with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
+            TestUDTF().collect()
+
+    def test_udtf_with_non_empty_output_schema_and_empty_output(self):
+        @udtf(returnType="a: int")
+        class TestUDTF:
+            def eval(self):
+                yield tuple()
+
+        with self.assertRaisesRegex(PythonException, 
"UDTF_RETURN_SCHEMA_MISMATCH"):
+            TestUDTF().collect()
+
     def test_udtf_init(self):
         @udtf(returnType="a: int, b: int, c: string")
         class TestUDTF:
@@ -323,8 +368,8 @@ class BaseUDTFTestsMixin:
 
     def test_udtf_terminate_with_wrong_num_output(self):
         err_msg = (
-            "java.lang.IllegalStateException: Input row doesn't have expected 
number of "
-            "values required by the schema."
+            r"\[UDTF_RETURN_SCHEMA_MISMATCH\] The number of columns in the 
result "
+            "does not match the specified schema."
         )
 
         @udtf(returnType="a: int, b: int")
@@ -335,7 +380,7 @@ class BaseUDTFTestsMixin:
             def terminate(self):
                 yield 1, 2, 3
 
-        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+        with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).show()
 
         @udtf(returnType="a: int, b: int")
@@ -346,7 +391,7 @@ class BaseUDTFTestsMixin:
             def terminate(self):
                 yield 1,
 
-        with self.assertRaisesRegex(Py4JJavaError, err_msg):
+        with self.assertRaisesRegex(PythonException, err_msg):
             TestUDTF(lit(1)).show()
 
     def test_nondeterministic_udtf(self):
@@ -523,18 +568,26 @@ class BaseUDTFTestsMixin:
         )
 
     def test_udtf_with_no_handler_class(self):
-        err_msg = "the function handler must be a class"
-        with self.assertRaisesRegex(TypeError, err_msg):
+        with self.assertRaises(PySparkTypeError) as e:
 
             @udtf(returnType="a: int")
             def test_udtf(a: int):
                 yield a,
 
-        def test_udtf(a: int):
-            yield a
+        self.check_error(
+            exception=e.exception,
+            error_class="INVALID_UDTF_HANDLER_TYPE",
+            message_parameters={"type": "function"},
+        )
+
+        with self.assertRaises(PySparkTypeError) as e:
+            udtf(1, returnType="a: int")
 
-        with self.assertRaisesRegex(TypeError, err_msg):
-            udtf(test_udtf, returnType="a: int")
+        self.check_error(
+            exception=e.exception,
+            error_class="INVALID_UDTF_HANDLER_TYPE",
+            message_parameters={"type": "int"},
+        )
 
     def test_udtf_with_table_argument_query(self):
         class TestUDTF:
@@ -804,52 +857,6 @@ class UDTFArrowTestsMixin(BaseUDTFTestsMixin):
         func = udtf(TestUDTF, returnType="a: int")
         self.assertEqual(func(lit(1)).collect(), [Row(a=1)])
 
-    def test_udtf_terminate_with_wrong_num_output(self):
-        # The error message for arrow-optimized UDTF is different from regular 
UDTF.
-        err_msg = "The number of columns in the result does not match the 
specified schema."
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-            def terminate(self):
-                yield 1, 2, 3
-
-        with self.assertRaisesRegex(PythonException, err_msg):
-            TestUDTF(lit(1)).show()
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-            def terminate(self):
-                yield 1,
-
-        with self.assertRaisesRegex(PythonException, err_msg):
-            TestUDTF(lit(1)).show()
-
-    def test_udtf_with_wrong_num_output(self):
-        # The error message for arrow-optimized UDTF is different.
-        err_msg = "The number of columns in the result does not match the 
specified schema."
-
-        @udtf(returnType="a: int, b: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a,
-
-        with self.assertRaisesRegex(PythonException, err_msg):
-            TestUDTF(lit(1)).collect()
-
-        @udtf(returnType="a: int")
-        class TestUDTF:
-            def eval(self, a: int):
-                yield a, a + 1
-
-        with self.assertRaisesRegex(PythonException, err_msg):
-            TestUDTF(lit(1)).collect()
-
 
 class UDTFArrowTests(UDTFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index 3ab74193093..50bba56880c 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -139,13 +139,10 @@ def _vectorize_udtf(cls: Type) -> Type:
 
 def _validate_udtf_handler(cls: Any) -> None:
     """Validate the handler class of a UDTF."""
-    # TODO(SPARK-43968): add more compile time checks for UDTFs
 
     if not isinstance(cls, type):
         raise PySparkTypeError(
-            f"Invalid user defined table function: the function handler "
-            f"must be a class, but got {type(cls).__name__}. Please provide "
-            "a class as the handler."
+            error_class="INVALID_UDTF_HANDLER_TYPE", 
message_parameters={"type": type(cls).__name__}
         )
 
     if not hasattr(cls, "eval"):
@@ -176,6 +173,8 @@ class UserDefinedTableFunction:
         evalType: int = PythonEvalType.SQL_TABLE_UDF,
         deterministic: bool = True,
     ):
+        _validate_udtf_handler(func)
+
         self.func = func
         self._returnType = returnType
         self._returnType_placeholder: Optional[StructType] = None
@@ -185,8 +184,6 @@ class UserDefinedTableFunction:
         self.evalType = evalType
         self.deterministic = deterministic
 
-        _validate_udtf_handler(func)
-
     @property
     def returnType(self) -> StructType:
         # `_parse_datatype_string` accesses to JVM for parsing a DDL formatted 
string.
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 2445b46970c..cbc9faad47c 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -532,8 +532,11 @@ def read_udtf(pickleSer, infile, eval_type):
                         },
                     )
 
-                # Check when the dataframe has both rows and columns.
-                if not result.empty or len(result.columns) != 0:
+                # Validate the output schema when the result dataframe has 
either output
+                # rows or columns. Note that we avoid using `df.empty` here 
because the
+                # result dataframe may contain an empty row. For example, when 
a UDTF is
+                # defined as follows: def eval(self): yield tuple().
+                if len(result) > 0 or len(result.columns) > 0:
                     if len(result.columns) != len(return_type):
                         raise PySparkRuntimeError(
                             error_class="UDTF_RETURN_SCHEMA_MISMATCH",
@@ -580,6 +583,19 @@ def read_udtf(pickleSer, infile, eval_type):
             assert return_type.needConversion()
             toInternal = return_type.toInternal
 
+            def verify_and_convert_result(result):
+                # TODO(SPARK-44005): support returning non-tuple values
+                if result is not None and hasattr(result, "__len__"):
+                    if len(result) != len(return_type):
+                        raise PySparkRuntimeError(
+                            error_class="UDTF_RETURN_SCHEMA_MISMATCH",
+                            message_parameters={
+                                "expected": str(len(return_type)),
+                                "actual": str(len(result)),
+                            },
+                        )
+                return toInternal(result)
+
             # Evaluate the function and return a tuple back to the executor.
             def evaluate(*a) -> tuple:
                 res = f(*a)
@@ -591,7 +607,7 @@ def read_udtf(pickleSer, infile, eval_type):
                 else:
                     # If the function returns a result, we map it to the 
internal representation and
                     # returns the results as a tuple.
-                    return tuple(map(toInternal, res))
+                    return tuple(map(verify_and_convert_result, res))
 
             return evaluate
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to