This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f66611d5e17 [SPARK-42982][CONNECT][PYTHON] Fix createDataFrame to respect the given schema DDL
f66611d5e17 is described below

commit f66611d5e1788745c907c6a54fe8d941a67b55b4
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Thu Apr 13 18:31:14 2023 +0900

    [SPARK-42982][CONNECT][PYTHON] Fix createDataFrame to respect the given schema DDL

    ### What changes were proposed in this pull request?

    Fixes `createDataFrame` to respect the given schema DDL string.

    ### Why are the changes needed?

    Currently, even if the schema is provided as a DDL string, it is not taken into account, which causes a schema mismatch on the server side. For example:

    ```py
    >>> import pandas as pd
    >>> map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]
    >>> pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data})
    >>> schema = "id long, m map<string, long>"
    >>>
    >>> spark.createDataFrame(pdf, schema=schema)
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.connect.AnalysisException: [INVALID_COLUMN_OR_FIELD_DATA_TYPE] Column or field `col_1` is of type "STRUCT<col_0: BIGINT, col_1: BIGINT, col_2: BIGINT, col_3: VOID>" while it's required to be "MAP<STRING, BIGINT>".
    ```

    ### Does this PR introduce _any_ user-facing change?

    Yes. The schema DDL string is now taken into account.

    ### How was this patch tested?

    Enabled/modified the related tests.

    Closes #40760 from ueshin/issues/SPARK-42982/schema_str.

    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/conversion.py               |  36 ++++---
 python/pyspark/sql/connect/session.py                  |  58 ++++-------
 .../sql/tests/connect/test_connect_basic.py            | 111 +++++++++++++--------
 python/pyspark/sql/tests/connect/test_parity_arrow.py  |   6 --
 python/pyspark/sql/tests/test_arrow.py                 |  76 ++++++++------
 5 files changed, 155 insertions(+), 132 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 310f26654df..5a31d1df67e 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -119,13 +119,17 @@ class LocalDataToArrowConversion:
                 _dict = {}
                 if not isinstance(value, Row) and hasattr(value, "__dict__"):
                     value = value.__dict__
-                for i, field in enumerate(field_names):
-                    if isinstance(value, dict):
-                        v = value.get(field)
-                    else:
-                        v = value[i]
-
-                    _dict[f"col_{i}"] = field_convs[i](v)
+                if isinstance(value, dict):
+                    for i, field in enumerate(field_names):
+                        _dict[f"col_{i}"] = field_convs[i](value.get(field))
+                else:
+                    if len(value) != len(field_names):
+                        raise ValueError(
+                            f"Length mismatch: Expected axis has {len(field_names)} elements, "
+                            f"new values have {len(value)} elements"
+                        )
+                    for i in range(len(field_names)):
+                        _dict[f"col_{i}"] = field_convs[i](value[i])
 
                 return _dict
 
@@ -272,13 +276,17 @@ class LocalDataToArrowConversion:
         for item in data:
             if not isinstance(item, Row) and hasattr(item, "__dict__"):
                 item = item.__dict__
-            for i, col in enumerate(column_names):
-                if isinstance(item, dict):
-                    value = item.get(col)
-                else:
-                    value = item[i]
-
-                pylist[i].append(column_convs[i](value))
+            if isinstance(item, dict):
+                for i, col in enumerate(column_names):
+                    pylist[i].append(column_convs[i](item.get(col)))
+            else:
+                if len(item) != len(column_names):
+                    raise ValueError(
+                        f"Length mismatch: Expected axis has {len(column_names)} elements, "
+                        f"new values have {len(item)} elements"
+                    )
+                for i in range(len(column_names)):
+                    pylist[i].append(column_convs[i](item[i]))
 
         def normalize(dt: DataType) -> DataType:
             if isinstance(dt, StructType):
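A note on the `conversion.py` change above: rows given as dicts are matched to struct fields by name (missing keys become `None`), while rows given as sequences must now supply exactly one value per field, otherwise a pandas-style "Length mismatch" error is raised on the client. Below is a minimal, self-contained sketch of that behavior; the helper name `row_to_columns` and its signature are illustrative only, not the actual Spark Connect code.

```py
from typing import Any, Dict, List, Sequence, Union


def row_to_columns(
    value: Union[Dict[str, Any], Sequence[Any]], field_names: List[str]
) -> Dict[str, Any]:
    """Map one input row onto positional column slots, as the converter above does."""
    converted: Dict[str, Any] = {}
    if isinstance(value, dict):
        # Dict rows are matched by field name; missing keys simply become None.
        for i, field in enumerate(field_names):
            converted[f"col_{i}"] = value.get(field)
    else:
        # Sequence rows must match the schema width exactly.
        if len(value) != len(field_names):
            raise ValueError(
                f"Length mismatch: Expected axis has {len(field_names)} elements, "
                f"new values have {len(value)} elements"
            )
        for i in range(len(field_names)):
            converted[f"col_{i}"] = value[i]
    return converted


print(row_to_columns({"a": 1}, ["a", "b"]))  # {'col_0': 1, 'col_1': None}
print(row_to_columns((1, 2), ["a", "b"]))    # {'col_0': 1, 'col_1': 2}
# row_to_columns((1, 2, 3), ["a", "b"]) raises the "Length mismatch" ValueError.
```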
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index e7b2ec6d2a6..3d9b641658a 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -220,10 +220,14 @@ class SparkSession:
             raise TypeError("data is already a DataFrame")
 
         _schema: Optional[Union[AtomicType, StructType]] = None
-        _schema_str: Optional[str] = None
        _cols: Optional[List[str]] = None
         _num_cols: Optional[int] = None
 
+        if isinstance(schema, str):
+            schema = self.client._analyze(  # type: ignore[assignment]
+                method="ddl_parse", ddl_string=schema
+            ).parsed
+
         if isinstance(schema, (AtomicType, StructType)):
             _schema = schema
             if isinstance(schema, StructType):
@@ -231,9 +235,6 @@ class SparkSession:
             else:
                 _num_cols = 1
 
-        elif isinstance(schema, str):
-            _schema_str = schema
-
         elif isinstance(schema, (list, tuple)):
             # Must re-encode any unicode strings to be consistent with StructField names
             _cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
@@ -244,13 +245,10 @@ class SparkSession:
         elif isinstance(data, Sized) and len(data) == 0:
             if _schema is not None:
                 return DataFrame.withPlan(LocalRelation(table=None, schema=_schema.json()), self)
-            elif _schema_str is not None:
-                return DataFrame.withPlan(LocalRelation(table=None, schema=_schema_str), self)
             else:
                 raise ValueError("can not infer schema from empty dataset")
 
         _table: Optional[pa.Table] = None
-        _inferred_schema: Optional[StructType] = None
 
         if isinstance(data, pd.DataFrame):
             # Logic was borrowed from `_create_from_pandas_with_arrow` in
@@ -309,16 +307,16 @@ class SparkSession:
             if data.ndim == 1:
                 if 1 != len(_cols):
                     raise ValueError(
-                        f"Length mismatch: Expected axis has 1 element, "
-                        f"new values have {len(_cols)} elements"
+                        f"Length mismatch: Expected axis has {len(_cols)} element, "
+                        "new values have 1 elements"
                     )
 
                 _table = pa.Table.from_arrays([pa.array(data)], _cols)
             else:
                 if data.shape[1] != len(_cols):
                     raise ValueError(
-                        f"Length mismatch: Expected axis has {data.shape[1]} elements, "
-                        f"new values have {len(_cols)} elements"
+                        f"Length mismatch: Expected axis has {len(_cols)} elements, "
+                        f"new values have {data.shape[1]} elements"
                     )
 
                 _table = pa.Table.from_arrays(
@@ -334,7 +332,7 @@ class SparkSession:
             if isinstance(_data[0], dict):
                 # Sort the data to respect inferred schema.
                 # For dictionaries, we sort the schema in alphabetical order.
-                _data = [dict(sorted(d.items())) for d in _data]
+                _data = [dict(sorted(d.items())) if d is not None else None for d in _data]
 
             elif not isinstance(_data[0], (Row, tuple, list, dict)) and not hasattr(
                 _data[0], "__dict__"
@@ -344,44 +342,28 @@ class SparkSession:
                 _data = [[d] for d in _data]
 
             if _schema is not None:
-                if isinstance(_schema, StructType):
-                    _inferred_schema = _schema
-                else:
-                    _inferred_schema = StructType().add("value", _schema)
+                if not isinstance(_schema, StructType):
+                    _schema = StructType().add("value", _schema)
             else:
-                _inferred_schema = self._inferSchemaFromList(_data, _cols)
+                _schema = self._inferSchemaFromList(_data, _cols)
 
                 if _cols is not None and cast(int, _num_cols) < len(_cols):
                     _num_cols = len(_cols)
 
-            if _has_nulltype(_inferred_schema):
+            if _has_nulltype(_schema):
                 # For cases like createDataFrame([("Alice", None, 80.1)], schema)
                 # we can not infer the schema from the data itself.
-                warnings.warn("failed to infer the schema from data")
-                if _schema_str is not None:
-                    _parsed = self.client._analyze(
-                        method="ddl_parse", ddl_string=_schema_str
-                    ).parsed
-                    if isinstance(_parsed, StructType):
-                        _inferred_schema = _parsed
-                    elif isinstance(_parsed, DataType):
-                        _inferred_schema = StructType().add("value", _parsed)
-                    _schema_str = None
-                if _has_nulltype(_inferred_schema):
-                    raise ValueError(
-                        "Some of types cannot be determined after inferring, "
-                        "a StructType Schema is required in this case"
-                    )
-
-            if _schema_str is None:
-                _schema = _inferred_schema
+                raise ValueError(
+                    "Some of types cannot be determined after inferring, "
+                    "a StructType Schema is required in this case"
+                )
 
             from pyspark.sql.connect.conversion import LocalDataToArrowConversion
 
             # Spark Connect will try its best to build the Arrow table with the
             # inferred schema in the client side, and then rename the columns and
             # cast the datatypes in the server side.
-            _table = LocalDataToArrowConversion.convert(_data, _inferred_schema)
+            _table = LocalDataToArrowConversion.convert(_data, _schema)
 
             # TODO: Beside the validation on number of columns, we should also check
             # whether the Arrow Schema is compatible with the user provided Schema.
@@ -393,8 +375,6 @@ class SparkSession:
 
         if _schema is not None:
             df = DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
-        elif _schema_str is not None:
-            df = DataFrame.withPlan(LocalRelation(_table, schema=_schema_str), self)
         else:
             df = DataFrame.withPlan(LocalRelation(_table), self)
 
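The net effect of the `session.py` changes above is that a DDL string is parsed into a `DataType` at the very top of `createDataFrame` (via the client's `ddl_parse` analyze call), after which it follows the same code path as an explicitly constructed schema, including the wrapping of non-struct types into a single `value` column. Below is a rough illustration of that wrapping step using only public `pyspark.sql.types`; the `parsed` value is written out by hand as a stand-in for the result of parsing `"id long, m map<string, long>"`, so no running session is needed.

```py
from pyspark.sql.types import DataType, LongType, MapType, StringType, StructType


def normalize_schema(parsed: DataType) -> StructType:
    # Struct schemas are used as-is; any other parsed type becomes a one-field
    # struct, mirroring `StructType().add("value", _schema)` in the change above.
    if isinstance(parsed, StructType):
        return parsed
    return StructType().add("value", parsed)


parsed = StructType().add("id", LongType()).add("m", MapType(StringType(), LongType()))
print(normalize_schema(parsed).simpleString())      # struct<id:bigint,m:map<string,bigint>>
print(normalize_schema(LongType()).simpleString())  # struct<value:bigint>
```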
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 2658ad79ab0..166a8b23bd7 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -555,17 +555,18 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
 
         with self.assertRaisesRegex(
             ValueError,
-            "Length mismatch: Expected axis has 4 elements, new values have 5 elements",
+            "Length mismatch: Expected axis has 5 elements, new values have 4 elements",
         ):
             self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
 
         with self.assertRaises(ParseException):
-            self.connect.createDataFrame(
-                data, "col1 magic_type, col2 int, col3 int, col4 int"
-            ).show()
+            self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
 
-        with self.assertRaises(SparkConnectException):
-            self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show()
+        with self.assertRaisesRegex(
+            ValueError,
+            "Length mismatch: Expected axis has 3 elements, new values have 4 elements",
+        ):
+            self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
 
         # test 1 dim ndarray
         data = np.array([1.0, 2.0, np.nan, 3.0, 4.0, float("NaN"), 5.0])
@@ -606,12 +607,13 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
 
         with self.assertRaises(ParseException):
-            self.connect.createDataFrame(
-                data, "col1 magic_type, col2 int, col3 int, col4 int"
-            ).show()
+            self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
 
-        with self.assertRaises(SparkConnectException):
-            self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show()
+        with self.assertRaisesRegex(
+            ValueError,
+            "Length mismatch: Expected axis has 3 elements, new values have 4 elements",
+        ):
+            self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
 
     def test_with_local_rows(self):
         # SPARK-41789, SPARK-41810: Test creating a dataframe with list of rows and dictionaries
@@ -2982,79 +2984,106 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(cdf4.collect(), sdf4.collect())
 
     def test_array_has_nullable(self):
-        for schema, data in [
+        for schemas, data in [
             (
-                StructType().add("arr", ArrayType(IntegerType(), False), True),
+                [StructType().add("arr", ArrayType(IntegerType(), False), True)],
                 [Row([1, 2]), Row([3]), Row(None)],
             ),
             (
-                StructType().add("arr", ArrayType(IntegerType(), True), True),
+                [
+                    StructType().add("arr", ArrayType(IntegerType(), True), True),
+                    "arr array<integer>",
+                ],
                 [Row([1, None]), Row([3]), Row(None)],
             ),
             (
-                StructType().add("arr", ArrayType(IntegerType(), False), False),
+                [StructType().add("arr", ArrayType(IntegerType(), False), False)],
                 [Row([1, 2]), Row([3])],
             ),
             (
-                StructType().add("arr", ArrayType(IntegerType(), True), False),
+                [
+                    StructType().add("arr", ArrayType(IntegerType(), True), False),
+                    "arr array<integer> not null",
+                ],
                 [Row([1, None]), Row([3])],
             ),
         ]:
-            with self.subTest(schema=schema):
-                cdf = self.connect.createDataFrame(data, schema=schema)
-                sdf = self.spark.createDataFrame(data, schema=schema)
-                self.assertEqual(cdf.schema, sdf.schema)
-                self.assertEqual(cdf.collect(), sdf.collect())
+            for schema in schemas:
+                with self.subTest(schema=schema):
+                    cdf = self.connect.createDataFrame(data, schema=schema)
+                    sdf = self.spark.createDataFrame(data, schema=schema)
+                    self.assertEqual(cdf.schema, sdf.schema)
+                    self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_map_has_nullable(self):
-        for schema, data in [
+        for schemas, data in [
             (
-                StructType().add("map", MapType(StringType(), IntegerType(), False), True),
+                [StructType().add("map", MapType(StringType(), IntegerType(), False), True)],
                 [Row({"a": 1, "b": 2}), Row({"a": 3}), Row(None)],
             ),
             (
-                StructType().add("map", MapType(StringType(), IntegerType(), True), True),
+                [
+                    StructType().add("map", MapType(StringType(), IntegerType(), True), True),
+                    "map map<string, integer>",
+                ],
                 [Row({"a": 1, "b": None}), Row({"a": 3}), Row(None)],
             ),
             (
-                StructType().add("map", MapType(StringType(), IntegerType(), False), False),
+                [StructType().add("map", MapType(StringType(), IntegerType(), False), False)],
                 [Row({"a": 1, "b": 2}), Row({"a": 3})],
             ),
             (
-                StructType().add("map", MapType(StringType(), IntegerType(), True), False),
+                [
+                    StructType().add("map", MapType(StringType(), IntegerType(), True), False),
+                    "map map<string, integer> not null",
+                ],
                 [Row({"a": 1, "b": None}), Row({"a": 3})],
             ),
         ]:
-            with self.subTest(schema=schema):
-                cdf = self.connect.createDataFrame(data, schema=schema)
-                sdf = self.spark.createDataFrame(data, schema=schema)
-                self.assertEqual(cdf.schema, sdf.schema)
-                self.assertEqual(cdf.collect(), sdf.collect())
+            for schema in schemas:
+                with self.subTest(schema=schema):
+                    cdf = self.connect.createDataFrame(data, schema=schema)
+                    sdf = self.spark.createDataFrame(data, schema=schema)
+                    self.assertEqual(cdf.schema, sdf.schema)
+                    self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_struct_has_nullable(self):
-        for schema, data in [
+        for schemas, data in [
             (
-                StructType().add("struct", StructType().add("i", IntegerType(), False), True),
+                [
+                    StructType().add("struct", StructType().add("i", IntegerType(), False), True),
+                    "struct struct<i: integer not null>",
+                ],
                 [Row(Row(1)), Row(Row(2)), Row(None)],
             ),
             (
-                StructType().add("struct", StructType().add("i", IntegerType(), True), True),
+                [
+                    StructType().add("struct", StructType().add("i", IntegerType(), True), True),
+                    "struct struct<i: integer>",
+                ],
                 [Row(Row(1)), Row(Row(2)), Row(Row(None)), Row(None)],
             ),
             (
-                StructType().add("struct", StructType().add("i", IntegerType(), False), False),
+                [
+                    StructType().add("struct", StructType().add("i", IntegerType(), False), False),
+                    "struct struct<i: integer not null> not null",
+                ],
                 [Row(Row(1)), Row(Row(2))],
             ),
             (
-                StructType().add("struct", StructType().add("i", IntegerType(), True), False),
+                [
+                    StructType().add("struct", StructType().add("i", IntegerType(), True), False),
+                    "struct struct<i: integer> not null",
+                ],
                 [Row(Row(1)), Row(Row(2)), Row(Row(None))],
             ),
         ]:
-            with self.subTest(schema=schema):
-                cdf = self.connect.createDataFrame(data, schema=schema)
-                sdf = self.spark.createDataFrame(data, schema=schema)
-                self.assertEqual(cdf.schema, sdf.schema)
-                self.assertEqual(cdf.collect(), sdf.collect())
+            for schema in schemas:
+                with self.subTest(schema=schema):
+                    cdf = self.connect.createDataFrame(data, schema=schema)
+                    sdf = self.spark.createDataFrame(data, schema=schema)
+                    self.assertEqual(cdf.schema, sdf.schema)
+                    self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_large_client_data(self):
         # SPARK-42816 support more than 4MB message size.
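The DDL strings added to these tests are meant to be equivalent to the `StructType` they are paired with; in particular, a column-level `not null` corresponds to `nullable=False` on the field, while array elements and map values stay nullable. A quick manual check of one such pair is sketched below; it assumes an active `SparkSession` named `spark` (classic or Connect) and is not part of the patch.

```py
from pyspark.sql import Row
from pyspark.sql.types import ArrayType, IntegerType, StructType

data = [Row([1, None]), Row([3])]
ddl_schema = "arr array<integer> not null"
struct_schema = StructType().add("arr", ArrayType(IntegerType(), True), False)

df_ddl = spark.createDataFrame(data, schema=ddl_schema)
df_struct = spark.createDataFrame(data, schema=struct_schema)

# Both spellings should yield the same schema: a non-nullable `arr` column
# whose array elements are themselves still nullable.
assert df_ddl.schema == df_struct.schema
```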
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index 0ed8642383b..ec33bb22a4b 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -37,8 +37,6 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
     def test_createDataFrame_with_incorrect_schema(self):
         self.check_createDataFrame_with_incorrect_schema()
 
-    # TODO(SPARK-42982): INVALID_COLUMN_OR_FIELD_DATA_TYPE
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_createDataFrame_with_map_type(self):
         self.check_createDataFrame_with_map_type(True)
 
@@ -92,13 +90,9 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
     def test_toPandas_fallback_enabled(self):
         super().test_toPandas_fallback_enabled()
 
-    # TODO(SPARK-42982): INVALID_COLUMN_OR_FIELD_DATA_TYPE
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_toPandas_with_map_type(self):
         self.check_toPandas_with_map_type(True)
 
-    # TODO(SPARK-42982): INVALID_COLUMN_OR_FIELD_DATA_TYPE
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_toPandas_with_map_type_nulls(self):
         self.check_toPandas_with_map_type_nulls(True)
 
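The three `@unittest.skip` markers removed above were tracking this very issue (SPARK-42982); with the DDL schema now respected, the map-type parity tests run against Spark Connect again. For reference, a minimal sketch of the scenario they cover, which is also the example from the PR description: it assumes a Spark Connect session bound to `spark` and pandas installed, and the printed schema is the expected result rather than a captured run.

```py
import pandas as pd

map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]
pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data})

# Previously this raised INVALID_COLUMN_OR_FIELD_DATA_TYPE over Spark Connect;
# with this change the DDL string is parsed on the client before conversion.
df = spark.createDataFrame(pdf, schema="id long, m map<string, long>")
df.printSchema()
# root
#  |-- id: long (nullable = true)
#  |-- m: map (nullable = true)
#  |    |-- key: string
#  |    |-- value: long (valueContainsNull = true)
```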
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index dfc65d02b55..04aaa2b1c32 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -43,6 +43,7 @@ from pyspark.sql.types import (
     BinaryType,
     StructField,
     ArrayType,
+    MapType,
     NullType,
     DayTimeIntervalType,
 )
@@ -620,20 +621,23 @@ class ArrowTestsMixin:
         map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]
         pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data})
 
-        schema = "id long, m map<string, long>"
-
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
-            if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
-                with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
-                    self.spark.createDataFrame(pdf, schema=schema).collect()
-            else:
-                df = self.spark.createDataFrame(pdf, schema=schema)
-
-                result = df.collect()
-
-                for row in result:
-                    i, m = row
-                    self.assertEqual(m, map_data[i])
+        for schema in (
+            "id long, m map<string, long>",
+            StructType().add("id", LongType()).add("m", MapType(StringType(), LongType())),
+        ):
+            with self.subTest(schema=schema):
+                with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                    if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+                        with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
+                            self.spark.createDataFrame(pdf, schema=schema).collect()
+                    else:
+                        df = self.spark.createDataFrame(pdf, schema=schema)
+
+                        result = df.collect()
+
+                        for row in result:
+                            i, m = row
+                            self.assertEqual(m, map_data[i])
 
     def test_createDataFrame_with_string_dtype(self):
         # SPARK-34521: spark.createDataFrame does not support Pandas StringDtype extension type
@@ -667,16 +671,20 @@ class ArrowTestsMixin:
             {"id": [0, 1, 2, 3], "m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}]}
         )
 
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
-            df = self.spark.createDataFrame(origin, schema="id long, m map<string, long>")
+        for schema in [
+            "id long, m map<string, long>",
+            StructType().add("id", LongType()).add("m", MapType(StringType(), LongType())),
+        ]:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+                df = self.spark.createDataFrame(origin, schema=schema)
 
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
-            if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
-                with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
-                    df.toPandas()
-            else:
-                pdf = df.toPandas()
-                assert_frame_equal(origin, pdf)
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+                    with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
+                        df.toPandas()
+                else:
+                    pdf = df.toPandas()
+                    assert_frame_equal(origin, pdf)
 
     def test_toPandas_with_map_type_nulls(self):
         with QuietTest(self.sc):
@@ -689,16 +697,20 @@ class ArrowTestsMixin:
             {"id": [0, 1, 2, 3, 4], "m": [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]}
         )
 
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
-            df = self.spark.createDataFrame(origin, schema="id long, m map<string, long>")
+        for schema in [
+            "id long, m map<string, long>",
+            StructType().add("id", LongType()).add("m", MapType(StringType(), LongType())),
+        ]:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+                df = self.spark.createDataFrame(origin, schema=schema)
 
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
-            if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
-                with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
-                    df.toPandas()
-            else:
-                pdf = df.toPandas()
-                assert_frame_equal(origin, pdf)
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+                    with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
+                        df.toPandas()
+                else:
+                    pdf = df.toPandas()
+                    assert_frame_equal(origin, pdf)
 
     def test_createDataFrame_with_int_col_names(self):
         for arrow_enabled in [True, False]:

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org