This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8ae9c512a55 [SPARK-43178][CONNECT][PYTHON] Migrate UDF errors into 
PySpark error framework
8ae9c512a55 is described below

commit 8ae9c512a55df1651508dc0de468fd6826955344
Author: itholic <haejoon....@databricks.com>
AuthorDate: Mon Apr 24 18:17:52 2023 +0800

    [SPARK-43178][CONNECT][PYTHON] Migrate UDF errors into PySpark error 
framework
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to migrate UDF errors into the PySpark error framework.
    
    ### Why are the changes needed?
    
    To leverage the PySpark error framework so that we can provide more 
actionable and consistent errors for users.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    The existing CI should pass.
    
    Closes #40866 from itholic/udf_errors.
    
    Authored-by: itholic <haejoon....@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/errors/error_classes.py             | 15 ++++++
 python/pyspark/sql/connect/udf.py                  | 33 +++++++-----
 .../sql/tests/pandas/test_pandas_grouped_map.py    | 13 +++--
 python/pyspark/sql/tests/pandas/test_pandas_udf.py | 22 ++++++--
 python/pyspark/sql/tests/test_udf.py               | 11 +++-
 python/pyspark/sql/udf.py                          | 62 +++++++++++++---------
 6 files changed, 108 insertions(+), 48 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index 2b41f54def9..e3742441fe4 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -74,6 +74,11 @@ ERROR_CLASSES_JSON = """
       "<arg_list> should not be set together."
     ]
   },
+  "CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF": {
+    "message": [
+      "returnType can not be specified when `<arg_name>` is a user-defined 
function, but got <return_type>."
+    ]
+  },
   "COLUMN_IN_LIST": {
     "message": [
       "`<func_name>` does not allow a Column in a list."
@@ -99,11 +104,21 @@ ERROR_CLASSES_JSON = """
       "All items in `<arg_name>` should be in <allowed_types>, got 
<item_type>."
     ]
   },
+  "INVALID_RETURN_TYPE_FOR_PANDAS_UDF": {
+    "message": [
+      "Pandas UDF should return StructType for <eval_type>, got <return_type>."
+    ]
+  },
   "INVALID_TIMEOUT_TIMESTAMP" : {
     "message" : [
       "Timeout timestamp (<timestamp>) cannot be earlier than the current 
watermark (<watermark>)."
     ]
   },
+  "INVALID_UDF_EVAL_TYPE" : {
+    "message" : [
+      "Eval type for UDF must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, 
SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF."
+    ]
+  },
   "INVALID_WHEN_USAGE": {
     "message": [
       "when() can only be applied on a Column previously generated by when() 
function, and cannot be applied once otherwise() is applied."
diff --git a/python/pyspark/sql/connect/udf.py 
b/python/pyspark/sql/connect/udf.py
index aab7bb3c0d3..89d126e43d1 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -37,6 +37,7 @@ from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.types import UnparsedDataType
 from pyspark.sql.types import ArrayType, DataType, MapType, StringType, 
StructType
 from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
+from pyspark.errors import PySparkTypeError
 
 
 if TYPE_CHECKING:
@@ -125,20 +126,24 @@ class UserDefinedFunction:
         deterministic: bool = True,
     ):
         if not callable(func):
-            raise TypeError(
-                "Invalid function: not a function or callable (__call__ is not 
defined): "
-                "{0}".format(type(func))
+            raise PySparkTypeError(
+                error_class="NOT_CALLABLE",
+                message_parameters={"arg_name": "func", "arg_type": 
type(func).__name__},
             )
 
         if not isinstance(returnType, (DataType, str)):
-            raise TypeError(
-                "Invalid return type: returnType should be DataType or str "
-                "but is {}".format(returnType)
+            raise PySparkTypeError(
+                error_class="NOT_DATATYPE_OR_STR",
+                message_parameters={
+                    "arg_name": "returnType",
+                    "arg_type": type(returnType).__name__,
+                },
             )
 
         if not isinstance(evalType, int):
-            raise TypeError(
-                "Invalid evaluation type: evalType should be an int but is 
{}".format(evalType)
+            raise PySparkTypeError(
+                error_class="NOT_INT",
+                message_parameters={"arg_name": "evalType", "arg_type": 
type(evalType).__name__},
             )
 
         self.func = func
@@ -241,9 +246,9 @@ class UDFRegistration:
         # Python function.
         if hasattr(f, "asNondeterministic"):
             if returnType is not None:
-                raise TypeError(
-                    "Invalid return type: data type can not be specified when 
f is"
-                    "a user-defined function, but got %s." % returnType
+                raise PySparkTypeError(
+                    error_class="CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF",
+                    message_parameters={"arg_name": "f", "return_type": 
str(returnType)},
                 )
             f = cast("UserDefinedFunctionLike", f)
             if f.evalType not in [
@@ -252,9 +257,9 @@ class UDFRegistration:
                 PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                 PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
             ]:
-                raise ValueError(
-                    "Invalid f: f must be SQL_BATCHED_UDF, 
SQL_SCALAR_PANDAS_UDF, "
-                    "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF."
+                raise PySparkTypeError(
+                    error_class="INVALID_UDF_EVAL_TYPE",
+                    message_parameters={},
                 )
             return_udf = f
             self.sparkSession._client.register_udf(
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py 
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 36bdae02944..f5051c412ff 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -51,7 +51,7 @@ from pyspark.sql.types import (
     NullType,
     TimestampType,
 )
-from pyspark.errors import PythonException
+from pyspark.errors import PythonException, PySparkTypeError
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -212,12 +212,15 @@ class GroupedApplyInPandasTestsMixin:
     def test_register_grouped_map_udf(self):
         foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
         with QuietTest(self.sc):
-            with self.assertRaisesRegex(
-                ValueError,
-                
"f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*",
-            ):
+            with self.assertRaises(PySparkTypeError) as pe:
                 self.spark.catalog.registerFunction("foo_udf", foo_udf)
 
+            self.check_error(
+                exception=pe.exception,
+                error_class="INVALID_UDF_EVAL_TYPE",
+                message_parameters={},
+            )
+
     def test_decorator(self):
         df = self.data
 
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf.py 
b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
index 4e1eec38a0c..8278c03ea04 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf.py
@@ -21,7 +21,7 @@ from typing import cast
 
 from pyspark.sql.functions import udf, pandas_udf, PandasUDFType, assert_true, 
lit
 from pyspark.sql.types import DoubleType, StructType, StructField, LongType, 
DayTimeIntervalType
-from pyspark.errors import ParseException, PythonException
+from pyspark.errors import ParseException, PythonException, PySparkTypeError
 from pyspark.rdd import PythonEvalType
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
@@ -153,18 +153,34 @@ class PandasUDFTestsMixin:
                 def zero_with_type():
                     return 1
 
-            with self.assertRaisesRegex(TypeError, "Invalid return type"):
+            with self.assertRaises(PySparkTypeError) as pe:
 
                 @pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
                 def foo(df):
                     return df
 
-            with self.assertRaisesRegex(TypeError, "Invalid return type"):
+            self.check_error(
+                exception=pe.exception,
+                error_class="NOT_DATATYPE_OR_STR",
+                message_parameters={"arg_name": "returnType", "arg_type": 
"int"},
+            )
+
+            with self.assertRaises(PySparkTypeError) as pe:
 
                 @pandas_udf(returnType="double", 
functionType=PandasUDFType.GROUPED_MAP)
                 def foo(df):
                     return df
 
+            self.check_error(
+                exception=pe.exception,
+                error_class="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
+                message_parameters={
+                    "eval_type": "SQL_GROUPED_MAP_PANDAS_UDF "
+                    "or SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE",
+                    "return_type": "DoubleType()",
+                },
+            )
+
             with self.assertRaisesRegex(ValueError, "Invalid function"):
 
                 @pandas_udf(returnType="k int, v double", 
functionType=PandasUDFType.GROUPED_MAP)
diff --git a/python/pyspark/sql/tests/test_udf.py 
b/python/pyspark/sql/tests/test_udf.py
index d8a464b006f..2e6a7813524 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -38,7 +38,7 @@ from pyspark.sql.types import (
     TimestampNTZType,
     DayTimeIntervalType,
 )
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PySparkTypeError
 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, 
test_not_compiled_message
 from pyspark.testing.utils import QuietTest
 
@@ -109,11 +109,18 @@ class BaseUDFTestsMixin(object):
             self.check_udf_registration_return_type_not_none()
 
     def check_udf_registration_return_type_not_none(self):
-        with self.assertRaisesRegex(TypeError, "Invalid return type"):
+        # negative test for incorrect type
+        with self.assertRaises(PySparkTypeError) as pe:
             self.spark.catalog.registerFunction(
                 "f", UserDefinedFunction(lambda x, y: len(x) + y, 
StringType()), StringType()
             )
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF",
+            message_parameters={"arg_name": "f", "return_type": 
"StringType()"},
+        )
+
     def test_nondeterministic_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained 
UDF evaluations
         import random
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index c486d869cba..58be9ed17ff 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -43,6 +43,7 @@ from pyspark.sql.types import (
 from pyspark.sql.utils import get_active_spark_context
 from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.sql.pandas.utils import require_minimum_pandas_version, 
require_minimum_pyarrow_version
+from pyspark.errors import PySparkTypeError
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import DataTypeOrString, ColumnOrName, 
UserDefinedFunctionLike
@@ -218,20 +219,24 @@ class UserDefinedFunction:
         deterministic: bool = True,
     ):
         if not callable(func):
-            raise TypeError(
-                "Invalid function: not a function or callable (__call__ is not 
defined): "
-                "{0}".format(type(func))
+            raise PySparkTypeError(
+                error_class="NOT_CALLABLE",
+                message_parameters={"arg_name": "func", "arg_type": 
type(func).__name__},
             )
 
         if not isinstance(returnType, (DataType, str)):
-            raise TypeError(
-                "Invalid return type: returnType should be DataType or str "
-                "but is {}".format(returnType)
+            raise PySparkTypeError(
+                error_class="NOT_DATATYPE_OR_STR",
+                message_parameters={
+                    "arg_name": "returnType",
+                    "arg_type": type(returnType).__name__,
+                },
             )
 
         if not isinstance(evalType, int):
-            raise TypeError(
-                "Invalid evaluation type: evalType should be an int but is 
{}".format(evalType)
+            raise PySparkTypeError(
+                error_class="NOT_INT",
+                message_parameters={"arg_name": "evalType", "arg_type": 
type(evalType).__name__},
             )
 
         self.func = func
@@ -280,10 +285,13 @@ class UserDefinedFunction:
                         % str(self._returnType_placeholder)
                     )
             else:
-                raise TypeError(
-                    "Invalid return type for grouped map Pandas "
-                    "UDFs or at groupby.applyInPandas(WithState): return type 
must be a "
-                    "StructType."
+                raise PySparkTypeError(
+                    error_class="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
+                    message_parameters={
+                        "eval_type": "SQL_GROUPED_MAP_PANDAS_UDF or "
+                        "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE",
+                        "return_type": str(self._returnType_placeholder),
+                    },
                 )
         elif (
             self.evalType == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
@@ -298,9 +306,12 @@ class UserDefinedFunction:
                         "%s is not supported" % 
str(self._returnType_placeholder)
                     )
             else:
-                raise TypeError(
-                    "Invalid return type in mapInPandas/mapInArrow: "
-                    "return type must be a StructType."
+                raise PySparkTypeError(
+                    error_class="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
+                    message_parameters={
+                        "eval_type": "SQL_MAP_PANDAS_ITER_UDF or 
SQL_MAP_ARROW_ITER_UDF",
+                        "return_type": str(self._returnType_placeholder),
+                    },
                 )
         elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
             if isinstance(self._returnType_placeholder, StructType):
@@ -312,9 +323,12 @@ class UserDefinedFunction:
                         "%s is not supported" % 
str(self._returnType_placeholder)
                     )
             else:
-                raise TypeError(
-                    "Invalid return type in cogroup.applyInPandas: "
-                    "return type must be a StructType."
+                raise PySparkTypeError(
+                    error_class="INVALID_RETURN_TYPE_FOR_PANDAS_UDF",
+                    message_parameters={
+                        "eval_type": "SQL_COGROUPED_MAP_PANDAS_UDF",
+                        "return_type": str(self._returnType_placeholder),
+                    },
                 )
         elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
             try:
@@ -591,9 +605,9 @@ class UDFRegistration:
         # Python function.
         if hasattr(f, "asNondeterministic"):
             if returnType is not None:
-                raise TypeError(
-                    "Invalid return type: data type can not be specified when 
f is"
-                    "a user-defined function, but got %s." % returnType
+                raise PySparkTypeError(
+                    error_class="CANNOT_SPECIFY_RETURN_TYPE_FOR_UDF",
+                    message_parameters={"arg_name": "f", "return_type": 
str(returnType)},
                 )
             f = cast("UserDefinedFunctionLike", f)
             if f.evalType not in [
@@ -602,9 +616,9 @@ class UDFRegistration:
                 PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                 PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
             ]:
-                raise ValueError(
-                    "Invalid f: f must be SQL_BATCHED_UDF, 
SQL_SCALAR_PANDAS_UDF, "
-                    "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF."
+                raise PySparkTypeError(
+                    error_class="INVALID_UDF_EVAL_TYPE",
+                    message_parameters={},
                 )
             register_udf = _create_udf(
                 f.func,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to