This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 04816474bfc [SPARK-43261][PYTHON] Migrate `TypeError` from Spark SQL types into error class
04816474bfc is described below

commit 04816474bfcc05c7d90f7b7e8d35184d95c78cbd
Author: itholic <haejoon....@databricks.com>
AuthorDate: Thu Apr 27 16:55:52 2023 +0800

    [SPARK-43261][PYTHON] Migrate `TypeError` from Spark SQL types into error class
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to migrate `TypeError` from Spark SQL types into error class.
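
    For illustration, a minimal, hypothetical sketch of the pattern this migration follows,
    mirroring the `_merge_type` change in the diff below (the error classes and their message
    templates live in python/pyspark/errors/error_classes.py; the helper name here is made up
    for the example):

        from pyspark.errors import PySparkTypeError
        from pyspark.sql.types import DoubleType, LongType

        # Hypothetical helper mirroring the new _merge_type behaviour:
        # raise an error-class-based exception instead of a plain TypeError.
        def merge_or_raise(a, b):
            if type(a) is not type(b):
                raise PySparkTypeError(
                    error_class="CANNOT_MERGE_TYPE",
                    message_parameters={
                        "data_type1": type(a).__name__,
                        "data_type2": type(b).__name__,
                    },
                )
            return a

        merge_or_raise(LongType(), DoubleType())  # raises PySparkTypeError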
    
    ### Why are the changes needed?
    
    To improve PySpark error messages.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No API change, only error improvement.
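
    For example, merging incompatible types now fails with a message built from the
    `CANNOT_MERGE_TYPE` template added below, along the lines of

        [CANNOT_MERGE_TYPE] Can not merge type `LongType` and `DoubleType`.

    instead of the previous ad-hoc `TypeError` text (the template comes from the diff;
    the exact rendering of the error-class prefix is assumed here).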
    
    ### How was this patch tested?
    
    The existing CI should pass
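
    As a sketch of the updated test pattern (taken from the test_types.py changes below;
    `check_error` is provided by the `PySparkErrorTestUtils` mixin), tests now assert the
    error class and message parameters instead of matching message substrings:

        with self.assertRaises(PySparkTypeError) as pe:
            _merge_type(ArrayType(LongType()), ArrayType(DoubleType()))

        self.check_error(
            exception=pe.exception,
            error_class="CANNOT_MERGE_TYPE",
            message_parameters={"data_type1": "LongType", "data_type2": "DoubleType"},
        )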
    
    Closes #40926 from itholic/error_sql_types.
    
    Authored-by: itholic <haejoon....@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/errors/error_classes.py     | 35 +++++++++++++
 python/pyspark/sql/tests/test_dataframe.py |  2 +-
 python/pyspark/sql/tests/test_functions.py | 12 +++--
 python/pyspark/sql/tests/test_types.py     | 61 +++++++++++++++++++----
 python/pyspark/sql/types.py                | 80 +++++++++++++++++++++++-------
 5 files changed, 158 insertions(+), 32 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 34efd471707..f35971c4a94 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -44,6 +44,11 @@ ERROR_CLASSES_JSON = """
       "Not supported to call `<func_name>` before initialize <object>."
     ]
   },
+  "CANNOT_ACCEPT_OBJECT_IN_TYPE": {
+    "message": [
+      "`<data_type>` can not accept object `<obj_name>` in type `<obj_type>`."
+    ]
+  },
   "CANNOT_ACCESS_TO_DUNDER": {
     "message": [
       "Dunder(double underscore) attribute is for internal use only."
@@ -69,11 +74,31 @@ ERROR_CLASSES_JSON = """
       "Cannot convert column into bool: please use '&' for 'and', '|' for 
'or', '~' for 'not' when building DataFrame boolean expressions."
     ]
   },
+  "CANNOT_CONVERT_TYPE": {
+    "message": [
+      "Cannot convert <from_type> into <to_type>."
+    ]
+  },
   "CANNOT_INFER_ARRAY_TYPE": {
     "message": [
       "Can not infer Array Type from an list with None as the first element."
     ]
   },
+  "CANNOT_INFER_SCHEMA_FOR_TYPE": {
+    "message": [
+      "Can not infer schema for type: `<data_type>`."
+    ]
+  },
+  "CANNOT_INFER_TYPE_FOR_FIELD": {
+    "message": [
+      "Unable to infer the type of the field `<field_name>`."
+    ]
+  },
+  "CANNOT_MERGE_TYPE": {
+    "message": [
+      "Can not merge type `<data_type1>` and `<data_type2>`."
+    ]
+  },
   "CANNOT_OPEN_SOCKET": {
     "message": [
       "Can not open socket: <errors>."
@@ -155,6 +180,11 @@ ERROR_CLASSES_JSON = """
       "Timeout timestamp (<timestamp>) cannot be earlier than the current 
watermark (<watermark>)."
     ]
   },
+  "INVALID_TYPENAME_CALL" : {
+    "message" : [
+      "StructField does not have typeName. Use typeName on its type explicitly 
instead."
+    ]
+  },
   "INVALID_UDF_EVAL_TYPE" : {
     "message" : [
       "Eval type for UDF must be <eval_type>."
@@ -335,6 +365,11 @@ ERROR_CLASSES_JSON = """
       "Argument `<arg_name>` should be an int, got <arg_type>."
     ]
   },
+  "NOT_INT_OR_SLICE_OR_STR" : {
+    "message" : [
+      "Argument `<arg_name>` should be an int, slice or str, got <arg_type>."
+    ]
+  },
   "NOT_IN_BARRIER_STAGE" : {
     "message" : [
       "It is not in a barrier stage."
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 96b31dfee7b..27e12568b28 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1010,7 +1010,7 @@ class DataFrameTestsMixin:
         # field types mismatch will cause exception at runtime.
         self.assertRaisesRegex(
             Exception,
-            "FloatType\\(\\) can not accept",
+            "CANNOT_ACCEPT_OBJECT_IN_TYPE",
             lambda: rdd.toDF("key: float, value: string").collect(),
         )
 
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 38de87b0e72..9067de34633 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -1136,11 +1136,17 @@ class FunctionsTestsMixin:
                 expected_spark_dtypes, self.spark.range(1).select(F.lit(arr).alias("b")).dtypes
             )
         arr = np.array([1, 2]).astype(np.uint)
-        with self.assertRaisesRegex(
-            TypeError, "The type of array scalar '%s' is not supported" % arr.dtype
-        ):
+        with self.assertRaises(PySparkTypeError) as pe:
             self.spark.range(1).select(F.lit(arr).alias("b"))
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="UNSUPPORTED_NUMPY_ARRAY_SCALAR",
+            message_parameters={
+                "dtype": "uint64",
+            },
+        )
+
     def test_binary_math_function(self):
         funcs, expected = zip(
             *[(F.atan2, 0.13664), (F.hypot, 8.07527), (F.pow, 2.14359), (F.pmod, 1.1)]
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index cd1ae1f2964..49952c2c135 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -26,7 +26,7 @@ import unittest
 
 from pyspark.sql import Row
 from pyspark.sql import functions as F
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PySparkTypeError
 from pyspark.sql.types import (
     ByteType,
     ShortType,
@@ -66,6 +66,7 @@ from pyspark.testing.sqlutils import (
     PythonOnlyPoint,
     MyObject,
 )
+from pyspark.testing.utils import PySparkErrorTestUtils
 
 
 class TypesTestsMixin:
@@ -906,8 +907,13 @@ class TypesTestsMixin:
         self.assertEqual(
             _merge_type(ArrayType(LongType()), ArrayType(LongType())), ArrayType(LongType())
         )
-        with self.assertRaisesRegex(TypeError, "element in array"):
+        with self.assertRaises(PySparkTypeError) as pe:
             _merge_type(ArrayType(LongType()), ArrayType(DoubleType()))
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_MERGE_TYPE",
+            message_parameters={"data_type1": "LongType", "data_type2": 
"DoubleType"},
+        )
 
         self.assertEqual(
             _merge_type(MapType(StringType(), LongType()), MapType(StringType(), LongType())),
@@ -919,8 +925,13 @@ class TypesTestsMixin:
             MapType(StringType(), LongType()),
         )
 
-        with self.assertRaisesRegex(TypeError, "value of map"):
+        with self.assertRaises(PySparkTypeError) as pe:
             _merge_type(MapType(StringType(), LongType()), MapType(StringType(), DoubleType()))
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_MERGE_TYPE",
+            message_parameters={"data_type1": "LongType", "data_type2": 
"DoubleType"},
+        )
 
         self.assertEqual(
             _merge_type(
@@ -929,11 +940,16 @@ class TypesTestsMixin:
             ),
             StructType([StructField("f1", LongType()), StructField("f2", 
StringType())]),
         )
-        with self.assertRaisesRegex(TypeError, "field f1"):
+        with self.assertRaises(PySparkTypeError) as pe:
             _merge_type(
                 StructType([StructField("f1", LongType()), StructField("f2", 
StringType())]),
                 StructType([StructField("f1", DoubleType()), StructField("f2", 
StringType())]),
             )
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_MERGE_TYPE",
+            message_parameters={"data_type1": "LongType", "data_type2": 
"DoubleType"},
+        )
 
         self.assertEqual(
             _merge_type(
@@ -961,7 +977,7 @@ class TypesTestsMixin:
             ),
             StructType([StructField("f1", ArrayType(LongType())), 
StructField("f2", StringType())]),
         )
-        with self.assertRaisesRegex(TypeError, "element in array field f1"):
+        with self.assertRaises(PySparkTypeError) as pe:
             _merge_type(
                 StructType(
                     [StructField("f1", ArrayType(LongType())), 
StructField("f2", StringType())]
@@ -970,6 +986,11 @@ class TypesTestsMixin:
                     [StructField("f1", ArrayType(DoubleType())), 
StructField("f2", StringType())]
                 ),
             )
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_MERGE_TYPE",
+            message_parameters={"data_type1": "LongType", "data_type2": 
"DoubleType"},
+        )
 
         self.assertEqual(
             _merge_type(
@@ -993,7 +1014,7 @@ class TypesTestsMixin:
                 ]
             ),
         )
-        with self.assertRaisesRegex(TypeError, "value of map field f1"):
+        with self.assertRaises(PySparkTypeError) as pe:
             _merge_type(
                 StructType(
                     [
@@ -1008,6 +1029,11 @@ class TypesTestsMixin:
                     ]
                 ),
             )
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_MERGE_TYPE",
+            message_parameters={"data_type1": "LongType", "data_type2": 
"DoubleType"},
+        )
 
         self.assertEqual(
             _merge_type(
@@ -1110,10 +1136,16 @@ class TypesTestsMixin:
         unsupported_types = all_types - set(supported_types)
         # test unsupported types
         for t in unsupported_types:
-            with self.assertRaisesRegex(TypeError, "infer the type of the field myarray"):
+            with self.assertRaises(PySparkTypeError) as pe:
                 a = array.array(t)
                 self.spark.createDataFrame([Row(myarray=a)]).collect()
 
+            self.check_error(
+                exception=pe.exception,
+                error_class="CANNOT_INFER_TYPE_FOR_FIELD",
+                message_parameters={"field_name": "myarray"},
+            )
+
     def test_repr(self):
         instances = [
             NullType(),
@@ -1304,7 +1336,7 @@ class DataTypeTests(unittest.TestCase):
         self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
 
 
-class DataTypeVerificationTests(unittest.TestCase):
+class DataTypeVerificationTests(unittest.TestCase, PySparkErrorTestUtils):
     def test_verify_type_exception_msg(self):
         self.assertRaisesRegex(
             ValueError,
@@ -1313,8 +1345,17 @@ class DataTypeVerificationTests(unittest.TestCase):
         )
 
         schema = StructType([StructField("a", StructType([StructField("b", 
IntegerType())]))])
-        self.assertRaisesRegex(
-            TypeError, "field b in field a", lambda: 
_make_type_verifier(schema)([["data"]])
+        with self.assertRaises(PySparkTypeError) as pe:
+            _make_type_verifier(schema)([["data"]])
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_ACCEPT_OBJECT_IN_TYPE",
+            message_parameters={
+                "data_type": "IntegerType()",
+                "obj_name": "data",
+                "obj_type": "str",
+            },
         )
 
     def test_verify_type_ok_nullable(self):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 721be76e8ba..5876d55e426 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -50,6 +50,7 @@ from py4j.java_gateway import GatewayClient, JavaClass, JavaGateway, JavaObject,
 
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.utils import has_numpy, get_active_spark_context
+from pyspark.errors import PySparkTypeError
 
 if has_numpy:
     import numpy as np
@@ -718,8 +719,9 @@ class StructField(DataType):
         return self.dataType.fromInternal(obj)
 
     def typeName(self) -> str:  # type: ignore[override]
-        raise TypeError(
-            "StructField does not have typeName. " "Use typeName on its type 
explicitly instead."
+        raise PySparkTypeError(
+            error_class="INVALID_TYPENAME_CALL",
+            message_parameters={},
         )
 
 
@@ -898,7 +900,10 @@ class StructType(DataType):
         elif isinstance(key, slice):
             return StructType(self.fields[key])
         else:
-            raise TypeError("StructType keys should be strings, integers or 
slices")
+            raise PySparkTypeError(
+                error_class="NOT_INT_OR_SLICE_OR_STR",
+                message_parameters={"arg_name": "key", "arg_type": 
type(key).__name__},
+            )
 
     def simpleString(self) -> str:
         return "struct<%s>" % (",".join(f.simpleString() for f in self))
@@ -1584,7 +1589,10 @@ def _infer_type(
         if obj.typecode in _array_type_mappings:
             return ArrayType(_array_type_mappings[obj.typecode](), False)
         else:
-            raise TypeError("not supported type: array(%s)" % obj.typecode)
+            raise PySparkTypeError(
+                error_class="UNSUPPORTED_DATA_TYPE",
+                message_parameters={"data_type": f"array({obj.typecode})"},
+            )
     else:
         try:
             return _infer_schema(
@@ -1593,7 +1601,10 @@ def _infer_type(
                 infer_array_from_first_element=infer_array_from_first_element,
             )
         except TypeError:
-            raise TypeError("not supported type: %s" % type(obj))
+            raise PySparkTypeError(
+                error_class="UNSUPPORTED_DATA_TYPE",
+                message_parameters={"data_type": type(obj).__name__},
+            )
 
 
 def _infer_schema(
@@ -1624,7 +1635,10 @@ def _infer_schema(
         items = sorted(row.__dict__.items())
 
     else:
-        raise TypeError("Can not infer schema for type: %s" % type(row))
+        raise PySparkTypeError(
+            error_class="CANNOT_INFER_SCHEMA_FOR_TYPE",
+            message_parameters={"data_type": type(row).__name__},
+        )
 
     fields = []
     for k, v in items:
@@ -1641,8 +1655,11 @@ def _infer_schema(
                     True,
                 )
             )
-        except TypeError as e:
-            raise TypeError("Unable to infer the type of the field 
{}.".format(k)) from e
+        except TypeError:
+            raise PySparkTypeError(
+                error_class="CANNOT_INFER_TYPE_FOR_FIELD",
+                message_parameters={"field_name": k},
+            )
     return StructType(fields)
 
 
@@ -1713,7 +1730,10 @@ def _merge_type(
         return a
     elif type(a) is not type(b):
         # TODO: type cast (such as int -> long)
-        raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), 
type(b))))
+        raise PySparkTypeError(
+            error_class="CANNOT_MERGE_TYPE",
+            message_parameters={"data_type1": type(a).__name__, "data_type2": 
type(b).__name__},
+        )
 
     # same type
     if isinstance(a, StructType):
@@ -1801,7 +1821,10 @@ def _create_converter(dataType: DataType) -> Callable:
         elif hasattr(obj, "__dict__"):  # object
             d = obj.__dict__
         else:
-            raise TypeError("Unexpected obj type: %s" % type(obj))
+            raise PySparkTypeError(
+                error_class="UNSUPPORTED_DATA_TYPE",
+                message_parameters={"data_type": type(obj).__name__},
+            )
 
         if convert_fields:
             return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
@@ -1860,7 +1883,7 @@ def _make_type_verifier(
     >>> _make_type_verifier(ArrayType(StringType()))(set()) # doctest: +IGNORE_EXCEPTION_DETAIL
     Traceback (most recent call last):
         ...
-    TypeError:...
+    pyspark.errors.exceptions.base.PySparkTypeError:...
     >>> _make_type_verifier(MapType(StringType(), IntegerType()))({})
     >>> _make_type_verifier(StructType([]))(())
     >>> _make_type_verifier(StructType([]))([])
@@ -1883,7 +1906,9 @@ def _make_type_verifier(
     Traceback (most recent call last):
         ...
     ValueError:...
-    >>> _make_type_verifier(MapType(StringType(), IntegerType()))({None: 1})
+    >>> _make_type_verifier(  # doctest: +IGNORE_EXCEPTION_DETAIL
+    ...     MapType(StringType(), IntegerType())
+    ...     )({None: 1})
     Traceback (most recent call last):
         ...
     ValueError:...
@@ -1929,8 +1954,13 @@ def _make_type_verifier(
     def verify_acceptable_types(obj: Any) -> None:
         # subclass of them can not be fromInternal in JVM
         if type(obj) not in _acceptable_types[_type]:
-            raise TypeError(
-                new_msg("%s can not accept object %r in type %s" % (dataType, 
obj, type(obj)))
+            raise PySparkTypeError(
+                error_class="CANNOT_ACCEPT_OBJECT_IN_TYPE",
+                message_parameters={
+                    "data_type": str(dataType),
+                    "obj_name": str(obj),
+                    "obj_type": type(obj).__name__,
+                },
             )
 
     if isinstance(dataType, (StringType, CharType, VarcharType)):
@@ -2043,8 +2073,13 @@ def _make_type_verifier(
                 for f, verifier in verifiers:
                     verifier(d.get(f))
             else:
-                raise TypeError(
-                    new_msg("StructType can not accept object %r in type %s" % 
(obj, type(obj)))
+                raise PySparkTypeError(
+                    error_class="CANNOT_ACCEPT_OBJECT_IN_TYPE",
+                    message_parameters={
+                        "data_type": "StructType",
+                        "obj_name": str(obj),
+                        "obj_type": type(obj).__name__,
+                    },
                 )
 
         verify_value = verify_struct
@@ -2183,7 +2218,13 @@ class Row(tuple):
         True
         """
         if not hasattr(self, "__fields__"):
-            raise TypeError("Cannot convert a Row class into dict")
+            raise PySparkTypeError(
+                error_class="CANNOT_CONVERT_TYPE",
+                message_parameters={
+                    "from_type": "Row",
+                    "to_type": "dict",
+                },
+            )
 
         if recursive:
 
@@ -2368,7 +2409,10 @@ class NumpyArrayConverter:
         else:
             jtpe = self._from_numpy_type_to_java_type(obj.dtype, gateway)
             if jtpe is None:
-                raise TypeError("The type of array scalar '%s' is not 
supported" % (obj.dtype))
+                raise PySparkTypeError(
+                    error_class="UNSUPPORTED_NUMPY_ARRAY_SCALAR",
+                    message_parameters={"dtype": str(obj.dtype)},
+                )
         jarr = gateway.new_array(jtpe, len(obj))
         for i in range(len(plist)):
             jarr[i] = plist[i]

