pitrou commented on code in PR #37797:
URL: https://github.com/apache/arrow/pull/37797#discussion_r1346079581


##########
python/pyarrow/tests/test_cffi.py:
##########
@@ -411,3 +415,120 @@ def test_imported_batch_reader_error():
                        match="Expected to be able to read 16 bytes "
                              "for message body, got 8"):
         reader_new.read_all()
+
+
+@pytest.mark.parametrize('obj', [pa.int32(), pa.field('foo', pa.int32()),
+                                 pa.schema({'foo': pa.int32()})],
+                         ids=['type', 'field', 'schema'])
+def test_roundtrip_schema_capsule(obj):
+    gc.collect()  # Make sure no Arrow data dangles in a ref cycle
+    old_allocated = pa.total_allocated_bytes()
+
+    capsule = obj.__arrow_c_schema__()
+    assert PyCapsule_IsValid(capsule, b"arrow_schema") == 1
+    obj_out = type(obj)._import_from_c_capsule(capsule)
+    assert obj_out == obj
+
+    assert pa.total_allocated_bytes() == old_allocated
+
+    capsule = obj.__arrow_c_schema__()
+
+    assert pa.total_allocated_bytes() > old_allocated
+    del capsule
+    assert pa.total_allocated_bytes() == old_allocated
+
+
+@pytest.mark.parametrize('arr,schema_accessor,bad_type,good_type', [
+    (pa.array(['a', 'b', 'c']), lambda x: x.type, pa.int32(), pa.string()),
+    (
+        pa.record_batch([pa.array(['a', 'b', 'c'])], names=['x']),
+        lambda x: x.schema,
+        pa.schema({'x': pa.int32()}),
+        pa.schema({'x': pa.string()})
+    ),
+], ids=['array', 'record_batch'])
+def test_roundtrip_array_capsule(arr, schema_accessor, bad_type, good_type):
+    gc.collect()  # Make sure no Arrow data dangles in a ref cycle
+    old_allocated = pa.total_allocated_bytes()
+
+    import_array = type(arr)._import_from_c_capsule
+
+    schema_capsule, capsule = arr.__arrow_c_array__()
+    assert PyCapsule_IsValid(schema_capsule, b"arrow_schema") == 1
+    assert PyCapsule_IsValid(capsule, b"arrow_array") == 1
+    arr_out = import_array(schema_capsule, capsule)
+    assert arr_out.equals(arr)
+
+    assert pa.total_allocated_bytes() > old_allocated
+    del arr_out
+
+    assert pa.total_allocated_bytes() == old_allocated
+
+    capsule = arr.__arrow_c_array__()
+
+    assert pa.total_allocated_bytes() > old_allocated
+    del capsule
+    assert pa.total_allocated_bytes() == old_allocated
+
+    with pytest.raises(ValueError,

Review Comment:
   Ah, it's a bit more sophisticated: we catch the `ArrowInvalid` and replace it with our own error message.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to