This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 062ac987063 [SPARK-43296][CONNECT][PYTHON] Migrate Spark Connect session errors into error class
062ac987063 is described below

commit 062ac987063f8814c8d92925ddc6d2c72df2d208
Author: itholic <haejoon....@databricks.com>
AuthorDate: Tue May 9 10:48:34 2023 +0800

    [SPARK-43296][CONNECT][PYTHON] Migrate Spark Connect session errors into error class
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to migrate Spark Connect session errors into error classes.
    
    ### Why are the changes needed?
    
    To improve PySpark error usability.
    
    ### Does this PR introduce _any_ user-facing change?
    
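    No API changes, but some raised exceptions change: for example, the "single data type is not supported with Arrow" error is now raised as `PySparkTypeError` instead of `ValueError`, and error messages now follow the new error-class templates.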
    
    ### How was this patch tested?
    
    The existing CI should pass.
    
    Closes #40964 from itholic/error_connect_session.
    
    Authored-by: itholic <haejoon....@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/errors/error_classes.py             | 30 +++++++++
 python/pyspark/sql/connect/session.py              | 77 +++++++++++++++-------
 python/pyspark/sql/pandas/conversion.py            |  6 +-
 .../sql/tests/connect/test_connect_basic.py        | 44 ++++++++-----
 python/pyspark/sql/tests/test_arrow.py             |  8 ++-
 5 files changed, 123 insertions(+), 42 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index 6af8d5bc6ff..c7b00e0736d 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -39,6 +39,11 @@ ERROR_CLASSES_JSON = """
       "Attribute `<attr_name>` is not supported."
     ]
   },
+  "AXIS_LENGTH_MISMATCH" : {
+    "message" : [
+      "Length mismatch: Expected axis has <expected_length> element, new 
values have <actual_length> elements."
+    ]
+  },
   "BROADCAST_VARIABLE_NOT_LOADED": {
     "message": [
       "Broadcast variable `<variable>` not loaded."
@@ -94,6 +99,11 @@ ERROR_CLASSES_JSON = """
       "Can not infer Array Type from an list with None as the first element."
     ]
   },
+  "CANNOT_INFER_EMPTY_SCHEMA": {
+    "message": [
+      "Can not infer schema from empty dataset."
+    ]
+  },
   "CANNOT_INFER_SCHEMA_FOR_TYPE": {
     "message": [
       "Can not infer schema for type: `<data_type>`."
@@ -195,6 +205,11 @@ ERROR_CLASSES_JSON = """
       "All items in `<arg_name>` should be in <allowed_types>, got 
<item_type>."
     ]
   },
+  "INVALID_NDARRAY_DIMENSION": {
+    "message": [
+      "NumPy array input should be of <dimensions> dimensions."
+    ]
+  },
   "INVALID_PANDAS_UDF" : {
     "message" : [
       "Invalid function: <detail>"
@@ -215,6 +230,11 @@ ERROR_CLASSES_JSON = """
       "Timeout timestamp (<timestamp>) cannot be earlier than the current 
watermark (<watermark>)."
     ]
   },
+  "INVALID_TYPE" : {
+    "message" : [
+      "Argument `<arg_name>` should not be a <data_type>."
+    ]
+  },
   "INVALID_TYPENAME_CALL" : {
     "message" : [
       "StructField does not have typeName. Use typeName on its type explicitly 
instead."
@@ -556,6 +576,11 @@ ERROR_CLASSES_JSON = """
       "Result vector from pandas_udf was not the required length: expected 
<expected>, got <actual>."
     ]
   },
+  "SESSION_OR_CONTEXT_EXISTS" : {
+    "message" : [
+      "There should not be an existing Spark Session or Spark Context."
+    ]
+  },
   "SLICE_WITH_STEP" : {
     "message" : [
       "Slice with step is not supported."
@@ -611,6 +636,11 @@ ERROR_CLASSES_JSON = """
       "Unsupported DataType `<data_type>`."
     ]
   },
+  "UNSUPPORTED_DATA_TYPE_FOR_ARROW" : {
+    "message" : [
+      "Single data type <data_type> is not supported with Arrow."
+    ]
+  },
   "UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION" : {
     "message" : [
       "<data_type> is not supported in conversion to Arrow."
diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index 4f8fa419119..c23b6c5d11a 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -68,7 +68,13 @@ from pyspark.sql.types import (
     TimestampType,
 )
 from pyspark.sql.utils import to_str
-from pyspark.errors import PySparkAttributeError, PySparkNotImplementedError
+from pyspark.errors import (
+    PySparkAttributeError,
+    PySparkNotImplementedError,
+    PySparkRuntimeError,
+    PySparkValueError,
+    PySparkTypeError,
+)
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import OptionalPrimitiveType
@@ -153,8 +159,7 @@ class SparkSession:
 
         def enableHiveSupport(self) -> "SparkSession.Builder":
             raise PySparkNotImplementedError(
-                error_class="NOT_IMPLEMENTED",
-                message_parameters={"feature": "enableHiveSupport"},
+                error_class="NOT_IMPLEMENTED", message_parameters={"feature": 
"enableHiveSupport"}
             )
 
         def getOrCreate(self) -> "SparkSession":
@@ -233,7 +238,10 @@ class SparkSession:
         Infer schema from list of Row, dict, or tuple.
         """
         if not data:
-            raise ValueError("can not infer schema from empty dataset")
+            raise PySparkValueError(
+                error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                message_parameters={},
+            )
 
         (
             infer_dict_as_struct,
@@ -265,7 +273,10 @@ class SparkSession:
     ) -> "DataFrame":
         assert data is not None
         if isinstance(data, DataFrame):
-            raise TypeError("data is already a DataFrame")
+            raise PySparkTypeError(
+                error_class="INVALID_TYPE",
+                message_parameters={"arg_name": "data", "data_type": 
"DataFrame"},
+            )
 
         _schema: Optional[Union[AtomicType, StructType]] = None
         _cols: Optional[List[str]] = None
@@ -289,12 +300,18 @@ class SparkSession:
             _num_cols = len(_cols)
 
         if isinstance(data, np.ndarray) and data.ndim not in [1, 2]:
-            raise ValueError("NumPy array input should be of 1 or 2 
dimensions.")
+            raise PySparkValueError(
+                error_class="INVALID_NDARRAY_DIMENSION",
+                message_parameters={"dimensions": "1 or 2"},
+            )
         elif isinstance(data, Sized) and len(data) == 0:
             if _schema is not None:
                 return DataFrame.withPlan(LocalRelation(table=None, 
schema=_schema.json()), self)
             else:
-                raise ValueError("can not infer schema from empty dataset")
+                raise PySparkValueError(
+                    error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                    message_parameters={},
+                )
 
         _table: Optional[pa.Table] = None
 
@@ -317,7 +334,10 @@ class SparkSession:
                 arrow_types = [field.type for field in arrow_schema]
                 _cols = [str(x) if not isinstance(x, str) else x for x in 
schema.fieldNames()]
             elif isinstance(schema, DataType):
-                raise ValueError("Single data type %s is not supported with 
Arrow" % str(schema))
+                raise PySparkTypeError(
+                    error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
+                    message_parameters={"data_type": str(schema)},
+                )
             else:
                 # Any timestamps must be coerced to be compatible with Spark
                 arrow_types = [
@@ -354,17 +374,23 @@ class SparkSession:
 
             if data.ndim == 1:
                 if 1 != len(_cols):
-                    raise ValueError(
-                        f"Length mismatch: Expected axis has {len(_cols)} 
element, "
-                        "new values have 1 elements"
+                    raise PySparkValueError(
+                        error_class="AXIS_LENGTH_MISMATCH",
+                        message_parameters={
+                            "expected_length": str(len(_cols)),
+                            "actual_length": "1",
+                        },
                     )
 
                 _table = pa.Table.from_arrays([pa.array(data)], _cols)
             else:
                 if data.shape[1] != len(_cols):
-                    raise ValueError(
-                        f"Length mismatch: Expected axis has {len(_cols)} 
elements, "
-                        f"new values have {data.shape[1]} elements"
+                    raise PySparkValueError(
+                        error_class="AXIS_LENGTH_MISMATCH",
+                        message_parameters={
+                            "expected_length": str(len(_cols)),
+                            "actual_length": str(data.shape[1]),
+                        },
                     )
 
                 _table = pa.Table.from_arrays(
@@ -416,9 +442,12 @@ class SparkSession:
         # TODO: Beside the validation on number of columns, we should also 
check
         # whether the Arrow Schema is compatible with the user provided Schema.
         if _num_cols is not None and _num_cols != _table.shape[1]:
-            raise ValueError(
-                f"Length mismatch: Expected axis has {_num_cols} elements, "
-                f"new values have {_table.shape[1]} elements"
+            raise PySparkValueError(
+                error_class="AXIS_LENGTH_MISMATCH",
+                message_parameters={
+                    "expected_length": str(_num_cols),
+                    "actual_length": str(_table.shape[1]),
+                },
             )
 
         if _schema is not None:
@@ -517,14 +546,12 @@ class SparkSession:
     @classmethod
     def getActiveSession(cls) -> Any:
         raise PySparkNotImplementedError(
-            error_class="NOT_IMPLEMENTED",
-            message_parameters={"feature": "getActiveSession()"},
+            error_class="NOT_IMPLEMENTED", message_parameters={"feature": 
"getActiveSession()"}
         )
 
     def newSession(self) -> Any:
         raise PySparkNotImplementedError(
-            error_class="NOT_IMPLEMENTED",
-            message_parameters={"feature": "newSession()"},
+            error_class="NOT_IMPLEMENTED", message_parameters={"feature": 
"newSession()"}
         )
 
     @property
@@ -534,8 +561,7 @@ class SparkSession:
     @property
     def sparkContext(self) -> Any:
         raise PySparkNotImplementedError(
-            error_class="NOT_IMPLEMENTED",
-            message_parameters={"feature": "sparkContext()"},
+            error_class="NOT_IMPLEMENTED", message_parameters={"feature": 
"sparkContext()"}
         )
 
     @property
@@ -705,7 +731,10 @@ class SparkSession:
                 if origin_remote is not None:
                     os.environ["SPARK_REMOTE"] = origin_remote
         else:
-            raise RuntimeError("There should not be an existing Spark Session 
or Spark Context.")
+            raise PySparkRuntimeError(
+                error_class="SESSION_OR_CONTEXT_EXISTS",
+                message_parameters={},
+            )
 
     @property
     def session_id(self) -> str:
diff --git a/python/pyspark/sql/pandas/conversion.py 
b/python/pyspark/sql/pandas/conversion.py
index 0c29dcceed0..a4503661cad 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -31,6 +31,7 @@ from pyspark.sql.pandas.serializers import 
ArrowCollectSerializer
 from pyspark.sql.types import TimestampType, StructType, DataType
 from pyspark.sql.utils import is_timestamp_ntz_preferred
 from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.errors import PySparkTypeError
 
 if TYPE_CHECKING:
     import numpy as np
@@ -488,7 +489,10 @@ class SparkConversionMixin:
         if isinstance(schema, StructType):
             arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
         elif isinstance(schema, DataType):
-            raise ValueError("Single data type %s is not supported with Arrow" 
% str(schema))
+            raise PySparkTypeError(
+                error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
+                message_parameters={"data_type": str(schema)},
+            )
         else:
             # Any timestamps must be coerced to be compatible with Spark
             arrow_types = [
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 45dbe182f12..b0bc2cba78e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -552,21 +552,27 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
                 self.assertEqual(sdf.schema, cdf.schema)
                 self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
-        with self.assertRaisesRegex(
-            ValueError,
-            "Length mismatch: Expected axis has 5 elements, new values have 4 
elements",
-        ):
+        with self.assertRaises(PySparkValueError) as pe:
             self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="AXIS_LENGTH_MISMATCH",
+            message_parameters={"expected_length": "5", "actual_length": "4"},
+        )
+
         with self.assertRaises(ParseException):
             self.connect.createDataFrame(data, "col1 magic_type, col2 int, 
col3 int, col4 int")
 
-        with self.assertRaisesRegex(
-            ValueError,
-            "Length mismatch: Expected axis has 3 elements, new values have 4 
elements",
-        ):
+        with self.assertRaises(PySparkValueError) as pe:
             self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="AXIS_LENGTH_MISMATCH",
+            message_parameters={"expected_length": "3", "actual_length": "4"},
+        )
+
         # test 1 dim ndarray
         data = np.array([1.0, 2.0, np.nan, 3.0, 4.0, float("NaN"), 5.0])
         self.assertEqual(data.ndim, 1)
@@ -599,12 +605,15 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.assertEqual(sdf.schema, cdf.schema)
             self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
-        with self.assertRaisesRegex(
-            ValueError,
-            "Length mismatch: Expected axis has 5 elements, new values have 4 
elements",
-        ):
+        with self.assertRaises(PySparkValueError) as pe:
             self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="AXIS_LENGTH_MISMATCH",
+            message_parameters={"expected_length": "5", "actual_length": "4"},
+        )
+
         with self.assertRaises(ParseException):
             self.connect.createDataFrame(data, "col1 magic_type, col2 int, 
col3 int, col4 int")
 
@@ -765,12 +774,15 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.assert_eq(cdf.toPandas(), sdf.toPandas())
 
         # check error
-        with self.assertRaisesRegex(
-            ValueError,
-            "can not infer schema from empty dataset",
-        ):
+        with self.assertRaises(PySparkValueError) as pe:
             self.connect.createDataFrame(data=[])
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_INFER_EMPTY_SCHEMA",
+            message_parameters={},
+        )
+
     def test_create_dataframe_from_arrays(self):
         # SPARK-42021: createDataFrame support array.array
         data1 = [Row(a=1, b=array.array("i", [1, 2, 3]), c=array.array("d", 
[4, 5, 6]))]
diff --git a/python/pyspark/sql/tests/test_arrow.py 
b/python/pyspark/sql/tests/test_arrow.py
index 52e13782199..91fc6969185 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -542,9 +542,15 @@ class ArrowTestsMixin:
     def check_createDataFrame_with_single_data_type(self):
         for schema in ["int", IntegerType()]:
             with self.subTest(schema=schema):
-                with self.assertRaisesRegex(ValueError, ".*IntegerType.*not 
supported.*"):
+                with self.assertRaises(PySparkTypeError) as pe:
                     self.spark.createDataFrame(pd.DataFrame({"a": [1]}), 
schema=schema).collect()
 
+                self.check_error(
+                    exception=pe.exception,
+                    error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
+                    message_parameters={"data_type": "IntegerType()"},
+                )
+
     def test_createDataFrame_does_not_modify_input(self):
         # Some series get converted for Spark to consume, this makes sure 
input is unchanged
         pdf = self.create_pandas_data_frame()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to