This is an automated email from the ASF dual-hosted git repository.

kosiew pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new c4e74864 Add PyCapsule Type Support and Type Hint Enhancements for AggregateUDF in DataFusion Python Bindings (#1277)
c4e74864 is described below

commit c4e748641ad17c1c04c5f4411a32fb4bd3010904
Author: kosiew <[email protected]>
AuthorDate: Sat Nov 8 18:45:12 2025 +0800

    Add PyCapsule Type Support and Type Hint Enhancements for AggregateUDF in DataFusion Python Bindings (#1277)
    
    Added TypeGuard function _is_pycapsule() for lightweight PyCapsule type validation.
    Introduced _PyCapsule proxy class for static typing compatibility in non-type-checking contexts.
    Extended overloads in AggregateUDF.__init__ and AggregateUDF.udaf() to include AggregateUDFExportable | _PyCapsule argument types.
    Added stricter constructor argument validation for callable accumulators.
    Updated AggregateUDF.from_pycapsule() to support direct PyCapsule initialization.
    Refactored Rust PyAggregateUDF::from_pycapsule() logic to delegate PyCapsule validation to a new helper function aggregate_udf_from_capsule() for cleaner handling.
---
 python/datafusion/user_defined.py   | 65 ++++++++++++++++++++++++++++++++++---
 python/tests/test_pyclass_frozen.py |  3 +-
 src/udaf.rs                         | 33 ++++++++++++-------
 3 files changed, 82 insertions(+), 19 deletions(-)

diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index 21b2de63..43a72c80 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -22,7 +22,7 @@ from __future__ import annotations
 import functools
 from abc import ABCMeta, abstractmethod
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Protocol, TypeGuard, TypeVar, cast, overload
 
 import pyarrow as pa
 
@@ -30,6 +30,8 @@ import datafusion._internal as df_internal
 from datafusion.expr import Expr
 
 if TYPE_CHECKING:
+    from _typeshed import CapsuleType as _PyCapsule
+
     _R = TypeVar("_R", bound=pa.DataType)
     from collections.abc import Callable
 
@@ -84,6 +86,11 @@ class ScalarUDFExportable(Protocol):
     def __datafusion_scalar_udf__(self) -> object: ...  # noqa: D105
 
 
+def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]:
+    """Return ``True`` when ``value`` is a CPython ``PyCapsule``."""
+    return value.__class__.__name__ == "PyCapsule"
+
+
 class ScalarUDF:
     """Class for performing scalar user-defined functions (UDF).
 
@@ -291,6 +298,7 @@ class AggregateUDF:
     also :py:class:`ScalarUDF` for operating on a row by row basis.
     """
 
+    @overload
     def __init__(
         self,
         name: str,
@@ -299,6 +307,27 @@ class AggregateUDF:
         return_type: pa.DataType,
         state_type: list[pa.DataType],
         volatility: Volatility | str,
+    ) -> None: ...
+
+    @overload
+    def __init__(
+        self,
+        name: str,
+        accumulator: AggregateUDFExportable,
+        input_types: None = ...,
+        return_type: None = ...,
+        state_type: None = ...,
+        volatility: None = ...,
+    ) -> None: ...
+
+    def __init__(
+        self,
+        name: str,
+        accumulator: Callable[[], Accumulator] | AggregateUDFExportable,
+        input_types: list[pa.DataType] | None,
+        return_type: pa.DataType | None,
+        state_type: list[pa.DataType] | None,
+        volatility: Volatility | str | None,
     ) -> None:
         """Instantiate a user-defined aggregate function (UDAF).
 
@@ -308,6 +337,18 @@ class AggregateUDF:
         if hasattr(accumulator, "__datafusion_aggregate_udf__"):
             self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator)
             return
+        if (
+            input_types is None
+            or return_type is None
+            or state_type is None
+            or volatility is None
+        ):
+            msg = (
+                "`input_types`, `return_type`, `state_type`, and `volatility` "
+                "must be provided when `accumulator` is callable."
+            )
+            raise TypeError(msg)
+
         self._udaf = df_internal.AggregateUDF(
             name,
             accumulator,
@@ -351,6 +392,14 @@ class AggregateUDF:
         name: str | None = None,
     ) -> AggregateUDF: ...
 
+    @overload
+    @staticmethod
+    def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ...
+
+    @overload
+    @staticmethod
+    def udaf(accum: _PyCapsule) -> AggregateUDF: ...
+
     @staticmethod
     def udaf(*args: Any, **kwargs: Any):  # noqa: D417, C901
         """Create a new User-Defined Aggregate Function (UDAF).
@@ -471,7 +520,7 @@ class AggregateUDF:
 
             return decorator
 
-        if hasattr(args[0], "__datafusion_aggregate_udf__"):
+        if hasattr(args[0], "__datafusion_aggregate_udf__") or _is_pycapsule(args[0]):
             return AggregateUDF.from_pycapsule(args[0])
 
         if args and callable(args[0]):
@@ -481,16 +530,22 @@ class AggregateUDF:
         return _decorator(*args, **kwargs)
 
     @staticmethod
-    def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF:
+    def from_pycapsule(func: AggregateUDFExportable | _PyCapsule) -> AggregateUDF:
         """Create an Aggregate UDF from AggregateUDF PyCapsule object.
 
         This function will instantiate a Aggregate UDF that uses a DataFusion
         AggregateUDF that is exported via the FFI bindings.
         """
-        name = str(func.__class__)
+        if _is_pycapsule(func):
+            aggregate = cast(AggregateUDF, object.__new__(AggregateUDF))
+            aggregate._udaf = df_internal.AggregateUDF.from_pycapsule(func)
+            return aggregate
+
+        capsule = cast(AggregateUDFExportable, func)
+        name = str(capsule.__class__)
         return AggregateUDF(
             name=name,
-            accumulator=func,
+            accumulator=capsule,
             input_types=None,
             return_type=None,
             state_type=None,
diff --git a/python/tests/test_pyclass_frozen.py b/python/tests/test_pyclass_frozen.py
index 3500c5e3..33338bf5 100644
--- a/python/tests/test_pyclass_frozen.py
+++ b/python/tests/test_pyclass_frozen.py
@@ -35,8 +35,7 @@ ARG_STRING_RE = re.compile(
     r"(?P<key>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P<value>[^\"]+)\"",
 )
 STRUCT_NAME_RE = re.compile(
-    r"\b(?:pub\s+)?(?:struct|enum)\s+"
-    r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
+    r"\b(?:pub\s+)?(?:struct|enum)\s+" r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
 )
 
 
diff --git a/src/udaf.rs b/src/udaf.rs
index eab4581d..e48e35f8 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -154,6 +154,15 @@ pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {
     })
 }
 
+fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
+    validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
+
+    let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
+    let udaf: ForeignAggregateUDF = udaf.try_into()?;
+
+    Ok(udaf.into())
+}
+
 /// Represents an AggregateUDF
 #[pyclass(frozen, name = "AggregateUDF", module = "datafusion", subclass)]
 #[derive(Debug, Clone)]
@@ -186,22 +195,22 @@ impl PyAggregateUDF {
 
     #[staticmethod]
     pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
+        if func.is_instance_of::<PyCapsule>() {
+            let capsule = func.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
+            let function = aggregate_udf_from_capsule(capsule)?;
+            return Ok(Self { function });
+        }
+
         if func.hasattr("__datafusion_aggregate_udf__")? {
             let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?;
             let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
-            validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
-
-            let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
-            let udaf: ForeignAggregateUDF = udaf.try_into()?;
-
-            Ok(Self {
-                function: udaf.into(),
-            })
-        } else {
-            Err(crate::errors::PyDataFusionError::Common(
-                "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
-            ))
+            let function = aggregate_udf_from_capsule(capsule)?;
+            return Ok(Self { function });
         }
+
+        Err(crate::errors::PyDataFusionError::Common(
+            "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
+        ))
     }
 
     /// creates a new PyExpr with the call of the udf


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to