This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fd376718f75c [SPARK-55365][PYTHON] Generalize the utils for arrow array conversion
fd376718f75c is described below

commit fd376718f75c1dde29eff0f0e33499e7784cc0e5
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Feb 5 16:49:51 2026 +0800

    [SPARK-55365][PYTHON] Generalize the utils for arrow array conversion
    
    ### What changes were proposed in this pull request?
    Generalize the utils for arrow array conversion
    
    ### Why are the changes needed?
    We have a `localize_tz` util to drop timezones from `TimestampType`. When I wanted to implement more time-related conversions in `convert_numpy`, I found it has two drawbacks:
    1, the new function for `coerce_temporal_nanoseconds` would be pretty similar, and I would have to copy-paste the code for nested type handling;
    2, it doesn't support `ChunkedArray` for now, so the timezone handling is not yet enabled in the new arrow->pandas conversion.

    So I think we need to make the utils more flexible for adding new conversion functions; a minimal sketch of the resulting usage is below.
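
    For example, with the generalized util a new conversion only needs two callables: a type predicate and a leaf-level converter. A minimal sketch (the nanosecond coercion below is a hypothetical illustration of the API added by this PR, not code taken from it):

    ```python
    import pyarrow as pa
    import pyarrow.types as types

    from pyspark.sql.conversion import ArrowArrayConversion

    def is_non_nano_timestamp(t: pa.DataType) -> bool:
        # match timestamps whose unit is not already nanoseconds
        return types.is_timestamp(t) and t.unit != "ns"

    def to_nano(arr: pa.Array) -> pa.Array:
        # cast a flat timestamp array to nanosecond precision, keeping the timezone
        return arr.cast(pa.timestamp("ns", tz=arr.type.tz))

    # works on both pa.Array and pa.ChunkedArray, and recurses into nested types
    data = pa.chunked_array([pa.array([1, 2], type=pa.timestamp("us"))])
    out = ArrowArrayConversion.convert(data, check_type=is_non_nano_timestamp, convert=to_nano)
    ```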
    
    ### Does this PR introduce _any_ user-facing change?
    no, internal changes
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #54152 from zhengruifeng/arrow_convert.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/conversion.py            | 173 +++++++++++++++++++---------
 python/pyspark/sql/tests/test_conversion.py |   6 +-
 2 files changed, 122 insertions(+), 57 deletions(-)

diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index f65e6cb814bf..e15e678a2dc9 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -903,29 +903,121 @@ class ArrowTableToRowsConversion:
                 return [_create_row(fields, tuple())] * table.num_rows
 
 
-class ArrowTimestampConversion:
+class ArrowArrayConversion:
     @classmethod
-    def _need_localization(cls, at: "pa.DataType") -> bool:
+    def check_conversion(
+        cls,
+        pa_type: "pa.DataType",
+        check_type: Callable[["pa.DataType"], bool],
+    ) -> bool:
         import pyarrow.types as types
 
-        if types.is_timestamp(at) and at.tz is not None:
+        if check_type(pa_type):
             return True
         elif (
-            types.is_list(at)
-            or types.is_large_list(at)
-            or types.is_fixed_size_list(at)
-            or types.is_dictionary(at)
+            types.is_list(pa_type)
+            or types.is_large_list(pa_type)
+            or types.is_fixed_size_list(pa_type)
+            or types.is_dictionary(pa_type)
         ):
-            return cls._need_localization(at.value_type)
-        elif types.is_map(at):
-            return any(cls._need_localization(dt) for dt in [at.key_type, at.item_type])
-        elif types.is_struct(at):
-            return any(cls._need_localization(field.type) for field in at)
+            return cls.check_conversion(pa_type.value_type, check_type)
+        elif types.is_map(pa_type):
+            return any(
+                cls.check_conversion(at, check_type)
+                for at in [
+                    pa_type.key_type,
+                    pa_type.item_type,
+                ]
+            )
+        elif types.is_struct(pa_type):
+            return any(cls.check_conversion(field.type, check_type) for field in pa_type)
         else:
             return False
 
-    @staticmethod
-    def localize_tz(arr: "pa.Array") -> "pa.Array":
+    @classmethod
+    def convert_array(
+        cls,
+        arr: "pa.Array",
+        check_type: Callable[["pa.DataType"], bool],
+        convert: Callable[["pa.Array"], "pa.Array"],
+    ) -> "pa.Array":
+        import pyarrow as pa
+        import pyarrow.types as types
+
+        assert isinstance(arr, pa.Array)
+
+        pa_type = arr.type
+        # fastpath
+        if not cls.check_conversion(pa_type, check_type):
+            return arr
+
+        if check_type(pa_type):
+            converted = convert(arr)
+            assert len(converted) == len(arr), f"array length changed: {arr} -> {converted}"
+            return converted
+        elif types.is_list(pa_type):
+            return pa.ListArray.from_arrays(
+                offsets=arr.offsets,
+                values=cls.convert_array(arr.values, check_type, convert),
+            )
+        elif types.is_large_list(pa_type):
+            return pa.LargeListArray.from_arrays(
+                offsets=arr.offsets,
+                values=cls.convert_array(arr.values, check_type, convert),
+            )
+        elif types.is_fixed_size_list(pa_type):
+            return pa.FixedSizeListArray.from_arrays(
+                values=cls.convert_array(arr.values, check_type, convert),
+            )
+        elif types.is_dictionary(pa_type):
+            return pa.DictionaryArray.from_arrays(
+                indices=arr.indices,
+                dictionary=cls.convert_array(arr.dictionary, check_type, convert),
+            )
+        elif types.is_map(pa_type):
+            return pa.MapArray.from_arrays(
+                offsets=arr.offsets,
+                keys=cls.convert_array(arr.keys, check_type, convert),
+                items=cls.convert_array(arr.items, check_type, convert),
+            )
+        elif types.is_struct(pa_type):
+            return pa.StructArray.from_arrays(
+                arrays=[
+                    cls.convert_array(arr.field(i), check_type, convert)
+                    for i in range(len(arr.type))
+                ],
+                names=arr.type.names,
+            )
+        else:  # pragma: no cover
+            assert False, f"Need converter for {pa_type} but failed to find one."
+
+    @classmethod
+    def convert(
+        cls,
+        arr: Union["pa.Array", "pa.ChunkedArray"],
+        check_type: Callable[["pa.DataType"], bool],
+        convert: Callable[["pa.Array"], "pa.Array"],
+    ) -> Union["pa.Array", "pa.ChunkedArray"]:
+        import pyarrow as pa
+
+        assert isinstance(arr, (pa.Array, pa.ChunkedArray))
+
+        # fastpath
+        if not cls.check_conversion(arr.type, check_type):
+            return arr
+
+        if isinstance(arr, pa.Array):
+            return cls.convert_array(arr, check_type, convert)
+        else:
+            return pa.chunked_array(
+                (cls.convert_array(a, check_type, convert) for a in arr.iterchunks())
+            )
+
+    @classmethod
+    def localize_tz(
+        cls,
+        arr: Union["pa.Array", "pa.ChunkedArray"],
+    ) -> Union["pa.Array", "pa.ChunkedArray"]:
         """
         Convert Arrow timezone-aware timestamps to timezone-naive in the specified timezone.
         This function works on Arrow Arrays, and it recurses to convert nested types.
@@ -960,11 +1052,13 @@ class ArrowTimestampConversion:
         import pyarrow.types as types
         import pyarrow.compute as pc
 
-        pa_type = arr.type
-        if not ArrowTimestampConversion._need_localization(pa_type):
-            return arr
+        def check_type_func(pa_type: pa.DataType) -> bool:
+            # match timezone-aware TimestampType
+            return types.is_timestamp(pa_type) and pa_type.tz is not None
+
+        def convert_func(arr: pa.Array) -> pa.Array:
+            assert isinstance(arr, pa.TimestampArray)
 
-        if types.is_timestamp(pa_type) and pa_type.tz is not None:
             # import datetime
             # from zoneinfo import ZoneInfo
             # ts = datetime.datetime(2022, 1, 5, 15, 0, 1, tzinfo=ZoneInfo('Asia/Singapore'))
@@ -974,42 +1068,13 @@ class ArrowTimestampConversion:
             # arr = pc.local_timestamp(arr)
             # arr[0]
             # <pyarrow.TimestampScalar: '2022-01-05T15:00:01.000000'>
-
             return pc.local_timestamp(arr)
-        elif types.is_list(pa_type):
-            return pa.ListArray.from_arrays(
-                offsets=arr.offsets,
-                values=ArrowTimestampConversion.localize_tz(arr.values),
-            )
-        elif types.is_large_list(pa_type):
-            return pa.LargeListArray.from_arrays(
-                offsets=arr.offsets,
-                values=ArrowTimestampConversion.localize_tz(arr.values),
-            )
-        elif types.is_fixed_size_list(pa_type):
-            return pa.FixedSizeListArray.from_arrays(
-                values=ArrowTimestampConversion.localize_tz(arr.values),
-            )
-        elif types.is_dictionary(pa_type):
-            return pa.DictionaryArray.from_arrays(
-                indices=arr.indices,
-                dictionary=ArrowTimestampConversion.localize_tz(arr.dictionary),
-            )
-        elif types.is_map(pa_type):
-            return pa.MapArray.from_arrays(
-                offsets=arr.offsets,
-                keys=ArrowTimestampConversion.localize_tz(arr.keys),
-                items=ArrowTimestampConversion.localize_tz(arr.items),
-            )
-        elif types.is_struct(pa_type):
-            return pa.StructArray.from_arrays(
-                arrays=[
-                    ArrowTimestampConversion.localize_tz(arr.field(i)) for i in range(len(arr.type))
-                ],
-                names=arr.type.names,
-            )
-        else:  # pragma: no cover
-            assert False, f"Need converter for {pa_type} but failed to find one."
+
+        return cls.convert(
+            arr,
+            check_type=check_type_func,
+            convert=convert_func,
+        )
 
 
 class ArrowArrayToPandasConversion:
@@ -1237,7 +1302,7 @@ class ArrowArrayToPandasConversion:
             pdf.columns = spark_type.names  # type: ignore[assignment]
             return pdf
 
-        arr = ArrowTimestampConversion.localize_tz(arr)
+        arr = ArrowArrayConversion.localize_tz(arr)
 
         # TODO(SPARK-55332): Create benchmark for pa.array -> pd.series integer conversion
         # 1, benchmark a nullable integral array
diff --git a/python/pyspark/sql/tests/test_conversion.py b/python/pyspark/sql/tests/test_conversion.py
index c3fa1fd19304..adee81a158f8 100644
--- a/python/pyspark/sql/tests/test_conversion.py
+++ b/python/pyspark/sql/tests/test_conversion.py
@@ -22,7 +22,7 @@ from pyspark.errors import PySparkValueError
 from pyspark.sql.conversion import (
     ArrowTableToRowsConversion,
     LocalDataToArrowConversion,
-    ArrowTimestampConversion,
+    ArrowArrayConversion,
     ArrowBatchTransformer,
 )
 from pyspark.sql.types import (
@@ -304,7 +304,7 @@ class ConversionTests(unittest.TestCase):
             pa.StructArray.from_arrays([pa.array([1, 2]), pa.array(["x", "y"])], names=["a", "b"]),
             pa.array([{1: None, 2: "x"}], type=pa.map_(pa.int32(), pa.string())),
         ]:
-            output = ArrowTimestampConversion.localize_tz(arr)
+            output = ArrowArrayConversion.localize_tz(arr)
             self.assertTrue(output is arr, f"MUST not generate a new array {output.tolist()}")
 
         # timestamp types
@@ -372,7 +372,7 @@ class ConversionTests(unittest.TestCase):
                 ),
             ),  # map<int, array<ts-ltz>>
         ]:
-            output = ArrowTimestampConversion.localize_tz(arr)
+            output = ArrowArrayConversion.localize_tz(arr)
             self.assertEqual(output, expected, f"{output.tolist()} != {expected.tolist()}")
 
 

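As a usage note: with this change, `localize_tz` also accepts a `ChunkedArray`, which is what enables the timezone handling in the new arrow->pandas path. A minimal sketch of the new entry point (illustrative values, not taken from the test suite):

    import pyarrow as pa

    from pyspark.sql.conversion import ArrowArrayConversion

    # a two-chunk timezone-aware timestamp column
    ca = pa.chunked_array(
        [
            pa.array([0], type=pa.timestamp("us", tz="Asia/Singapore")),
            pa.array([1], type=pa.timestamp("us", tz="Asia/Singapore")),
        ]
    )

    out = ArrowArrayConversion.localize_tz(ca)
    assert out.type == pa.timestamp("us")  # timezone dropped, chunk by chunk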
