This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fd376718f75c [SPARK-55365][PYTHON] Generalize the utils for arrow array conversion
fd376718f75c is described below
commit fd376718f75c1dde29eff0f0e33499e7784cc0e5
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Feb 5 16:49:51 2026 +0800
[SPARK-55365][PYTHON] Generalize the utils for arrow array conversion
### What changes were proposed in this pull request?
Generalize the utils for arrow array conversion
### Why are the changes needed?
We have a `localize_tz` util that drops timezones from `TimestampType`. When I tried
to implement more time-related conversions in `convert_numpy`, I found it has two
drawbacks:
1, a new function for `coerce_temporal_nanoseconds` would be pretty similar, and I
would have to copy-paste the code for nested type handling;
2, it doesn't support `ChunkedArray` yet, so timezone handling is not enabled in the
new arrow->pandas conversion.
So we should make the utils flexible enough that new conversion functions are easy
to add, as sketched below.
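With the generalized `ArrowArrayConversion.convert` from this patch, a new conversion only needs a leaf-level type predicate and a convert function; recursion into nested types and `ChunkedArray` support come for free. A minimal sketch of what such a helper could look like (the `coerce_to_ns` name and its cast-based body are illustrative, not part of this commit):

```python
import pyarrow as pa
import pyarrow.types as types

from pyspark.sql.conversion import ArrowArrayConversion


def coerce_to_ns(arr):
    # Hypothetical helper: cast timestamp arrays (possibly nested or
    # chunked) with a non-nanosecond unit to nanosecond precision.
    def check_type_func(pa_type: pa.DataType) -> bool:
        # match TimestampType with a unit other than nanoseconds
        return types.is_timestamp(pa_type) and pa_type.unit != "ns"

    def convert_func(a: pa.Array) -> pa.Array:
        # leaf-level conversion on a flat timestamp array
        return a.cast(pa.timestamp("ns", tz=a.type.tz))

    return ArrowArrayConversion.convert(arr, check_type=check_type_func, convert=convert_func)


# works on a plain Array ...
ts = pa.array([1, 2, 3], type=pa.timestamp("us"))
assert coerce_to_ns(ts).type == pa.timestamp("ns")

# ... and, unlike the old `localize_tz`, on a ChunkedArray as well
chunked = pa.chunked_array([ts, ts])
assert coerce_to_ns(chunked).type == pa.timestamp("ns")
```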
### Does this PR introduce _any_ user-facing change?
no, internal changes
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #54152 from zhengruifeng/arrow_convert.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/conversion.py | 173 +++++++++++++++++++---------
python/pyspark/sql/tests/test_conversion.py | 6 +-
2 files changed, 122 insertions(+), 57 deletions(-)
diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index f65e6cb814bf..e15e678a2dc9 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -903,29 +903,121 @@ class ArrowTableToRowsConversion:
return [_create_row(fields, tuple())] * table.num_rows
-class ArrowTimestampConversion:
+class ArrowArrayConversion:
@classmethod
- def _need_localization(cls, at: "pa.DataType") -> bool:
+ def check_conversion(
+ cls,
+ pa_type: "pa.DataType",
+ check_type: Callable[["pa.DataType"], bool],
+ ) -> bool:
import pyarrow.types as types
- if types.is_timestamp(at) and at.tz is not None:
+ if check_type(pa_type):
return True
elif (
- types.is_list(at)
- or types.is_large_list(at)
- or types.is_fixed_size_list(at)
- or types.is_dictionary(at)
+ types.is_list(pa_type)
+ or types.is_large_list(pa_type)
+ or types.is_fixed_size_list(pa_type)
+ or types.is_dictionary(pa_type)
):
- return cls._need_localization(at.value_type)
- elif types.is_map(at):
- return any(cls._need_localization(dt) for dt in [at.key_type, at.item_type])
- elif types.is_struct(at):
- return any(cls._need_localization(field.type) for field in at)
+ return cls.check_conversion(pa_type.value_type, check_type)
+ elif types.is_map(pa_type):
+ return any(
+ cls.check_conversion(at, check_type)
+ for at in [
+ pa_type.key_type,
+ pa_type.item_type,
+ ]
+ )
+ elif types.is_struct(pa_type):
+ return any(cls.check_conversion(field.type, check_type) for field in pa_type)
else:
return False
- @staticmethod
- def localize_tz(arr: "pa.Array") -> "pa.Array":
+ @classmethod
+ def convert_array(
+ cls,
+ arr: "pa.Array",
+ check_type: Callable[["pa.DataType"], bool],
+ convert: Callable[["pa.Array"], "pa.Array"],
+ ) -> "pa.Array":
+ import pyarrow as pa
+ import pyarrow.types as types
+
+ assert isinstance(arr, pa.Array)
+
+ pa_type = arr.type
+ # fastpath
+ if not cls.check_conversion(pa_type, check_type):
+ return arr
+
+ if check_type(pa_type):
+ converted = convert(arr)
+ assert len(converted) == len(arr), f"array length changed: {arr} -> {converted}"
+ return converted
+ elif types.is_list(pa_type):
+ return pa.ListArray.from_arrays(
+ offsets=arr.offsets,
+ values=cls.convert_array(arr.values, check_type, convert),
+ )
+ elif types.is_large_list(pa_type):
+ return pa.LargeListArray.from_arrays(
+ offsets=arr.offsets,
+ values=cls.convert_array(arr.values, check_type, convert),
+ )
+ elif types.is_fixed_size_list(pa_type):
+ return pa.FixedSizeListArray.from_arrays(
+ values=cls.convert_array(arr.values, check_type, convert),
+ )
+ elif types.is_dictionary(pa_type):
+ return pa.DictionaryArray.from_arrays(
+ indices=arr.indices,
+ dictionary=cls.convert_array(arr.dictionary, check_type, convert),
+ )
+ elif types.is_map(pa_type):
+ return pa.MapArray.from_arrays(
+ offsets=arr.offsets,
+ keys=cls.convert_array(arr.keys, check_type, convert),
+ items=cls.convert_array(arr.items, check_type, convert),
+ )
+ elif types.is_struct(pa_type):
+ return pa.StructArray.from_arrays(
+ arrays=[
+ cls.convert_array(arr.field(i), check_type, convert)
+ for i in range(len(arr.type))
+ ],
+ names=arr.type.names,
+ )
+ else: # pragma: no cover
+ assert False, f"Need converter for {pa_type} but failed to find
one."
+
+ @classmethod
+ def convert(
+ cls,
+ arr: Union["pa.Array", "pa.ChunkedArray"],
+ check_type: Callable[["pa.DataType"], bool],
+ convert: Callable[["pa.Array"], "pa.Array"],
+ ) -> Union["pa.Array", "pa.ChunkedArray"]:
+ import pyarrow as pa
+
+ assert isinstance(arr, (pa.Array, pa.ChunkedArray))
+
+ # fastpath
+ if not cls.check_conversion(arr.type, check_type):
+ return arr
+
+ if isinstance(arr, pa.Array):
+ return cls.convert_array(arr, check_type, convert)
+ else:
+ return pa.chunked_array(
+ (cls.convert_array(a, check_type, convert) for a in arr.iterchunks())
+ )
+
+ @classmethod
+ def localize_tz(
+ cls,
+ arr: Union["pa.Array", "pa.ChunkedArray"],
+ ) -> Union["pa.Array", "pa.ChunkedArray"]:
"""
Convert Arrow timezone-aware timestamps to timezone-naive in the specified timezone.
This function works on Arrow Arrays, and it recurses to convert nested types.
@@ -960,11 +1052,13 @@ class ArrowTimestampConversion:
import pyarrow.types as types
import pyarrow.compute as pc
- pa_type = arr.type
- if not ArrowTimestampConversion._need_localization(pa_type):
- return arr
+ def check_type_func(pa_type: pa.DataType) -> bool:
+ # match timezone-aware TimestampType
+ return types.is_timestamp(pa_type) and pa_type.tz is not None
+
+ def convert_func(arr: pa.Array) -> pa.Array:
+ assert isinstance(arr, pa.TimestampArray)
- if types.is_timestamp(pa_type) and pa_type.tz is not None:
# import datetime
# from zoneinfo import ZoneInfo
# ts = datetime.datetime(2022, 1, 5, 15, 0, 1, tzinfo=ZoneInfo('Asia/Singapore'))
@@ -974,42 +1068,13 @@ class ArrowTimestampConversion:
# arr = pc.local_timestamp(arr)
# arr[0]
# <pyarrow.TimestampScalar: '2022-01-05T15:00:01.000000'>
-
return pc.local_timestamp(arr)
- elif types.is_list(pa_type):
- return pa.ListArray.from_arrays(
- offsets=arr.offsets,
- values=ArrowTimestampConversion.localize_tz(arr.values),
- )
- elif types.is_large_list(pa_type):
- return pa.LargeListType.from_arrays(
- offsets=arr.offsets,
- values=ArrowTimestampConversion.localize_tz(arr.values),
- )
- elif types.is_fixed_size_list(pa_type):
- return pa.FixedSizeListArray.from_arrays(
- values=ArrowTimestampConversion.localize_tz(arr.values),
- )
- elif types.is_dictionary(pa_type):
- return pa.DictionaryArray.from_arrays(
- indices=arr.indices,
- dictionary=ArrowTimestampConversion.localize_tz(arr.dictionary),
- )
- elif types.is_map(pa_type):
- return pa.MapArray.from_arrays(
- offsets=arr.offsets,
- keys=ArrowTimestampConversion.localize_tz(arr.keys),
- items=ArrowTimestampConversion.localize_tz(arr.items),
- )
- elif types.is_struct(pa_type):
- return pa.StructArray.from_arrays(
- arrays=[
- ArrowTimestampConversion.localize_tz(arr.field(i)) for i in range(len(arr.type))
- ],
- names=arr.type.names,
- )
- else: # pragma: no cover
- assert False, f"Need converter for {pa_type} but failed to find
one."
+
+ return cls.convert(
+ arr,
+ check_type=check_type_func,
+ convert=convert_func,
+ )
class ArrowArrayToPandasConversion:
@@ -1237,7 +1302,7 @@ class ArrowArrayToPandasConversion:
pdf.columns = spark_type.names # type: ignore[assignment]
return pdf
- arr = ArrowTimestampConversion.localize_tz(arr)
+ arr = ArrowArrayConversion.localize_tz(arr)
# TODO(SPARK-55332): Create benchmark for pa.array -> pd.series integer conversion
# 1, benchmark a nullable integral array
diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py
index c3fa1fd19304..adee81a158f8 100644
--- a/python/pyspark/sql/tests/test_conversion.py
+++ b/python/pyspark/sql/tests/test_conversion.py
@@ -22,7 +22,7 @@ from pyspark.errors import PySparkValueError
from pyspark.sql.conversion import (
ArrowTableToRowsConversion,
LocalDataToArrowConversion,
- ArrowTimestampConversion,
+ ArrowArrayConversion,
ArrowBatchTransformer,
)
from pyspark.sql.types import (
@@ -304,7 +304,7 @@ class ConversionTests(unittest.TestCase):
pa.StructArray.from_arrays([pa.array([1, 2]), pa.array(["x", "y"])], names=["a", "b"]),
pa.array([{1: None, 2: "x"}], type=pa.map_(pa.int32(), pa.string())),
]:
- output = ArrowTimestampConversion.localize_tz(arr)
+ output = ArrowArrayConversion.localize_tz(arr)
self.assertTrue(output is arr, f"MUST not generate a new array {output.tolist()}")
# timestamp types
@@ -372,7 +372,7 @@ class ConversionTests(unittest.TestCase):
),
), # map<int, array<ts-ltz>>
]:
- output = ArrowTimestampConversion.localize_tz(arr)
+ output = ArrowArrayConversion.localize_tz(arr)
self.assertEqual(output, expected, f"{output.tolist()} != {expected.tolist()}")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]