This is an automated email from the ASF dual-hosted git repository.

wjones127 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new aff86e704d Implement Arrow PyCapsule Interface (#5070)
aff86e704d is described below

commit aff86e704dabecbf99edd1e0ad62c216819dbc15
Author: Kyle Barron <[email protected]>
AuthorDate: Wed Nov 15 13:18:45 2023 -0500

    Implement Arrow PyCapsule Interface (#5070)
    
    * arrow ffi array copy
    
    * remove copy_ffi_array
    
    * docstring
    
    * wip: pycapsule support
    
    * return
    
    * Update arrow/src/pyarrow.rs
    
    Co-authored-by: Raphael Taylor-Davies 
<[email protected]>
    
    * remove sync impl
    
    * Update arrow/src/pyarrow.rs
    
    Co-authored-by: Will Jones <[email protected]>
    
    * Remove copy()
    
    * Need &mut FFI_ArrowArray for std::mem::replace
    
    * Use std::ptr::replace
    
    * update comments
    
    * Minimize unsafe block
    
    * revert pub release functions
    
    * Add RecordBatch and Stream conversion
    
    * fix returns
    
    * Fix return type
    
    * Fix name
    
    * fix ci
    
    * Add tests
    
    * Add table test
    
    * skip if pre pyarrow 14
    
    * bump python version in CI to use pyarrow 14
    
    * Add record batch test
    
    * Update arrow/src/pyarrow.rs
    
    Co-authored-by: Raphael Taylor-Davies 
<[email protected]>
    
    * run on pyarrow 13 and 14
    
    * Update .github/workflows/integration.yml
    
    Co-authored-by: Will Jones <[email protected]>
    
    ---------
    
    Co-authored-by: Raphael Taylor-Davies 
<[email protected]>
    Co-authored-by: Will Jones <[email protected]>
---
 .github/workflows/integration.yml                  |   6 +-
 arrow-pyarrow-integration-testing/README.md        |   2 +
 .../tests/test_sql.py                              | 138 ++++++++++++++++++++-
 arrow-schema/src/ffi.rs                            |   2 +
 arrow/src/pyarrow.rs                               | 134 +++++++++++++++++++-
 5 files changed, 274 insertions(+), 8 deletions(-)

diff --git a/.github/workflows/integration.yml 
b/.github/workflows/integration.yml
index 6e2b442040..f939a6a13b 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -106,6 +106,8 @@ jobs:
     strategy:
       matrix:
         rust: [ stable ]
+        # PyArrow 13 was the last version prior to the introduction of Arrow PyCapsules
+        pyarrow: [ "13", "14" ]
     steps:
       - uses: actions/checkout@v4
         with:
@@ -128,14 +130,14 @@ jobs:
           key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ 
matrix.rust }}-
       - uses: actions/setup-python@v4
         with:
-          python-version: '3.7'
+          python-version: '3.8'
       - name: Upgrade pip and setuptools
         run: pip install --upgrade pip setuptools wheel virtualenv
       - name: Create virtualenv and install dependencies
         run: |
           virtualenv venv
           source venv/bin/activate
-          pip install maturin toml pytest pytz pyarrow>=5.0
+          pip install maturin toml pytest pytz pyarrow==${{ matrix.pyarrow }}
       - name: Run Rust tests
         run: |
           source venv/bin/activate
diff --git a/arrow-pyarrow-integration-testing/README.md 
b/arrow-pyarrow-integration-testing/README.md
index e63953ad79..5ca2ea76b8 100644
--- a/arrow-pyarrow-integration-testing/README.md
+++ b/arrow-pyarrow-integration-testing/README.md
@@ -25,6 +25,7 @@ Note that this crate uses two languages and an external ABI:
 * `Rust`
 * `Python`
 * C ABI privately exposed by `Pyarrow`.
+* PyCapsule ABI publicly exposed by `pyarrow`
 
 ## Basic idea
 
@@ -36,6 +37,7 @@ we can use pyarrow's interface to move pointers from and to 
Rust.
 ## Relevant literature
 
 * [Arrow's 
CDataInterface](https://arrow.apache.org/docs/format/CDataInterface.html)
+* [Arrow PyCapsule 
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html)
 * [Rust's FFI](https://doc.rust-lang.org/nomicon/ffi.html)
 * [Pyarrow private 
binds](https://github.com/apache/arrow/blob/ae1d24efcc3f1ac2a876d8d9f544a34eb04ae874/python/pyarrow/array.pxi#L1226)
 * [PyO3](https://docs.rs/pyo3/0.12.1/pyo3/index.html)
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py 
b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 1748fd3ffb..16d4e0f12f 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -27,6 +27,8 @@ import pytz
 
 import arrow_pyarrow_integration_testing as rust
 
+PYARROW_PRE_14 = int(pa.__version__.split('.')[0]) < 14
+
 
 @contextlib.contextmanager
 def no_pyarrow_leak():
@@ -113,6 +115,34 @@ _supported_pyarrow_types = [
 _unsupported_pyarrow_types = [
 ]
 
+# As of pyarrow 14, pyarrow implements the Arrow PyCapsule interface
+# (https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
+# This defines that Arrow consumers should allow any object that has specific "dunder"
+# methods, `__arrow_c_*__`. These wrapper classes ensure that arrow-rs is able to handle
+# _any_ class, without pyarrow-specific handling.
+class SchemaWrapper:
+    """Expose only ``__arrow_c_schema__`` of the wrapped pyarrow object."""
+
+    def __init__(self, schema):
+        self.schema = schema
+
+    def __arrow_c_schema__(self):
+        return self.schema.__arrow_c_schema__()
+
+
+class ArrayWrapper:
+    """Expose only ``__arrow_c_array__`` of the wrapped pyarrow object."""
+
+    def __init__(self, array):
+        self.array = array
+
+    def __arrow_c_array__(self, requested_schema=None):
+        # The PyCapsule interface lets a consumer pass a requested schema;
+        # forward it when given so the wrapper behaves like a real producer.
+        if requested_schema is None:
+            return self.array.__arrow_c_array__()
+        return self.array.__arrow_c_array__(requested_schema=requested_schema)
+
+
+class StreamWrapper:
+    """Expose only ``__arrow_c_stream__`` of the wrapped pyarrow object."""
+
+    def __init__(self, stream):
+        self.stream = stream
+
+    def __arrow_c_stream__(self, requested_schema=None):
+        # Same optional-argument contract as ``__arrow_c_array__``.
+        if requested_schema is None:
+            return self.stream.__arrow_c_stream__()
+        return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
+
 
 @pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
 def test_type_roundtrip(pyarrow_type):
@@ -120,6 +150,14 @@ def test_type_roundtrip(pyarrow_type):
     assert restored == pyarrow_type
     assert restored is not pyarrow_type
 
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
+def test_type_roundtrip_pycapsule(pyarrow_type):
+    """A DataType exposed only through ``__arrow_c_schema__`` round-trips."""
+    wrapped = SchemaWrapper(pyarrow_type)
+    restored = rust.round_trip_type(wrapped)
+    assert restored == pyarrow_type
+    assert restored is not pyarrow_type
+
 
 @pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str)
 def test_type_roundtrip_raises(pyarrow_type):
@@ -138,6 +176,20 @@ def test_field_roundtrip(pyarrow_type):
         field = rust.round_trip_field(pyarrow_field)
         assert field == pyarrow_field
 
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str)
+def test_field_roundtrip_pycapsule(pyarrow_type):
+    """Fields exposed only through ``__arrow_c_schema__`` round-trip, nullable or not."""
+    pyarrow_field = pa.field("test", pyarrow_type, nullable=True)
+    wrapped = SchemaWrapper(pyarrow_field)
+    field = rust.round_trip_field(wrapped)
+    assert field == pyarrow_field
+
+    if pyarrow_type != pa.null():
+        # A null type field may not be non-nullable
+        pyarrow_field = pa.field("test", pyarrow_type, nullable=False)
+        # Re-wrap: the previous wrapper still holds the nullable field.
+        wrapped = SchemaWrapper(pyarrow_field)
+        field = rust.round_trip_field(wrapped)
+        assert field == pyarrow_field
+
 def test_field_metadata_roundtrip():
     metadata = {"hello": "World! 😊", "x": "2"}
     pyarrow_field = pa.field("test", pa.int32(), metadata=metadata)
@@ -163,6 +215,17 @@ def test_primitive_python():
     del b
 
 
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+def test_primitive_python_pycapsule():
+    """
+    Python -> Rust -> Python, with the input exposed only via ``__arrow_c_array__``.
+    """
+    a = pa.array([1, 2, 3])
+    wrapped = ArrayWrapper(a)
+    b = rust.double(wrapped)
+    assert b == pa.array([2, 4, 6])
+
+
 def test_primitive_rust():
     """
     Rust -> Python -> Rust
@@ -433,6 +496,33 @@ def test_record_batch_reader():
     got_batches = list(b)
     assert got_batches == batches
 
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+def test_record_batch_reader_pycapsule():
+    """
+    Python -> Rust -> Python, with the reader exposed only via ``__arrow_c_stream__``.
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema),
+        pa.record_batch([[None, [], [5, 6]]], schema),
+    ]
+
+    # Exercise both the plain and the boxed reader variants. Each iteration
+    # builds a fresh reader because the previous one has been consumed.
+    for round_trip in (rust.round_trip_record_batch_reader, rust.boxed_reader_roundtrip):
+        a = pa.RecordBatchReader.from_batches(schema, batches)
+        wrapped = StreamWrapper(a)
+        b = round_trip(wrapped)
+        assert b.schema == schema
+        got_batches = list(b)
+        assert got_batches == batches
+
+
 def test_record_batch_reader_error():
     schema = pa.schema([('ints', pa.list_(pa.int32()))])
 
@@ -453,24 +543,64 @@ def test_record_batch_reader_error():
     with pytest.raises(ValueError, match="invalid utf-8"):
         rust.round_trip_record_batch_reader(reader)
 
+
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+def test_record_batch_pycapsule():
+    """
+    Python -> Rust -> Python
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    batch = pa.record_batch([[[1], [2, 42]]], schema)
+    # The batch is read back through its stream interface, so a single-batch
+    # table is expected on the way out.
+    wrapped = StreamWrapper(batch)
+    b = rust.round_trip_record_batch_reader(wrapped)
+    new_table = b.read_all()
+    new_batches = new_table.to_batches()
+
+    assert len(new_batches) == 1
+    new_batch = new_batches[0]
+
+    assert batch == new_batch
+    assert batch.schema == new_batch.schema
+
+
+@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
+def test_table_pycapsule():
+    """
+    Python -> Rust -> Python, with a Table exposed only via ``__arrow_c_stream__``.
+    """
+    schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
+    batches = [
+        pa.record_batch([[[1], [2, 42]]], schema),
+        pa.record_batch([[None, [], [5, 6]]], schema),
+    ]
+    table = pa.Table.from_batches(batches)
+    wrapped = StreamWrapper(table)
+    b = rust.round_trip_record_batch_reader(wrapped)
+    new_table = b.read_all()
+
+    assert table.schema == new_table.schema
+    assert table == new_table
+    # Batch boundaries should survive the round trip, not just the contents.
+    assert len(table.to_batches()) == len(new_table.to_batches())
+
+
 def test_reject_other_classes():
     # Arbitrary type that is not a PyArrow type
     not_pyarrow = ["hello"]
 
     with pytest.raises(TypeError, match="Expected instance of 
pyarrow.lib.Array, got builtins.list"):
         rust.round_trip_array(not_pyarrow)
-    
+
     with pytest.raises(TypeError, match="Expected instance of 
pyarrow.lib.Schema, got builtins.list"):
         rust.round_trip_schema(not_pyarrow)
-    
+
     with pytest.raises(TypeError, match="Expected instance of 
pyarrow.lib.Field, got builtins.list"):
         rust.round_trip_field(not_pyarrow)
-    
+
     with pytest.raises(TypeError, match="Expected instance of 
pyarrow.lib.DataType, got builtins.list"):
         rust.round_trip_type(not_pyarrow)
 
     with pytest.raises(TypeError, match="Expected instance of 
pyarrow.lib.RecordBatch, got builtins.list"):
         rust.round_trip_record_batch(not_pyarrow)
-    
+
     with pytest.raises(TypeError, match="Expected instance of 
pyarrow.lib.RecordBatchReader, got builtins.list"):
         rust.round_trip_record_batch_reader(not_pyarrow)
diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs
index 7e33a78fec..640a7de798 100644
--- a/arrow-schema/src/ffi.rs
+++ b/arrow-schema/src/ffi.rs
@@ -351,6 +351,8 @@ impl Drop for FFI_ArrowSchema {
     }
 }
 
+// SAFETY: NOTE(review): this assumes an `FFI_ArrowSchema` exclusively owns the
+// C data it points to and that its release callback may run on any thread, so
+// moving it across threads is sound. Confirm against the Arrow C Data Interface
+// ownership rules before relying on this.
+unsafe impl Send for FFI_ArrowSchema {}
+
 impl TryFrom<&FFI_ArrowSchema> for DataType {
     type Error = ArrowError;
 
diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs
index 517c333add..4d262b0d10 100644
--- a/arrow/src/pyarrow.rs
+++ b/arrow/src/pyarrow.rs
@@ -59,12 +59,12 @@ use std::convert::{From, TryFrom};
 use std::ptr::{addr_of, addr_of_mut};
 use std::sync::Arc;
 
-use arrow_array::{RecordBatchIterator, RecordBatchReader};
+use arrow_array::{RecordBatchIterator, RecordBatchReader, StructArray};
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::ffi::Py_uintptr_t;
 use pyo3::import_exception;
 use pyo3::prelude::*;
-use pyo3::types::{PyList, PyTuple};
+use pyo3::types::{PyCapsule, PyList, PyTuple};
 
 use crate::array::{make_array, ArrayData};
 use crate::datatypes::{DataType, Field, Schema};
@@ -118,8 +118,40 @@ fn validate_class(expected: &str, value: &PyAny) -> 
PyResult<()> {
     Ok(())
 }
 
+/// Checks that `capsule` is named exactly `name` (for example
+/// `"arrow_schema"`, `"arrow_array"` or `"arrow_array_stream"`).
+///
+/// # Errors
+///
+/// Returns a Python `ValueError` if the capsule has no name or a different
+/// name. It propagates the conversion error if the name is not valid UTF-8.
+fn validate_pycapsule(capsule: &PyCapsule, name: &str) -> PyResult<()> {
+    // Report the expected capsule kind instead of always saying "schema":
+    // this helper also validates array and stream capsules.
+    let capsule_name = capsule.name()?.ok_or_else(|| {
+        PyValueError::new_err(format!(
+            "Expected {} PyCapsule to have name set.",
+            name
+        ))
+    })?;
+
+    let capsule_name = capsule_name.to_str()?;
+    if capsule_name != name {
+        return Err(PyValueError::new_err(format!(
+            "Expected name '{}' in PyCapsule, instead got '{}'",
+            name, capsule_name
+        )));
+    }
+
+    Ok(())
+}
+
 impl FromPyArrow for DataType {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        // Newer versions of PyArrow as well as other libraries with Arrow 
data implement this
+        // method, so prefer it over _export_to_c.
+        // See 
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+        if value.hasattr("__arrow_c_schema__")? {
+            let capsule: &PyCapsule =
+                
PyTryInto::try_into(value.getattr("__arrow_c_schema__")?.call0()?)?;
+            validate_pycapsule(capsule, "arrow_schema")?;
+
+            let schema_ptr = unsafe { capsule.reference::<FFI_ArrowSchema>() };
+            let dtype = DataType::try_from(schema_ptr).map_err(to_py_err)?;
+            return Ok(dtype);
+        }
+
         validate_class("DataType", value)?;
 
         let c_schema = FFI_ArrowSchema::empty();
@@ -143,6 +175,19 @@ impl ToPyArrow for DataType {
 
 impl FromPyArrow for Field {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        // Newer versions of PyArrow as well as other libraries with Arrow 
data implement this
+        // method, so prefer it over _export_to_c.
+        // See 
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+        if value.hasattr("__arrow_c_schema__")? {
+            let capsule: &PyCapsule =
+                
PyTryInto::try_into(value.getattr("__arrow_c_schema__")?.call0()?)?;
+            validate_pycapsule(capsule, "arrow_schema")?;
+
+            let schema_ptr = unsafe { capsule.reference::<FFI_ArrowSchema>() };
+            let field = Field::try_from(schema_ptr).map_err(to_py_err)?;
+            return Ok(field);
+        }
+
         validate_class("Field", value)?;
 
         let c_schema = FFI_ArrowSchema::empty();
@@ -166,6 +211,19 @@ impl ToPyArrow for Field {
 
 impl FromPyArrow for Schema {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        // Newer versions of PyArrow as well as other libraries with Arrow 
data implement this
+        // method, so prefer it over _export_to_c.
+        // See 
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+        if value.hasattr("__arrow_c_schema__")? {
+            let capsule: &PyCapsule =
+                
PyTryInto::try_into(value.getattr("__arrow_c_schema__")?.call0()?)?;
+            validate_pycapsule(capsule, "arrow_schema")?;
+
+            let schema_ptr = unsafe { capsule.reference::<FFI_ArrowSchema>() };
+            let schema = Schema::try_from(schema_ptr).map_err(to_py_err)?;
+            return Ok(schema);
+        }
+
         validate_class("Schema", value)?;
 
         let c_schema = FFI_ArrowSchema::empty();
@@ -189,6 +247,30 @@ impl ToPyArrow for Schema {
 
 impl FromPyArrow for ArrayData {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        // Newer versions of PyArrow as well as other libraries with Arrow 
data implement this
+        // method, so prefer it over _export_to_c.
+        // See 
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+        if value.hasattr("__arrow_c_array__")? {
+            let tuple = value.getattr("__arrow_c_array__")?.call0()?;
+
+            if !tuple.is_instance_of::<PyTuple>() {
+                return Err(PyTypeError::new_err(
+                    "Expected __arrow_c_array__ to return a tuple.",
+                ));
+            }
+
+            let schema_capsule: &PyCapsule = 
PyTryInto::try_into(tuple.get_item(0)?)?;
+            let array_capsule: &PyCapsule = 
PyTryInto::try_into(tuple.get_item(1)?)?;
+
+            validate_pycapsule(schema_capsule, "arrow_schema")?;
+            validate_pycapsule(array_capsule, "arrow_array")?;
+
+            let schema_ptr = unsafe { 
schema_capsule.reference::<FFI_ArrowSchema>() };
+            let array_ptr = array_capsule.pointer() as *mut FFI_ArrowArray;
+            let array = unsafe { std::ptr::replace(array_ptr, 
FFI_ArrowArray::empty()) };
+            return ffi::from_ffi(array, schema_ptr).map_err(to_py_err);
+        }
+
         validate_class("Array", value)?;
 
         // prepare a pointer to receive the Array struct
@@ -247,6 +329,37 @@ impl<T: ToPyArrow> ToPyArrow for Vec<T> {
 
 impl FromPyArrow for RecordBatch {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        // Newer versions of PyArrow as well as other libraries with Arrow 
data implement this
+        // method, so prefer it over _export_to_c.
+        // See 
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+        if value.hasattr("__arrow_c_array__")? {
+            let tuple = value.getattr("__arrow_c_array__")?.call0()?;
+
+            if !tuple.is_instance_of::<PyTuple>() {
+                return Err(PyTypeError::new_err(
+                    "Expected __arrow_c_array__ to return a tuple.",
+                ));
+            }
+
+            let schema_capsule: &PyCapsule = 
PyTryInto::try_into(tuple.get_item(0)?)?;
+            let array_capsule: &PyCapsule = 
PyTryInto::try_into(tuple.get_item(1)?)?;
+
+            validate_pycapsule(schema_capsule, "arrow_schema")?;
+            validate_pycapsule(array_capsule, "arrow_array")?;
+
+            let schema_ptr = unsafe { 
schema_capsule.reference::<FFI_ArrowSchema>() };
+            let array_ptr = array_capsule.pointer() as *mut FFI_ArrowArray;
+            let ffi_array = unsafe { std::ptr::replace(array_ptr, 
FFI_ArrowArray::empty()) };
+            let array_data = ffi::from_ffi(ffi_array, 
schema_ptr).map_err(to_py_err)?;
+            if !matches!(array_data.data_type(), DataType::Struct(_)) {
+                return Err(PyTypeError::new_err(
+                    "Expected Struct type from __arrow_c_array.",
+                ));
+            }
+            let array = StructArray::from(array_data);
+            return Ok(array.into());
+        }
+
         validate_class("RecordBatch", value)?;
         // TODO(kszucs): implement the FFI conversions in arrow-rs for 
RecordBatches
         let schema = value.getattr("schema")?;
@@ -276,6 +389,23 @@ impl ToPyArrow for RecordBatch {
 /// Supports conversion from `pyarrow.RecordBatchReader` to 
[ArrowArrayStreamReader].
 impl FromPyArrow for ArrowArrayStreamReader {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        // Newer versions of PyArrow as well as other libraries with Arrow 
data implement this
+        // method, so prefer it over _export_to_c.
+        // See 
https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html
+        if value.hasattr("__arrow_c_stream__")? {
+            let capsule: &PyCapsule =
+                
PyTryInto::try_into(value.getattr("__arrow_c_stream__")?.call0()?)?;
+            validate_pycapsule(capsule, "arrow_array_stream")?;
+
+            let stream_ptr = capsule.pointer() as *mut FFI_ArrowArrayStream;
+            let stream = unsafe { std::ptr::replace(stream_ptr, 
FFI_ArrowArrayStream::empty()) };
+
+            let stream_reader = ArrowArrayStreamReader::try_new(stream)
+                .map_err(|err| PyValueError::new_err(err.to_string()))?;
+
+            return Ok(stream_reader);
+        }
+
         validate_class("RecordBatchReader", value)?;
 
         // prepare a pointer to receive the stream struct

Reply via email to