This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new 14f1eb97fb pyarrow: Cache the imported classes to avoid importing them 
each time (#9439)
14f1eb97fb is described below

commit 14f1eb97fbf017dbd0faef749f62f6cd9389a451
Author: Thomas Tanon <[email protected]>
AuthorDate: Thu Mar 19 18:42:07 2026 +0100

    pyarrow: Cache the imported classes to avoid importing them each time (#9439)
    
    # Which issue does this PR close?
    
    - Closes #9438.
    
    # Rationale for this change
    
    Speed up conversions by importing the `pyarrow` classes only once, instead of re-importing `pyarrow` on every conversion call.
    
    # What changes are included in this PR?
    
    - Use `PyOnceLock::import` to import the `pyarrow` types once and cache them.
    - Remove some unnecessary `.extract::<PyBackedStr>()?` calls (the `Display`
    implementation used by `format!` already renders these values as strings).
    
    # Are these changes tested?
    
    Covered by existing tests. It would be nice to add a benchmark, but that
    would require either:
    - adding a dependency on a Python benchmark runner, or
    - writing some hacky code to import `pyarrow` from criterion benchmarks
    (likely by running `pip`/`uv` from the Rust benchmark code).
    
    # Are there any user-facing changes?
    
    No
---
 arrow-pyarrow/src/lib.rs | 98 +++++++++++++++++++++++++++++-------------------
 1 file changed, 59 insertions(+), 39 deletions(-)

diff --git a/arrow-pyarrow/src/lib.rs b/arrow-pyarrow/src/lib.rs
index 15951f8dcf..e396711f87 100644
--- a/arrow-pyarrow/src/lib.rs
+++ b/arrow-pyarrow/src/lib.rs
@@ -75,10 +75,10 @@ use arrow_data::ArrayData;
 use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::ffi::Py_uintptr_t;
+use pyo3::import_exception;
 use pyo3::prelude::*;
-use pyo3::pybacked::PyBackedStr;
-use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple};
-use pyo3::{import_exception, intern};
+use pyo3::sync::PyOnceLock;
+use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
 
 import_exception!(pyarrow, ArrowException);
 /// Represents an exception raised by PyArrow.
@@ -118,17 +118,13 @@ impl<T: ToPyArrow> IntoPyArrow for T {
     }
 }
 
-fn validate_class(expected: &str, value: &Bound<PyAny>) -> PyResult<()> {
-    let pyarrow = PyModule::import(value.py(), "pyarrow")?;
-    let class = pyarrow.getattr(expected)?;
-    if !value.is_instance(&class)? {
-        let expected_module = 
class.getattr("__module__")?.extract::<PyBackedStr>()?;
-        let expected_name = 
class.getattr("__name__")?.extract::<PyBackedStr>()?;
+fn validate_class(expected: &Bound<PyType>, value: &Bound<PyAny>) -> 
PyResult<()> {
+    if !value.is_instance(expected)? {
+        let expected_module = expected.getattr("__module__")?;
+        let expected_name = expected.getattr("__name__")?;
         let found_class = value.get_type();
-        let found_module = found_class
-            .getattr("__module__")?
-            .extract::<PyBackedStr>()?;
-        let found_name = 
found_class.getattr("__name__")?.extract::<PyBackedStr>()?;
+        let found_module = found_class.getattr("__module__")?;
+        let found_name = found_class.getattr("__name__")?;
         return Err(PyTypeError::new_err(format!(
             "Expected instance of {expected_module}.{expected_name}, got 
{found_module}.{found_name}",
         )));
@@ -173,7 +169,7 @@ impl FromPyArrow for DataType {
             }
         }
 
-        validate_class("DataType", value)?;
+        validate_class(data_type_class(value.py())?, value)?;
 
         let c_schema = FFI_ArrowSchema::empty();
         let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
@@ -187,9 +183,8 @@ impl ToPyArrow for DataType {
     fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
         let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
         let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
-        let module = py.import("pyarrow")?;
-        let class = module.getattr("DataType")?;
-        let dtype = class.call_method1("_import_from_c", (c_schema_ptr as 
Py_uintptr_t,))?;
+        let dtype =
+            data_type_class(py)?.call_method1("_import_from_c", (c_schema_ptr 
as Py_uintptr_t,))?;
         Ok(dtype)
     }
 }
@@ -213,7 +208,7 @@ impl FromPyArrow for Field {
             }
         }
 
-        validate_class("Field", value)?;
+        validate_class(field_class(value.py())?, value)?;
 
         let c_schema = FFI_ArrowSchema::empty();
         let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
@@ -227,9 +222,8 @@ impl ToPyArrow for Field {
     fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
         let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
         let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
-        let module = py.import("pyarrow")?;
-        let class = module.getattr("Field")?;
-        let dtype = class.call_method1("_import_from_c", (c_schema_ptr as 
Py_uintptr_t,))?;
+        let dtype =
+            field_class(py)?.call_method1("_import_from_c", (c_schema_ptr as 
Py_uintptr_t,))?;
         Ok(dtype)
     }
 }
@@ -253,7 +247,7 @@ impl FromPyArrow for Schema {
             }
         }
 
-        validate_class("Schema", value)?;
+        validate_class(schema_class(value.py())?, value)?;
 
         let c_schema = FFI_ArrowSchema::empty();
         let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
@@ -267,9 +261,8 @@ impl ToPyArrow for Schema {
     fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
         let c_schema = FFI_ArrowSchema::try_from(self).map_err(to_py_err)?;
         let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
-        let module = py.import("pyarrow")?;
-        let class = module.getattr("Schema")?;
-        let schema = class.call_method1("_import_from_c", (c_schema_ptr as 
Py_uintptr_t,))?;
+        let schema =
+            schema_class(py)?.call_method1("_import_from_c", (c_schema_ptr as 
Py_uintptr_t,))?;
         Ok(schema)
     }
 }
@@ -310,7 +303,7 @@ impl FromPyArrow for ArrayData {
             return unsafe { ffi::from_ffi(array, schema_ptr.as_ref()) 
}.map_err(to_py_err);
         }
 
-        validate_class("Array", value)?;
+        validate_class(array_class(value.py())?, value)?;
 
         // prepare a pointer to receive the Array struct
         let mut array = FFI_ArrowArray::empty();
@@ -336,9 +329,7 @@ impl ToPyArrow for ArrayData {
         let array = FFI_ArrowArray::new(self);
         let schema = 
FFI_ArrowSchema::try_from(self.data_type()).map_err(to_py_err)?;
 
-        let module = py.import("pyarrow")?;
-        let class = module.getattr("Array")?;
-        let array = class.call_method1(
+        let array = array_class(py)?.call_method1(
             "_import_from_c",
             (
                 addr_of!(array) as Py_uintptr_t,
@@ -423,7 +414,7 @@ impl FromPyArrow for RecordBatch {
             return RecordBatch::try_new_with_options(schema, columns, 
&options).map_err(to_py_err);
         }
 
-        validate_class("RecordBatch", value)?;
+        validate_class(record_batch_class(value.py())?, value)?;
         // TODO(kszucs): implement the FFI conversions in arrow-rs for 
RecordBatches
         let schema = value.getattr("schema")?;
         let schema = Arc::new(Schema::from_pyarrow_bound(&schema)?);
@@ -483,7 +474,7 @@ impl FromPyArrow for ArrowArrayStreamReader {
             return Ok(stream_reader);
         }
 
-        validate_class("RecordBatchReader", value)?;
+        validate_class(record_batch_reader_class(value.py())?, value)?;
 
         // prepare a pointer to receive the stream struct
         let mut stream = FFI_ArrowArrayStream::empty();
@@ -510,10 +501,8 @@ impl IntoPyArrow for Box<dyn RecordBatchReader + Send> {
         let mut stream = FFI_ArrowArrayStream::new(self);
 
         let stream_ptr = (&mut stream) as *mut FFI_ArrowArrayStream;
-        let module = py.import("pyarrow")?;
-        let class = module.getattr("RecordBatchReader")?;
-        let args = PyTuple::new(py, [stream_ptr as Py_uintptr_t])?;
-        let reader = class.call_method1("_import_from_c", args)?;
+        let reader = record_batch_reader_class(py)?
+            .call_method1("_import_from_c", (stream_ptr as Py_uintptr_t,))?;
 
         Ok(reader)
     }
@@ -606,21 +595,52 @@ impl FromPyArrow for Table {
 /// Convert a [`Table`] into `pyarrow.Table`.
 impl IntoPyArrow for Table {
     fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
-        let module = py.import(intern!(py, "pyarrow"))?;
-        let class = module.getattr(intern!(py, "Table"))?;
-
         let py_batches = PyList::new(py, 
self.record_batches.into_iter().map(PyArrowType))?;
         let py_schema = PyArrowType(Arc::unwrap_or_clone(self.schema));
 
         let kwargs = PyDict::new(py);
         kwargs.set_item("schema", py_schema)?;
 
-        let reader = class.call_method("from_batches", (py_batches,), 
Some(&kwargs))?;
+        let reader = table_class(py)?.call_method("from_batches", 
(py_batches,), Some(&kwargs))?;
 
         Ok(reader)
     }
 }
 
+fn array_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
+    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
+    TYPE.import(py, "pyarrow", "Array")
+}
+
+fn record_batch_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
+    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
+    TYPE.import(py, "pyarrow", "RecordBatch")
+}
+
+fn record_batch_reader_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
+    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
+    TYPE.import(py, "pyarrow", "RecordBatchReader")
+}
+fn data_type_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
+    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
+    TYPE.import(py, "pyarrow", "DataType")
+}
+
+fn field_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
+    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
+    TYPE.import(py, "pyarrow", "Field")
+}
+
+fn schema_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
+    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
+    TYPE.import(py, "pyarrow", "Schema")
+}
+
+fn table_class(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
+    static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
+    TYPE.import(py, "pyarrow", "Table")
+}
+
 /// A newtype wrapper for types implementing [`FromPyArrow`] or 
[`IntoPyArrow`].
 ///
 /// When wrapped around a type `T: FromPyArrow`, it

Reply via email to