This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 3481904f Fix Python UDAF list-of-timestamps return by enforcing 
list-valued scalars and caching PyArrow types (#1347)
3481904f is described below

commit 3481904fe9c770ffc218cc043d47d35e860148a6
Author: kosiew <[email protected]>
AuthorDate: Thu Feb 19 00:15:06 2026 +0800

    Fix Python UDAF list-of-timestamps return by enforcing list-valued scalars 
and caching PyArrow types (#1347)
    
    * Implement UDAF improvements for list type handling
    
    Store UDAF return type in Rust accumulator and wrap
    pyarrow Array/ChunkedArray returns into list scalars
    for list-like return types. Add a UDAF test to return
    a list of timestamps via a pyarrow array, validating
    the aggregate output for correctness.
    
    * Document UDAF list-valued scalar returns
    
    Add documented list-valued scalar returns for UDAF
    accumulators, including an example with pa.scalar and a note
    about unsupported pyarrow.Array returns from evaluate().
    Also, introduce a UDAF FAQ entry detailing list-returning
    patterns and required return_type/state_type definitions.
    
    * Fix pyarrow calls and improve type handling in RustAccumulator
    
    * Refactor RustAccumulator to support pyarrow array types and improve type 
checking for list types
    
    * Fixed PyO3 type mismatch by cloning Array/ChunkedArray types before 
unbinding and binding fresh copies when checking array-likeness, eliminating 
the Bound reference error
    
    * Add timezone information to datetime objects in 
test_udaf_list_timestamp_return
    
    * clippy fix
    
    * Refactor RustAccumulator and utility functions for improved type handling 
and conversion from Python objects to Arrow types
    
    * Enhance PyArrow integration by refining type handling and conversion in 
RustAccumulator and utility functions
    
    * Fix array data binding in py_obj_to_scalar_value function
    
    * Implement single point for scalar conversion from python objects
    
    * Add unit tests and simplify python wrapper for literal
    
    * Add nanoarrow and arro3-core to dev dependencies. Sort the dependencies 
alphabetically.
    
    * Refactor common code into helper function so we do not duplicate it.
    
    * Update import path to access Scalar type
    
    * Add test for generic python objects that support the C interface
    
    * Update unit test to pass back either pyarrow array or array wrapped as 
scalar
    
    * Update tests to pass back raw python values or pyarrow scalar
    
    * Expand on user documentation for how to return list arrays
    
    * More user documentation
    
    ---------
    
    Co-authored-by: Tim Saucer <[email protected]>
---
 .../user-guide/common-operations/udf-and-udfa.rst  |  33 ++++--
 pyproject.toml                                     |  20 ++--
 python/datafusion/expr.py                          |   3 -
 python/datafusion/user_defined.py                  |  22 +++-
 python/tests/test_expr.py                          |  45 ++++++++
 python/tests/test_udaf.py                          |  89 ++++++++++++++--
 src/common/data_type.rs                            |   3 +
 src/config.rs                                      |   6 +-
 src/dataframe.rs                                   |   9 +-
 src/pyarrow_util.rs                                | 118 +++++++++++++++++++--
 src/udaf.rs                                        |  34 +++---
 src/udwf.rs                                        |   1 -
 src/utils.rs                                       |  18 ----
 13 files changed, 320 insertions(+), 81 deletions(-)

diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst 
b/docs/source/user-guide/common-operations/udf-and-udfa.rst
index d554e1e2..f669721a 100644
--- a/docs/source/user-guide/common-operations/udf-and-udfa.rst
+++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst
@@ -123,7 +123,7 @@ also see how the inputs to ``update`` and ``merge`` differ.
 
 .. code-block:: python
 
-    import pyarrow
+    import pyarrow as pa
     import pyarrow.compute
     import datafusion
     from datafusion import col, udaf, Accumulator
@@ -136,16 +136,16 @@ also see how the inputs to ``update`` and ``merge`` 
differ.
         def __init__(self):
             self._sum = 0.0
 
-        def update(self, values_a: pyarrow.Array, values_b: pyarrow.Array) -> 
None:
+        def update(self, values_a: pa.Array, values_b: pa.Array) -> None:
             self._sum = self._sum + pyarrow.compute.sum(values_a).as_py() - 
pyarrow.compute.sum(values_b).as_py()
 
-        def merge(self, states: List[pyarrow.Array]) -> None:
+        def merge(self, states: list[pa.Array]) -> None:
             self._sum = self._sum + pyarrow.compute.sum(states[0]).as_py()
 
-        def state(self) -> pyarrow.Array:
-            return pyarrow.array([self._sum])
+        def state(self) -> list[pa.Scalar]:
+            return [pyarrow.scalar(self._sum)]
 
-        def evaluate(self) -> pyarrow.Scalar:
+        def evaluate(self) -> pa.Scalar:
             return pyarrow.scalar(self._sum)
 
     ctx = datafusion.SessionContext()
@@ -156,10 +156,29 @@ also see how the inputs to ``update`` and ``merge`` 
differ.
         }
     )
 
-    my_udaf = udaf(MyAccumulator, [pyarrow.float64(), pyarrow.float64()], 
pyarrow.float64(), [pyarrow.float64()], 'stable')
+    my_udaf = udaf(MyAccumulator, [pa.float64(), pa.float64()], pa.float64(), 
[pa.float64()], 'stable')
 
     df.aggregate([], [my_udaf(col("a"), col("b")).alias("col_diff")])
 
+FAQ
+^^^
+
+**How do I return a list from a UDAF?**
+
+``evaluate`` is expected to return a scalar and ``state`` a list of scalars.
+To return a list as a single scalar value, the recommended approach is to
+wrap the values in a ``pyarrow.Scalar`` object. For example, you can return a
+timestamp list with ``pa.scalar([...], type=pa.list_(pa.timestamp("ms")))`` and
+register the corresponding return and state types as
+``return_type=pa.list_(pa.timestamp("ms"))`` and
+``state_type=[pa.list_(pa.timestamp("ms"))]``, respectively.
+
+As of DataFusion 52.0.0, you can return any Python object, including a
+PyArrow array, as the return value(s) of these functions and DataFusion will
+attempt to convert it into a scalar value; a returned array is wrapped into a
+single list-valued scalar. DataFusion has been tested to convert PyArrow,
+nanoarrow, and arro3 objects as well as primitive types like integers and strings.
+
 Window Functions
 ----------------
 
diff --git a/pyproject.toml b/pyproject.toml
index d315dbe1..5a5128a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -173,27 +173,29 @@ ignore-words-list = ["ans", "IST"]
 
 [dependency-groups]
 dev = [
+  "arro3-core==0.6.5",
+  "codespell==2.4.1",
   "maturin>=1.8.1",
+  "nanoarrow==0.8.0",
   "numpy>1.25.0;python_version<'3.14'",
   "numpy>=2.3.2;python_version>='3.14'",
-  "pyarrow>=19.0.0",
   "pre-commit>=4.3.0",
-  "pyyaml>=6.0.3",
+  "pyarrow>=19.0.0",
+  "pygithub==2.5.0",
   "pytest>=7.4.4",
   "pytest-asyncio>=0.23.3",
+  "pyyaml>=6.0.3",
   "ruff>=0.9.1",
   "toml>=0.10.2",
-  "pygithub==2.5.0",
-  "codespell==2.4.1",
 ]
 docs = [
-  "sphinx>=7.1.2",
-  "pydata-sphinx-theme==0.8.0",
-  "myst-parser>=3.0.1",
-  "jinja2>=3.1.5",
   "ipython>=8.12.3",
+  "jinja2>=3.1.5",
+  "myst-parser>=3.0.1",
   "pandas>=2.0.3",
   "pickleshare>=0.7.5",
-  "sphinx-autoapi>=3.4.0",
+  "pydata-sphinx-theme==0.8.0",
   "setuptools>=75.3.0",
+  "sphinx>=7.1.2",
+  "sphinx-autoapi>=3.4.0",
 ]
diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index 695fe7c4..9df58f52 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -562,8 +562,6 @@ class Expr:
         """
         if isinstance(value, str):
             value = pa.scalar(value, type=pa.string_view())
-        if not isinstance(value, pa.Scalar):
-            value = pa.scalar(value)
         return Expr(expr_internal.RawExpr.literal(value))
 
     @staticmethod
@@ -576,7 +574,6 @@ class Expr:
         """
         if isinstance(value, str):
             value = pa.scalar(value, type=pa.string_view())
-        value = value if isinstance(value, pa.Scalar) else pa.scalar(value)
 
         return Expr(expr_internal.RawExpr.literal_with_metadata(value, 
metadata))
 
diff --git a/python/datafusion/user_defined.py 
b/python/datafusion/user_defined.py
index 5dd62656..d4e5302b 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -298,7 +298,16 @@ class Accumulator(metaclass=ABCMeta):
 
     @abstractmethod
     def state(self) -> list[pa.Scalar]:
-        """Return the current state."""
+        """Return the current state.
+
+        Although the signature is annotated as returning PyArrow Scalars,
+        each element may be any value that can be converted into a Scalar.
+        This includes basic Python data types such as integers and strings,
+        as well as PyArrow, nanoarrow, and arro3 objects; arrays from these
+        libraries are converted into a single list-valued scalar. Other
+        objects that implement the Arrow PyCapsule interface will be given
+        a "best attempt" at conversion to scalar objects.
+        """
 
     @abstractmethod
     def update(self, *values: pa.Array) -> None:
@@ -310,7 +319,16 @@ class Accumulator(metaclass=ABCMeta):
 
     @abstractmethod
     def evaluate(self) -> pa.Scalar:
-        """Return the resultant value."""
+        """Return the resultant value.
+
+        Although the signature is annotated as returning a PyArrow Scalar,
+        you can return any value that can be converted into a Scalar. This
+        includes basic Python data types such as integers and strings, as
+        well as PyArrow, nanoarrow, and arro3 objects; arrays from these
+        libraries are converted into a single list-valued scalar. Other
+        objects that implement the Arrow PyCapsule interface will be given
+        a "best attempt" at conversion to scalar objects.
+        """
 
 
 class AggregateUDFExportable(Protocol):
diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py
index 39e48f7c..92251827 100644
--- a/python/tests/test_expr.py
+++ b/python/tests/test_expr.py
@@ -20,6 +20,8 @@ from concurrent.futures import ThreadPoolExecutor
 from datetime import date, datetime, time, timezone
 from decimal import Decimal
 
+import arro3.core
+import nanoarrow
 import pyarrow as pa
 import pytest
 from datafusion import (
@@ -980,6 +982,49 @@ def test_literal_metadata(ctx):
         assert expected_field.metadata == actual_field.metadata
 
 
+def test_scalar_conversion() -> None:
+    class WrappedPyArrow:
+        """Wrapper class for testing __arrow_c_array__."""
+
+        def __init__(self, val: pa.Array) -> None:
+            self.val = val
+
+        def __arrow_c_array__(self, requested_schema=None):
+            return 
self.val.__arrow_c_array__(requested_schema=requested_schema)
+
+    expected_value = lit(1)
+    assert str(expected_value) == "Expr(Int64(1))"
+
+    # Test pyarrow imports
+    assert expected_value == lit(pa.scalar(1))
+    assert expected_value == lit(pa.scalar(1, type=pa.int32()))
+
+    # Test nanoarrow
+    na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0]
+    assert expected_value == lit(na_scalar)
+
+    # Test arro3
+    arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32())
+    assert expected_value == lit(arro3_scalar)
+
+    generic_scalar = WrappedPyArrow(pa.array([1]))
+    assert expected_value == lit(generic_scalar)
+
+    expected_value = lit([1, 2, 3])
+    assert str(expected_value) == "Expr(List([1, 2, 3]))"
+
+    assert expected_value == lit(pa.scalar([1, 2, 3]))
+
+    na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32())
+    assert expected_value == lit(na_array)
+
+    arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32())
+    assert expected_value == lit(arro3_array)
+
+    generic_array = WrappedPyArrow(pa.array([1, 2, 3]))
+    assert expected_value == lit(generic_array)
+
+
 def test_ensure_expr():
     e = col("a")
     assert ensure_expr(e) is e.expr
diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
index 453ff6f4..8cd480e3 100644
--- a/python/tests/test_udaf.py
+++ b/python/tests/test_udaf.py
@@ -17,6 +17,8 @@
 
 from __future__ import annotations
 
+from datetime import datetime, timezone
+
 import pyarrow as pa
 import pyarrow.compute as pc
 import pytest
@@ -26,23 +28,28 @@ from datafusion import Accumulator, column, udaf
 class Summarize(Accumulator):
     """Interface of a user-defined accumulation."""
 
-    def __init__(self, initial_value: float = 0.0):
-        self._sum = pa.scalar(initial_value)
+    def __init__(self, initial_value: float = 0.0, as_scalar: bool = False):
+        self._sum = initial_value
+        self.as_scalar = as_scalar
 
     def state(self) -> list[pa.Scalar]:
+        if self.as_scalar:
+            return [pa.scalar(self._sum)]
         return [self._sum]
 
     def update(self, values: pa.Array) -> None:
         # Not nice since pyarrow scalars can't be summed yet.
         # This breaks on `None`
-        self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
+        self._sum = self._sum + pc.sum(values).as_py()
 
     def merge(self, states: list[pa.Array]) -> None:
         # Not nice since pyarrow scalars can't be summed yet.
         # This breaks on `None`
-        self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
+        self._sum = self._sum + pc.sum(states[0]).as_py()
 
     def evaluate(self) -> pa.Scalar:
+        if self.as_scalar:
+            return pa.scalar(self._sum)
         return self._sum
 
 
@@ -58,6 +65,30 @@ class MissingMethods(Accumulator):
         return [self._sum]
 
 
+class CollectTimestamps(Accumulator):
+    def __init__(self, wrap_in_scalar: bool):
+        self._values: list[datetime] = []
+        self.wrap_in_scalar = wrap_in_scalar
+
+    def state(self) -> list[pa.Scalar]:
+        if self.wrap_in_scalar:
+            return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))]
+        return [pa.array(self._values, type=pa.timestamp("ns"))]
+
+    def update(self, values: pa.Array) -> None:
+        self._values.extend(values.to_pylist())
+
+    def merge(self, states: list[pa.Array]) -> None:
+        for state in states[0].to_pylist():
+            if state is not None:
+                self._values.extend(state)
+
+    def evaluate(self) -> pa.Scalar:
+        if self.wrap_in_scalar:
+            return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))
+        return pa.array(self._values, type=pa.timestamp("ns"))
+
+
 @pytest.fixture
 def df(ctx):
     # create a RecordBatch and a new DataFrame from it
@@ -137,11 +168,12 @@ def test_udaf_decorator_aggregate(df):
     assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
 
 
-def test_udaf_aggregate_with_arguments(df):
[email protected]("as_scalar", [True, False])
+def test_udaf_aggregate_with_arguments(df, as_scalar):
     bias = 10.0
 
     summarize = udaf(
-        lambda: Summarize(bias),
+        lambda: Summarize(initial_value=bias, as_scalar=as_scalar),
         pa.float64(),
         pa.float64(),
         [pa.float64()],
@@ -217,3 +249,48 @@ def test_register_udaf(ctx, df) -> None:
     df_result = ctx.sql("select summarize(b) from test_table")
 
     assert df_result.collect()[0][0][0].as_py() == 14.0
+
+
[email protected]("wrap_in_scalar", [True, False])
+def test_udaf_list_timestamp_return(ctx, wrap_in_scalar) -> None:
+    timestamps1 = [
+        datetime(2024, 1, 1, tzinfo=timezone.utc),
+        datetime(2024, 1, 2, tzinfo=timezone.utc),
+    ]
+    timestamps2 = [
+        datetime(2024, 1, 3, tzinfo=timezone.utc),
+        datetime(2024, 1, 4, tzinfo=timezone.utc),
+    ]
+    batch1 = pa.RecordBatch.from_arrays(
+        [pa.array(timestamps1, type=pa.timestamp("ns"))],
+        names=["ts"],
+    )
+    batch2 = pa.RecordBatch.from_arrays(
+        [pa.array(timestamps2, type=pa.timestamp("ns"))],
+        names=["ts"],
+    )
+    df = ctx.create_dataframe([[batch1], [batch2]], name="timestamp_table")
+
+    list_type = pa.list_(
+        pa.field("item", type=pa.timestamp("ns"), nullable=wrap_in_scalar)
+    )
+
+    collect = udaf(
+        lambda: CollectTimestamps(wrap_in_scalar),
+        pa.timestamp("ns"),
+        list_type,
+        [list_type],
+        volatility="immutable",
+    )
+
+    result = df.aggregate([], [collect(column("ts"))]).collect()[0]
+
+    # There is no guarantee about the ordering of the batches, so sort the
+    # values to get consistent results. Alternatively we could sort in evaluate().
+    assert (
+        result.column(0).values.sort()
+        == pa.array(
+            [[*timestamps1, *timestamps2]],
+            type=list_type,
+        ).values
+    )
diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index 55848da5..1ff332eb 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -22,6 +22,9 @@ use datafusion::logical_expr::expr::NullTreatment as 
DFNullTreatment;
 use pyo3::exceptions::{PyNotImplementedError, PyValueError};
 use pyo3::prelude::*;
 
+/// A newtype wrapper around [`ScalarValue`] used at the Python boundary. It allows
+/// conversion from a variety of Python objects into a [`ScalarValue`]. See
+/// `FromPyArrow::from_pyarrow_bound` for conversion details.
 #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
 pub struct PyScalarValue(pub ScalarValue);
 
diff --git a/src/config.rs b/src/config.rs
index 583dea7e..38936e6c 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -22,8 +22,8 @@ use parking_lot::RwLock;
 use pyo3::prelude::*;
 use pyo3::types::*;
 
+use crate::common::data_type::PyScalarValue;
 use crate::errors::PyDataFusionResult;
-use crate::utils::py_obj_to_scalar_value;
 #[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
 #[derive(Clone)]
 pub(crate) struct PyConfig {
@@ -65,9 +65,9 @@ impl PyConfig {
 
     /// Set a configuration option
     pub fn set(&self, key: &str, value: Py<PyAny>, py: Python) -> 
PyDataFusionResult<()> {
-        let scalar_value = py_obj_to_scalar_value(py, value)?;
+        let scalar_value: PyScalarValue = value.extract(py)?;
         let mut options = self.config.write();
-        options.set(key, scalar_value.to_string().as_str())?;
+        options.set(key, scalar_value.0.to_string().as_str())?;
         Ok(())
     }
 
diff --git a/src/dataframe.rs b/src/dataframe.rs
index fe039593..53fab58c 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -48,6 +48,7 @@ use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
 
+use crate::common::data_type::PyScalarValue;
 use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err};
 use crate::expr::PyExpr;
 use crate::expr::sort_expr::{PySortExpr, to_sort_expressions};
@@ -55,9 +56,7 @@ use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::{PyRecordBatchStream, poll_next_batch};
 use crate::sql::logical::PyLogicalPlan;
 use crate::table::{PyTable, TempViewTable};
-use crate::utils::{
-    is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule, 
wait_for_future,
-};
+use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, 
wait_for_future};
 
 /// File-level static CStr for the Arrow array stream capsule name.
 static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");
@@ -1191,14 +1190,14 @@ impl PyDataFrame {
         columns: Option<Vec<PyBackedStr>>,
         py: Python,
     ) -> PyDataFusionResult<Self> {
-        let scalar_value = py_obj_to_scalar_value(py, value)?;
+        let scalar_value: PyScalarValue = value.extract(py)?;
 
         let cols = match columns {
             Some(col_names) => col_names.iter().map(|c| 
c.to_string()).collect(),
             None => Vec::new(), // Empty vector means fill null for all columns
         };
 
-        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
+        let df = self.df.as_ref().clone().fill_null(scalar_value.0, cols)?;
         Ok(Self::new(df))
     }
 }
diff --git a/src/pyarrow_util.rs b/src/pyarrow_util.rs
index 264cfd34..2a119274 100644
--- a/src/pyarrow_util.rs
+++ b/src/pyarrow_util.rs
@@ -17,8 +17,13 @@
 
 //! Conversions between PyArrow and DataFusion types
 
-use arrow::array::{Array, ArrayData};
+use std::sync::Arc;
+
+use arrow::array::{Array, ArrayData, ArrayRef, ListArray, make_array};
+use arrow::buffer::OffsetBuffer;
+use arrow::datatypes::Field;
 use arrow::pyarrow::{FromPyArrow, ToPyArrow};
+use datafusion::common::exec_err;
 use datafusion::scalar::ScalarValue;
 use pyo3::types::{PyAnyMethods, PyList};
 use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
@@ -26,21 +31,114 @@ use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
 use crate::common::data_type::PyScalarValue;
 use crate::errors::PyDataFusionError;
 
+/// Convert an Arrow array into a [`PyScalarValue`]. If `as_list_array` is true,
+/// the whole array is wrapped as a single-row `ListArray` (`ScalarValue::List`);
+/// otherwise the first element of the array is extracted.
+fn array_to_scalar_value(array: ArrayRef, as_list_array: bool) -> 
PyResult<PyScalarValue> {
+    if as_list_array {
+        let field = Arc::new(Field::new_list_field(
+            array.data_type().clone(),
+            array.nulls().is_some(),
+        ));
+        let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
+        let list_array = ListArray::new(field, offsets, array, None);
+        Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))))
+    } else {
+        let scalar = ScalarValue::try_from_array(&array, 
0).map_err(PyDataFusionError::from)?;
+        Ok(PyScalarValue(scalar))
+    }
+}
+
+/// Import a Python object that exposes Arrow array data (e.g. through the Arrow
+/// PyCapsule interface) and convert it into a scalar value. If `as_list_array`
+/// is true, the whole array becomes a single `ScalarValue::List`; otherwise the
+/// first element of the array is extracted.
+fn pyobj_extract_scalar_via_capsule(
+    value: &Bound<'_, PyAny>,
+    as_list_array: bool,
+) -> PyResult<PyScalarValue> {
+    let array_data = ArrayData::from_pyarrow_bound(value)?;
+    let array = make_array(array_data);
+
+    array_to_scalar_value(array, as_list_array)
+}
+
 impl FromPyArrow for PyScalarValue {
     fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
         let py = value.py();
-        let typ = value.getattr("type")?;
+        let pyarrow_mod = py.import("pyarrow");
 
-        // construct pyarrow array from the python value and pyarrow type
-        let factory = py.import("pyarrow")?.getattr("array")?;
-        let args = PyList::new(py, [value])?;
-        let array = factory.call1((args, typ))?;
+        // Is it a PyArrow object?
+        if let Ok(pa) = pyarrow_mod.as_ref() {
+            let scalar_type = pa.getattr("Scalar")?;
+            if value.is_instance(&scalar_type)? {
+                let typ = value.getattr("type")?;
 
-        // convert the pyarrow array to rust array using C data interface
-        let array = 
arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?);
-        let scalar = ScalarValue::try_from_array(&array, 
0).map_err(PyDataFusionError::from)?;
+                // construct pyarrow array from the python value and pyarrow 
type
+                let factory = py.import("pyarrow")?.getattr("array")?;
+                let args = PyList::new(py, [value])?;
+                let array = factory.call1((args, typ))?;
 
-        Ok(PyScalarValue(scalar))
+                return pyobj_extract_scalar_via_capsule(&array, false);
+            }
+
+            let array_type = pa.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // Is it a nanoarrow scalar or array?
+        if let Ok(na) = py.import("nanoarrow") {
+            let scalar_type = py.import("nanoarrow.array")?.getattr("Scalar")?;
+            if value.is_instance(&scalar_type)? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            let array_type = na.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // Is it an arro3 scalar or array?
+        if let Ok(arro3) = py.import("arro3").and_then(|arro3| 
arro3.getattr("core")) {
+            let scalar_type = arro3.getattr("Scalar")?;
+            if value.is_instance(&scalar_type)? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            let array_type = arro3.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // An object exposing the Arrow PyCapsule interface from an unknown library:
+        // make a best guess. Use the type name when it mentions "Scalar" or "Array";
+        // otherwise return a single value for length-1 data and a List value otherwise.
+        if value.hasattr("__arrow_c_array__")? {
+            let type_name = value.get_type().repr()?;
+            if type_name.contains("Scalar")? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            if type_name.contains("Array")? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+
+            let array_data = ArrayData::from_pyarrow_bound(value)?;
+            let array = make_array(array_data);
+
+            let as_array_list = array.len() != 1;
+            return array_to_scalar_value(array, as_array_list);
+        }
+
+        // Last attempt: try to create a PyArrow scalar from a plain Python object
+        if let Ok(pa) = pyarrow_mod.as_ref() {
+            let scalar = pa.call_method1("scalar", (value,))?;
+
+            PyScalarValue::from_pyarrow_bound(&scalar)
+        } else {
+            exec_err!("Unable to import scalar 
value").map_err(PyDataFusionError::from)?
+        }
     }
 }
 
diff --git a/src/udaf.rs b/src/udaf.rs
index 298a59b0..cc166035 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -17,7 +17,7 @@
 
 use std::sync::Arc;
 
-use datafusion::arrow::array::{Array, ArrayRef};
+use datafusion::arrow::array::ArrayRef;
 use datafusion::arrow::datatypes::DataType;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::common::ScalarValue;
@@ -47,24 +47,24 @@ impl RustAccumulator {
 
 impl Accumulator for RustAccumulator {
     fn state(&mut self) -> Result<Vec<ScalarValue>> {
-        Python::attach(|py| {
-            self.accum
-                .bind(py)
-                .call_method0("state")?
-                .extract::<Vec<PyScalarValue>>()
+        Python::attach(|py| -> PyResult<Vec<ScalarValue>> {
+            let values = self.accum.bind(py).call_method0("state")?;
+            let mut scalars = Vec::new();
+            for item in values.try_iter()? {
+                let item: Bound<'_, PyAny> = item?;
+                let scalar = item.extract::<PyScalarValue>()?.0;
+                scalars.push(scalar);
+            }
+            Ok(scalars)
         })
-        .map(|v| v.into_iter().map(|x| x.0).collect())
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
     }
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
-        Python::attach(|py| {
-            self.accum
-                .bind(py)
-                .call_method0("evaluate")?
-                .extract::<PyScalarValue>()
+        Python::attach(|py| -> PyResult<ScalarValue> {
+            let value = self.accum.bind(py).call_method0("evaluate")?;
+            value.extract::<PyScalarValue>().map(|v| v.0)
         })
-        .map(|v| v.0)
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
     }
 
@@ -73,7 +73,7 @@ impl Accumulator for RustAccumulator {
             // 1. cast args to Pyarrow array
             let py_args = values
                 .iter()
-                .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+                .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
                 .collect::<Vec<_>>();
             let py_args = PyTuple::new(py, 
py_args).map_err(to_datafusion_err)?;
 
@@ -94,7 +94,7 @@ impl Accumulator for RustAccumulator {
                 .iter()
                 .map(|state| {
                     state
-                        .into_data()
+                        .to_data()
                         .to_pyarrow(py)
                         .map_err(|e| 
DataFusionError::Execution(format!("{e}")))
                 })
@@ -119,7 +119,7 @@ impl Accumulator for RustAccumulator {
             // 1. cast args to Pyarrow array
             let py_args = values
                 .iter()
-                .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+                .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
                 .collect::<Vec<_>>();
             let py_args = PyTuple::new(py, 
py_args).map_err(to_datafusion_err)?;
 
@@ -144,7 +144,7 @@ impl Accumulator for RustAccumulator {
 }
 
 pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
-    Arc::new(move |_| -> Result<Box<dyn Accumulator>> {
+    Arc::new(move |_args| -> Result<Box<dyn Accumulator>> {
         let accum = Python::attach(|py| {
             accum
                 .call0(py)
diff --git a/src/udwf.rs b/src/udwf.rs
index b5b795d2..4bf55a85 100644
--- a/src/udwf.rs
+++ b/src/udwf.rs
@@ -94,7 +94,6 @@ impl PartitionEvaluator for RustPartitionEvaluator {
     }
 
     fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> 
Result<ArrayRef> {
-        println!("evaluate all called with number of values {}", values.len());
         Python::attach(|py| {
             let py_values = PyList::new(
                 py,
diff --git a/src/utils.rs b/src/utils.rs
index 311f8fc8..28b58ba0 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -19,7 +19,6 @@ use std::future::Future;
 use std::sync::{Arc, OnceLock};
 use std::time::Duration;
 
-use datafusion::common::ScalarValue;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::context::SessionContext;
 use datafusion::logical_expr::Volatility;
@@ -34,7 +33,6 @@ use tokio::task::JoinHandle;
 use tokio::time::sleep;
 
 use crate::TokioRuntime;
-use crate::common::data_type::PyScalarValue;
 use crate::context::PySessionContext;
 use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, 
to_datafusion_err};
 
@@ -199,22 +197,6 @@ pub(crate) fn table_provider_from_pycapsule<'py>(
     }
 }
 
-pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> 
PyResult<ScalarValue> {
-    // convert Python object to PyScalarValue to ScalarValue
-
-    let pa = py.import("pyarrow")?;
-
-    // Convert Python object to PyArrow scalar
-    let scalar = pa.call_method1("scalar", (obj,))?;
-
-    // Convert PyArrow scalar to PyScalarValue
-    let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
-        .map_err(|e| PyValueError::new_err(format!("Failed to extract 
PyScalarValue: {e}")))?;
-
-    // Convert PyScalarValue to ScalarValue
-    Ok(py_scalar.into())
-}
-
 pub(crate) fn extract_logical_extension_codec(
     py: Python,
     obj: Option<Bound<PyAny>>,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to