This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 3481904f Fix Python UDAF list-of-timestamps return by enforcing list-valued scalars and caching PyArrow types (#1347)
3481904f is described below
commit 3481904fe9c770ffc218cc043d47d35e860148a6
Author: kosiew <[email protected]>
AuthorDate: Thu Feb 19 00:15:06 2026 +0800
Fix Python UDAF list-of-timestamps return by enforcing list-valued scalars and caching PyArrow types (#1347)
* Implement UDAF improvements for list type handling
Store UDAF return type in Rust accumulator and wrap
pyarrow Array/ChunkedArray returns into list scalars
for list-like return types. Add a UDAF test to return
a list of timestamps via a pyarrow array, validating
the aggregate output for correctness.
* Document UDAF list-valued scalar returns
Add documented list-valued scalar returns for UDAF
accumulators, including an example with pa.scalar and a note
about unsupported pyarrow.Array returns from evaluate().
Also, introduce a UDAF FAQ entry detailing list-returning
patterns and required return_type/state_type definitions.
* Fix pyarrow calls and improve type handling in RustAccumulator
* Refactor RustAccumulator to support pyarrow array types and improve type
checking for list types
* Fixed PyO3 type mismatch by cloning Array/ChunkedArray types before
unbinding and binding fresh copies when checking array-likeness, eliminating
the Bound reference error
* Add timezone information to datetime objects in
test_udaf_list_timestamp_return
* clippy fix
* Refactor RustAccumulator and utility functions for improved type handling
and conversion from Python objects to Arrow types
* Enhance PyArrow integration by refining type handling and conversion in
RustAccumulator and utility functions
* Fix array data binding in py_obj_to_scalar_value function
* Implement a single point of scalar conversion from Python objects
* Add unit tests and simplify python wrapper for literal
* Add nanoarrow and arro3-core to dev dependencies. Sort the dependencies
alphabetically.
* Refactor common code into helper function so we do not duplicate it.
* Update import path to access Scalar type
* Add test for generic python objects that support the C interface
* Update unit test to pass back either pyarrow array or array wrapped as
scalar
* Update tests to pass back raw Python values or a pyarrow scalar
* Expand on user documentation for how to return list arrays
* More user documentation
---------
Co-authored-by: Tim Saucer <[email protected]>
---
.../user-guide/common-operations/udf-and-udfa.rst | 33 ++++--
pyproject.toml | 20 ++--
python/datafusion/expr.py | 3 -
python/datafusion/user_defined.py | 22 +++-
python/tests/test_expr.py | 45 ++++++++
python/tests/test_udaf.py | 89 ++++++++++++++--
src/common/data_type.rs | 3 +
src/config.rs | 6 +-
src/dataframe.rs | 9 +-
src/pyarrow_util.rs | 118 +++++++++++++++++++--
src/udaf.rs | 34 +++---
src/udwf.rs | 1 -
src/utils.rs | 18 ----
13 files changed, 320 insertions(+), 81 deletions(-)
diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst b/docs/source/user-guide/common-operations/udf-and-udfa.rst
index d554e1e2..f669721a 100644
--- a/docs/source/user-guide/common-operations/udf-and-udfa.rst
+++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst
@@ -123,7 +123,7 @@ also see how the inputs to ``update`` and ``merge`` differ.
.. code-block:: python
- import pyarrow
+ import pyarrow as pa
import pyarrow.compute
import datafusion
from datafusion import col, udaf, Accumulator
@@ -136,16 +136,16 @@ also see how the inputs to ``update`` and ``merge`` differ.
def __init__(self):
self._sum = 0.0
- def update(self, values_a: pyarrow.Array, values_b: pyarrow.Array) -> None:
+ def update(self, values_a: pa.Array, values_b: pa.Array) -> None:
self._sum = self._sum + pyarrow.compute.sum(values_a).as_py() -
pyarrow.compute.sum(values_b).as_py()
- def merge(self, states: List[pyarrow.Array]) -> None:
+ def merge(self, states: list[pa.Array]) -> None:
self._sum = self._sum + pyarrow.compute.sum(states[0]).as_py()
- def state(self) -> pyarrow.Array:
- return pyarrow.array([self._sum])
+ def state(self) -> list[pa.Scalar]:
+ return [pyarrow.scalar(self._sum)]
- def evaluate(self) -> pyarrow.Scalar:
+ def evaluate(self) -> pa.Scalar:
return pyarrow.scalar(self._sum)
ctx = datafusion.SessionContext()
@@ -156,10 +156,29 @@ also see how the inputs to ``update`` and ``merge`` differ.
}
)
- my_udaf = udaf(MyAccumulator, [pyarrow.float64(), pyarrow.float64()], pyarrow.float64(), [pyarrow.float64()], 'stable')
+ my_udaf = udaf(MyAccumulator, [pa.float64(), pa.float64()], pa.float64(), [pa.float64()], 'stable')
df.aggregate([], [my_udaf(col("a"), col("b")).alias("col_diff")])
+FAQ
+^^^
+
+**How do I return a list from a UDAF?**
+
+Both the ``evaluate`` and the ``state`` functions expect to return scalar
+values.
+If you wish to return a list array as a scalar value, the best practice is to
+wrap the values in a ``pyarrow.Scalar`` object. For example, you can return a
+timestamp list with ``pa.scalar([...], type=pa.list_(pa.timestamp("ms")))`` and
+register the appropriate return or state types as
+``return_type=pa.list_(pa.timestamp("ms"))`` and
+``state_type=[pa.list_(pa.timestamp("ms"))]``, respectively.
+
+As of DataFusion 52.0.0, you can return any Python object, including a
+PyArrow array, from these functions, and DataFusion will attempt to create a
+scalar value from it. DataFusion has been tested to
+convert PyArrow, nanoarrow, and arro3 objects as well as primitive data types
+like integers, strings, and so on.
+
Window Functions
----------------
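For readers skimming the diff, here is a minimal, self-contained sketch of the list-returning UDAF pattern that the new FAQ above describes. It mirrors the CollectTimestamps test added below in python/tests/test_udaf.py; the millisecond timestamp unit and the class name are illustrative choices, not part of the commit.

    import pyarrow as pa
    from datafusion import Accumulator, udaf

    class CollectTimestamps(Accumulator):
        """Accumulate timestamps and return them as a single list value."""

        def __init__(self) -> None:
            self._values = []

        def update(self, values: pa.Array) -> None:
            self._values.extend(values.to_pylist())

        def merge(self, states: list[pa.Array]) -> None:
            for state in states[0].to_pylist():
                if state is not None:
                    self._values.extend(state)

        def state(self) -> list[pa.Scalar]:
            # Wrap the collected list in a list-typed scalar, as the FAQ recommends.
            return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ms")))]

        def evaluate(self) -> pa.Scalar:
            return pa.scalar(self._values, type=pa.list_(pa.timestamp("ms")))

    collect = udaf(
        CollectTimestamps,
        pa.timestamp("ms"),              # input type
        pa.list_(pa.timestamp("ms")),    # return_type matching the list scalar
        [pa.list_(pa.timestamp("ms"))],  # state_type, one entry per state value
        volatility="immutable",
    )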
diff --git a/pyproject.toml b/pyproject.toml
index d315dbe1..5a5128a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -173,27 +173,29 @@ ignore-words-list = ["ans", "IST"]
[dependency-groups]
dev = [
+ "arro3-core==0.6.5",
+ "codespell==2.4.1",
"maturin>=1.8.1",
+ "nanoarrow==0.8.0",
"numpy>1.25.0;python_version<'3.14'",
"numpy>=2.3.2;python_version>='3.14'",
- "pyarrow>=19.0.0",
"pre-commit>=4.3.0",
- "pyyaml>=6.0.3",
+ "pyarrow>=19.0.0",
+ "pygithub==2.5.0",
"pytest>=7.4.4",
"pytest-asyncio>=0.23.3",
+ "pyyaml>=6.0.3",
"ruff>=0.9.1",
"toml>=0.10.2",
- "pygithub==2.5.0",
- "codespell==2.4.1",
]
docs = [
- "sphinx>=7.1.2",
- "pydata-sphinx-theme==0.8.0",
- "myst-parser>=3.0.1",
- "jinja2>=3.1.5",
"ipython>=8.12.3",
+ "jinja2>=3.1.5",
+ "myst-parser>=3.0.1",
"pandas>=2.0.3",
"pickleshare>=0.7.5",
- "sphinx-autoapi>=3.4.0",
+ "pydata-sphinx-theme==0.8.0",
"setuptools>=75.3.0",
+ "sphinx>=7.1.2",
+ "sphinx-autoapi>=3.4.0",
]
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index 695fe7c4..9df58f52 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -562,8 +562,6 @@ class Expr:
"""
if isinstance(value, str):
value = pa.scalar(value, type=pa.string_view())
- if not isinstance(value, pa.Scalar):
- value = pa.scalar(value)
return Expr(expr_internal.RawExpr.literal(value))
@staticmethod
@@ -576,7 +574,6 @@ class Expr:
"""
if isinstance(value, str):
value = pa.scalar(value, type=pa.string_view())
- value = value if isinstance(value, pa.Scalar) else pa.scalar(value)
return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata))
diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index 5dd62656..d4e5302b 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -298,7 +298,16 @@ class Accumulator(metaclass=ABCMeta):
@abstractmethod
def state(self) -> list[pa.Scalar]:
- """Return the current state."""
+ """Return the current state.
+
+ While this function template expects a list of PyArrow Scalar values as
+ the return type, you can return any values that can be converted into
+ Scalars. This includes basic Python data types such as integers and
+ strings. Beyond primitive types, PyArrow, nanoarrow, and arro3 objects
+ are currently supported. Other objects that support the Arrow FFI
+ standard will be given a "best attempt" at conversion to scalar objects.
+ """
@abstractmethod
def update(self, *values: pa.Array) -> None:
@@ -310,7 +319,16 @@ class Accumulator(metaclass=ABCMeta):
@abstractmethod
def evaluate(self) -> pa.Scalar:
- """Return the resultant value."""
+ """Return the resultant value.
+
+ While this function template expects a PyArrow Scalar value as the
+ return type, you can return any value that can be converted into a
+ Scalar. This includes basic Python data types such as integers and
+ strings. Beyond primitive types, PyArrow, nanoarrow, and arro3 objects
+ are currently supported. Other objects that support the Arrow FFI
+ standard will be given a "best attempt" at conversion to scalar objects.
+ """
class AggregateUDFExportable(Protocol):
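As a companion to these docstring changes, here is a minimal sketch of what they allow: an accumulator that returns a plain Python value from ``state`` (converted to a scalar by DataFusion) and an explicit PyArrow scalar from ``evaluate``. It mirrors the updated Summarize test in python/tests/test_udaf.py; the class name is illustrative.

    import pyarrow as pa
    import pyarrow.compute as pc
    from datafusion import Accumulator

    class Summarize(Accumulator):
        """Sum a float64 column while mixing raw Python and PyArrow returns."""

        def __init__(self, initial_value: float = 0.0) -> None:
            self._sum = initial_value

        def update(self, values: pa.Array) -> None:
            self._sum += pc.sum(values).as_py()

        def merge(self, states: list[pa.Array]) -> None:
            self._sum += pc.sum(states[0]).as_py()

        def state(self) -> list[pa.Scalar]:
            # A plain Python float is accepted; DataFusion converts it to a scalar.
            return [self._sum]

        def evaluate(self) -> pa.Scalar:
            # Returning an explicit pyarrow scalar works equally well.
            return pa.scalar(self._sum)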
diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py
index 39e48f7c..92251827 100644
--- a/python/tests/test_expr.py
+++ b/python/tests/test_expr.py
@@ -20,6 +20,8 @@ from concurrent.futures import ThreadPoolExecutor
from datetime import date, datetime, time, timezone
from decimal import Decimal
+import arro3.core
+import nanoarrow
import pyarrow as pa
import pytest
from datafusion import (
@@ -980,6 +982,49 @@ def test_literal_metadata(ctx):
assert expected_field.metadata == actual_field.metadata
+def test_scalar_conversion() -> None:
+ class WrappedPyArrow:
+ """Wrapper class for testing __arrow_c_array__."""
+
+ def __init__(self, val: pa.Array) -> None:
+ self.val = val
+
+ def __arrow_c_array__(self, requested_schema=None):
+ return self.val.__arrow_c_array__(requested_schema=requested_schema)
+
+ expected_value = lit(1)
+ assert str(expected_value) == "Expr(Int64(1))"
+
+ # Test pyarrow imports
+ assert expected_value == lit(pa.scalar(1))
+ assert expected_value == lit(pa.scalar(1, type=pa.int32()))
+
+ # Test nanoarrow
+ na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0]
+ assert expected_value == lit(na_scalar)
+
+ # Test arro3
+ arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32())
+ assert expected_value == lit(arro3_scalar)
+
+ generic_scalar = WrappedPyArrow(pa.array([1]))
+ assert expected_value == lit(generic_scalar)
+
+ expected_value = lit([1, 2, 3])
+ assert str(expected_value) == "Expr(List([1, 2, 3]))"
+
+ assert expected_value == lit(pa.scalar([1, 2, 3]))
+
+ na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32())
+ assert expected_value == lit(na_array)
+
+ arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32())
+ assert expected_value == lit(arro3_array)
+
+ generic_array = WrappedPyArrow(pa.array([1, 2, 3]))
+ assert expected_value == lit(generic_array)
+
+
def test_ensure_expr():
e = col("a")
assert ensure_expr(e) is e.expr
diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
index 453ff6f4..8cd480e3 100644
--- a/python/tests/test_udaf.py
+++ b/python/tests/test_udaf.py
@@ -17,6 +17,8 @@
from __future__ import annotations
+from datetime import datetime, timezone
+
import pyarrow as pa
import pyarrow.compute as pc
import pytest
@@ -26,23 +28,28 @@ from datafusion import Accumulator, column, udaf
class Summarize(Accumulator):
"""Interface of a user-defined accumulation."""
- def __init__(self, initial_value: float = 0.0):
- self._sum = pa.scalar(initial_value)
+ def __init__(self, initial_value: float = 0.0, as_scalar: bool = False):
+ self._sum = initial_value
+ self.as_scalar = as_scalar
def state(self) -> list[pa.Scalar]:
+ if self.as_scalar:
+ return [pa.scalar(self._sum)]
return [self._sum]
def update(self, values: pa.Array) -> None:
# Not nice since pyarrow scalars can't be summed yet.
# This breaks on `None`
- self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
+ self._sum = self._sum + pc.sum(values).as_py()
def merge(self, states: list[pa.Array]) -> None:
# Not nice since pyarrow scalars can't be summed yet.
# This breaks on `None`
- self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
+ self._sum = self._sum + pc.sum(states[0]).as_py()
def evaluate(self) -> pa.Scalar:
+ if self.as_scalar:
+ return pa.scalar(self._sum)
return self._sum
@@ -58,6 +65,30 @@ class MissingMethods(Accumulator):
return [self._sum]
+class CollectTimestamps(Accumulator):
+ def __init__(self, wrap_in_scalar: bool):
+ self._values: list[datetime] = []
+ self.wrap_in_scalar = wrap_in_scalar
+
+ def state(self) -> list[pa.Scalar]:
+ if self.wrap_in_scalar:
+ return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))]
+ return [pa.array(self._values, type=pa.timestamp("ns"))]
+
+ def update(self, values: pa.Array) -> None:
+ self._values.extend(values.to_pylist())
+
+ def merge(self, states: list[pa.Array]) -> None:
+ for state in states[0].to_pylist():
+ if state is not None:
+ self._values.extend(state)
+
+ def evaluate(self) -> pa.Scalar:
+ if self.wrap_in_scalar:
+ return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))
+ return pa.array(self._values, type=pa.timestamp("ns"))
+
+
@pytest.fixture
def df(ctx):
# create a RecordBatch and a new DataFrame from it
@@ -137,11 +168,12 @@ def test_udaf_decorator_aggregate(df):
assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
-def test_udaf_aggregate_with_arguments(df):
[email protected]("as_scalar", [True, False])
+def test_udaf_aggregate_with_arguments(df, as_scalar):
bias = 10.0
summarize = udaf(
- lambda: Summarize(bias),
+ lambda: Summarize(initial_value=bias, as_scalar=as_scalar),
pa.float64(),
pa.float64(),
[pa.float64()],
@@ -217,3 +249,48 @@ def test_register_udaf(ctx, df) -> None:
df_result = ctx.sql("select summarize(b) from test_table")
assert df_result.collect()[0][0][0].as_py() == 14.0
+
+
[email protected]("wrap_in_scalar", [True, False])
+def test_udaf_list_timestamp_return(ctx, wrap_in_scalar) -> None:
+ timestamps1 = [
+ datetime(2024, 1, 1, tzinfo=timezone.utc),
+ datetime(2024, 1, 2, tzinfo=timezone.utc),
+ ]
+ timestamps2 = [
+ datetime(2024, 1, 3, tzinfo=timezone.utc),
+ datetime(2024, 1, 4, tzinfo=timezone.utc),
+ ]
+ batch1 = pa.RecordBatch.from_arrays(
+ [pa.array(timestamps1, type=pa.timestamp("ns"))],
+ names=["ts"],
+ )
+ batch2 = pa.RecordBatch.from_arrays(
+ [pa.array(timestamps2, type=pa.timestamp("ns"))],
+ names=["ts"],
+ )
+ df = ctx.create_dataframe([[batch1], [batch2]], name="timestamp_table")
+
+ list_type = pa.list_(
+ pa.field("item", type=pa.timestamp("ns"), nullable=wrap_in_scalar)
+ )
+
+ collect = udaf(
+ lambda: CollectTimestamps(wrap_in_scalar),
+ pa.timestamp("ns"),
+ list_type,
+ [list_type],
+ volatility="immutable",
+ )
+
+ result = df.aggregate([], [collect(column("ts"))]).collect()[0]
+
+ # There is no guarantee about the ordering of the batches, so perform a sort
+ # to get consistent results. Alternatively we could sort on evaluate().
+ assert (
+ result.column(0).values.sort()
+ == pa.array(
+ [[*timestamps1, *timestamps2]],
+ type=list_type,
+ ).values
+ )
diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index 55848da5..1ff332eb 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -22,6 +22,9 @@ use datafusion::logical_expr::expr::NullTreatment as DFNullTreatment;
use pyo3::exceptions::{PyNotImplementedError, PyValueError};
use pyo3::prelude::*;
+/// A wrapper around [`ScalarValue`] that allows for conversion from a variety
+/// of Python objects into a [`ScalarValue`]. See
+/// ``FromPyArrow::from_pyarrow_bound`` for conversion details.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
pub struct PyScalarValue(pub ScalarValue);
diff --git a/src/config.rs b/src/config.rs
index 583dea7e..38936e6c 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -22,8 +22,8 @@ use parking_lot::RwLock;
use pyo3::prelude::*;
use pyo3::types::*;
+use crate::common::data_type::PyScalarValue;
use crate::errors::PyDataFusionResult;
-use crate::utils::py_obj_to_scalar_value;
#[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
#[derive(Clone)]
pub(crate) struct PyConfig {
@@ -65,9 +65,9 @@ impl PyConfig {
/// Set a configuration option
pub fn set(&self, key: &str, value: Py<PyAny>, py: Python) -> PyDataFusionResult<()> {
- let scalar_value = py_obj_to_scalar_value(py, value)?;
+ let scalar_value: PyScalarValue = value.extract(py)?;
let mut options = self.config.write();
- options.set(key, scalar_value.to_string().as_str())?;
+ options.set(key, scalar_value.0.to_string().as_str())?;
Ok(())
}
diff --git a/src/dataframe.rs b/src/dataframe.rs
index fe039593..53fab58c 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -48,6 +48,7 @@ use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
+use crate::common::data_type::PyScalarValue;
use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err};
use crate::expr::PyExpr;
use crate::expr::sort_expr::{PySortExpr, to_sort_expressions};
@@ -55,9 +56,7 @@ use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::{PyRecordBatchStream, poll_next_batch};
use crate::sql::logical::PyLogicalPlan;
use crate::table::{PyTable, TempViewTable};
-use crate::utils::{
- is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule, wait_for_future,
-};
+use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future};
/// File-level static CStr for the Arrow array stream capsule name.
static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");
@@ -1191,14 +1190,14 @@ impl PyDataFrame {
columns: Option<Vec<PyBackedStr>>,
py: Python,
) -> PyDataFusionResult<Self> {
- let scalar_value = py_obj_to_scalar_value(py, value)?;
+ let scalar_value: PyScalarValue = value.extract(py)?;
let cols = match columns {
Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
None => Vec::new(), // Empty vector means fill null for all columns
};
- let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
+ let df = self.df.as_ref().clone().fill_null(scalar_value.0, cols)?;
Ok(Self::new(df))
}
}
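For context, a hedged Python-side sketch of what extracting ``PyScalarValue`` enables here: ``fill_null`` can take a plain Python value or a pyarrow scalar, and either goes through the same conversion path. The column name and data are illustrative, and the exact ``fill_null`` call shape is assumed from the existing Python wrapper.

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays([pa.array([1.0, None, 3.0])], names=["a"])
    df = ctx.create_dataframe([[batch]])

    # A plain Python float is extracted into a ScalarValue internally.
    filled = df.fill_null(0.0)

    # A pyarrow scalar takes the same conversion path.
    filled = df.fill_null(pa.scalar(0.0))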
diff --git a/src/pyarrow_util.rs b/src/pyarrow_util.rs
index 264cfd34..2a119274 100644
--- a/src/pyarrow_util.rs
+++ b/src/pyarrow_util.rs
@@ -17,8 +17,13 @@
//! Conversions between PyArrow and DataFusion types
-use arrow::array::{Array, ArrayData};
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrayData, ArrayRef, ListArray, make_array};
+use arrow::buffer::OffsetBuffer;
+use arrow::datatypes::Field;
use arrow::pyarrow::{FromPyArrow, ToPyArrow};
+use datafusion::common::exec_err;
use datafusion::scalar::ScalarValue;
use pyo3::types::{PyAnyMethods, PyList};
use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
@@ -26,21 +31,114 @@ use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
use crate::common::data_type::PyScalarValue;
use crate::errors::PyDataFusionError;
+/// Helper function to turn an Array into a ScalarValue. If ``as_list_array`` is true,
+/// the array will be turned into a ``ListArray``. Otherwise, we extract the first value
+/// from the array.
+fn array_to_scalar_value(array: ArrayRef, as_list_array: bool) -> PyResult<PyScalarValue> {
+ if as_list_array {
+ let field = Arc::new(Field::new_list_field(
+ array.data_type().clone(),
+ array.nulls().is_some(),
+ ));
+ let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
+ let list_array = ListArray::new(field, offsets, array, None);
+ Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))))
+ } else {
+ let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+ Ok(PyScalarValue(scalar))
+ }
+}
+
+/// Helper function to take any Python object that contains an Arrow PyCapsule
+/// interface and attempt to extract a scalar value from it. If `as_list_array`
+/// is true, the array will be turned into a ``ListArray``. Otherwise, we extract
+/// the first value from the array.
+fn pyobj_extract_scalar_via_capsule(
+ value: &Bound<'_, PyAny>,
+ as_list_array: bool,
+) -> PyResult<PyScalarValue> {
+ let array_data = ArrayData::from_pyarrow_bound(value)?;
+ let array = make_array(array_data);
+
+ array_to_scalar_value(array, as_list_array)
+}
+
impl FromPyArrow for PyScalarValue {
fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
let py = value.py();
- let typ = value.getattr("type")?;
+ let pyarrow_mod = py.import("pyarrow");
- // construct pyarrow array from the python value and pyarrow type
- let factory = py.import("pyarrow")?.getattr("array")?;
- let args = PyList::new(py, [value])?;
- let array = factory.call1((args, typ))?;
+ // Is it a PyArrow object?
+ if let Ok(pa) = pyarrow_mod.as_ref() {
+ let scalar_type = pa.getattr("Scalar")?;
+ if value.is_instance(&scalar_type)? {
+ let typ = value.getattr("type")?;
- // convert the pyarrow array to rust array using C data interface
- let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?);
- let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+ // construct pyarrow array from the python value and pyarrow type
+ let factory = py.import("pyarrow")?.getattr("array")?;
+ let args = PyList::new(py, [value])?;
+ let array = factory.call1((args, typ))?;
- Ok(PyScalarValue(scalar))
+ return pyobj_extract_scalar_via_capsule(&array, false);
+ }
+
+ let array_type = pa.getattr("Array")?;
+ if value.is_instance(&array_type)? {
+ return pyobj_extract_scalar_via_capsule(value, true);
+ }
+ }
+
+ // Is it a NanoArrow scalar?
+ if let Ok(na) = py.import("nanoarrow") {
+ let scalar_type = py.import("nanoarrow.array")?.getattr("Scalar")?;
+ if value.is_instance(&scalar_type)? {
+ return pyobj_extract_scalar_via_capsule(value, false);
+ }
+ let array_type = na.getattr("Array")?;
+ if value.is_instance(&array_type)? {
+ return pyobj_extract_scalar_via_capsule(value, true);
+ }
+ }
+
+ // Is it an arro3 scalar?
+ if let Ok(arro3) = py.import("arro3").and_then(|arro3| arro3.getattr("core")) {
+ let scalar_type = arro3.getattr("Scalar")?;
+ if value.is_instance(&scalar_type)? {
+ return pyobj_extract_scalar_via_capsule(value, false);
+ }
+ let array_type = arro3.getattr("Array")?;
+ if value.is_instance(&array_type)? {
+ return pyobj_extract_scalar_via_capsule(value, true);
+ }
+ }
+
+ // Does it have a PyCapsule interface but isn't one of our known libraries?
+ // If so, do our "best guess": check the type name, and if that fails,
+ // return a single value if the length is 1 and a List value otherwise.
+ if value.hasattr("__arrow_c_array__")? {
+ let type_name = value.get_type().repr()?;
+ if type_name.contains("Scalar")? {
+ return pyobj_extract_scalar_via_capsule(value, false);
+ }
+ if type_name.contains("Array")? {
+ return pyobj_extract_scalar_via_capsule(value, true);
+ }
+
+ let array_data = ArrayData::from_pyarrow_bound(value)?;
+ let array = make_array(array_data);
+
+ let as_array_list = array.len() != 1;
+ return array_to_scalar_value(array, as_array_list);
+ }
+
+ // Last attempt - try to create a PyArrow scalar from a plain Python object
+ if let Ok(pa) = pyarrow_mod.as_ref() {
+ let scalar = pa.call_method1("scalar", (value,))?;
+
+ PyScalarValue::from_pyarrow_bound(&scalar)
+ } else {
+ exec_err!("Unable to import scalar
value").map_err(PyDataFusionError::from)?
+ }
}
}
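To illustrate the fallback branch for objects that only expose the Arrow PyCapsule interface, here is a small Python sketch in the spirit of the WrappedPyArrow test class added in python/tests/test_expr.py. The wrapper name is illustrative; because it contains neither "Scalar" nor "Array", conversion falls through to the length heuristic described in the comments above.

    import pyarrow as pa
    from datafusion import lit

    class WrappedPyArrow:
        """Exposes only the Arrow C data interface, like a third-party array type."""

        def __init__(self, values) -> None:
            self._array = pa.array(values)

        def __arrow_c_array__(self, requested_schema=None):
            return self._array.__arrow_c_array__(requested_schema=requested_schema)

    # Length 1 is treated as a single scalar; longer input becomes a list-valued scalar.
    single = lit(WrappedPyArrow([1]))
    as_list = lit(WrappedPyArrow([1, 2, 3]))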
diff --git a/src/udaf.rs b/src/udaf.rs
index 298a59b0..cc166035 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -17,7 +17,7 @@
use std::sync::Arc;
-use datafusion::arrow::array::{Array, ArrayRef};
+use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::common::ScalarValue;
@@ -47,24 +47,24 @@ impl RustAccumulator {
impl Accumulator for RustAccumulator {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
- Python::attach(|py| {
- self.accum
- .bind(py)
- .call_method0("state")?
- .extract::<Vec<PyScalarValue>>()
+ Python::attach(|py| -> PyResult<Vec<ScalarValue>> {
+ let values = self.accum.bind(py).call_method0("state")?;
+ let mut scalars = Vec::new();
+ for item in values.try_iter()? {
+ let item: Bound<'_, PyAny> = item?;
+ let scalar = item.extract::<PyScalarValue>()?.0;
+ scalars.push(scalar);
+ }
+ Ok(scalars)
})
- .map(|v| v.into_iter().map(|x| x.0).collect())
.map_err(|e| DataFusionError::Execution(format!("{e}")))
}
fn evaluate(&mut self) -> Result<ScalarValue> {
- Python::attach(|py| {
- self.accum
- .bind(py)
- .call_method0("evaluate")?
- .extract::<PyScalarValue>()
+ Python::attach(|py| -> PyResult<ScalarValue> {
+ let value = self.accum.bind(py).call_method0("evaluate")?;
+ value.extract::<PyScalarValue>().map(|v| v.0)
})
- .map(|v| v.0)
.map_err(|e| DataFusionError::Execution(format!("{e}")))
}
@@ -73,7 +73,7 @@ impl Accumulator for RustAccumulator {
// 1. cast args to Pyarrow array
let py_args = values
.iter()
- .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+ .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
@@ -94,7 +94,7 @@ impl Accumulator for RustAccumulator {
.iter()
.map(|state| {
state
- .into_data()
+ .to_data()
.to_pyarrow(py)
.map_err(|e| DataFusionError::Execution(format!("{e}")))
})
@@ -119,7 +119,7 @@ impl Accumulator for RustAccumulator {
// 1. cast args to Pyarrow array
let py_args = values
.iter()
- .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+ .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
@@ -144,7 +144,7 @@ impl Accumulator for RustAccumulator {
}
pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
- Arc::new(move |_| -> Result<Box<dyn Accumulator>> {
+ Arc::new(move |_args| -> Result<Box<dyn Accumulator>> {
let accum = Python::attach(|py| {
accum
.call0(py)
diff --git a/src/udwf.rs b/src/udwf.rs
index b5b795d2..4bf55a85 100644
--- a/src/udwf.rs
+++ b/src/udwf.rs
@@ -94,7 +94,6 @@ impl PartitionEvaluator for RustPartitionEvaluator {
}
fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
- println!("evaluate all called with number of values {}", values.len());
Python::attach(|py| {
let py_values = PyList::new(
py,
diff --git a/src/utils.rs b/src/utils.rs
index 311f8fc8..28b58ba0 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -19,7 +19,6 @@ use std::future::Future;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
-use datafusion::common::ScalarValue;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionContext;
use datafusion::logical_expr::Volatility;
@@ -34,7 +33,6 @@ use tokio::task::JoinHandle;
use tokio::time::sleep;
use crate::TokioRuntime;
-use crate::common::data_type::PyScalarValue;
use crate::context::PySessionContext;
use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
@@ -199,22 +197,6 @@ pub(crate) fn table_provider_from_pycapsule<'py>(
}
}
-pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<ScalarValue> {
- // convert Python object to PyScalarValue to ScalarValue
-
- let pa = py.import("pyarrow")?;
-
- // Convert Python object to PyArrow scalar
- let scalar = pa.call_method1("scalar", (obj,))?;
-
- // Convert PyArrow scalar to PyScalarValue
- let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
- .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
-
- // Convert PyScalarValue to ScalarValue
- Ok(py_scalar.into())
-}
-
pub(crate) fn extract_logical_extension_codec(
py: Python,
obj: Option<Bound<PyAny>>,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]