This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ecf179c3485b [SPARK-54337][PS] Add support for PyCapsule to Pyspark
ecf179c3485b is described below

commit ecf179c3485ba8bac72afd9105892d9798d23f8f
Author: Devin Petersohn <[email protected]>
AuthorDate: Mon Jan 12 09:07:51 2026 +0800

    [SPARK-54337][PS] Add support for PyCapsule to Pyspark
    
    ### What changes were proposed in this pull request?
    
    Add support for the PyCapsule protocol to facilitate interchange between
    Spark and other Python libraries. Here is a demo of what this enables with
    Polars and DuckDB:
    
    ```
    Welcome to
          ____              __
         / __/__  ___ _____/ /__
        _\ \/ _ \/ _ `/ __/  '_/
       /__ / .__/\_,_/_/ /_/\_\   version 4.2.0-SNAPSHOT
          /_/
    
    Using Python version 3.11.5 (main, Sep 11 2023 08:31:25)
    Spark context Web UI available at http://192.168.86.83:4040
    Spark context available as 'sc' (master = local[*], app id = local-1765227291836).
    SparkSession available as 'spark'.
    
    In [1]: import pyspark.pandas as ps
       ...: import pandas as pd
       ...: import numpy as np
       ...: import polars as pl
       ...:
       ...: pdf = pd.DataFrame(
       ...:     {"A": [True, False], "B": [1, np.nan], "C": [True, None], "D": 
[None, np.nan]}
       ...: )
       ...: psdf = ps.from_pandas(pdf)
       ...: polars_df = pl.DataFrame(psdf)
    
    /Users/dpetersohn/software_sources/spark/python/pyspark/pandas/__init__.py:43: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.
      warnings.warn(
    [Stage 0:>                                                          (0 + 1) / 1]
    In [2]: polars_df
    Out[2]:
    shape: (2, 5)
    ┌───────────────────┬───────┬──────┬──────┬──────┐
    │ __index_level_0__ ┆ A     ┆ B    ┆ C    ┆ D    │
    │ ---               ┆ ---   ┆ ---  ┆ ---  ┆ ---  │
    │ i64               ┆ bool  ┆ f64  ┆ bool ┆ f64  │
    ╞═══════════════════╪═══════╪══════╪══════╪══════╡
    │ 0                 ┆ true  ┆ 1.0  ┆ true ┆ null │
    │ 1                 ┆ false ┆ null ┆ null ┆ null │
    └───────────────────┴───────┴──────┴──────┴──────┘
    
    In [3]: import duckdb
    
    In [4]: import pyarrow as pa
    
    In [5]: stream = pa.RecordBatchReader.from_stream(psdf)
    
    In [6]: duckdb.sql("SELECT count(*) AS total, avg(B) FROM stream WHERE B IS NOT NULL").fetchall()
    Out[6]: [(1, 1.0)]
    ```
    
    Polars can now consume a full PySpark DataFrame (or a `pyspark.pandas` one),
    and DuckDB can consume a stream built from the PySpark DataFrame.
    Importantly, the `stream = pa.RecordBatchReader.from_stream(psdf)` line does
    not trigger any computation; it simply creates a stream object that is
    consumed incrementally by DuckDB when the `fetchall` call is executed.
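
    For illustration, here is a minimal sketch of the same protocol applied to a
    plain `pyspark.sql.DataFrame`, mirroring the demo above. It assumes an active
    `spark` session; the example data and column names are made up for the sketch
    and are not taken from the patch:

    ```
    import duckdb
    import polars as pl
    import pyarrow as pa

    # Hypothetical example data; any DataFrame works the same way.
    sdf = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "value"])

    # Polars consumes the DataFrame through __arrow_c_stream__.
    polars_df = pl.DataFrame(sdf)

    # DuckDB consumes an Arrow stream built from the DataFrame.
    stream = pa.RecordBatchReader.from_stream(sdf)
    duckdb.sql("SELECT count(*) FROM stream").fetchall()
    ```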
    
    ### Why are the changes needed?
    
    Currently, PySpark (and to a lesser degree the pandas API on Spark) does not
    integrate well with the broader Python ecosystem. The usual workaround is to
    go through pandas with `toPandas`, but that materializes all of the data on
    the driver at once. The new API and protocol let the data stream, one Arrow
    batch at a time, so libraries like DuckDB and Polars can consume it
    incrementally.
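
    As a rough sketch of the contrast (again assuming an active `spark` session;
    `spark.range` here is just a stand-in for any DataFrame), the stream can be
    consumed one batch at a time instead of materializing everything with
    `toPandas`:

    ```
    import pyarrow as pa

    sdf = spark.range(1_000_000)

    # Per the description above, creating the reader does not trigger computation;
    # batches are pulled lazily as the stream is consumed.
    reader = pa.RecordBatchReader.from_stream(sdf)
    for batch in reader:
        # Each pyarrow.RecordBatch arrives incrementally.
        print(batch.num_rows)
    ```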
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, new user-level API.
    
    ### How was this patch tested?
    
    Tested locally, together with the new unit test added in
    python/pyspark/sql/tests/arrow/test_arrow_c_stream.py.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #53391 from devin-petersohn/devin/pycapsule.
    
    Lead-authored-by: Devin Petersohn <[email protected]>
    Co-authored-by: Devin Petersohn <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 dev/sparktestsupport/modules.py                    |  1 +
 python/pyspark/pandas/frame.py                     | 19 +++++
 python/pyspark/sql/dataframe.py                    | 17 +++++
 python/pyspark/sql/interchange.py                  | 89 ++++++++++++++++++++++
 .../pyspark/sql/tests/arrow/test_arrow_c_stream.py | 64 ++++++++++++++++
 5 files changed, 190 insertions(+)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 4e956314c3d8..0ff9d6634377 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -551,6 +551,7 @@ pyspark_sql = Module(
         "pyspark.sql.tests.test_job_cancellation",
         "pyspark.sql.tests.arrow.test_arrow",
         "pyspark.sql.tests.arrow.test_arrow_map",
+        "pyspark.sql.tests.arrow.test_arrow_c_stream",
         "pyspark.sql.tests.arrow.test_arrow_cogrouped_map",
         "pyspark.sql.tests.arrow.test_arrow_grouped_map",
         "pyspark.sql.tests.arrow.test_arrow_python_udf",
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index df68e31d4f33..23ac31c8ebfb 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -13823,6 +13823,25 @@ defaultdict(<class 'list'>, {'col..., 'col...})]
         # we always wraps the given type hints by a tuple to mimic the variadic generic.
         return create_tuple_for_frame_type(params)
 
+    def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object:
+        """
+        Export to a C PyCapsule stream object.
+
+        Parameters
+        ----------
+        requested_schema : PyCapsule, optional
+            The schema to attempt to use for the output stream. This is a best-effort request.
+
+        Returns
+        -------
+        A C PyCapsule stream object.
+        """
+        from pyspark.sql.interchange import SparkArrowCStreamer
+
+        return SparkArrowCStreamer(self._internal.to_internal_spark_frame).__arrow_c_stream__(
+            requested_schema
+        )
+
 
 def _reduce_spark_multi(sdf: PySparkDataFrame, aggs: List[PySparkColumn]) -> Any:
     """
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 7b7547b68ff2..2ddfdda762d7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -6982,6 +6982,23 @@ class DataFrame:
         """
         ...
 
+    def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object:
+        """
+        Export to a C PyCapsule stream object.
+
+        Parameters
+        ----------
+        requested_schema : PyCapsule, optional
+            The schema to attempt to use for the output stream. This is a best-effort request.
+
+        Returns
+        -------
+        A C PyCapsule stream object.
+        """
+        from pyspark.sql.interchange import SparkArrowCStreamer
+
+        return SparkArrowCStreamer(self).__arrow_c_stream__(requested_schema)
+
 
 class DataFrameNaFunctions:
     """Functionality for working with missing data in :class:`DataFrame`.
diff --git a/python/pyspark/sql/interchange.py b/python/pyspark/sql/interchange.py
new file mode 100644
index 000000000000..141d9f37148e
--- /dev/null
+++ b/python/pyspark/sql/interchange.py
@@ -0,0 +1,89 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Iterator, Optional
+
+import pyarrow as pa
+
+import pyspark.sql
+from pyspark.sql.types import StructType, StructField, BinaryType
+from pyspark.sql.pandas.types import to_arrow_schema
+
+
+def _get_arrow_array_partition_stream(df: pyspark.sql.DataFrame) -> Iterator[pa.RecordBatch]:
+    """Return all the partitions as Arrow arrays in an Iterator."""
+    # We will be using mapInArrow to convert each partition to Arrow RecordBatches.
+    # The return type of the function will be a single binary column containing
+    # the serialized RecordBatch in Arrow IPC format.
+    binary_schema = StructType([StructField("arrow_ipc_bytes", BinaryType(), nullable=False)])
+
+    def batch_to_bytes_iter(batch_iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+        """
+        A generator function that converts RecordBatches to serialized Arrow IPC format.
+
+        Spark sends each partition as an iterator of RecordBatches. In order to return
+        the entire partition as a stream of Arrow RecordBatches, we need to serialize
+        each RecordBatch to Arrow IPC format and yield it as a single binary blob.
+        """
+        # The size of the batch can be controlled by the Spark config
+        # `spark.sql.execution.arrow.maxRecordsPerBatch`.
+        for arrow_batch in batch_iter:
+            # We create an in-memory byte stream to hold the serialized batch
+            sink = pa.BufferOutputStream()
+            # Write the batch to the stream using Arrow IPC format
+            with pa.ipc.new_stream(sink, arrow_batch.schema) as writer:
+                writer.write_batch(arrow_batch)
+            buf = sink.getvalue()
+            # The second buffer contains the offsets we are manually creating.
+            offset_buf = pa.array([0, len(buf)], type=pa.int32()).buffers()[1]
+            null_bitmap = None
+            # Wrap the bytes in a new 1-row, 1-column RecordBatch to satisfy mapInArrow return
+            # signature. This serializes the whole batch into a single pyarrow serialized cell.
+            storage_arr = pa.Array.from_buffers(
+                type=pa.binary(), length=1, buffers=[null_bitmap, offset_buf, buf]
+            )
+            yield pa.RecordBatch.from_arrays([storage_arr], names=["arrow_ipc_bytes"])
+
+    # Convert all partitions to Arrow RecordBatches and map to binary blobs.
+    byte_df = df.mapInArrow(batch_to_bytes_iter, binary_schema)
+    # A row is actually a batch of data in Arrow IPC format. Fetch the batches one by one.
+    for row in byte_df.toLocalIterator():
+        with pa.ipc.open_stream(row.arrow_ipc_bytes) as reader:
+            for batch in reader:
+                # Each batch corresponds to a chunk of data in the partition.
+                yield batch
+
+
+class SparkArrowCStreamer:
+    """
+    A class that implements the __arrow_c_stream__ protocol for Spark partitions.
+
+    This class is implemented in a way that allows consumers to consume each partition
+    one at a time without materializing all partitions at once on the driver side.
+    """
+
+    def __init__(self, df: pyspark.sql.DataFrame):
+        self._df = df
+        self._schema = to_arrow_schema(df.schema)
+
+    def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object:
+        """
+        Return the Arrow C stream for the dataframe partitions.
+        """
+        reader: pa.RecordBatchReader = pa.RecordBatchReader.from_batches(
+            self._schema, _get_arrow_array_partition_stream(self._df)
+        )
+        return reader.__arrow_c_stream__(requested_schema=requested_schema)
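
As an aside (not part of the patch), the per-batch Arrow IPC round-trip that
interchange.py relies on can be reproduced with pyarrow alone; this standalone
sketch mirrors the serialize-then-open_stream steps above:

```
import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3]})

# Serialize one RecordBatch to Arrow IPC stream format in memory,
# as batch_to_bytes_iter does on the executor side.
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
ipc_bytes = sink.getvalue()  # a pyarrow.Buffer holding the IPC bytes

# Read the batch back from the serialized bytes, as the driver side does.
with pa.ipc.open_stream(ipc_bytes) as reader:
    for restored in reader:
        assert restored.equals(batch)
```
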
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_c_stream.py b/python/pyspark/sql/tests/arrow/test_arrow_c_stream.py
new file mode 100644
index 000000000000..9534db71bae6
--- /dev/null
+++ b/python/pyspark/sql/tests/arrow/test_arrow_c_stream.py
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import ctypes
+import unittest
+import pyarrow as pa
+import pandas as pd
+import pyspark.pandas as ps
+
+
+class TestSparkArrowCStreamer(unittest.TestCase):
+    def test_spark_arrow_c_streamer_arrow_consumer(self):
+        pdf = pd.DataFrame([[1, "a"], [2, "b"], [3, "c"], [4, "d"]], columns=["id", "value"])
+        psdf = ps.from_pandas(pdf)
+
+        capsule = psdf.__arrow_c_stream__()
+        assert (
+            ctypes.pythonapi.PyCapsule_IsValid(ctypes.py_object(capsule), b"arrow_array_stream")
+            == 1
+        )
+
+        stream = pa.RecordBatchReader.from_stream(psdf)
+        assert isinstance(stream, pa.RecordBatchReader)
+        result = pa.Table.from_batches(stream)
+        schema = pa.schema(
+            [
+                ("__index_level_0__", pa.int64(), False),
+                ("id", pa.int64(), False),
+                ("value", pa.string(), False),
+            ]
+        )
+        expected = pa.Table.from_pandas(
+            pd.DataFrame(
+                [[0, 1, "a"], [1, 2, "b"], [2, 3, "c"], [3, 4, "d"]],
+                columns=["__index_level_0__", "id", "value"],
+            ),
+            schema=schema,
+        )
+        self.assertEqual(result, expected)
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.arrow.test_arrow_c_stream import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore
+
+        test_runner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        test_runner = None
+    unittest.main(testRunner=test_runner, verbosity=2)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
