This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new c7ea90d Pyo3 `Bound<'py, T>` api (#734)
c7ea90d is described below
commit c7ea90d3cc17d72e26ef67c12efb4a2bd7bef942
Author: Michael J Ward <[email protected]>
AuthorDate: Tue Jun 18 09:37:22 2024 -0500
Pyo3 `Bound<'py, T>` api (#734)
* remove gil-refs feature from pyo3
* migrate module instantiation to Bound api
* migrate utils.rs to Bound api
* migrate config.rs to Bound api
* migrate context.rs to Bound api
* migrate udaf.rs to Bound api
* migrate pyarrow_filter_expression to Bound api
* migrate dataframe.rs to Bound api
* migrate dataset and dataset_exec to Bound api
* migrate substrait.rs to Bound api
---
Cargo.toml | 3 +-
src/common.rs | 2 +-
src/config.rs | 2 +-
src/context.rs | 23 ++++++++------
src/dataframe.rs | 67 ++++++++++++++++++++++++----------------
src/dataset.rs | 13 ++++----
src/dataset_exec.rs | 36 ++++++++++-----------
src/expr.rs | 2 +-
src/functions.rs | 2 +-
src/lib.rs | 36 ++++++++++-----------
src/pyarrow_filter_expression.rs | 23 +++++++-------
src/store.rs | 2 +-
src/substrait.rs | 8 ++---
src/udaf.rs | 28 ++++++++---------
src/utils.rs | 4 +--
15 files changed, 136 insertions(+), 115 deletions(-)
diff --git a/Cargo.toml b/Cargo.toml
index 7285cf3..78b46e4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,7 +36,7 @@ substrait = ["dep:datafusion-substrait"]
[dependencies]
tokio = { version = "1.35", features = ["macros", "rt", "rt-multi-thread",
"sync"] }
rand = "0.8"
-pyo3 = { version = "0.21", features = ["extension-module", "abi3",
"abi3-py38", "gil-refs"] }
+pyo3 = { version = "0.21", features = ["extension-module", "abi3",
"abi3-py38"] }
arrow = { version = "52", feature = ["pyarrow"] }
datafusion = { version = "39.0.0", features = ["pyarrow", "avro",
"unicode_expressions"] }
datafusion-common = { version = "39.0.0", features = ["pyarrow"] }
@@ -67,3 +67,4 @@ crate-type = ["cdylib", "rlib"]
[profile.release]
lto = true
codegen-units = 1
+
\ No newline at end of file
diff --git a/src/common.rs b/src/common.rs
index 682639a..44c557c 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -23,7 +23,7 @@ pub mod function;
pub mod schema;
/// Initializes the `common` module to match the pattern of
`datafusion-common`
https://docs.rs/datafusion-common/18.0.0/datafusion_common/index.html
-pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
+pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<df_schema::PyDFSchema>()?;
m.add_class::<data_type::PyDataType>()?;
m.add_class::<data_type::DataTypeMap>()?;
diff --git a/src/config.rs b/src/config.rs
index 228f95a..82a4f93 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -65,7 +65,7 @@ impl PyConfig {
/// Get all configuration options
pub fn get_all(&mut self, py: Python) -> PyResult<PyObject> {
- let dict = PyDict::new(py);
+ let dict = PyDict::new_bound(py);
let options = self.config.to_owned();
for entry in options.entries() {
dict.set_item(entry.key, entry.value.clone().into_py(py))?;
diff --git a/src/context.rs b/src/context.rs
index 9462d0b..ec63adb 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -291,11 +291,11 @@ impl PySessionContext {
pub fn register_object_store(
&mut self,
scheme: &str,
- store: &PyAny,
+ store: &Bound<'_, PyAny>,
host: Option<&str>,
) -> PyResult<()> {
let res: Result<(Arc<dyn ObjectStore>, String), PyErr> =
- match StorageContexts::extract(store) {
+ match StorageContexts::extract_bound(store) {
Ok(store) => match store {
StorageContexts::AmazonS3(s3) => Ok((s3.inner,
s3.bucket_name)),
StorageContexts::GoogleCloudStorage(gcs) => Ok((gcs.inner,
gcs.bucket_name)),
@@ -443,8 +443,8 @@ impl PySessionContext {
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
- let table_class = py.import("pyarrow")?.getattr("Table")?;
- let args = PyTuple::new(py, &[data]);
+ let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+ let args = PyTuple::new_bound(py, &[data]);
let table = table_class.call_method1("from_pylist", args)?.into();
// Convert Arrow Table to datafusion DataFrame
@@ -463,8 +463,8 @@ impl PySessionContext {
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
- let table_class = py.import("pyarrow")?.getattr("Table")?;
- let args = PyTuple::new(py, &[data]);
+ let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+ let args = PyTuple::new_bound(py, &[data]);
let table = table_class.call_method1("from_pydict", args)?.into();
// Convert Arrow Table to datafusion DataFrame
@@ -507,8 +507,8 @@ impl PySessionContext {
) -> PyResult<PyDataFrame> {
Python::with_gil(|py| {
// Instantiate pyarrow Table object & convert to Arrow Table
- let table_class = py.import("pyarrow")?.getattr("Table")?;
- let args = PyTuple::new(py, &[data]);
+ let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+ let args = PyTuple::new_bound(py, &[data]);
let table = table_class.call_method1("from_pandas", args)?.into();
// Convert Arrow Table to datafusion DataFrame
@@ -710,7 +710,12 @@ impl PySessionContext {
}
// Registers a PyArrow.Dataset
- pub fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) ->
PyResult<()> {
+ pub fn register_dataset(
+ &self,
+ name: &str,
+ dataset: &Bound<'_, PyAny>,
+ py: Python,
+ ) -> PyResult<()> {
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset,
py)?);
self.ctx
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 8f45143..1b91067 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -28,6 +28,7 @@ use datafusion::prelude::*;
use datafusion_common::UnnestOptions;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
+use pyo3::pybacked::PyBackedStr;
use pyo3::types::PyTuple;
use tokio::task::JoinHandle;
@@ -56,23 +57,25 @@ impl PyDataFrame {
#[pymethods]
impl PyDataFrame {
- fn __getitem__(&self, key: PyObject) -> PyResult<Self> {
- Python::with_gil(|py| {
- if let Ok(key) = key.extract::<&str>(py) {
- self.select_columns(vec![key])
- } else if let Ok(tuple) = key.extract::<&PyTuple>(py) {
- let keys = tuple
- .iter()
- .map(|item| item.extract::<&str>())
- .collect::<PyResult<Vec<&str>>>()?;
- self.select_columns(keys)
- } else if let Ok(keys) = key.extract::<Vec<&str>>(py) {
- self.select_columns(keys)
- } else {
- let message = "DataFrame can only be indexed by string index
or indices";
- Err(PyTypeError::new_err(message))
- }
- })
+ /// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1,
col2, col3]]`
+ fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyResult<Self> {
+ if let Ok(key) = key.extract::<PyBackedStr>() {
+ // df[col]
+ self.select_columns(vec![key])
+ } else if let Ok(tuple) = key.extract::<&PyTuple>() {
+ // df[col1, col2, col3]
+ let keys = tuple
+ .iter()
+ .map(|item| item.extract::<PyBackedStr>())
+ .collect::<PyResult<Vec<PyBackedStr>>>()?;
+ self.select_columns(keys)
+ } else if let Ok(keys) = key.extract::<Vec<PyBackedStr>>() {
+ // df[[col1, col2, col3]]
+ self.select_columns(keys)
+ } else {
+ let message = "DataFrame can only be indexed by string index or
indices";
+ Err(PyTypeError::new_err(message))
+ }
}
fn __repr__(&self, py: Python) -> PyResult<String> {
@@ -98,7 +101,8 @@ impl PyDataFrame {
}
#[pyo3(signature = (*args))]
- fn select_columns(&self, args: Vec<&str>) -> PyResult<Self> {
+ fn select_columns(&self, args: Vec<PyBackedStr>) -> PyResult<Self> {
+ let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
let df = self.df.as_ref().clone().select_columns(&args)?;
Ok(Self::new(df))
}
@@ -194,7 +198,7 @@ impl PyDataFrame {
fn join(
&self,
right: PyDataFrame,
- join_keys: (Vec<&str>, Vec<&str>),
+ join_keys: (Vec<PyBackedStr>, Vec<PyBackedStr>),
how: &str,
) -> PyResult<Self> {
let join_type = match how {
@@ -212,11 +216,22 @@ impl PyDataFrame {
}
};
+ let left_keys = join_keys
+ .0
+ .iter()
+ .map(|s| s.as_ref())
+ .collect::<Vec<&str>>();
+ let right_keys = join_keys
+ .1
+ .iter()
+ .map(|s| s.as_ref())
+ .collect::<Vec<&str>>();
+
let df = self.df.as_ref().clone().join(
right.df.as_ref().clone(),
join_type,
- &join_keys.0,
- &join_keys.1,
+ &left_keys,
+ &right_keys,
None,
)?;
Ok(Self::new(df))
@@ -414,8 +429,8 @@ impl PyDataFrame {
Python::with_gil(|py| {
// Instantiate pyarrow Table object and use its from_batches method
- let table_class = py.import("pyarrow")?.getattr("Table")?;
- let args = PyTuple::new(py, &[batches, schema]);
+ let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+ let args = PyTuple::new_bound(py, &[batches, schema]);
let table: PyObject = table_class.call_method1("from_batches",
args)?.into();
Ok(table)
})
@@ -489,8 +504,8 @@ impl PyDataFrame {
let table = self.to_arrow_table(py)?;
Python::with_gil(|py| {
- let dataframe = py.import("polars")?.getattr("DataFrame")?;
- let args = PyTuple::new(py, &[table]);
+ let dataframe = py.import_bound("polars")?.getattr("DataFrame")?;
+ let args = PyTuple::new_bound(py, &[table]);
let result: PyObject = dataframe.call1(args)?.into();
Ok(result)
})
@@ -514,7 +529,7 @@ fn print_dataframe(py: Python, df: DataFrame) ->
PyResult<()> {
// Import the Python 'builtins' module to access the print function
// Note that println! does not print to the Python debug console and is
not visible in notebooks for instance
- let print = py.import("builtins")?.getattr("print")?;
+ let print = py.import_bound("builtins")?.getattr("print")?;
print.call1((result,))?;
Ok(())
}
diff --git a/src/dataset.rs b/src/dataset.rs
index fcbb503..724b4af 100644
--- a/src/dataset.rs
+++ b/src/dataset.rs
@@ -46,13 +46,14 @@ pub(crate) struct Dataset {
impl Dataset {
// Creates a Python PyArrow.Dataset
- pub fn new(dataset: &PyAny, py: Python) -> PyResult<Self> {
+ pub fn new(dataset: &Bound<'_, PyAny>, py: Python) -> PyResult<Self> {
// Ensure that we were passed an instance of pyarrow.dataset.Dataset
- let ds = PyModule::import(py, "pyarrow.dataset")?;
- let ds_type: &PyType = ds.getattr("Dataset")?.downcast()?;
+ let ds = PyModule::import_bound(py, "pyarrow.dataset")?;
+ let ds_attr = ds.getattr("Dataset")?;
+ let ds_type = ds_attr.downcast::<PyType>()?;
if dataset.is_instance(ds_type)? {
Ok(Dataset {
- dataset: dataset.into(),
+ dataset: dataset.clone().unbind(),
})
} else {
Err(PyValueError::new_err(
@@ -73,7 +74,7 @@ impl TableProvider for Dataset {
/// Get a reference to the schema for this table
fn schema(&self) -> SchemaRef {
Python::with_gil(|py| {
- let dataset = self.dataset.as_ref(py);
+ let dataset = self.dataset.bind(py);
// This can panic but since we checked that self.dataset is a
pyarrow.dataset.Dataset it should never
Arc::new(
dataset
@@ -108,7 +109,7 @@ impl TableProvider for Dataset {
) -> DFResult<Arc<dyn ExecutionPlan>> {
Python::with_gil(|py| {
let plan: Arc<dyn ExecutionPlan> = Arc::new(
- DatasetExec::new(py, self.dataset.as_ref(py),
projection.cloned(), filters)
+ DatasetExec::new(py, self.dataset.bind(py),
projection.cloned(), filters)
.map_err(|err| DataFusionError::External(Box::new(err)))?,
);
Ok(plan)
diff --git a/src/dataset_exec.rs b/src/dataset_exec.rs
index 8ef3a56..240c864 100644
--- a/src/dataset_exec.rs
+++ b/src/dataset_exec.rs
@@ -53,7 +53,7 @@ impl Iterator for PyArrowBatchesAdapter {
fn next(&mut self) -> Option<Self::Item> {
Python::with_gil(|py| {
- let mut batches: &PyIterator = self.batches.as_ref(py);
+ let mut batches = self.batches.clone().into_bound(py);
Some(
batches
.next()?
@@ -79,7 +79,7 @@ pub(crate) struct DatasetExec {
impl DatasetExec {
pub fn new(
py: Python,
- dataset: &PyAny,
+ dataset: &Bound<'_, PyAny>,
projection: Option<Vec<usize>>,
filters: &[Expr],
) -> Result<Self, DataFusionError> {
@@ -103,7 +103,7 @@ impl DatasetExec {
})
.transpose()?;
- let kwargs = PyDict::new(py);
+ let kwargs = PyDict::new_bound(py);
kwargs.set_item("columns", columns.clone())?;
kwargs.set_item(
@@ -111,7 +111,7 @@ impl DatasetExec {
filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
)?;
- let scanner = dataset.call_method("scanner", (), Some(kwargs))?;
+ let scanner = dataset.call_method("scanner", (), Some(&kwargs))?;
let schema = Arc::new(
scanner
@@ -120,19 +120,17 @@ impl DatasetExec {
.0,
);
- let builtins = Python::import(py, "builtins")?;
+ let builtins = Python::import_bound(py, "builtins")?;
let pylist = builtins.getattr("list")?;
// Get the fragments or partitions of the dataset
- let fragments_iterator: &PyAny = dataset.call_method1(
+ let fragments_iterator: Bound<'_, PyAny> = dataset.call_method1(
"get_fragments",
(filter_expr.as_ref().map(|expr| expr.clone_ref(py)),),
)?;
- let fragments: &PyList = pylist
- .call1((fragments_iterator,))?
- .downcast()
- .map_err(PyErr::from)?;
+ let fragments_iter = pylist.call1((fragments_iterator,))?;
+ let fragments =
fragments_iter.downcast::<PyList>().map_err(PyErr::from)?;
let projected_statistics = Statistics::new_unknown(&schema);
let plan_properties = datafusion::physical_plan::PlanProperties::new(
@@ -142,9 +140,9 @@ impl DatasetExec {
);
Ok(DatasetExec {
- dataset: dataset.into(),
+ dataset: dataset.clone().unbind(),
schema,
- fragments: fragments.into(),
+ fragments: fragments.clone().unbind(),
columns,
filter_expr,
projected_statistics,
@@ -183,8 +181,8 @@ impl ExecutionPlan for DatasetExec {
) -> DFResult<SendableRecordBatchStream> {
let batch_size = context.session_config().batch_size();
Python::with_gil(|py| {
- let dataset = self.dataset.as_ref(py);
- let fragments = self.fragments.as_ref(py);
+ let dataset = self.dataset.bind(py);
+ let fragments = self.fragments.bind(py);
let fragment = fragments
.get_item(partition)
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
@@ -193,7 +191,7 @@ impl ExecutionPlan for DatasetExec {
let dataset_schema = dataset
.getattr("schema")
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
- let kwargs = PyDict::new(py);
+ let kwargs = PyDict::new_bound(py);
kwargs
.set_item("columns", self.columns.clone())
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
@@ -207,7 +205,7 @@ impl ExecutionPlan for DatasetExec {
.set_item("batch_size", batch_size)
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
let scanner = fragment
- .call_method("scanner", (dataset_schema,), Some(kwargs))
+ .call_method("scanner", (dataset_schema,), Some(&kwargs))
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
let schema: SchemaRef = Arc::new(
scanner
@@ -215,7 +213,7 @@ impl ExecutionPlan for DatasetExec {
.and_then(|schema|
Ok(schema.extract::<PyArrowType<_>>()?.0))
.map_err(|err|
InnerDataFusionError::External(Box::new(err)))?,
);
- let record_batches: &PyIterator = scanner
+ let record_batches: Bound<'_, PyIterator> = scanner
.call_method0("to_batches")
.map_err(|err| InnerDataFusionError::External(Box::new(err)))?
.iter()
@@ -264,7 +262,7 @@ impl ExecutionPlanProperties for DatasetExec {
impl DisplayAs for DatasetExec {
fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) ->
std::fmt::Result {
Python::with_gil(|py| {
- let number_of_fragments = self.fragments.as_ref(py).len();
+ let number_of_fragments = self.fragments.bind(py).len();
match t {
DisplayFormatType::Default | DisplayFormatType::Verbose => {
let projected_columns: Vec<String> = self
@@ -274,7 +272,7 @@ impl DisplayAs for DatasetExec {
.map(|x| x.name().to_owned())
.collect();
if let Some(filter_expr) = &self.filter_expr {
- let filter_expr =
filter_expr.as_ref(py).str().or(Err(std::fmt::Error))?;
+ let filter_expr =
filter_expr.bind(py).str().or(Err(std::fmt::Error))?;
write!(
f,
"DatasetExec: number_of_fragments={},
filter_expr={}, projection=[{}]",
diff --git a/src/expr.rs b/src/expr.rs
index 09a773c..dc1de66 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -553,7 +553,7 @@ impl PyExpr {
}
/// Initializes the `expr` module to match the pattern of `datafusion-expr`
https://docs.rs/datafusion-expr/latest/datafusion_expr/
-pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
+pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyExpr>()?;
m.add_class::<PyColumn>()?;
m.add_class::<PyLiteral>()?;
diff --git a/src/functions.rs b/src/functions.rs
index 09cdee6..8e395ae 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -670,7 +670,7 @@ aggregate_function!(bit_xor, BitXor);
aggregate_function!(bool_and, BoolAnd);
aggregate_function!(bool_or, BoolOr);
-pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
+pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(abs))?;
m.add_wrapped(wrap_pyfunction!(acos))?;
m.add_wrapped(wrap_pyfunction!(acosh))?;
diff --git a/src/lib.rs b/src/lib.rs
index a696ebf..71c27e1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -72,7 +72,7 @@ pub(crate) struct TokioRuntime(tokio::runtime::Runtime);
/// The higher-level public API is defined in pure python files under the
/// datafusion directory.
#[pymodule]
-fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
+fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
// Register the Tokio Runtime as a module attribute so we can reuse it
m.add(
"runtime",
@@ -94,35 +94,35 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<physical_plan::PyExecutionPlan>()?;
// Register `common` as a submodule. Matching `datafusion-common`
https://docs.rs/datafusion-common/latest/datafusion_common/
- let common = PyModule::new(py, "common")?;
- common::init_module(common)?;
- m.add_submodule(common)?;
+ let common = PyModule::new_bound(py, "common")?;
+ common::init_module(&common)?;
+ m.add_submodule(&common)?;
// Register `expr` as a submodule. Matching `datafusion-expr`
https://docs.rs/datafusion-expr/latest/datafusion_expr/
- let expr = PyModule::new(py, "expr")?;
- expr::init_module(expr)?;
- m.add_submodule(expr)?;
+ let expr = PyModule::new_bound(py, "expr")?;
+ expr::init_module(&expr)?;
+ m.add_submodule(&expr)?;
// Register the functions as a submodule
- let funcs = PyModule::new(py, "functions")?;
- functions::init_module(funcs)?;
- m.add_submodule(funcs)?;
+ let funcs = PyModule::new_bound(py, "functions")?;
+ functions::init_module(&funcs)?;
+ m.add_submodule(&funcs)?;
- let store = PyModule::new(py, "object_store")?;
- store::init_module(store)?;
- m.add_submodule(store)?;
+ let store = PyModule::new_bound(py, "object_store")?;
+ store::init_module(&store)?;
+ m.add_submodule(&store)?;
// Register substrait as a submodule
#[cfg(feature = "substrait")]
- setup_substrait_module(py, m)?;
+ setup_substrait_module(py, &m)?;
Ok(())
}
#[cfg(feature = "substrait")]
-fn setup_substrait_module(py: Python, m: &PyModule) -> PyResult<()> {
- let substrait = PyModule::new(py, "substrait")?;
- substrait::init_module(substrait)?;
- m.add_submodule(substrait)?;
+fn setup_substrait_module(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()>
{
+ let substrait = PyModule::new_bound(py, "substrait")?;
+ substrait::init_module(&substrait)?;
+ m.add_submodule(&substrait)?;
Ok(())
}
diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs
index 64124fb..fca8851 100644
--- a/src/pyarrow_filter_expression.rs
+++ b/src/pyarrow_filter_expression.rs
@@ -32,9 +32,9 @@ pub(crate) struct PyArrowFilterExpression(PyObject);
fn operator_to_py<'py>(
operator: &Operator,
- op: &'py PyModule,
-) -> Result<&'py PyAny, DataFusionError> {
- let py_op: &PyAny = match operator {
+ op: &Bound<'py, PyModule>,
+) -> Result<Bound<'py, PyAny>, DataFusionError> {
+ let py_op: Bound<'_, PyAny> = match operator {
Operator::Eq => op.getattr("eq")?,
Operator::NotEq => op.getattr("ne")?,
Operator::Lt => op.getattr("lt")?,
@@ -96,9 +96,9 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
//
https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow-dataset-expression
fn try_from(expr: &Expr) -> Result<Self, Self::Error> {
Python::with_gil(|py| {
- let pc = Python::import(py, "pyarrow.compute")?;
- let op_module = Python::import(py, "operator")?;
- let pc_expr: Result<&PyAny, DataFusionError> = match expr {
+ let pc = Python::import_bound(py, "pyarrow.compute")?;
+ let op_module = Python::import_bound(py, "operator")?;
+ let pc_expr: Result<Bound<'_, PyAny>, DataFusionError> = match
expr {
Expr::Column(Column { name, .. }) =>
Ok(pc.getattr("field")?.call1((name,))?),
Expr::Literal(v) => match v {
ScalarValue::Boolean(Some(b)) =>
Ok(pc.getattr("scalar")?.call1((*b,))?),
@@ -118,7 +118,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
))),
},
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
- let operator = operator_to_py(op, op_module)?;
+ let operator = operator_to_py(op, &op_module)?;
let left =
PyArrowFilterExpression::try_from(left.as_ref())?.0;
let right =
PyArrowFilterExpression::try_from(right.as_ref())?.0;
Ok(operator.call1((left, right))?)
@@ -131,14 +131,15 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
Expr::IsNotNull(expr) => {
let py_expr =
PyArrowFilterExpression::try_from(expr.as_ref())?
.0
- .into_ref(py);
+ .into_bound(py);
Ok(py_expr.call_method0("is_valid")?)
}
Expr::IsNull(expr) => {
let expr =
PyArrowFilterExpression::try_from(expr.as_ref())?
.0
- .into_ref(py);
- Ok(expr.call_method1("is_null", (expr,))?)
+ .into_bound(py);
+ // TODO: this expression does not seems like it should be
`call_method0`
+ Ok(expr.clone().call_method1("is_null", (expr,))?)
}
Expr::Between(Between {
expr,
@@ -168,7 +169,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
}) => {
let expr =
PyArrowFilterExpression::try_from(expr.as_ref())?
.0
- .into_ref(py);
+ .into_bound(py);
let scalars = extract_scalar_list(list, py)?;
let ret = expr.call_method1("isin", (scalars,))?;
let invert = op_module.getattr("invert")?;
diff --git a/src/store.rs b/src/store.rs
index 542cfa9..846d96a 100644
--- a/src/store.rs
+++ b/src/store.rs
@@ -219,7 +219,7 @@ impl PyAmazonS3Context {
}
}
-pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
+pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyAmazonS3Context>()?;
m.add_class::<PyMicrosoftAzureContext>()?;
m.add_class::<PyGoogleCloudContext>()?;
diff --git a/src/substrait.rs b/src/substrait.rs
index ff83f6f..1e9e16c 100644
--- a/src/substrait.rs
+++ b/src/substrait.rs
@@ -40,7 +40,7 @@ impl PyPlan {
self.plan
.encode(&mut proto_bytes)
.map_err(DataFusionError::EncodeError)?;
- Ok(PyBytes::new(py, &proto_bytes).into())
+ Ok(PyBytes::new_bound(py, &proto_bytes).unbind().into())
}
}
@@ -76,7 +76,7 @@ impl PySubstraitSerializer {
pub fn serialize_to_plan(sql: &str, ctx: PySessionContext, py: Python) ->
PyResult<PyPlan> {
match PySubstraitSerializer::serialize_bytes(sql, ctx, py) {
Ok(proto_bytes) => {
- let proto_bytes: &PyBytes =
proto_bytes.as_ref(py).downcast().unwrap();
+ let proto_bytes =
proto_bytes.bind(py).downcast::<PyBytes>().unwrap();
PySubstraitSerializer::deserialize_bytes(proto_bytes.as_bytes().to_vec(), py)
}
Err(e) => Err(py_datafusion_err(e)),
@@ -87,7 +87,7 @@ impl PySubstraitSerializer {
pub fn serialize_bytes(sql: &str, ctx: PySessionContext, py: Python) ->
PyResult<PyObject> {
let proto_bytes: Vec<u8> = wait_for_future(py,
serializer::serialize_bytes(sql, &ctx.ctx))
.map_err(DataFusionError::from)?;
- Ok(PyBytes::new(py, &proto_bytes).into())
+ Ok(PyBytes::new_bound(py, &proto_bytes).unbind().into())
}
#[staticmethod]
@@ -140,7 +140,7 @@ impl PySubstraitConsumer {
}
}
-pub fn init_module(m: &PyModule) -> PyResult<()> {
+pub fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyPlan>()?;
m.add_class::<PySubstraitConsumer>()?;
m.add_class::<PySubstraitProducer>()?;
diff --git a/src/udaf.rs b/src/udaf.rs
index 9aea761..7b5e036 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -17,7 +17,7 @@
use std::sync::Arc;
-use pyo3::{prelude::*, types::PyBool, types::PyTuple};
+use pyo3::{prelude::*, types::PyTuple};
use datafusion::arrow::array::{Array, ArrayRef};
use datafusion::arrow::datatypes::DataType;
@@ -42,12 +42,12 @@ impl RustAccumulator {
impl Accumulator for RustAccumulator {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
- Python::with_gil(|py|
self.accum.as_ref(py).call_method0("state")?.extract())
+ Python::with_gil(|py|
self.accum.bind(py).call_method0("state")?.extract())
.map_err(|e| DataFusionError::Execution(format!("{e}")))
}
fn evaluate(&mut self) -> Result<ScalarValue> {
- Python::with_gil(|py|
self.accum.as_ref(py).call_method0("evaluate")?.extract())
+ Python::with_gil(|py|
self.accum.bind(py).call_method0("evaluate")?.extract())
.map_err(|e| DataFusionError::Execution(format!("{e}")))
}
@@ -58,11 +58,11 @@ impl Accumulator for RustAccumulator {
.iter()
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
- let py_args = PyTuple::new(py, py_args);
+ let py_args = PyTuple::new_bound(py, py_args);
// 2. call function
self.accum
- .as_ref(py)
+ .bind(py)
.call_method1("update", py_args)
.map_err(|e| DataFusionError::Execution(format!("{e}")))?;
@@ -82,7 +82,7 @@ impl Accumulator for RustAccumulator {
// 2. call merge
self.accum
- .as_ref(py)
+ .bind(py)
.call_method1("merge", (state,))
.map_err(|e| DataFusionError::Execution(format!("{e}")))?;
@@ -101,11 +101,11 @@ impl Accumulator for RustAccumulator {
.iter()
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
- let py_args = PyTuple::new(py, py_args);
+ let py_args = PyTuple::new_bound(py, py_args);
// 2. call function
self.accum
- .as_ref(py)
+ .bind(py)
.call_method1("retract_batch", py_args)
.map_err(|e| DataFusionError::Execution(format!("{e}")))?;
@@ -114,12 +114,12 @@ impl Accumulator for RustAccumulator {
}
fn supports_retract_batch(&self) -> bool {
- Python::with_gil(|py| {
- let x: Result<&PyAny, PyErr> =
- self.accum.as_ref(py).call_method0("supports_retract_batch");
- let x: &PyAny = x.unwrap_or(PyBool::new(py, false));
- x.extract().unwrap_or(false)
- })
+ Python::with_gil(
+ |py| match
self.accum.bind(py).call_method0("supports_retract_batch") {
+ Ok(x) => x.extract().unwrap_or(false),
+ Err(_) => false,
+ },
+ )
}
}
diff --git a/src/utils.rs b/src/utils.rs
index 62cf07d..4334f86 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -24,13 +24,13 @@ use tokio::runtime::Runtime;
/// Utility to get the Tokio Runtime from Python
pub(crate) fn get_tokio_runtime(py: Python) -> PyRef<TokioRuntime> {
- let datafusion = py.import("datafusion._internal").unwrap();
+ let datafusion = py.import_bound("datafusion._internal").unwrap();
let tmp = datafusion.getattr("runtime").unwrap();
match tmp.extract::<PyRef<TokioRuntime>>() {
Ok(runtime) => runtime,
Err(_e) => {
let rt = TokioRuntime(tokio::runtime::Runtime::new().unwrap());
- let obj: &PyAny = Py::new(py, rt).unwrap().into_ref(py);
+ let obj: Bound<'_, TokioRuntime> = Py::new(py,
rt).unwrap().into_bound(py);
obj.extract().unwrap()
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]