This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 062ac987063  [SPARK-43296][CONNECT][PYTHON] Migrate Spark Connect session errors into error class
062ac987063 is described below

commit 062ac987063f8814c8d92925ddc6d2c72df2d208
Author:     itholic <haejoon....@databricks.com>
AuthorDate: Tue May 9 10:48:34 2023 +0800

    [SPARK-43296][CONNECT][PYTHON] Migrate Spark Connect session errors into error class

    ### What changes were proposed in this pull request?

    This PR proposes to migrate Spark Connect session errors into error classes.

    ### Why are the changes needed?

    To improve PySpark error usability.

    ### Does this PR introduce _any_ user-facing change?

    No API changes.

    ### How was this patch tested?

    The existing CI should pass.

    Closes #40964 from itholic/error_connect_session.

    Authored-by: itholic <haejoon....@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/errors/error_classes.py              | 30 ++++++++++
 python/pyspark/sql/connect/session.py               | 77 +++++++++++++++-------
 python/pyspark/sql/pandas/conversion.py             |  6 +-
 .../sql/tests/connect/test_connect_basic.py         | 44 ++++++++-----
 python/pyspark/sql/tests/test_arrow.py              |  8 ++-
 5 files changed, 123 insertions(+), 42 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 6af8d5bc6ff..c7b00e0736d 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -39,6 +39,11 @@ ERROR_CLASSES_JSON = """
       "Attribute `<attr_name>` is not supported."
     ]
   },
+  "AXIS_LENGTH_MISMATCH" : {
+    "message" : [
+      "Length mismatch: Expected axis has <expected_length> element, new values have <actual_length> elements."
+    ]
+  },
   "BROADCAST_VARIABLE_NOT_LOADED": {
     "message": [
       "Broadcast variable `<variable>` not loaded."
     ]
   },
@@ -94,6 +99,11 @@ ERROR_CLASSES_JSON = """
       "Can not infer Array Type from an list with None as the first element."
     ]
   },
+  "CANNOT_INFER_EMPTY_SCHEMA": {
+    "message": [
+      "Can not infer schema from empty dataset."
+    ]
+  },
   "CANNOT_INFER_SCHEMA_FOR_TYPE": {
     "message": [
       "Can not infer schema for type: `<data_type>`."
     ]
   },
@@ -195,6 +205,11 @@ ERROR_CLASSES_JSON = """
       "All items in `<arg_name>` should be in <allowed_types>, got <item_type>."
     ]
   },
+  "INVALID_NDARRAY_DIMENSION": {
+    "message": [
+      "NumPy array input should be of <dimensions> dimensions."
+    ]
+  },
   "INVALID_PANDAS_UDF" : {
     "message" : [
       "Invalid function: <detail>"
     ]
   },
@@ -215,6 +230,11 @@ ERROR_CLASSES_JSON = """
       "Timeout timestamp (<timestamp>) cannot be earlier than the current watermark (<watermark>)."
     ]
   },
+  "INVALID_TYPE" : {
+    "message" : [
+      "Argument `<arg_name>` should not be a <data_type>."
+    ]
+  },
   "INVALID_TYPENAME_CALL" : {
     "message" : [
       "StructField does not have typeName. Use typeName on its type explicitly instead."
     ]
   },
@@ -556,6 +576,11 @@ ERROR_CLASSES_JSON = """
       "Result vector from pandas_udf was not the required length: expected <expected>, got <actual>."
     ]
   },
+  "SESSION_OR_CONTEXT_EXISTS" : {
+    "message" : [
+      "There should not be an existing Spark Session or Spark Context."
+    ]
+  },
   "SLICE_WITH_STEP" : {
     "message" : [
       "Slice with step is not supported."
     ]
   },
@@ -611,6 +636,11 @@ ERROR_CLASSES_JSON = """
       "Unsupported DataType `<data_type>`."
     ]
   },
+  "UNSUPPORTED_DATA_TYPE_FOR_ARROW" : {
+    "message" : [
+      "Single data type <data_type> is not supported with Arrow."
+    ]
+  },
   "UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION" : {
     "message" : [
       "<data_type> is not supported in conversion to Arrow."
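The entries above are consumed through the parameterized exception constructors in pyspark.errors, using the same error_class / message_parameters keywords that appear throughout the diff below. As a minimal sketch (not part of this patch, assuming a PySpark build that includes it; the rendered message text is inferred from the template above), one of the new classes surfaces to a caller like this:

    from pyspark.errors import PySparkValueError

    try:
        raise PySparkValueError(
            error_class="CANNOT_INFER_EMPTY_SCHEMA",
            message_parameters={},
        )
    except PySparkValueError as e:
        # Callers can branch on the error class instead of matching message text.
        assert e.getErrorClass() == "CANNOT_INFER_EMPTY_SCHEMA"
        # The rendered message should read roughly:
        #   [CANNOT_INFER_EMPTY_SCHEMA] Can not infer schema from empty dataset.
        print(str(e))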
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 4f8fa419119..c23b6c5d11a 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -68,7 +68,13 @@ from pyspark.sql.types import (
     TimestampType,
 )
 from pyspark.sql.utils import to_str
-from pyspark.errors import PySparkAttributeError, PySparkNotImplementedError
+from pyspark.errors import (
+    PySparkAttributeError,
+    PySparkNotImplementedError,
+    PySparkRuntimeError,
+    PySparkValueError,
+    PySparkTypeError,
+)
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import OptionalPrimitiveType
@@ -153,8 +159,7 @@ class SparkSession:
 
         def enableHiveSupport(self) -> "SparkSession.Builder":
             raise PySparkNotImplementedError(
-                error_class="NOT_IMPLEMENTED",
-                message_parameters={"feature": "enableHiveSupport"},
+                error_class="NOT_IMPLEMENTED", message_parameters={"feature": "enableHiveSupport"}
             )
 
         def getOrCreate(self) -> "SparkSession":
@@ -233,7 +238,10 @@ class SparkSession:
         Infer schema from list of Row, dict, or tuple.
         """
         if not data:
-            raise ValueError("can not infer schema from empty dataset")
+            raise PySparkValueError(
+                error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                message_parameters={},
+            )
 
         (
             infer_dict_as_struct,
@@ -265,7 +273,10 @@ class SparkSession:
     ) -> "DataFrame":
         assert data is not None
         if isinstance(data, DataFrame):
-            raise TypeError("data is already a DataFrame")
+            raise PySparkTypeError(
+                error_class="INVALID_TYPE",
+                message_parameters={"arg_name": "data", "data_type": "DataFrame"},
+            )
 
         _schema: Optional[Union[AtomicType, StructType]] = None
         _cols: Optional[List[str]] = None
@@ -289,12 +300,18 @@ class SparkSession:
                 _num_cols = len(_cols)
 
         if isinstance(data, np.ndarray) and data.ndim not in [1, 2]:
-            raise ValueError("NumPy array input should be of 1 or 2 dimensions.")
+            raise PySparkValueError(
+                error_class="INVALID_NDARRAY_DIMENSION",
+                message_parameters={"dimensions": "1 or 2"},
+            )
         elif isinstance(data, Sized) and len(data) == 0:
             if _schema is not None:
                 return DataFrame.withPlan(LocalRelation(table=None, schema=_schema.json()), self)
             else:
-                raise ValueError("can not infer schema from empty dataset")
+                raise PySparkValueError(
+                    error_class="CANNOT_INFER_EMPTY_SCHEMA",
+                    message_parameters={},
+                )
 
         _table: Optional[pa.Table] = None
 
@@ -317,7 +334,10 @@ class SparkSession:
             arrow_types = [field.type for field in arrow_schema]
             _cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
         elif isinstance(schema, DataType):
-            raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
+            raise PySparkTypeError(
+                error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
+                message_parameters={"data_type": str(schema)},
+            )
         else:
             # Any timestamps must be coerced to be compatible with Spark
             arrow_types = [
@@ -354,17 +374,23 @@ class SparkSession:
 
             if data.ndim == 1:
                 if 1 != len(_cols):
-                    raise ValueError(
-                        f"Length mismatch: Expected axis has {len(_cols)} element, "
-                        "new values have 1 elements"
+                    raise PySparkValueError(
+                        error_class="AXIS_LENGTH_MISMATCH",
+                        message_parameters={
+                            "expected_length": str(len(_cols)),
+                            "actual_length": "1",
+                        },
                     )
 
                 _table = pa.Table.from_arrays([pa.array(data)], _cols)
             else:
                 if data.shape[1] != len(_cols):
-                    raise ValueError(
-                        f"Length mismatch: Expected axis has {len(_cols)} elements, "
-                        f"new values have {data.shape[1]} elements"
+                    raise PySparkValueError(
+                        error_class="AXIS_LENGTH_MISMATCH",
+                        message_parameters={
+                            "expected_length": str(len(_cols)),
+                            "actual_length": str(data.shape[1]),
+                        },
                     )
 
                 _table = pa.Table.from_arrays(
@@ -416,9 +442,12 @@ class SparkSession:
             # TODO: Beside the validation on number of columns, we should also check
             # whether the Arrow Schema is compatible with the user provided Schema.
             if _num_cols is not None and _num_cols != _table.shape[1]:
-                raise ValueError(
-                    f"Length mismatch: Expected axis has {_num_cols} elements, "
-                    f"new values have {_table.shape[1]} elements"
+                raise PySparkValueError(
+                    error_class="AXIS_LENGTH_MISMATCH",
+                    message_parameters={
+                        "expected_length": str(_num_cols),
+                        "actual_length": str(_table.shape[1]),
+                    },
                 )
 
         if _schema is not None:
@@ -517,14 +546,12 @@ class SparkSession:
     @classmethod
    def getActiveSession(cls) -> Any:
         raise PySparkNotImplementedError(
-            error_class="NOT_IMPLEMENTED",
-            message_parameters={"feature": "getActiveSession()"},
+            error_class="NOT_IMPLEMENTED", message_parameters={"feature": "getActiveSession()"}
         )
 
     def newSession(self) -> Any:
         raise PySparkNotImplementedError(
-            error_class="NOT_IMPLEMENTED",
-            message_parameters={"feature": "newSession()"},
+            error_class="NOT_IMPLEMENTED", message_parameters={"feature": "newSession()"}
         )
 
     @property
@@ -534,8 +561,7 @@ class SparkSession:
     @property
     def sparkContext(self) -> Any:
         raise PySparkNotImplementedError(
-            error_class="NOT_IMPLEMENTED",
-            message_parameters={"feature": "sparkContext()"},
+            error_class="NOT_IMPLEMENTED", message_parameters={"feature": "sparkContext()"}
         )
 
     @property
@@ -705,7 +731,10 @@ class SparkSession:
             if origin_remote is not None:
                 os.environ["SPARK_REMOTE"] = origin_remote
         else:
-            raise RuntimeError("There should not be an existing Spark Session or Spark Context.")
+            raise PySparkRuntimeError(
+                error_class="SESSION_OR_CONTEXT_EXISTS",
+                message_parameters={},
+            )
 
     @property
     def session_id(self) -> str:
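With the changes above, the Connect client's createDataFrame validation now fails with typed errors before any plan is sent to the server. A minimal sketch of the user-visible effect, not taken from the patch: it assumes numpy and the Connect client dependencies are installed and that a Spark Connect server is reachable at the hypothetical URL below.

    import numpy as np
    from pyspark.errors import PySparkValueError
    from pyspark.sql.connect.session import SparkSession

    # Hypothetical endpoint; any running Spark Connect server works.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    try:
        # 2 ndarray columns vs. 3 column names -> AXIS_LENGTH_MISMATCH,
        # raised locally by the check shown in the diff above.
        spark.createDataFrame(np.zeros((4, 2)), ["a", "b", "c"])
    except PySparkValueError as e:
        assert e.getErrorClass() == "AXIS_LENGTH_MISMATCH"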
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 0c29dcceed0..a4503661cad 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -31,6 +31,7 @@ from pyspark.sql.pandas.serializers import ArrowCollectSerializer
 from pyspark.sql.types import TimestampType, StructType, DataType
 from pyspark.sql.utils import is_timestamp_ntz_preferred
 from pyspark.traceback_utils import SCCallSiteSync
+from pyspark.errors import PySparkTypeError
 
 if TYPE_CHECKING:
     import numpy as np
@@ -488,7 +489,10 @@ class SparkConversionMixin:
         if isinstance(schema, StructType):
             arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
         elif isinstance(schema, DataType):
-            raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
+            raise PySparkTypeError(
+                error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
+                message_parameters={"data_type": str(schema)},
+            )
         else:
             # Any timestamps must be coerced to be compatible with Spark
             arrow_types = [
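The classic Arrow path above now raises the same UNSUPPORTED_DATA_TYPE_FOR_ARROW class as the Connect path in session.py, so code that previously matched the old ValueError message needs to catch PySparkTypeError instead. A minimal sketch against a local non-Connect session, not part of the patch: it assumes pandas and pyarrow are installed and disables Arrow fallback so the Arrow-path error propagates instead of being retried without Arrow.

    import pandas as pd
    from pyspark.errors import PySparkTypeError
    from pyspark.sql import SparkSession
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")

    try:
        # A bare DataType schema is rejected on the Arrow path with a typed error
        # instead of the old ValueError.
        spark.createDataFrame(pd.DataFrame({"a": [1]}), schema=IntegerType())
    except PySparkTypeError as e:
        assert e.getErrorClass() == "UNSUPPORTED_DATA_TYPE_FOR_ARROW"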
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 45dbe182f12..b0bc2cba78e 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -552,21 +552,27 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(sdf.schema, cdf.schema)
         self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
-        with self.assertRaisesRegex(
-            ValueError,
-            "Length mismatch: Expected axis has 5 elements, new values have 4 elements",
-        ):
+        with self.assertRaises(PySparkValueError) as pe:
             self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="AXIS_LENGTH_MISMATCH",
+            message_parameters={"expected_length": "5", "actual_length": "4"},
+        )
+
         with self.assertRaises(ParseException):
             self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
 
-        with self.assertRaisesRegex(
-            ValueError,
-            "Length mismatch: Expected axis has 3 elements, new values have 4 elements",
-        ):
+        with self.assertRaises(PySparkValueError) as pe:
             self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="AXIS_LENGTH_MISMATCH",
+            message_parameters={"expected_length": "3", "actual_length": "4"},
+        )
+
         # test 1 dim ndarray
         data = np.array([1.0, 2.0, np.nan, 3.0, 4.0, float("NaN"), 5.0])
         self.assertEqual(data.ndim, 1)
@@ -599,12 +605,15 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(sdf.schema, cdf.schema)
         self.assert_eq(sdf.toPandas(), cdf.toPandas())
 
-        with self.assertRaisesRegex(
-            ValueError,
-            "Length mismatch: Expected axis has 5 elements, new values have 4 elements",
-        ):
+        with self.assertRaises(PySparkValueError) as pe:
             self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="AXIS_LENGTH_MISMATCH",
+            message_parameters={"expected_length": "5", "actual_length": "4"},
+        )
+
         with self.assertRaises(ParseException):
             self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
 
@@ -765,12 +774,15 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assert_eq(cdf.toPandas(), sdf.toPandas())
 
         # check error
-        with self.assertRaisesRegex(
-            ValueError,
-            "can not infer schema from empty dataset",
-        ):
+        with self.assertRaises(PySparkValueError) as pe:
             self.connect.createDataFrame(data=[])
 
+        self.check_error(
+            exception=pe.exception,
+            error_class="CANNOT_INFER_EMPTY_SCHEMA",
+            message_parameters={},
+        )
+
     def test_create_dataframe_from_arrays(self):
         # SPARK-42021: createDataFrame support array.array
         data1 = [Row(a=1, b=array.array("i", [1, 2, 3]), c=array.array("d", [4, 5, 6]))]
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 52e13782199..91fc6969185 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -542,9 +542,15 @@ class ArrowTestsMixin:
     def check_createDataFrame_with_single_data_type(self):
         for schema in ["int", IntegerType()]:
             with self.subTest(schema=schema):
-                with self.assertRaisesRegex(ValueError, ".*IntegerType.*not supported.*"):
+                with self.assertRaises(PySparkTypeError) as pe:
                     self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema=schema).collect()
 
+                self.check_error(
+                    exception=pe.exception,
+                    error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
+                    message_parameters={"data_type": "IntegerType()"},
+                )
+
     def test_createDataFrame_does_not_modify_input(self):
         # Some series get converted for Spark to consume, this makes sure input is unchanged
         pdf = self.create_pandas_data_frame()

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org