This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f66611d5e17 [SPARK-42982][CONNECT][PYTHON] Fix createDataFrame to respect the given schema ddl
f66611d5e17 is described below

commit f66611d5e1788745c907c6a54fe8d941a67b55b4
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Thu Apr 13 18:31:14 2023 +0900

    [SPARK-42982][CONNECT][PYTHON] Fix createDataFrame to respect the given schema ddl
    
    ### What changes were proposed in this pull request?
    
    Fixes `createDataFrame` to respect the given schema ddl.
    
    ### Why are the changes needed?
    
    Currently, even if the schema is provided as a DDL string, it is not taken into account, which causes a schema mismatch on the server side.
    
    For example:
    
    ```py
    >>> import pandas as pd
    >>> map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]
    >>> pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data})
    >>> schema = "id long, m map<string, long>"
    >>>
    >>> spark.createDataFrame(pdf, schema=schema)
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.connect.AnalysisException: [INVALID_COLUMN_OR_FIELD_DATA_TYPE] Column or field `col_1` is of type "STRUCT<col_0: BIGINT, col_1: BIGINT, col_2: BIGINT, col_3: VOID>" while it's required to be "MAP<STRING, BIGINT>".
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the schema DDL string will now be taken into account.
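    
    For example, with this change the snippet above is expected to succeed. A minimal illustrative check (reusing `pdf`, `schema`, and `map_data` from the example; it mirrors the assertion in the updated `test_arrow.py` test rather than being a captured session):
    
    ```py
    >>> df = spark.createDataFrame(pdf, schema=schema)  # no longer raises AnalysisException
    >>> all(m == map_data[i] for i, m in df.collect())  # maps round-trip per the given DDL schema
    True
    ```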
    
    ### How was this patch tested?
    
    Enabled/modified the related tests.
    
    Closes #40760 from ueshin/issues/SPARK-42982/schema_str.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/sql/connect/conversion.py           |  36 ++++---
 python/pyspark/sql/connect/session.py              |  58 ++++-------
 .../sql/tests/connect/test_connect_basic.py        | 111 +++++++++++++--------
 .../pyspark/sql/tests/connect/test_parity_arrow.py |   6 --
 python/pyspark/sql/tests/test_arrow.py             |  76 ++++++++------
 5 files changed, 155 insertions(+), 132 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 310f26654df..5a31d1df67e 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -119,13 +119,17 @@ class LocalDataToArrowConversion:
                     _dict = {}
                     if not isinstance(value, Row) and hasattr(value, "__dict__"):
                         value = value.__dict__
-                    for i, field in enumerate(field_names):
-                        if isinstance(value, dict):
-                            v = value.get(field)
-                        else:
-                            v = value[i]
-
-                        _dict[f"col_{i}"] = field_convs[i](v)
+                    if isinstance(value, dict):
+                        for i, field in enumerate(field_names):
+                            _dict[f"col_{i}"] = field_convs[i](value.get(field))
+                    else:
+                        if len(value) != len(field_names):
+                            raise ValueError(
+                                f"Length mismatch: Expected axis has {len(field_names)} elements, "
+                                f"new values have {len(value)} elements"
+                            )
+                        for i in range(len(field_names)):
+                            _dict[f"col_{i}"] = field_convs[i](value[i])
 
                     return _dict
 
@@ -272,13 +276,17 @@ class LocalDataToArrowConversion:
         for item in data:
             if not isinstance(item, Row) and hasattr(item, "__dict__"):
                 item = item.__dict__
-            for i, col in enumerate(column_names):
-                if isinstance(item, dict):
-                    value = item.get(col)
-                else:
-                    value = item[i]
-
-                pylist[i].append(column_convs[i](value))
+            if isinstance(item, dict):
+                for i, col in enumerate(column_names):
+                    pylist[i].append(column_convs[i](item.get(col)))
+            else:
+                if len(item) != len(column_names):
+                    raise ValueError(
+                        f"Length mismatch: Expected axis has {len(column_names)} elements, "
+                        f"new values have {len(item)} elements"
+                    )
+                for i in range(len(column_names)):
+                    pylist[i].append(column_convs[i](item[i]))
 
         def normalize(dt: DataType) -> DataType:
             if isinstance(dt, StructType):
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index e7b2ec6d2a6..3d9b641658a 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -220,10 +220,14 @@ class SparkSession:
             raise TypeError("data is already a DataFrame")
 
         _schema: Optional[Union[AtomicType, StructType]] = None
-        _schema_str: Optional[str] = None
         _cols: Optional[List[str]] = None
         _num_cols: Optional[int] = None
 
+        if isinstance(schema, str):
+            schema = self.client._analyze(  # type: ignore[assignment]
+                method="ddl_parse", ddl_string=schema
+            ).parsed
+
         if isinstance(schema, (AtomicType, StructType)):
             _schema = schema
             if isinstance(schema, StructType):
@@ -231,9 +235,6 @@ class SparkSession:
             else:
                 _num_cols = 1
 
-        elif isinstance(schema, str):
-            _schema_str = schema
-
         elif isinstance(schema, (list, tuple)):
             # Must re-encode any unicode strings to be consistent with StructField names
             _cols = [x.encode("utf-8") if not isinstance(x, str) else x for x in schema]
@@ -244,13 +245,10 @@ class SparkSession:
         elif isinstance(data, Sized) and len(data) == 0:
             if _schema is not None:
                 return DataFrame.withPlan(LocalRelation(table=None, schema=_schema.json()), self)
-            elif _schema_str is not None:
-                return DataFrame.withPlan(LocalRelation(table=None, schema=_schema_str), self)
             else:
                 raise ValueError("can not infer schema from empty dataset")
 
         _table: Optional[pa.Table] = None
-        _inferred_schema: Optional[StructType] = None
 
         if isinstance(data, pd.DataFrame):
             # Logic was borrowed from `_create_from_pandas_with_arrow` in
@@ -309,16 +307,16 @@ class SparkSession:
             if data.ndim == 1:
                 if 1 != len(_cols):
                     raise ValueError(
-                        f"Length mismatch: Expected axis has 1 element, "
-                        f"new values have {len(_cols)} elements"
+                        f"Length mismatch: Expected axis has {len(_cols)} element, "
+                        "new values have 1 elements"
                     )
 
                 _table = pa.Table.from_arrays([pa.array(data)], _cols)
             else:
                 if data.shape[1] != len(_cols):
                     raise ValueError(
-                        f"Length mismatch: Expected axis has {data.shape[1]} elements, "
-                        f"new values have {len(_cols)} elements"
+                        f"Length mismatch: Expected axis has {len(_cols)} elements, "
+                        f"new values have {data.shape[1]} elements"
                     )
 
                 _table = pa.Table.from_arrays(
@@ -334,7 +332,7 @@ class SparkSession:
             if isinstance(_data[0], dict):
                 # Sort the data to respect inferred schema.
                 # For dictionaries, we sort the schema in alphabetical order.
-                _data = [dict(sorted(d.items())) for d in _data]
+                _data = [dict(sorted(d.items())) if d is not None else None for d in _data]
 
             elif not isinstance(_data[0], (Row, tuple, list, dict)) and not hasattr(
                 _data[0], "__dict__"
@@ -344,44 +342,28 @@ class SparkSession:
                 _data = [[d] for d in _data]
 
             if _schema is not None:
-                if isinstance(_schema, StructType):
-                    _inferred_schema = _schema
-                else:
-                    _inferred_schema = StructType().add("value", _schema)
+                if not isinstance(_schema, StructType):
+                    _schema = StructType().add("value", _schema)
             else:
-                _inferred_schema = self._inferSchemaFromList(_data, _cols)
+                _schema = self._inferSchemaFromList(_data, _cols)
 
                 if _cols is not None and cast(int, _num_cols) < len(_cols):
                     _num_cols = len(_cols)
 
-                if _has_nulltype(_inferred_schema):
+                if _has_nulltype(_schema):
                     # For cases like createDataFrame([("Alice", None, 80.1)], 
schema)
                     # we can not infer the schema from the data itself.
-                    warnings.warn("failed to infer the schema from data")
-                    if _schema_str is not None:
-                        _parsed = self.client._analyze(
-                            method="ddl_parse", ddl_string=_schema_str
-                        ).parsed
-                        if isinstance(_parsed, StructType):
-                            _inferred_schema = _parsed
-                        elif isinstance(_parsed, DataType):
-                            _inferred_schema = StructType().add("value", _parsed)
-                        _schema_str = None
-                    if _has_nulltype(_inferred_schema):
-                        raise ValueError(
-                            "Some of types cannot be determined after inferring, "
-                            "a StructType Schema is required in this case"
-                        )
-
-                if _schema_str is None:
-                    _schema = _inferred_schema
+                    raise ValueError(
+                        "Some of types cannot be determined after inferring, "
+                        "a StructType Schema is required in this case"
+                    )
 
             from pyspark.sql.connect.conversion import LocalDataToArrowConversion
 
             # Spark Connect will try its best to build the Arrow table with the
             # inferred schema in the client side, and then rename the columns and
             # cast the datatypes in the server side.
-            _table = LocalDataToArrowConversion.convert(_data, _inferred_schema)
+            _table = LocalDataToArrowConversion.convert(_data, _schema)
 
         # TODO: Beside the validation on number of columns, we should also check
         # whether the Arrow Schema is compatible with the user provided Schema.
@@ -393,8 +375,6 @@ class SparkSession:
 
         if _schema is not None:
             df = DataFrame.withPlan(LocalRelation(_table, schema=_schema.json()), self)
-        elif _schema_str is not None:
-            df = DataFrame.withPlan(LocalRelation(_table, schema=_schema_str), self)
         else:
             df = DataFrame.withPlan(LocalRelation(_table), self)
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 2658ad79ab0..166a8b23bd7 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -555,17 +555,18 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
 
         with self.assertRaisesRegex(
             ValueError,
-            "Length mismatch: Expected axis has 4 elements, new values have 5 elements",
+            "Length mismatch: Expected axis has 5 elements, new values have 4 elements",
         ):
             self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
 
         with self.assertRaises(ParseException):
-            self.connect.createDataFrame(
-                data, "col1 magic_type, col2 int, col3 int, col4 int"
-            ).show()
+            self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
 
-        with self.assertRaises(SparkConnectException):
-            self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show()
+        with self.assertRaisesRegex(
+            ValueError,
+            "Length mismatch: Expected axis has 3 elements, new values have 4 elements",
+        ):
+            self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
 
         # test 1 dim ndarray
         data = np.array([1.0, 2.0, np.nan, 3.0, 4.0, float("NaN"), 5.0])
@@ -606,12 +607,13 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"])
 
         with self.assertRaises(ParseException):
-            self.connect.createDataFrame(
-                data, "col1 magic_type, col2 int, col3 int, col4 int"
-            ).show()
+            self.connect.createDataFrame(data, "col1 magic_type, col2 int, col3 int, col4 int")
 
-        with self.assertRaises(SparkConnectException):
-            self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show()
+        with self.assertRaisesRegex(
+            ValueError,
+            "Length mismatch: Expected axis has 3 elements, new values have 4 elements",
+        ):
+            self.connect.createDataFrame(data, "col1 int, col2 int, col3 int")
 
     def test_with_local_rows(self):
         # SPARK-41789, SPARK-41810: Test creating a dataframe with list of rows and dictionaries
@@ -2982,79 +2984,106 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(cdf4.collect(), sdf4.collect())
 
     def test_array_has_nullable(self):
-        for schema, data in [
+        for schemas, data in [
             (
-                StructType().add("arr", ArrayType(IntegerType(), False), True),
+                [StructType().add("arr", ArrayType(IntegerType(), False), True)],
                 [Row([1, 2]), Row([3]), Row(None)],
             ),
             (
-                StructType().add("arr", ArrayType(IntegerType(), True), True),
+                [
+                    StructType().add("arr", ArrayType(IntegerType(), True), True),
+                    "arr array<integer>",
+                ],
                 [Row([1, None]), Row([3]), Row(None)],
             ),
             (
-                StructType().add("arr", ArrayType(IntegerType(), False), False),
+                [StructType().add("arr", ArrayType(IntegerType(), False), False)],
                 [Row([1, 2]), Row([3])],
             ),
             (
-                StructType().add("arr", ArrayType(IntegerType(), True), False),
+                [
+                    StructType().add("arr", ArrayType(IntegerType(), True), False),
+                    "arr array<integer> not null",
+                ],
                 [Row([1, None]), Row([3])],
             ),
         ]:
-            with self.subTest(schema=schema):
-                cdf = self.connect.createDataFrame(data, schema=schema)
-                sdf = self.spark.createDataFrame(data, schema=schema)
-                self.assertEqual(cdf.schema, sdf.schema)
-                self.assertEqual(cdf.collect(), sdf.collect())
+            for schema in schemas:
+                with self.subTest(schema=schema):
+                    cdf = self.connect.createDataFrame(data, schema=schema)
+                    sdf = self.spark.createDataFrame(data, schema=schema)
+                    self.assertEqual(cdf.schema, sdf.schema)
+                    self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_map_has_nullable(self):
-        for schema, data in [
+        for schemas, data in [
             (
-                StructType().add("map", MapType(StringType(), IntegerType(), False), True),
+                [StructType().add("map", MapType(StringType(), IntegerType(), False), True)],
                 [Row({"a": 1, "b": 2}), Row({"a": 3}), Row(None)],
             ),
             (
-                StructType().add("map", MapType(StringType(), IntegerType(), True), True),
+                [
+                    StructType().add("map", MapType(StringType(), IntegerType(), True), True),
+                    "map map<string, integer>",
+                ],
                 [Row({"a": 1, "b": None}), Row({"a": 3}), Row(None)],
             ),
             (
-                StructType().add("map", MapType(StringType(), IntegerType(), False), False),
+                [StructType().add("map", MapType(StringType(), IntegerType(), False), False)],
                 [Row({"a": 1, "b": 2}), Row({"a": 3})],
             ),
             (
-                StructType().add("map", MapType(StringType(), IntegerType(), True), False),
+                [
+                    StructType().add("map", MapType(StringType(), IntegerType(), True), False),
+                    "map map<string, integer> not null",
+                ],
                 [Row({"a": 1, "b": None}), Row({"a": 3})],
             ),
         ]:
-            with self.subTest(schema=schema):
-                cdf = self.connect.createDataFrame(data, schema=schema)
-                sdf = self.spark.createDataFrame(data, schema=schema)
-                self.assertEqual(cdf.schema, sdf.schema)
-                self.assertEqual(cdf.collect(), sdf.collect())
+            for schema in schemas:
+                with self.subTest(schema=schema):
+                    cdf = self.connect.createDataFrame(data, schema=schema)
+                    sdf = self.spark.createDataFrame(data, schema=schema)
+                    self.assertEqual(cdf.schema, sdf.schema)
+                    self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_struct_has_nullable(self):
-        for schema, data in [
+        for schemas, data in [
             (
-                StructType().add("struct", StructType().add("i", IntegerType(), False), True),
+                [
+                    StructType().add("struct", StructType().add("i", IntegerType(), False), True),
+                    "struct struct<i: integer not null>",
+                ],
                 [Row(Row(1)), Row(Row(2)), Row(None)],
             ),
             (
-                StructType().add("struct", StructType().add("i", IntegerType(), True), True),
+                [
+                    StructType().add("struct", StructType().add("i", IntegerType(), True), True),
+                    "struct struct<i: integer>",
+                ],
                 [Row(Row(1)), Row(Row(2)), Row(Row(None)), Row(None)],
             ),
             (
-                StructType().add("struct", StructType().add("i", IntegerType(), False), False),
+                [
+                    StructType().add("struct", StructType().add("i", IntegerType(), False), False),
+                    "struct struct<i: integer not null> not null",
+                ],
                 [Row(Row(1)), Row(Row(2))],
             ),
             (
-                StructType().add("struct", StructType().add("i", IntegerType(), True), False),
+                [
+                    StructType().add("struct", StructType().add("i", IntegerType(), True), False),
+                    "struct struct<i: integer> not null",
+                ],
                 [Row(Row(1)), Row(Row(2)), Row(Row(None))],
             ),
         ]:
-            with self.subTest(schema=schema):
-                cdf = self.connect.createDataFrame(data, schema=schema)
-                sdf = self.spark.createDataFrame(data, schema=schema)
-                self.assertEqual(cdf.schema, sdf.schema)
-                self.assertEqual(cdf.collect(), sdf.collect())
+            for schema in schemas:
+                with self.subTest(schema=schema):
+                    cdf = self.connect.createDataFrame(data, schema=schema)
+                    sdf = self.spark.createDataFrame(data, schema=schema)
+                    self.assertEqual(cdf.schema, sdf.schema)
+                    self.assertEqual(cdf.collect(), sdf.collect())
 
     def test_large_client_data(self):
         # SPARK-42816 support more than 4MB message size.
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index 0ed8642383b..ec33bb22a4b 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -37,8 +37,6 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
     def test_createDataFrame_with_incorrect_schema(self):
         self.check_createDataFrame_with_incorrect_schema()
 
-    # TODO(SPARK-42982): INVALID_COLUMN_OR_FIELD_DATA_TYPE
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_createDataFrame_with_map_type(self):
         self.check_createDataFrame_with_map_type(True)
 
@@ -92,13 +90,9 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
     def test_toPandas_fallback_enabled(self):
         super().test_toPandas_fallback_enabled()
 
-    # TODO(SPARK-42982): INVALID_COLUMN_OR_FIELD_DATA_TYPE
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_toPandas_with_map_type(self):
         self.check_toPandas_with_map_type(True)
 
-    # TODO(SPARK-42982): INVALID_COLUMN_OR_FIELD_DATA_TYPE
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_toPandas_with_map_type_nulls(self):
         self.check_toPandas_with_map_type_nulls(True)
 
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index dfc65d02b55..04aaa2b1c32 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -43,6 +43,7 @@ from pyspark.sql.types import (
     BinaryType,
     StructField,
     ArrayType,
+    MapType,
     NullType,
     DayTimeIntervalType,
 )
@@ -620,20 +621,23 @@ class ArrowTestsMixin:
         map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]
 
         pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data})
-        schema = "id long, m map<string, long>"
-
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
-            if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
-                with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
-                    self.spark.createDataFrame(pdf, schema=schema).collect()
-            else:
-                df = self.spark.createDataFrame(pdf, schema=schema)
-
-                result = df.collect()
-
-                for row in result:
-                    i, m = row
-                    self.assertEqual(m, map_data[i])
+        for schema in (
+            "id long, m map<string, long>",
+            StructType().add("id", LongType()).add("m", MapType(StringType(), LongType())),
+        ):
+            with self.subTest(schema=schema):
+                with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                    if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+                        with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
+                            self.spark.createDataFrame(pdf, schema=schema).collect()
+                    else:
+                        df = self.spark.createDataFrame(pdf, schema=schema)
+
+                        result = df.collect()
+
+                        for row in result:
+                            i, m = row
+                            self.assertEqual(m, map_data[i])
 
     def test_createDataFrame_with_string_dtype(self):
         # SPARK-34521: spark.createDataFrame does not support Pandas StringDtype extension type
@@ -667,16 +671,20 @@ class ArrowTestsMixin:
             {"id": [0, 1, 2, 3], "m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 
1, "b": 2, "c": 3}]}
         )
 
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
-            df = self.spark.createDataFrame(origin, schema="id long, m map<string, long>")
+        for schema in [
+            "id long, m map<string, long>",
+            StructType().add("id", LongType()).add("m", MapType(StringType(), LongType())),
+        ]:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+                df = self.spark.createDataFrame(origin, schema=schema)
 
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
-            if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
-                with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
-                    df.toPandas()
-            else:
-                pdf = df.toPandas()
-                assert_frame_equal(origin, pdf)
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+                    with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
+                        df.toPandas()
+                else:
+                    pdf = df.toPandas()
+                    assert_frame_equal(origin, pdf)
 
     def test_toPandas_with_map_type_nulls(self):
         with QuietTest(self.sc):
@@ -689,16 +697,20 @@ class ArrowTestsMixin:
             {"id": [0, 1, 2, 3, 4], "m": [{"a": 1}, {"b": 2, "c": 3}, {}, 
None, {"d": None}]}
         )
 
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
-            df = self.spark.createDataFrame(origin, schema="id long, m map<string, long>")
+        for schema in [
+            "id long, m map<string, long>",
+            StructType().add("id", LongType()).add("m", MapType(StringType(), LongType())),
+        ]:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
+                df = self.spark.createDataFrame(origin, schema=schema)
 
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
-            if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
-                with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
-                    df.toPandas()
-            else:
-                pdf = df.toPandas()
-                assert_frame_equal(origin, pdf)
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+                    with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
+                        df.toPandas()
+                else:
+                    pdf = df.toPandas()
+                    assert_frame_equal(origin, pdf)
 
     def test_createDataFrame_with_int_col_names(self):
         for arrow_enabled in [True, False]:


