This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 916b0d3de97 [SPARK-43817][SPARK-43702][PYTHON] Support UserDefinedType in createDataFrame from pandas DataFrame and toPandas
916b0d3de97 is described below

commit 916b0d3de973b8b30a8ede3d56b9f8a711110512
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Sun May 28 08:47:35 2023 +0800

    [SPARK-43817][SPARK-43702][PYTHON] Support UserDefinedType in createDataFrame from pandas DataFrame and toPandas
    
    ### What changes were proposed in this pull request?
    
    Support `UserDefinedType` in `createDataFrame` from pandas DataFrame and `toPandas`.
    
    For the following schema and pandas DataFrame:
    
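    ```py
    # Assumes an active SparkSession bound to `spark`.
    import pandas as pd
    from pyspark.sql import Row
    from pyspark.sql.types import ArrayType, MapType, StringType, StructType
    from pyspark.testing.sqlutils import ExamplePoint, ExamplePointUDT
    
    schema = (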
        StructType()
        .add("point", ExamplePointUDT())
        .add("struct", StructType().add("point", ExamplePointUDT()))
        .add("array", ArrayType(ExamplePointUDT()))
        .add("map", MapType(StringType(), ExamplePointUDT()))
    )
    data = [
        Row(
            ExamplePoint(1.0, 2.0),
            Row(ExamplePoint(3.0, 4.0)),
            [ExamplePoint(5.0, 6.0)],
            dict(point=ExamplePoint(7.0, 8.0)),
        )
    ]
    
    df = spark.createDataFrame(data, schema)
    
    pdf = pd.DataFrame.from_records(data, columns=schema.names)
    ```
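A `UserDefinedType` pairs a Python class with an underlying SQL type: `serialize()` converts a user object into that SQL representation, and `deserialize()` converts it back. The following dependency-free sketch mimics that contract. `Point` and `PointUDT` are hypothetical stand-ins for `ExamplePoint` and `ExamplePointUDT`, and storing a point as an `[x, y]` list of doubles is an assumption that matches the internal values shown in the Spark Connect `toPandas()` output further down.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Point:
    """Hypothetical user-facing class (stand-in for ExamplePoint)."""

    x: float
    y: float

    def __str__(self) -> str:
        return f"({self.x},{self.y})"


class PointUDT:
    """Hypothetical UDT (stand-in for ExamplePointUDT)."""

    def sqlType(self) -> str:
        # A real UDT returns a pyspark DataType; a string keeps the sketch self-contained.
        return "array<double>"

    def serialize(self, obj: Point) -> List[float]:
        # User object -> internal SQL representation.
        return [obj.x, obj.y]

    def deserialize(self, datum: List[float]) -> Point:
        # Internal SQL representation -> user object.
        return Point(datum[0], datum[1])


udt = PointUDT()
internal = udt.serialize(Point(1.0, 2.0))
print(internal)                        # [1.0, 2.0]
print(str(udt.deserialize(internal)))  # (1.0,2.0)
```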
    
    ##### `spark.createDataFrame()`
    
    In all three modes (without Arrow, with Arrow, and in Spark Connect), it returns the same result:
    
    ```py
    >>> spark.createDataFrame(pdf, schema).show(truncate=False)
    +----------+------------+------------+---------------------+
    |point     |struct      |array       |map                  |
    +----------+------------+------------+---------------------+
    |(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
    +----------+------------+------------+---------------------+
    ```
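On the `createDataFrame` side, the diff adds a `UserDefinedType` branch to the pandas-to-Arrow value converter: `None` passes through, and any other value is serialized with the UDT and then converted according to its `sqlType()`. The sketch below models that recursion in plain Python. The type classes (`ArrayOf`, `MapOf`, `StructOf`), `PointUDT`, and `converter_from_pandas` are hypothetical simplifications, map keys are assumed to be plain strings, and struct values come out as dicts keyed by field name.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple


@dataclass(frozen=True)
class Point:
    x: float
    y: float


@dataclass(frozen=True)
class ArrayOf:
    element: Any


@dataclass(frozen=True)
class MapOf:
    value: Any  # map keys are assumed to be plain strings


@dataclass(frozen=True)
class StructOf:
    fields: Tuple[Tuple[str, Any], ...]


class PointUDT:
    def sqlType(self) -> Any:
        return ArrayOf("double")

    def serialize(self, obj: Point) -> List[float]:
        return [obj.x, obj.y]


def converter_from_pandas(dt: Any) -> Optional[Callable[[Any], Any]]:
    """Return a per-value converter for dt, or None if values pass through unchanged."""
    if isinstance(dt, ArrayOf):
        elem = converter_from_pandas(dt.element)
        if elem is None:
            return None
        return lambda v: None if v is None else [elem(e) for e in v]
    if isinstance(dt, MapOf):
        val = converter_from_pandas(dt.value)
        if val is None:
            return None
        return lambda v: None if v is None else {k: val(x) for k, x in v.items()}
    if isinstance(dt, StructOf):
        names = [name for name, _ in dt.fields]
        convs = [converter_from_pandas(t) or (lambda x: x) for _, t in dt.fields]

        def convert_struct(v: Any) -> Any:
            if v is None:
                return None
            if isinstance(v, dict):  # match fields by name
                return {n: c(v.get(n)) for n, c in zip(names, convs)}
            return {n: c(x) for n, c, x in zip(names, convs, v)}  # tuple: by position

        return convert_struct
    if isinstance(dt, PointUDT):
        # UDT branch: serialize first, then convert according to the sqlType.
        inner = converter_from_pandas(dt.sqlType()) or (lambda x: x)
        return lambda v: None if v is None else inner(dt.serialize(v))
    return None  # leaf types such as "double" need no conversion


schema = StructOf((
    ("point", PointUDT()),
    ("struct", StructOf((("point", PointUDT()),))),
    ("array", ArrayOf(PointUDT())),
    ("map", MapOf(PointUDT())),
))
row = (Point(1.0, 2.0), (Point(3.0, 4.0),), [Point(5.0, 6.0)], {"point": Point(7.0, 8.0)})
print(converter_from_pandas(schema)(row))
# {'point': [1.0, 2.0], 'struct': {'point': [3.0, 4.0]}, 'array': [[5.0, 6.0]], 'map': {'point': [7.0, 8.0]}}
```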
    
    ##### `df.toPandas()`
    
    ```py
    >>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
    >>> df.toPandas()
           point        struct        array                   map
    0  (1.0,2.0)  ((3.0,4.0),)  [(5.0,6.0)]  {'point': (7.0,8.0)}
    ```
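With `spark.sql.execution.pandas.structHandlingMode` set to `'row'`, struct values are returned as `Row` objects, which are tuples; this is consistent with the one-element tuple shown in the `struct` column above. The diff's `_create_converter_to_pandas` distinguishes two struct modes, `"row"` and `"dict"`, and always uses `"row"` when converting a UDT's `sqlType()`. A minimal sketch of the two modes, where `struct_converter` is a hypothetical name and a `namedtuple` is an assumed stand-in for PySpark's `Row`:

```python
from collections import namedtuple
from typing import Any, Callable, List


def struct_converter(field_names: List[str], mode: str) -> Callable[[Any], Any]:
    """Convert a struct value, given as a sequence of field values, per mode."""
    if mode == "row":
        # A namedtuple stands in for pyspark.sql.Row; both are tuple subclasses.
        row_cls = namedtuple("Row", field_names)
        return lambda v: None if v is None else row_cls(*v)
    if mode == "dict":
        return lambda v: None if v is None else dict(zip(field_names, v))
    raise ValueError(f"Unknown value for `struct_in_pandas`: {mode}")


as_row = struct_converter(["point"], "row")
as_dict = struct_converter(["point"], "dict")
print(as_row([(3.0, 4.0)]))   # Row(point=(3.0, 4.0))
print(as_dict([(3.0, 4.0)]))  # {'point': (3.0, 4.0)}
```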
    
    ### Why are the changes needed?
    
    Currently, `UserDefinedType` is not supported by `spark.createDataFrame()` with a pandas DataFrame or by `df.toPandas()` when Arrow is enabled or in Spark Connect.
    
    ##### `spark.createDataFrame()`
    
    Works without Arrow:
    
    ```py
    >>> spark.createDataFrame(pdf, schema).show(truncate=False)
    +----------+------------+------------+---------------------+
    |point     |struct      |array       |map                  |
    +----------+------------+------------+---------------------+
    |(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
    +----------+------------+------------+---------------------+
    ```
    
    In contrast:
    
    - With Arrow:
    
    Works only via the non-Arrow fallback:
    
    ```py
    >>> spark.createDataFrame(pdf, schema).show(truncate=False)
    /.../python/pyspark/sql/pandas/conversion.py:351: UserWarning: createDataFrame attempted Arrow optimization because 'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by the reason below:
      [UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not supported in conversion to Arrow.
    Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
      warn(msg)
    +----------+------------+------------+---------------------+
    |point     |struct      |array       |map                  |
    +----------+------------+------------+---------------------+
    |(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
    +----------+------------+------------+---------------------+
    ```
    
    - Spark Connect:
    
    ```py
    >>> spark.createDataFrame(pdf, schema).show(truncate=False)
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.base.PySparkTypeError: [UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not supported in conversion to Arrow.
    ```
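Both failures above come from the Spark-to-Arrow type mapping, which had no case for a UDT. The diff adds a branch to `pyspark.sql.pandas.types.to_arrow_type` that maps a `UserDefinedType` to the Arrow type of its `sqlType()`, recursing through arrays, maps, and structs as before. The sketch below models only that recursion: the type classes and `to_arrow_type_name` are hypothetical, and Arrow types are rendered in a simplified string notation rather than as `pyarrow` objects.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class DoubleT:
    pass


@dataclass(frozen=True)
class StringT:
    pass


@dataclass(frozen=True)
class ArrayT:
    element: Any


@dataclass(frozen=True)
class MapT:
    key: Any
    value: Any


@dataclass(frozen=True)
class StructT:
    fields: Tuple[Tuple[str, Any], ...]


class UDT:
    """Hypothetical base class: subclasses declare their underlying SQL type."""

    def sqlType(self) -> Any:
        raise NotImplementedError


class PointUDT(UDT):
    def sqlType(self) -> Any:
        return ArrayT(DoubleT())


def to_arrow_type_name(dt: Any) -> str:
    """Render the Arrow type a Spark-like type maps to, as a simplified string."""
    if isinstance(dt, DoubleT):
        return "double"
    if isinstance(dt, StringT):
        return "string"
    if isinstance(dt, ArrayT):
        return f"list<{to_arrow_type_name(dt.element)}>"
    if isinstance(dt, MapT):
        return f"map<{to_arrow_type_name(dt.key)}, {to_arrow_type_name(dt.value)}>"
    if isinstance(dt, StructT):
        inner = ", ".join(f"{n}: {to_arrow_type_name(t)}" for n, t in dt.fields)
        return f"struct<{inner}>"
    if isinstance(dt, UDT):
        # The added case: a UDT has no Arrow type of its own, so use its sqlType.
        return to_arrow_type_name(dt.sqlType())
    raise TypeError(f"{dt!r} is not supported in conversion to Arrow.")


schema = StructT((
    ("point", PointUDT()),
    ("map", MapT(StringT(), PointUDT())),
))
print(to_arrow_type_name(schema))
# struct<point: list<double>, map: map<string, list<double>>>
```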
    
    ##### `df.toPandas()`
    
    Works without Arrow:
    
    ```py
    >>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
    >>> df.toPandas()
           point        struct        array                   map
    0  (1.0,2.0)  ((3.0,4.0),)  [(5.0,6.0)]  {'point': (7.0,8.0)}
    ```
    
    In contrast:
    
    - With Arrow:
    
    Works only via the non-Arrow fallback:
    
    ```py
    >>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
    >>> df.toPandas()
    /.../python/pyspark/sql/pandas/conversion.py:111: UserWarning: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by the reason below:
      [UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not supported in conversion to Arrow.
    Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
      warn(msg)
           point        struct        array                   map
    0  (1.0,2.0)  ((3.0,4.0),)  [(5.0,6.0)]  {'point': (7.0,8.0)}
    ```
    
    - Spark Connect:
    
    Returns the UDT's internal representation instead of `ExamplePoint` objects:
    
    ```py
    >>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
    >>> df.toPandas()
            point         struct         array                    map
    0  [1.0, 2.0]  ([3.0, 4.0],)  [[5.0, 6.0]]  {'point': [7.0, 8.0]}
    ```
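The Spark Connect result above contains the UDT's internal values rather than `ExamplePoint` objects. The diff adds a `UserDefinedType` branch to `_create_converter_to_pandas`: `None` stays `None`, a value that already carries a `__UDT__` attribute is returned unchanged, and anything else is deserialized. A simplified sketch of that rule follows; `Point`, `PointUDT`, and `make_udt_converter` are hypothetical, and the sketch skips the inner conversion of the UDT's `sqlType()` that the real converter applies before deserializing.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


class PointUDT:
    """Hypothetical UDT whose internal form is an [x, y] list of doubles."""

    def deserialize(self, datum: List[float]) -> "Point":
        return Point(datum[0], datum[1])


@dataclass(frozen=True)
class Point:
    x: float
    y: float

    # Class attribute (not a dataclass field) linking the class to its UDT.
    __UDT__ = PointUDT()


def make_udt_converter(udt: PointUDT) -> Callable[[Any], Any]:
    def convert_udt(value: Any) -> Any:
        if value is None:
            return None
        if hasattr(value, "__UDT__"):
            # Already a user object: keep it unchanged.
            assert isinstance(value.__UDT__, type(udt))
            return value
        # Internal representation: rebuild the user object.
        return udt.deserialize(value)

    return convert_udt


conv = make_udt_converter(PointUDT())
column = [[1.0, 2.0], Point(3.0, 4.0), None]
print([conv(v) for v in column])
# [Point(x=1.0, y=2.0), Point(x=3.0, y=4.0), None]
```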
    
    ### Does this PR introduce _any_ user-facing change?
    
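    Yes. Users will be able to use `UserDefinedType` in `createDataFrame` from a pandas DataFrame and in `toPandas` with Arrow enabled and in Spark Connect.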
    
    ### How was this patch tested?
    
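    Added `test_createDataFrame_udt` and `test_toPandas_udt` to `test_arrow.py` (run with Arrow enabled and disabled) and to the Spark Connect parity tests, and re-enabled the previously skipped `UDTOpsParityTests` cases.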
    
    Closes #41333 from ueshin/issues/SPARK-43817/udt.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../connect/data_type_ops/test_parity_udt_ops.py   |  48 +------
 python/pyspark/sql/connect/client/core.py          |   4 +-
 python/pyspark/sql/connect/conversion.py           |   7 +-
 python/pyspark/sql/connect/dataframe.py            |   2 +-
 python/pyspark/sql/connect/session.py              |  22 +++-
 python/pyspark/sql/connect/types.py                | 146 ---------------------
 python/pyspark/sql/pandas/conversion.py            |  17 ++-
 python/pyspark/sql/pandas/serializers.py           |  28 ++--
 python/pyspark/sql/pandas/types.py                 |  58 ++++++--
 .../pyspark/sql/tests/connect/test_parity_arrow.py |   6 +
 python/pyspark/sql/tests/test_arrow.py             |  67 +++++++++-
 11 files changed, 171 insertions(+), 234 deletions(-)

diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py
index 81511829c06..70a79e4cd3f 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py
@@ -25,53 +25,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
 class UDTOpsParityTests(
     UDTOpsTestsMixin, PandasOnSparkTestUtils, OpsTestBase, ReusedConnectTestCase
 ):
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_eq(self):
-        super().test_eq()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_from_to_pandas(self):
-        super().test_from_to_pandas()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_ge(self):
-        super().test_ge()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_gt(self):
-        super().test_gt()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_isnull(self):
-        super().test_isnull()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_le(self):
-        super().test_le()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_lt(self):
-        super().test_lt()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_ne(self):
-        super().test_ne()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 544ed5d4183..a0f790b2992 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -75,7 +75,7 @@ from pyspark.sql.connect.expressions import (
     CommonInlineUserDefinedFunction,
     JavaUDF,
 )
-from pyspark.sql.pandas.types import _create_converter_to_pandas
+from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_schema
 from pyspark.sql.types import DataType, StructType, TimestampType, _has_type
 from pyspark.rdd import PythonEvalType
 from pyspark.storagelevel import StorageLevel
@@ -717,7 +717,7 @@ class SparkConnectClient(object):
         table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req)
         assert table is not None
 
-        schema = schema or types.from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
+        schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
         assert schema is not None and isinstance(schema, StructType)
 
         # Rename columns to avoid duplicated column names.
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 3cc301c38ea..cdbc3a1e39c 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -42,9 +42,8 @@ from pyspark.sql.types import (
 )
 
 from pyspark.storagelevel import StorageLevel
-from pyspark.sql.connect.types import to_arrow_schema
 import pyspark.sql.connect.proto as pb2
-from pyspark.sql.pandas.types import _dedup_names, _deduplicate_field_names
+from pyspark.sql.pandas.types import to_arrow_schema, _dedup_names, _deduplicate_field_names
 
 from typing import (
     Any,
@@ -246,7 +245,7 @@ class LocalDataToArrowConversion:
         elif isinstance(dataType, UserDefinedType):
             udt: UserDefinedType = dataType
 
-            conv = LocalDataToArrowConversion._create_converter(dataType.sqlType())
+            conv = LocalDataToArrowConversion._create_converter(udt.sqlType())
 
             def convert_udt(value: Any) -> Any:
                 if value is None:
@@ -428,7 +427,7 @@ class ArrowTableToRowsConversion:
         elif isinstance(dataType, UserDefinedType):
             udt: UserDefinedType = dataType
 
-            conv = ArrowTableToRowsConversion._create_converter(dataType.sqlType())
+            conv = ArrowTableToRowsConversion._create_converter(udt.sqlType())
 
             def convert_udt(value: Any) -> Any:
                 if value is None:
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 70aa53ed73e..46218bb4dc0 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -76,7 +76,7 @@ from pyspark.sql.connect.functions import (
     lit,
     expr as sql_expression,
 )
-from pyspark.sql.connect.types import from_arrow_schema
+from pyspark.sql.pandas.types import from_arrow_schema
 
 
 if TYPE_CHECKING:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 7932ab54081..2d58ce1daf0 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -331,8 +331,12 @@ class SparkSession:
 
             # Determine arrow types to coerce data when creating batches
             arrow_schema: Optional[pa.Schema] = None
+            spark_types: List[Optional[DataType]]
+            arrow_types: List[Optional[pa.DataType]]
             if isinstance(schema, StructType):
-                arrow_schema = to_arrow_schema(cast(StructType, _deduplicate_field_names(schema)))
+                deduped_schema = cast(StructType, _deduplicate_field_names(schema))
+                spark_types = [field.dataType for field in deduped_schema.fields]
+                arrow_schema = to_arrow_schema(deduped_schema)
                 arrow_types = [field.type for field in arrow_schema]
                 _cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
             elif isinstance(schema, DataType):
@@ -342,14 +346,15 @@ class SparkSession:
                 )
             else:
                 # Any timestamps must be coerced to be compatible with Spark
-                arrow_types = [
-                    to_arrow_type(TimestampType())
+                spark_types = [
+                    TimestampType()
                     if is_datetime64_dtype(t) or is_datetime64tz_dtype(t)
-                    else to_arrow_type(DayTimeIntervalType())
+                    else DayTimeIntervalType()
                     if is_timedelta64_dtype(t)
                     else None
                     for t in data.dtypes
                 ]
+                arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
 
             timezone, safecheck = self._client.get_configs(
                 "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
@@ -358,7 +363,14 @@ class SparkSession:
             ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
 
             _table = pa.Table.from_batches(
-                [ser._create_batch([(c, t) for (_, c), t in zip(data.items(), arrow_types)])]
+                [
+                    ser._create_batch(
+                        [
+                            (c, at, st)
+                            for (_, c), at, st in zip(data.items(), arrow_types, spark_types)
+                        ]
+                    )
+                ]
             )
 
             if isinstance(schema, StructType):
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index fa8f9f5f8ff..2a21cdf0675 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -20,8 +20,6 @@ check_dependencies(__name__)
 
 import json
 
-import pyarrow as pa
-
 from typing import Any, Dict, Optional
 
 from pyspark.sql.types import (
@@ -299,147 +297,3 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
         return UserDefinedType.fromJson(json_value)
     else:
         raise Exception(f"Unsupported data type {schema}")
-
-
-def to_arrow_type(dt: DataType) -> "pa.DataType":
-    """
-    Convert Spark data type to pyarrow type.
-
-    This function refers to 'pyspark.sql.pandas.types.to_arrow_type' but relax the restriction,
-    e.g. it supports nested StructType.
-    """
-    if type(dt) == BooleanType:
-        arrow_type = pa.bool_()
-    elif type(dt) == ByteType:
-        arrow_type = pa.int8()
-    elif type(dt) == ShortType:
-        arrow_type = pa.int16()
-    elif type(dt) == IntegerType:
-        arrow_type = pa.int32()
-    elif type(dt) == LongType:
-        arrow_type = pa.int64()
-    elif type(dt) == FloatType:
-        arrow_type = pa.float32()
-    elif type(dt) == DoubleType:
-        arrow_type = pa.float64()
-    elif type(dt) == DecimalType:
-        arrow_type = pa.decimal128(dt.precision, dt.scale)
-    elif type(dt) == StringType:
-        arrow_type = pa.string()
-    elif type(dt) == BinaryType:
-        arrow_type = pa.binary()
-    elif type(dt) == DateType:
-        arrow_type = pa.date32()
-    elif type(dt) == TimestampType:
-        # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
-        arrow_type = pa.timestamp("us", tz="UTC")
-    elif type(dt) == TimestampNTZType:
-        arrow_type = pa.timestamp("us", tz=None)
-    elif type(dt) == DayTimeIntervalType:
-        arrow_type = pa.duration("us")
-    elif type(dt) == ArrayType:
-        field = pa.field("element", to_arrow_type(dt.elementType), nullable=dt.containsNull)
-        arrow_type = pa.list_(field)
-    elif type(dt) == MapType:
-        key_field = pa.field("key", to_arrow_type(dt.keyType), nullable=False)
-        value_field = pa.field("value", to_arrow_type(dt.valueType), nullable=dt.valueContainsNull)
-        arrow_type = pa.map_(key_field, value_field)
-    elif type(dt) == StructType:
-        fields = [
-            pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
-            for field in dt
-        ]
-        arrow_type = pa.struct(fields)
-    elif type(dt) == NullType:
-        arrow_type = pa.null()
-    elif isinstance(dt, UserDefinedType):
-        arrow_type = to_arrow_type(dt.sqlType())
-    else:
-        raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
-    return arrow_type
-
-
-def to_arrow_schema(schema: StructType) -> "pa.Schema":
-    """Convert a schema from Spark to Arrow"""
-    fields = [
-        pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
-        for field in schema
-    ]
-    return pa.schema(fields)
-
-
-def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
-    """Convert pyarrow type to Spark data type.
-
-    This function refers to 'pyspark.sql.pandas.types.from_arrow_type' but relax the restriction,
-    e.g. it supports nested StructType, Array of TimestampType. However, Arrow DictionaryType is
-    not allowed.
-    """
-    import pyarrow.types as types
-
-    spark_type: DataType
-    if types.is_boolean(at):
-        spark_type = BooleanType()
-    elif types.is_int8(at):
-        spark_type = ByteType()
-    elif types.is_int16(at):
-        spark_type = ShortType()
-    elif types.is_int32(at):
-        spark_type = IntegerType()
-    elif types.is_int64(at):
-        spark_type = LongType()
-    elif types.is_float32(at):
-        spark_type = FloatType()
-    elif types.is_float64(at):
-        spark_type = DoubleType()
-    elif types.is_decimal(at):
-        spark_type = DecimalType(precision=at.precision, scale=at.scale)
-    elif types.is_string(at):
-        spark_type = StringType()
-    elif types.is_binary(at):
-        spark_type = BinaryType()
-    elif types.is_date32(at):
-        spark_type = DateType()
-    elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
-        spark_type = TimestampNTZType()
-    elif types.is_timestamp(at):
-        spark_type = TimestampType()
-    elif types.is_duration(at):
-        spark_type = DayTimeIntervalType()
-    elif types.is_list(at):
-        spark_type = ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
-    elif types.is_map(at):
-        spark_type = MapType(
-            from_arrow_type(at.key_type, prefer_timestamp_ntz),
-            from_arrow_type(at.item_type, prefer_timestamp_ntz),
-        )
-    elif types.is_struct(at):
-        return StructType(
-            [
-                StructField(
-                    field.name,
-                    from_arrow_type(field.type, prefer_timestamp_ntz),
-                    nullable=field.nullable,
-                )
-                for field in at
-            ]
-        )
-    elif types.is_null(at):
-        spark_type = NullType()
-    else:
-        raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
-    return spark_type
-
-
-def from_arrow_schema(arrow_schema: "pa.Schema", prefer_timestamp_ntz: bool = False) -> StructType:
-    """Convert schema from Arrow to Spark."""
-    return StructType(
-        [
-            StructField(
-                field.name,
-                from_arrow_type(field.type, prefer_timestamp_ntz),
-                nullable=field.nullable,
-            )
-            for field in arrow_schema
-        ]
-    )
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 5e147aff48d..8664c4df73e 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -598,9 +598,7 @@ class SparkConversionMixin:
 
         # Determine arrow types to coerce data when creating batches
         if isinstance(schema, StructType):
-            arrow_types = [
-                to_arrow_type(_deduplicate_field_names(f.dataType)) for f in schema.fields
-            ]
+            spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]
         elif isinstance(schema, DataType):
             raise PySparkTypeError(
                 error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
@@ -608,10 +606,8 @@ class SparkConversionMixin:
             )
         else:
             # Any timestamps must be coerced to be compatible with Spark
-            arrow_types = [
-                to_arrow_type(TimestampType())
-                if is_datetime64_dtype(t) or is_datetime64tz_dtype(t)
-                else None
+            spark_types = [
+                TimestampType() if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                 for t in pdf.dtypes
             ]
 
@@ -619,9 +615,12 @@ class SparkConversionMixin:
         step = self._jconf.arrowMaxRecordsPerBatch()
         pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))
 
-        # Create list of Arrow (columns, type) for serializer dump_stream
+        # Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream
         arrow_data = [
-            [(c, t) for (_, c), t in zip(pdf_slice.items(), arrow_types)]
+            [
+                (c, to_arrow_type(t) if t is not None else None, t)
+                for (_, c), t in zip(pdf_slice.items(), spark_types)
+            ]
             for pdf_slice in pdf_slices
         ]
 
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index e81d90fc23e..84471143367 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -27,7 +27,7 @@ from pyspark.sql.pandas.types import (
     _create_converter_from_pandas,
     _create_converter_to_pandas,
 )
-from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType
+from pyspark.sql.types import DataType, StringType, StructType, BinaryType, StructField, LongType
 
 
 class SpecialLengths:
@@ -189,7 +189,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         )
         return converter(s)
 
-    def _create_array(self, series, arrow_type):
+    def _create_array(self, series, arrow_type, spark_type=None):
         """
         Create an Arrow Array from the given pandas.Series and optional type.
 
@@ -199,6 +199,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
             A single series
         arrow_type : pyarrow.DataType, optional
             If None, pyarrow's inferred type will be used
+        spark_type : DataType, optional
+            If None, spark type converted from arrow_type will be used
 
         Returns
         -------
@@ -211,10 +213,10 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
             series = series.astype(series.dtypes.categories.dtype)
 
         if arrow_type is not None:
-            spark_type = from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
+            dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
             # TODO(SPARK-43579): cache the converter for reuse
             conv = _create_converter_from_pandas(
-                spark_type, timezone=self._timezone, error_on_duplicated_field_names=False
+                dt, timezone=self._timezone, error_on_duplicated_field_names=False
             )
             series = conv(series)
 
@@ -261,14 +263,24 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         """
         import pyarrow as pa
 
-        # Make input conform to [(series1, type1), (series2, type2), ...]
-        if not isinstance(series, (list, tuple)) or (
-            len(series) == 2 and isinstance(series[1], pa.DataType)
+        # Make input conform to
+        # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...]
+        if (
+            not isinstance(series, (list, tuple))
+            or (len(series) == 2 and isinstance(series[1], pa.DataType))
+            or (
+                len(series) == 3
+                and isinstance(series[1], pa.DataType)
+                and isinstance(series[2], DataType)
+            )
         ):
             series = [series]
         series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+        series = ((s[0], s[1], None) if len(s) == 2 else s for s in series)
 
-        arrs = [self._create_array(s, t) for s, t in series]
+        arrs = [
+            self._create_array(s, arrow_type, spark_type) for s, arrow_type, spark_type in series
+        ]
         return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
 
     def dump_stream(self, iterator, stream):
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index adf497bbd73..ae7c25e0828 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -46,6 +46,7 @@ from pyspark.sql.types import (
     StructField,
     NullType,
     DataType,
+    UserDefinedType,
     Row,
     _create_row,
 )
@@ -119,6 +120,8 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
         arrow_type = pa.struct(fields)
     elif type(dt) == NullType:
         arrow_type = pa.null()
+    elif isinstance(dt, UserDefinedType):
+        arrow_type = to_arrow_type(dt.sqlType())
     else:
         raise PySparkTypeError(
             error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
@@ -561,10 +564,12 @@ def _create_converter_to_pandas(
 
         return correct_dtype
 
-    def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]:
+    def _converter(
+        dt: DataType, _struct_in_pandas: Optional[str]
+    ) -> Optional[Callable[[Any], Any]]:
 
         if isinstance(dt, ArrayType):
-            _element_conv = _converter(dt.elementType)
+            _element_conv = _converter(dt.elementType, _struct_in_pandas)
             if _element_conv is None:
                 return None
 
@@ -582,8 +587,8 @@ def _create_converter_to_pandas(
             return convert_array
 
         elif isinstance(dt, MapType):
-            _key_conv = _converter(dt.keyType) or (lambda x: x)
-            _value_conv = _converter(dt.valueType) or (lambda x: x)
+            _key_conv = _converter(dt.keyType, _struct_in_pandas) or (lambda x: x)
+            _value_conv = _converter(dt.valueType, _struct_in_pandas) or (lambda x: x)
 
             def convert_map(value: Any) -> Any:
                 if value is None:
@@ -599,7 +604,7 @@ def _create_converter_to_pandas(
             return convert_map
 
         elif isinstance(dt, StructType):
-            assert struct_in_pandas is not None
+            assert _struct_in_pandas is not None
 
             field_names = dt.names
 
@@ -611,9 +616,11 @@ def _create_converter_to_pandas(
 
             dedup_field_names = _dedup_names(field_names)
 
-            field_convs = [_converter(f.dataType) or (lambda x: x) for f in dt.fields]
+            field_convs = [
+                _converter(f.dataType, _struct_in_pandas) or (lambda x: x) for f in dt.fields
+            ]
 
-            if struct_in_pandas == "row":
+            if _struct_in_pandas == "row":
 
                 def convert_struct_as_row(value: Any) -> Any:
                     if value is None:
@@ -633,7 +640,7 @@ def _create_converter_to_pandas(
 
                 return convert_struct_as_row
 
-            elif struct_in_pandas == "dict":
+            elif _struct_in_pandas == "dict":
 
                 def convert_struct_as_dict(value: Any) -> Any:
                     if value is None:
@@ -654,7 +661,7 @@ def _create_converter_to_pandas(
                 return convert_struct_as_dict
 
             else:
-                raise ValueError(f"Unknown value for `struct_in_pandas`: {struct_in_pandas}")
+                raise ValueError(f"Unknown value for `struct_in_pandas`: {_struct_in_pandas}")
 
         elif isinstance(dt, TimestampType):
             assert timezone is not None
@@ -685,10 +692,26 @@ def _create_converter_to_pandas(
 
             return convert_timestamp_ntz
 
+        elif isinstance(dt, UserDefinedType):
+            udt: UserDefinedType = dt
+
+            conv = _converter(udt.sqlType(), _struct_in_pandas="row") or (lambda x: x)
+
+            def convert_udt(value: Any) -> Any:
+                if value is None:
+                    return None
+                elif hasattr(value, "__UDT__"):
+                    assert isinstance(value.__UDT__, type(udt))
+                    return value
+                else:
+                    return udt.deserialize(conv(value))
+
+            return convert_udt
+
         else:
             return None
 
-    conv = _converter(data_type)
+    conv = _converter(data_type, struct_in_pandas)
     if conv is not None:
         return lambda pser: pser.apply(conv)  # type: ignore[return-value]
     else:
@@ -779,7 +802,7 @@ def _create_converter_from_pandas(
                         for i, key in enumerate(field_names)
                     }
                 else:
-                    assert isinstance(value, Row)
+                    assert isinstance(value, tuple)
                     return {dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)}
 
             return convert_struct
@@ -799,6 +822,19 @@ def _create_converter_from_pandas(
 
             return convert_timestamp
 
+        elif isinstance(dt, UserDefinedType):
+            udt: UserDefinedType = dt
+
+            conv = _converter(udt.sqlType()) or (lambda x: x)
+
+            def convert_udt(value: Any) -> Any:
+                if value is None:
+                    return None
+                else:
+                    return conv(udt.serialize(value))
+
+            return convert_udt
+
         return None
 
     conv = _converter(data_type)
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index d1c8a1a55a0..60f1ef257c5 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -121,6 +121,12 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
     def test_toPandas_nested_timestamp(self):
         self.check_toPandas_nested_timestamp(True)
 
+    def test_createDataFrame_udt(self):
+        self.check_createDataFrame_udt(True)
+
+    def test_toPandas_udt(self):
+        self.check_toPandas_udt(True)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_parity_arrow import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index dfde747c265..e26aabbea27 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -53,6 +53,8 @@ from pyspark.testing.sqlutils import (
     have_pyarrow,
     pandas_requirement_message,
     pyarrow_requirement_message,
+    ExamplePoint,
+    ExamplePointUDT,
 )
 from pyspark.testing.utils import QuietTest
 from pyspark.errors import ArithmeticException, PySparkTypeError, UnsupportedOperationException
@@ -1022,7 +1024,7 @@ class ArrowTestsMixin:
         df = self.spark.range(2).select([])
 
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
-            assert_frame_equal(df.toPandas(), pd.DataFrame(index=range(2)))
+            assert_frame_equal(df.toPandas(), pd.DataFrame(columns=[], index=range(2)))
 
     def test_createDataFrame_nested_timestamp(self):
         for arrow_enabled in [True, False]:
@@ -1143,6 +1145,69 @@ class ArrowTestsMixin:
 
         assert_frame_equal(pdf, expected)
 
+    def test_createDataFrame_udt(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_createDataFrame_udt(arrow_enabled)
+
+    def check_createDataFrame_udt(self, arrow_enabled):
+        schema = (
+            StructType()
+            .add("point", ExamplePointUDT())
+            .add("struct", StructType().add("point", ExamplePointUDT()))
+            .add("array", ArrayType(ExamplePointUDT()))
+            .add("map", MapType(StringType(), ExamplePointUDT()))
+        )
+        data = [
+            Row(
+                ExamplePoint(1.0, 2.0),
+                Row(ExamplePoint(3.0, 4.0)),
+                [ExamplePoint(5.0, 6.0)],
+                dict(point=ExamplePoint(7.0, 8.0)),
+            )
+        ]
+        pdf = pd.DataFrame.from_records(data, columns=schema.names)
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+            df = self.spark.createDataFrame(pdf, schema)
+
+        self.assertEqual(df.collect(), data)
+
+    def test_toPandas_udt(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_toPandas_udt(arrow_enabled)
+
+    def check_toPandas_udt(self, arrow_enabled):
+        schema = (
+            StructType()
+            .add("point", ExamplePointUDT())
+            .add("struct", StructType().add("point", ExamplePointUDT()))
+            .add("array", ArrayType(ExamplePointUDT()))
+            .add("map", MapType(StringType(), ExamplePointUDT()))
+        )
+        data = [
+            Row(
+                ExamplePoint(1.0, 2.0),
+                Row(ExamplePoint(3.0, 4.0)),
+                [ExamplePoint(5.0, 6.0)],
+                dict(point=ExamplePoint(7.0, 8.0)),
+            )
+        ]
+        df = self.spark.createDataFrame(data, schema)
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.pyspark.enabled": arrow_enabled,
+                "spark.sql.execution.pandas.structHandlingMode": "row",
+            }
+        ):
+            pdf = df.toPandas()
+
+        expected = pd.DataFrame.from_records(data, columns=schema.names)
+
+        assert_frame_equal(pdf, expected)
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,

