This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 916b0d3de97 [SPARK-43817][SPARK-43702][PYTHON] Support UserDefinedType in createDataFrame from pandas DataFrame and toPandas
916b0d3de97 is described below

commit 916b0d3de973b8b30a8ede3d56b9f8a711110512
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Sun May 28 08:47:35 2023 +0800

[SPARK-43817][SPARK-43702][PYTHON] Support UserDefinedType in createDataFrame from pandas DataFrame and toPandas

### What changes were proposed in this pull request?

Support `UserDefinedType` in `createDataFrame` from pandas DataFrame and `toPandas`.

For the following schema and pandas DataFrame:

```py
schema = (
    StructType()
    .add("point", ExamplePointUDT())
    .add("struct", StructType().add("point", ExamplePointUDT()))
    .add("array", ArrayType(ExamplePointUDT()))
    .add("map", MapType(StringType(), ExamplePointUDT()))
)
data = [
    Row(
        ExamplePoint(1.0, 2.0),
        Row(ExamplePoint(3.0, 4.0)),
        [ExamplePoint(5.0, 6.0)],
        dict(point=ExamplePoint(7.0, 8.0)),
    )
]
df = spark.createDataFrame(data, schema)
pdf = pd.DataFrame.from_records(data, columns=schema.names)
```

##### `spark.createDataFrame()`

All modes (Arrow enabled or disabled, and Spark Connect) now return the same results:

```py
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
+----------+------------+------------+---------------------+
|point     |struct      |array       |map                  |
+----------+------------+------------+---------------------+
|(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
+----------+------------+------------+---------------------+
```

##### `df.toPandas()`

```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> df.toPandas()
       point        struct        array                   map
0  (1.0,2.0)  ((3.0,4.0),)  [(5.0,6.0)]  {'point': (7.0,8.0)}
```
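The snippets here use `ExamplePoint` / `ExamplePointUDT` from `pyspark.testing.sqlutils`. For readers who want to run them standalone, a minimal stand-in along these lines should work; this is a simplified sketch, not the actual test helper:

```py
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType


class ExamplePointUDT(UserDefinedType):
    """Sketch of a UDT that stores a 2-D point as array<double> internally."""

    @classmethod
    def sqlType(cls):
        # Underlying SQL storage type; this is also what the UDT becomes in Arrow.
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        # Module the UDT class can be imported from; "__main__" for an ad-hoc sketch.
        return "__main__"

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return ExamplePoint(datum[0], datum[1])


class ExamplePoint:
    """A toy 2-D point paired with ExamplePointUDT (sketch)."""

    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __repr__(self):
        return "(%s,%s)" % (self.x, self.y)

    def __eq__(self, other):
        return isinstance(other, ExamplePoint) and self.x == other.x and self.y == other.y
```

With such a UDT in scope, the `schema`/`data` setup above runs as-is: Spark stores each point as its `sqlType()` value and rebuilds `ExamplePoint` objects on the way out.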
### Why are the changes needed?

Currently, `UserDefinedType` is not supported in `spark.createDataFrame()` from a pandas DataFrame or in `df.toPandas()` when Arrow is enabled, nor in Spark Connect.

##### `spark.createDataFrame()`

Works without Arrow:

```py
>>> spark.createDataFrame(pdf, schema).show(truncate=False)
+----------+------------+------------+---------------------+
|point     |struct      |array       |map                  |
+----------+------------+------------+---------------------+
|(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
+----------+------------+------------+---------------------+
```

whereas:

- With Arrow: works only via fallback:

  ```py
  >>> spark.createDataFrame(pdf, schema).show(truncate=False)
  /.../python/pyspark/sql/pandas/conversion.py:351: UserWarning: createDataFrame attempted Arrow optimization because 'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by the reason below:
    [UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not supported in conversion to Arrow.
  Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
    warn(msg)
  +----------+------------+------------+---------------------+
  |point     |struct      |array       |map                  |
  +----------+------------+------------+---------------------+
  |(1.0, 2.0)|{(3.0, 4.0)}|[(5.0, 6.0)]|{point -> (7.0, 8.0)}|
  +----------+------------+------------+---------------------+
  ```

- Spark Connect: fails:

  ```py
  >>> spark.createDataFrame(pdf, schema).show(truncate=False)
  Traceback (most recent call last):
  ...
  pyspark.errors.exceptions.base.PySparkTypeError: [UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not supported in conversion to Arrow.
  ```

##### `df.toPandas()`

Works without Arrow:

```py
>>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
>>> df.toPandas()
       point        struct        array                   map
0  (1.0,2.0)  ((3.0,4.0),)  [(5.0,6.0)]  {'point': (7.0,8.0)}
```

whereas:

- With Arrow: works only via fallback:

  ```py
  >>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
  >>> df.toPandas()
  /.../python/pyspark/sql/pandas/conversion.py:111: UserWarning: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by the reason below:
    [UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION] ExamplePointUDT() is not supported in conversion to Arrow.
  Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
    warn(msg)
         point        struct        array                   map
  0  (1.0,2.0)  ((3.0,4.0),)  [(5.0,6.0)]  {'point': (7.0,8.0)}
  ```

- Spark Connect: returns results with the internal (serialized) type instead of the UDT values:

  ```py
  >>> spark.conf.set('spark.sql.execution.pandas.structHandlingMode', 'row')
  >>> df.toPandas()
          point         struct         array                    map
  0  [1.0, 2.0]  ([3.0, 4.0],)  [[5.0, 6.0]]  {'point': [7.0, 8.0]}
  ```

### Does this PR introduce _any_ user-facing change?

Yes: users will be able to use `UserDefinedType` in `createDataFrame` from pandas DataFrame and in `toPandas`, with Arrow enabled and in Spark Connect.

### How was this patch tested?

Added the related tests.

Closes #41333 from ueshin/issues/SPARK-43817/udt.

Authored-by: Takuya UESHIN <ues...@databricks.com>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
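For readers skimming the patch below: since Arrow has no notion of a UDT, the change represents a UDT column by the Arrow type of its `sqlType()` and applies `serialize()` / `deserialize()` at the pandas boundary. A condensed sketch of that strategy (`serialize_udt_value` / `deserialize_udt_value` are hypothetical names for illustration, not functions from the diff):

```py
def serialize_udt_value(udt, value):
    # pandas -> Arrow: None stays None; anything else travels in serialized form.
    return None if value is None else udt.serialize(value)


def deserialize_udt_value(udt, value):
    # Arrow -> pandas: rebuild the original Python object.
    return None if value is None else udt.deserialize(value)


# Round trip for a single value, using the ExamplePointUDT sketch above:
udt = ExamplePointUDT()
stored = serialize_udt_value(udt, ExamplePoint(1.0, 2.0))  # -> [1.0, 2.0]
restored = deserialize_udt_value(udt, stored)              # -> (1.0,2.0)
assert restored == ExamplePoint(1.0, 2.0)
```

The actual converters added in `python/pyspark/sql/pandas/types.py` below additionally skip deserialization for values that already carry a `__UDT__` attribute.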
---
 .../connect/data_type_ops/test_parity_udt_ops.py   |  48 +------
 python/pyspark/sql/connect/client/core.py          |   4 +-
 python/pyspark/sql/connect/conversion.py           |   7 +-
 python/pyspark/sql/connect/dataframe.py            |   2 +-
 python/pyspark/sql/connect/session.py              |  22 +++-
 python/pyspark/sql/connect/types.py                | 146 ---------------------
 python/pyspark/sql/pandas/conversion.py            |  17 ++-
 python/pyspark/sql/pandas/serializers.py           |  28 ++--
 python/pyspark/sql/pandas/types.py                 |  58 ++++++--
 .../pyspark/sql/tests/connect/test_parity_arrow.py |   6 +
 python/pyspark/sql/tests/test_arrow.py             |  67 +++++++++-
 11 files changed, 171 insertions(+), 234 deletions(-)

diff --git a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py
index 81511829c06..70a79e4cd3f 100644
--- a/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py
+++ b/python/pyspark/pandas/tests/connect/data_type_ops/test_parity_udt_ops.py
@@ -25,53 +25,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
 class UDTOpsParityTests(
     UDTOpsTestsMixin, PandasOnSparkTestUtils, OpsTestBase, ReusedConnectTestCase
 ):
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_eq(self):
-        super().test_eq()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_from_to_pandas(self):
-        super().test_from_to_pandas()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_ge(self):
-        super().test_ge()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_gt(self):
-        super().test_gt()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_isnull(self):
-        super().test_isnull()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_le(self):
-        super().test_le()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_lt(self):
-        super().test_lt()
-
-    @unittest.skip(
-        "TODO(SPARK-43702): Fix pyspark.sql.pandas.types.to_arrow_type to work with Spark Connect."
-    )
-    def test_ne(self):
-        super().test_ne()
+    pass
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 544ed5d4183..a0f790b2992 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -75,7 +75,7 @@ from pyspark.sql.connect.expressions import (
     CommonInlineUserDefinedFunction,
     JavaUDF,
 )
-from pyspark.sql.pandas.types import _create_converter_to_pandas
+from pyspark.sql.pandas.types import _create_converter_to_pandas, from_arrow_schema
 from pyspark.sql.types import DataType, StructType, TimestampType, _has_type
 from pyspark.rdd import PythonEvalType
 from pyspark.storagelevel import StorageLevel
@@ -717,7 +717,7 @@ class SparkConnectClient(object):
         table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req)
         assert table is not None
 
-        schema = schema or types.from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
+        schema = schema or from_arrow_schema(table.schema, prefer_timestamp_ntz=True)
         assert schema is not None and isinstance(schema, StructType)
 
         # Rename columns to avoid duplicated column names.
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 3cc301c38ea..cdbc3a1e39c 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -42,9 +42,8 @@ from pyspark.sql.types import (
 )
 from pyspark.storagelevel import StorageLevel
 
-from pyspark.sql.connect.types import to_arrow_schema
 import pyspark.sql.connect.proto as pb2
-from pyspark.sql.pandas.types import _dedup_names, _deduplicate_field_names
+from pyspark.sql.pandas.types import to_arrow_schema, _dedup_names, _deduplicate_field_names
 
 from typing import (
     Any,
@@ -246,7 +245,7 @@ class LocalDataToArrowConversion:
         elif isinstance(dataType, UserDefinedType):
             udt: UserDefinedType = dataType
 
-            conv = LocalDataToArrowConversion._create_converter(dataType.sqlType())
+            conv = LocalDataToArrowConversion._create_converter(udt.sqlType())
 
             def convert_udt(value: Any) -> Any:
                 if value is None:
@@ -428,7 +427,7 @@ class ArrowTableToRowsConversion:
         elif isinstance(dataType, UserDefinedType):
             udt: UserDefinedType = dataType
 
-            conv = ArrowTableToRowsConversion._create_converter(dataType.sqlType())
+            conv = ArrowTableToRowsConversion._create_converter(udt.sqlType())
 
             def convert_udt(value: Any) -> Any:
                 if value is None:
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 70aa53ed73e..46218bb4dc0 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -76,7 +76,7 @@ from pyspark.sql.connect.functions import (
     lit,
     expr as sql_expression,
 )
-from pyspark.sql.connect.types import from_arrow_schema
+from pyspark.sql.pandas.types import from_arrow_schema
 
 
 if TYPE_CHECKING:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 7932ab54081..2d58ce1daf0 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -331,8 +331,12 @@ class SparkSession:
 
         # Determine arrow types to coerce data when creating batches
         arrow_schema: Optional[pa.Schema] = None
+        spark_types: List[Optional[DataType]]
+        arrow_types: List[Optional[pa.DataType]]
         if isinstance(schema, StructType):
-            arrow_schema = to_arrow_schema(cast(StructType, _deduplicate_field_names(schema)))
+            deduped_schema = cast(StructType, _deduplicate_field_names(schema))
+            spark_types = [field.dataType for field in deduped_schema.fields]
+            arrow_schema = to_arrow_schema(deduped_schema)
             arrow_types = [field.type for field in arrow_schema]
             _cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
         elif isinstance(schema, DataType):
@@ -342,14 +346,15 @@
             )
         else:
             # Any timestamps must be coerced to be compatible with Spark
-            arrow_types = [
-                to_arrow_type(TimestampType())
+            spark_types = [
+                TimestampType()
                 if is_datetime64_dtype(t) or is_datetime64tz_dtype(t)
-                else to_arrow_type(DayTimeIntervalType())
+                else DayTimeIntervalType()
                 if is_timedelta64_dtype(t)
                 else None
                 for t in data.dtypes
             ]
+            arrow_types = [to_arrow_type(dt) if dt is not None else None for dt in spark_types]
 
         timezone, safecheck = self._client.get_configs(
             "spark.sql.session.timeZone", "spark.sql.execution.pandas.convertToArrowArraySafely"
@@ -358,7 +363,14 @@
         ser = ArrowStreamPandasSerializer(cast(str, timezone), safecheck == "true")
 
         _table = pa.Table.from_batches(
-            [ser._create_batch([(c, t) for (_, c), t in zip(data.items(), arrow_types)])]
+            [
+                ser._create_batch(
+                    [
+                        (c, at, st)
+                        for (_, c), at, st in zip(data.items(), arrow_types, spark_types)
+                    ]
+                )
+            ]
         )
 
         if isinstance(schema, StructType):
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index fa8f9f5f8ff..2a21cdf0675 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -20,8 +20,6 @@ check_dependencies(__name__)
 
 import json
 
-import pyarrow as pa
-
 from typing import Any, Dict, Optional
 
 from pyspark.sql.types import (
@@ -299,147 +297,3 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
         return UserDefinedType.fromJson(json_value)
     else:
         raise Exception(f"Unsupported data type {schema}")
-
-
-def to_arrow_type(dt: DataType) -> "pa.DataType":
-    """
-    Convert Spark data type to pyarrow type.
-
-    This function refers to 'pyspark.sql.pandas.types.to_arrow_type' but relax the restriction,
-    e.g. it supports nested StructType.
-    """
-    if type(dt) == BooleanType:
-        arrow_type = pa.bool_()
-    elif type(dt) == ByteType:
-        arrow_type = pa.int8()
-    elif type(dt) == ShortType:
-        arrow_type = pa.int16()
-    elif type(dt) == IntegerType:
-        arrow_type = pa.int32()
-    elif type(dt) == LongType:
-        arrow_type = pa.int64()
-    elif type(dt) == FloatType:
-        arrow_type = pa.float32()
-    elif type(dt) == DoubleType:
-        arrow_type = pa.float64()
-    elif type(dt) == DecimalType:
-        arrow_type = pa.decimal128(dt.precision, dt.scale)
-    elif type(dt) == StringType:
-        arrow_type = pa.string()
-    elif type(dt) == BinaryType:
-        arrow_type = pa.binary()
-    elif type(dt) == DateType:
-        arrow_type = pa.date32()
-    elif type(dt) == TimestampType:
-        # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
-        arrow_type = pa.timestamp("us", tz="UTC")
-    elif type(dt) == TimestampNTZType:
-        arrow_type = pa.timestamp("us", tz=None)
-    elif type(dt) == DayTimeIntervalType:
-        arrow_type = pa.duration("us")
-    elif type(dt) == ArrayType:
-        field = pa.field("element", to_arrow_type(dt.elementType), nullable=dt.containsNull)
-        arrow_type = pa.list_(field)
-    elif type(dt) == MapType:
-        key_field = pa.field("key", to_arrow_type(dt.keyType), nullable=False)
-        value_field = pa.field("value", to_arrow_type(dt.valueType), nullable=dt.valueContainsNull)
-        arrow_type = pa.map_(key_field, value_field)
-    elif type(dt) == StructType:
-        fields = [
-            pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
-            for field in dt
-        ]
-        arrow_type = pa.struct(fields)
-    elif type(dt) == NullType:
-        arrow_type = pa.null()
-    elif isinstance(dt, UserDefinedType):
-        arrow_type = to_arrow_type(dt.sqlType())
-    else:
-        raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
-    return arrow_type
-
-
-def to_arrow_schema(schema: StructType) -> "pa.Schema":
-    """Convert a schema from Spark to Arrow"""
-    fields = [
-        pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
-        for field in schema
-    ]
-    return pa.schema(fields)
-
-
-def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
-    """Convert pyarrow type to Spark data type.
-
-    This function refers to 'pyspark.sql.pandas.types.from_arrow_type' but relax the restriction,
-    e.g. it supports nested StructType, Array of TimestampType. However, Arrow DictionaryType is
-    not allowed.
-    """
-    import pyarrow.types as types
-
-    spark_type: DataType
-    if types.is_boolean(at):
-        spark_type = BooleanType()
-    elif types.is_int8(at):
-        spark_type = ByteType()
-    elif types.is_int16(at):
-        spark_type = ShortType()
-    elif types.is_int32(at):
-        spark_type = IntegerType()
-    elif types.is_int64(at):
-        spark_type = LongType()
-    elif types.is_float32(at):
-        spark_type = FloatType()
-    elif types.is_float64(at):
-        spark_type = DoubleType()
-    elif types.is_decimal(at):
-        spark_type = DecimalType(precision=at.precision, scale=at.scale)
-    elif types.is_string(at):
-        spark_type = StringType()
-    elif types.is_binary(at):
-        spark_type = BinaryType()
-    elif types.is_date32(at):
-        spark_type = DateType()
-    elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
-        spark_type = TimestampNTZType()
-    elif types.is_timestamp(at):
-        spark_type = TimestampType()
-    elif types.is_duration(at):
-        spark_type = DayTimeIntervalType()
-    elif types.is_list(at):
-        spark_type = ArrayType(from_arrow_type(at.value_type, prefer_timestamp_ntz))
-    elif types.is_map(at):
-        spark_type = MapType(
-            from_arrow_type(at.key_type, prefer_timestamp_ntz),
-            from_arrow_type(at.item_type, prefer_timestamp_ntz),
-        )
-    elif types.is_struct(at):
-        return StructType(
-            [
-                StructField(
-                    field.name,
-                    from_arrow_type(field.type, prefer_timestamp_ntz),
-                    nullable=field.nullable,
-                )
-                for field in at
-            ]
-        )
-    elif types.is_null(at):
-        spark_type = NullType()
-    else:
-        raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
-    return spark_type
-
-
-def from_arrow_schema(arrow_schema: "pa.Schema", prefer_timestamp_ntz: bool = False) -> StructType:
-    """Convert schema from Arrow to Spark."""
-    return StructType(
-        [
-            StructField(
-                field.name,
-                from_arrow_type(field.type, prefer_timestamp_ntz),
-                nullable=field.nullable,
-            )
-            for field in arrow_schema
-        ]
-    )
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 5e147aff48d..8664c4df73e 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -598,9 +598,7 @@ class SparkConversionMixin:
 
         # Determine arrow types to coerce data when creating batches
         if isinstance(schema, StructType):
-            arrow_types = [
-                to_arrow_type(_deduplicate_field_names(f.dataType)) for f in schema.fields
-            ]
+            spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]
         elif isinstance(schema, DataType):
             raise PySparkTypeError(
                 error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW",
@@ -608,10 +606,8 @@
             )
         else:
             # Any timestamps must be coerced to be compatible with Spark
-            arrow_types = [
-                to_arrow_type(TimestampType())
-                if is_datetime64_dtype(t) or is_datetime64tz_dtype(t)
-                else None
+            spark_types = [
+                TimestampType() if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                 for t in pdf.dtypes
             ]
 
@@ -619,9 +615,12 @@
         step = self._jconf.arrowMaxRecordsPerBatch()
         pdf_slices = (pdf.iloc[start : start + step] for start in range(0, len(pdf), step))
 
-        # Create list of Arrow (columns, type) for serializer dump_stream
+        # Create list of Arrow (columns, arrow_type, spark_type) for serializer dump_stream
         arrow_data = [
-            [(c, t) for (_, c), t in zip(pdf_slice.items(), arrow_types)]
+            [
+                (c, to_arrow_type(t) if t is not None else None, t)
+                for (_, c), t in zip(pdf_slice.items(), spark_types)
+            ]
             for pdf_slice in pdf_slices
         ]
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index e81d90fc23e..84471143367 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -27,7 +27,7 @@ from pyspark.sql.pandas.types import (
     _create_converter_from_pandas,
     _create_converter_to_pandas,
 )
-from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType
+from pyspark.sql.types import DataType, StringType, StructType, BinaryType, StructField, LongType
 
 
 class SpecialLengths:
@@ -189,7 +189,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         )
         return converter(s)
 
-    def _create_array(self, series, arrow_type):
+    def _create_array(self, series, arrow_type, spark_type=None):
         """
         Create an Arrow Array from the given pandas.Series and optional type.
 
@@ -199,6 +199,8 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
             A single series
         arrow_type : pyarrow.DataType, optional
             If None, pyarrow's inferred type will be used
+        spark_type : DataType, optional
+            If None, spark type converted from arrow_type will be used
 
         Returns
         -------
@@ -211,10 +213,10 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
             series = series.astype(series.dtypes.categories.dtype)
 
         if arrow_type is not None:
-            spark_type = from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
+            dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
             # TODO(SPARK-43579): cache the converter for reuse
             conv = _create_converter_from_pandas(
-                spark_type, timezone=self._timezone, error_on_duplicated_field_names=False
+                dt, timezone=self._timezone, error_on_duplicated_field_names=False
             )
             series = conv(series)
 
@@ -261,14 +263,24 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         """
        import pyarrow as pa
 
-        # Make input conform to [(series1, type1), (series2, type2), ...]
-        if not isinstance(series, (list, tuple)) or (
-            len(series) == 2 and isinstance(series[1], pa.DataType)
+        # Make input conform to
+        # [(series1, arrow_type1, spark_type1), (series2, arrow_type2, spark_type2), ...]
+        if (
+            not isinstance(series, (list, tuple))
+            or (len(series) == 2 and isinstance(series[1], pa.DataType))
+            or (
+                len(series) == 3
+                and isinstance(series[1], pa.DataType)
+                and isinstance(series[2], DataType)
+            )
         ):
             series = [series]
         series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+        series = ((s[0], s[1], None) if len(s) == 2 else s for s in series)
 
-        arrs = [self._create_array(s, t) for s, t in series]
+        arrs = [
+            self._create_array(s, arrow_type, spark_type) for s, arrow_type, spark_type in series
+        ]
         return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
 
     def dump_stream(self, iterator, stream):
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index adf497bbd73..ae7c25e0828 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -46,6 +46,7 @@ from pyspark.sql.types import (
     StructField,
     NullType,
     DataType,
+    UserDefinedType,
     Row,
     _create_row,
 )
@@ -119,6 +120,8 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
         arrow_type = pa.struct(fields)
     elif type(dt) == NullType:
         arrow_type = pa.null()
+    elif isinstance(dt, UserDefinedType):
+        arrow_type = to_arrow_type(dt.sqlType())
     else:
         raise PySparkTypeError(
             error_class="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
@@ -561,10 +564,12 @@ def _create_converter_to_pandas(
 
         return correct_dtype
 
-    def _converter(dt: DataType) -> Optional[Callable[[Any], Any]]:
+    def _converter(
+        dt: DataType, _struct_in_pandas: Optional[str]
+    ) -> Optional[Callable[[Any], Any]]:
 
         if isinstance(dt, ArrayType):
-            _element_conv = _converter(dt.elementType)
+            _element_conv = _converter(dt.elementType, _struct_in_pandas)
             if _element_conv is None:
                 return None
 
@@ -582,8 +587,8 @@
             return convert_array
 
         elif isinstance(dt, MapType):
-            _key_conv = _converter(dt.keyType) or (lambda x: x)
-            _value_conv = _converter(dt.valueType) or (lambda x: x)
+            _key_conv = _converter(dt.keyType, _struct_in_pandas) or (lambda x: x)
+            _value_conv = _converter(dt.valueType, _struct_in_pandas) or (lambda x: x)
 
             def convert_map(value: Any) -> Any:
                 if value is None:
@@ -599,7 +604,7 @@
             return convert_map
 
         elif isinstance(dt, StructType):
-            assert struct_in_pandas is not None
+            assert _struct_in_pandas is not None
 
             field_names = dt.names
 
@@ -611,9 +616,11 @@
 
             dedup_field_names = _dedup_names(field_names)
 
-            field_convs = [_converter(f.dataType) or (lambda x: x) for f in dt.fields]
+            field_convs = [
+                _converter(f.dataType, _struct_in_pandas) or (lambda x: x) for f in dt.fields
+            ]
 
-            if struct_in_pandas == "row":
+            if _struct_in_pandas == "row":
 
                 def convert_struct_as_row(value: Any) -> Any:
                     if value is None:
@@ -633,7 +640,7 @@
 
                 return convert_struct_as_row
 
-            elif struct_in_pandas == "dict":
+            elif _struct_in_pandas == "dict":
 
                 def convert_struct_as_dict(value: Any) -> Any:
                     if value is None:
@@ -654,7 +661,7 @@
                 return convert_struct_as_dict
 
             else:
-                raise ValueError(f"Unknown value for `struct_in_pandas`: {struct_in_pandas}")
+                raise ValueError(f"Unknown value for `struct_in_pandas`: {_struct_in_pandas}")
 
         elif isinstance(dt, TimestampType):
             assert timezone is not None
@@ -685,10 +692,26 @@
 
             return convert_timestamp_ntz
 
+        elif isinstance(dt, UserDefinedType):
+            udt: UserDefinedType = dt
+
+            conv = _converter(udt.sqlType(), _struct_in_pandas="row") or (lambda x: x)
+
+            def convert_udt(value: Any) -> Any:
+                if value is None:
+                    return None
+                elif hasattr(value, "__UDT__"):
+                    assert isinstance(value.__UDT__, type(udt))
+                    return value
+                else:
+                    return udt.deserialize(conv(value))
+
+            return convert_udt
+
         else:
             return None
 
-    conv = _converter(data_type)
+    conv = _converter(data_type, struct_in_pandas)
     if conv is not None:
         return lambda pser: pser.apply(conv)  # type: ignore[return-value]
     else:
@@ -779,7 +802,7 @@ def _create_converter_from_pandas(
                         for i, key in enumerate(field_names)
                     }
                 else:
-                    assert isinstance(value, Row)
+                    assert isinstance(value, tuple)
                     return {dedup_field_names[i]: field_convs[i](v) for i, v in enumerate(value)}
 
             return convert_struct
@@ -799,6 +822,19 @@
 
             return convert_timestamp
 
+        elif isinstance(dt, UserDefinedType):
+            udt: UserDefinedType = dt
+
+            conv = _converter(udt.sqlType()) or (lambda x: x)
+
+            def convert_udt(value: Any) -> Any:
+                if value is None:
+                    return None
+                else:
+                    return conv(udt.serialize(value))
+
+            return convert_udt
+
         return None
 
     conv = _converter(data_type)
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index d1c8a1a55a0..60f1ef257c5 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -121,6 +121,12 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
     def test_toPandas_nested_timestamp(self):
         self.check_toPandas_nested_timestamp(True)
 
+    def test_createDataFrame_udt(self):
+        self.check_createDataFrame_udt(True)
+
+    def test_toPandas_udt(self):
+        self.check_toPandas_udt(True)
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_parity_arrow import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index dfde747c265..e26aabbea27 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -53,6 +53,8 @@ from pyspark.testing.sqlutils import (
     have_pyarrow,
     pandas_requirement_message,
     pyarrow_requirement_message,
+    ExamplePoint,
+    ExamplePointUDT,
 )
 from pyspark.testing.utils import QuietTest
 from pyspark.errors import ArithmeticException, PySparkTypeError, UnsupportedOperationException
@@ -1022,7 +1024,7 @@ class ArrowTestsMixin:
         df = self.spark.range(2).select([])
 
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
-            assert_frame_equal(df.toPandas(), pd.DataFrame(index=range(2)))
+            assert_frame_equal(df.toPandas(), pd.DataFrame(columns=[], index=range(2)))
 
     def test_createDataFrame_nested_timestamp(self):
         for arrow_enabled in [True, False]:
@@ -1143,6 +1145,69 @@ class ArrowTestsMixin:
 
         assert_frame_equal(pdf, expected)
 
+    def test_createDataFrame_udt(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_createDataFrame_udt(arrow_enabled)
+
+    def check_createDataFrame_udt(self, arrow_enabled):
+        schema = (
+            StructType()
+            .add("point", ExamplePointUDT())
+            .add("struct", StructType().add("point", ExamplePointUDT()))
+            .add("array", ArrayType(ExamplePointUDT()))
+            .add("map", MapType(StringType(), ExamplePointUDT()))
+        )
+        data = [
+            Row(
+                ExamplePoint(1.0, 2.0),
+                Row(ExamplePoint(3.0, 4.0)),
+                [ExamplePoint(5.0, 6.0)],
+                dict(point=ExamplePoint(7.0, 8.0)),
+            )
+        ]
+        pdf = pd.DataFrame.from_records(data, columns=schema.names)
+
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+            df = self.spark.createDataFrame(pdf, schema)
+
+        self.assertEqual(df.collect(), data)
+
+    def test_toPandas_udt(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_toPandas_udt(arrow_enabled)
+
+    def check_toPandas_udt(self, arrow_enabled):
+        schema = (
+            StructType()
+            .add("point", ExamplePointUDT())
+            .add("struct", StructType().add("point", ExamplePointUDT()))
+            .add("array", ArrayType(ExamplePointUDT()))
+            .add("map", MapType(StringType(), ExamplePointUDT()))
+        )
+        data = [
+            Row(
+                ExamplePoint(1.0, 2.0),
+                Row(ExamplePoint(3.0, 4.0)),
+                [ExamplePoint(5.0, 6.0)],
+                dict(point=ExamplePoint(7.0, 8.0)),
+            )
+        ]
+        df = self.spark.createDataFrame(data, schema)
+
+        with self.sql_conf(
+            {
+                "spark.sql.execution.arrow.pyspark.enabled": arrow_enabled,
+                "spark.sql.execution.pandas.structHandlingMode": "row",
+            }
+        ):
+            pdf = df.toPandas()
+
+        expected = pd.DataFrame.from_records(data, columns=schema.names)
+
+        assert_frame_equal(pdf, expected)
+
 
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org