This is an automated email from the ASF dual-hosted git repository.
kosiew pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new c4e74864 Add PyCapsule Type Support and Type Hint Enhancements for AggregateUDF in DataFusion Python Bindings (#1277)
c4e74864 is described below
commit c4e748641ad17c1c04c5f4411a32fb4bd3010904
Author: kosiew <[email protected]>
AuthorDate: Sat Nov 8 18:45:12 2025 +0800
Add PyCapsule Type Support and Type Hint Enhancements for AggregateUDF in DataFusion Python Bindings (#1277)
Added TypeGuard function _is_pycapsule() for lightweight PyCapsule type validation.
Introduced _PyCapsule proxy class for static typing compatibility in non-type-checking contexts.
Extended overloads in AggregateUDF.__init__ and AggregateUDF.udaf() to include AggregateUDFExportable | _PyCapsule argument types.
Added stricter constructor argument validation for callable accumulators.
Updated AggregateUDF.from_pycapsule() to support direct PyCapsule initialization.
Refactored Rust PyAggregateUDF::from_pycapsule() logic to delegate PyCapsule validation to a new helper function aggregate_udf_from_capsule() for cleaner handling.
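As a quick illustration of the two entry points touched here, a minimal usage sketch follows. The make_provider() factory and the provider object are hypothetical stand-ins for any implementation of the AggregateUDFExportable protocol (an object whose __datafusion_aggregate_udf__() returns a PyCapsule tagged "datafusion_aggregate_udf"); only the AggregateUDF calls come from this change.

    from datafusion.user_defined import AggregateUDF

    provider = make_provider()  # hypothetical AggregateUDFExportable object

    # Existing path: pass the exporter object; udaf() detects
    # __datafusion_aggregate_udf__ and delegates to from_pycapsule().
    agg = AggregateUDF.udaf(provider)

    # New path: pass the raw PyCapsule itself; the _is_pycapsule() TypeGuard
    # routes it straight to df_internal.AggregateUDF.from_pycapsule() instead
    # of the callable-accumulator constructor.
    capsule = provider.__datafusion_aggregate_udf__()
    agg_from_capsule = AggregateUDF.from_pycapsule(capsule)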
---
python/datafusion/user_defined.py | 65 ++++++++++++++++++++++++++++++++++---
python/tests/test_pyclass_frozen.py | 3 +-
src/udaf.rs | 33 ++++++++++++-------
3 files changed, 82 insertions(+), 19 deletions(-)
diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index 21b2de63..43a72c80 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -22,7 +22,7 @@ from __future__ import annotations
import functools
from abc import ABCMeta, abstractmethod
from enum import Enum
-from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Protocol, TypeGuard, TypeVar, cast, overload
import pyarrow as pa
@@ -30,6 +30,8 @@ import datafusion._internal as df_internal
from datafusion.expr import Expr
if TYPE_CHECKING:
+ from _typeshed import CapsuleType as _PyCapsule
+
_R = TypeVar("_R", bound=pa.DataType)
from collections.abc import Callable
@@ -84,6 +86,11 @@ class ScalarUDFExportable(Protocol):
def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105
+def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]:
+ """Return ``True`` when ``value`` is a CPython ``PyCapsule``."""
+ return value.__class__.__name__ == "PyCapsule"
+
+
class ScalarUDF:
"""Class for performing scalar user-defined functions (UDF).
@@ -291,6 +298,7 @@ class AggregateUDF:
also :py:class:`ScalarUDF` for operating on a row by row basis.
"""
+ @overload
def __init__(
self,
name: str,
@@ -299,6 +307,27 @@ class AggregateUDF:
return_type: pa.DataType,
state_type: list[pa.DataType],
volatility: Volatility | str,
+ ) -> None: ...
+
+ @overload
+ def __init__(
+ self,
+ name: str,
+ accumulator: AggregateUDFExportable,
+ input_types: None = ...,
+ return_type: None = ...,
+ state_type: None = ...,
+ volatility: None = ...,
+ ) -> None: ...
+
+ def __init__(
+ self,
+ name: str,
+ accumulator: Callable[[], Accumulator] | AggregateUDFExportable,
+ input_types: list[pa.DataType] | None,
+ return_type: pa.DataType | None,
+ state_type: list[pa.DataType] | None,
+ volatility: Volatility | str | None,
) -> None:
"""Instantiate a user-defined aggregate function (UDAF).
@@ -308,6 +337,18 @@ class AggregateUDF:
if hasattr(accumulator, "__datafusion_aggregate_udf__"):
self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator)
return
+ if (
+ input_types is None
+ or return_type is None
+ or state_type is None
+ or volatility is None
+ ):
+ msg = (
+ "`input_types`, `return_type`, `state_type`, and `volatility` "
+ "must be provided when `accumulator` is callable."
+ )
+ raise TypeError(msg)
+
self._udaf = df_internal.AggregateUDF(
name,
accumulator,
@@ -351,6 +392,14 @@ class AggregateUDF:
name: str | None = None,
) -> AggregateUDF: ...
+ @overload
+ @staticmethod
+ def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ...
+
+ @overload
+ @staticmethod
+ def udaf(accum: _PyCapsule) -> AggregateUDF: ...
+
@staticmethod
def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
"""Create a new User-Defined Aggregate Function (UDAF).
@@ -471,7 +520,7 @@ class AggregateUDF:
return decorator
- if hasattr(args[0], "__datafusion_aggregate_udf__"):
+        if hasattr(args[0], "__datafusion_aggregate_udf__") or _is_pycapsule(args[0]):
return AggregateUDF.from_pycapsule(args[0])
if args and callable(args[0]):
@@ -481,16 +530,22 @@ class AggregateUDF:
return _decorator(*args, **kwargs)
@staticmethod
- def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF:
+    def from_pycapsule(func: AggregateUDFExportable | _PyCapsule) -> AggregateUDF:
"""Create an Aggregate UDF from AggregateUDF PyCapsule object.
This function will instantiate a Aggregate UDF that uses a DataFusion
AggregateUDF that is exported via the FFI bindings.
"""
- name = str(func.__class__)
+ if _is_pycapsule(func):
+ aggregate = cast(AggregateUDF, object.__new__(AggregateUDF))
+ aggregate._udaf = df_internal.AggregateUDF.from_pycapsule(func)
+ return aggregate
+
+ capsule = cast(AggregateUDFExportable, func)
+ name = str(capsule.__class__)
return AggregateUDF(
name=name,
- accumulator=func,
+ accumulator=capsule,
input_types=None,
return_type=None,
state_type=None,
diff --git a/python/tests/test_pyclass_frozen.py b/python/tests/test_pyclass_frozen.py
index 3500c5e3..33338bf5 100644
--- a/python/tests/test_pyclass_frozen.py
+++ b/python/tests/test_pyclass_frozen.py
@@ -35,8 +35,7 @@ ARG_STRING_RE = re.compile(
r"(?P<key>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P<value>[^\"]+)\"",
)
STRUCT_NAME_RE = re.compile(
- r"\b(?:pub\s+)?(?:struct|enum)\s+"
- r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
+ r"\b(?:pub\s+)?(?:struct|enum)\s+" r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
)
diff --git a/src/udaf.rs b/src/udaf.rs
index eab4581d..e48e35f8 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -154,6 +154,15 @@ pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {
})
}
+fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
+ validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
+
+ let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
+ let udaf: ForeignAggregateUDF = udaf.try_into()?;
+
+ Ok(udaf.into())
+}
+
/// Represents an AggregateUDF
#[pyclass(frozen, name = "AggregateUDF", module = "datafusion", subclass)]
#[derive(Debug, Clone)]
@@ -186,22 +195,22 @@ impl PyAggregateUDF {
#[staticmethod]
pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
+ if func.is_instance_of::<PyCapsule>() {
+            let capsule = func.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
+ let function = aggregate_udf_from_capsule(capsule)?;
+ return Ok(Self { function });
+ }
+
if func.hasattr("__datafusion_aggregate_udf__")? {
            let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?;
            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
- validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
-
- let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
- let udaf: ForeignAggregateUDF = udaf.try_into()?;
-
- Ok(Self {
- function: udaf.into(),
- })
- } else {
- Err(crate::errors::PyDataFusionError::Common(
-                "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
- ))
+ let function = aggregate_udf_from_capsule(capsule)?;
+ return Ok(Self { function });
}
+
+ Err(crate::errors::PyDataFusionError::Common(
+            "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
+ ))
}
/// creates a new PyExpr with the call of the udf
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]