This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new e98b94e feat: Introduce Experimental `tvm_ffi.dataclasses.c_class` (#8)
e98b94e is described below
commit e98b94e118dfa5ac4bcf3764a8b1695afee3d596
Author: Junru Shao <[email protected]>
AuthorDate: Sun Sep 21 17:22:49 2025 -0700
feat: Introduce Experimental `tvm_ffi.dataclasses.c_class` (#8)
Depends on #32.
This PR introduces an experimental `@c_class` decorator that enables
Python dataclass-like syntax for exposing C++ objects through TVM FFI.
The decorator automatically handles field reflection, inheritance, and
constructor generation for FFI-backed classes.
## Example
On the C++ side, register the type, its fields, and its constructor using
`tvm::ffi::reflection::ObjectDef<>`. The constructor must be registered under
the name `__ffi_init__`: that is the method the generated Python `__init__`
forwards to, and the decorator raises an error if it is missing.
```C++
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::ObjectDef<MyClass>()
      .def_static("__ffi_init__",
                  [](int64_t v_i64, int32_t v_i32, double v_f64, float v_f32) -> Any {
                    return ObjectRef(ffi::make_object<MyClass>(v_i64, v_i32, v_f64, v_f32));
                  })
      .def_rw("v_i64", &MyClass::v_i64)
      .def_rw("v_i32", &MyClass::v_i32)
      .def_rw("v_f64", &MyClass::v_f64)
      .def_rw("v_f32", &MyClass::v_f32);
}
```
Mirror the same structure in Python using dataclass-style annotations:
```python
from tvm_ffi.dataclasses import c_class, field


@c_class("example.MyClass")
class MyClass:
    v_i64: int
    v_i32: int
    v_f64: float = field(default=0.0)
    v_f32: float = field(default_factory=lambda: 1.0)


obj = MyClass(v_i64=4, v_i32=8)
obj.v_f64 = 3.14  # transparently forwards to the underlying C++ object
```
## Future work
Support as many standard `dataclasses` features as possible, such as `repr`,
`init`, and `order`.
---
python/tvm_ffi/core.pyi | 10 ++
python/tvm_ffi/cython/object.pxi | 10 ++
python/tvm_ffi/cython/type_info.pxi | 1 +
python/tvm_ffi/dataclasses/__init__.py | 24 ++++
python/tvm_ffi/dataclasses/_utils.py | 209 +++++++++++++++++++++++++++++++
python/tvm_ffi/dataclasses/c_class.py | 171 +++++++++++++++++++++++++
python/tvm_ffi/dataclasses/field.py | 95 ++++++++++++++
python/tvm_ffi/registry.py | 2 +
python/tvm_ffi/testing.py | 28 ++++-
src/ffi/extra/testing.cc | 51 ++++++++
tests/python/test_dataclasses_c_class.py | 66 ++++++++++
11 files changed, 665 insertions(+), 2 deletions(-)
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index afadcd1..9ca9a18 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -45,6 +45,15 @@ class Object:
def __ne__(self, other: Any) -> bool: ...
def __hash__(self) -> int: ...
def __init_handle_by_constructor__(self, fconstructor: Function, *args:
Any) -> None: ...
+ def __ffi_init__(self, *args: Any) -> None:
+ """Initialize the instance using the ` __init__` method registered on
C++ side.
+
+ Parameters
+ ----------
+ args: list of objects
+ The arguments to the constructor
+
+ """
def same_as(self, other: Any) -> bool: ...
def _move(self) -> ObjectRValueRef: ...
def __move_handle_from__(self, other: Object) -> None: ...
@@ -240,6 +249,7 @@ class TypeField:
frozen: bool
getter: Any
setter: Any
+ dataclass_field: Any | None
def as_property(self, cls: type) -> property: ...
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 3d0e33e..326a98b 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -138,6 +138,16 @@ cdef class Object:
(<Object>fconstructor).chandle, <PyObject*>args, &chandle, NULL)
self.chandle = chandle
+ def __ffi_init__(self, *args) -> None:
+ """Initialize the instance using the ` __init__` method registered on
C++ side.
+
+ Parameters
+ ----------
+ args: list of objects
+ The arguments to the constructor
+ """
+ self.__init_handle_by_constructor__(type(self).__c_ffi_init__, *args)
+
def same_as(self, other):
"""Check object identity.
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index bde25be..a50a95f 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -68,6 +68,7 @@ class TypeField:
frozen: bool
getter: FieldGetter
setter: FieldSetter
+ dataclass_field: object | None = None
def __post_init__(self):
assert self.setter is not None
diff --git a/python/tvm_ffi/dataclasses/__init__.py
b/python/tvm_ffi/dataclasses/__init__.py
new file mode 100644
index 0000000..3185413
--- /dev/null
+++ b/python/tvm_ffi/dataclasses/__init__.py
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Experimental FFI interface that exposes C++ classes to Python in dataclass
syntax."""
+
+from dataclasses import MISSING
+
+from .c_class import c_class
+from .field import Field, field
+
+__all__ = ["MISSING", "Field", "c_class", "field"]
diff --git a/python/tvm_ffi/dataclasses/_utils.py
b/python/tvm_ffi/dataclasses/_utils.py
new file mode 100644
index 0000000..ef7c7e4
--- /dev/null
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for constructing Python proxies of FFI types."""
+
+from __future__ import annotations
+
+import functools
+import inspect
+from dataclasses import MISSING
+from typing import Any, Callable, NamedTuple, TypeVar
+
+from ..core import (
+ Object,
+ TypeField,
+ TypeInfo,
+ _lookup_type_info_from_type_key,
+)
+
+_InputClsType = TypeVar("_InputClsType")
+
+
+def get_parent_type_info(type_cls: type) -> TypeInfo:
+ """Find the nearest ancestor with registered ``__tvm_ffi_type_info__``.
+
+ If none are found, return the base ``ffi.Object`` type info.
+ """
+ for base in type_cls.__bases__:
+ if (info := getattr(base, "__tvm_ffi_type_info__", None)) is not None:
+ return info
+ return _lookup_type_info_from_type_key("ffi.Object")
+
+
+def type_info_to_cls(
+ type_info: TypeInfo,
+ cls: type[_InputClsType],
+ methods: dict[str, Callable[..., Any] | None],
+) -> type[_InputClsType]:
+ assert type_info.type_cls is None, "Type class is already created"
+ # Step 1. Determine the base classes
+ cls_bases = cls.__bases__
+ if cls_bases == (object,):
+ # If the class inherits from `object`, we need to set the base class
to `Object`
+ cls_bases = (Object,)
+
+ # Step 2. Define the new class attributes
+ attrs = dict(cls.__dict__)
+ attrs.pop("__dict__", None)
+ attrs.pop("__weakref__", None)
+ attrs["__slots__"] = ()
+ attrs["__tvm_ffi_type_info__"] = type_info
+
+ # Step 2. Add fields
+ for field in type_info.fields:
+ attrs[field.name] = field.as_property(cls)
+
+ # Step 3. Add methods
+ def _add_method(name: str, func: Callable) -> None:
+ if name == "__ffi_init__":
+ name = "__c_ffi_init__"
+ if name in attrs: # already defined
+ return
+ func.__module__ = cls.__module__
+ func.__name__ = name
+ func.__qualname__ = f"{cls.__qualname__}.{name}"
+ func.__doc__ = f"Method `{name}` of class `{cls.__qualname__}`"
+ attrs[name] = func
+ setattr(cls, name, func)
+
+ for name, method in methods.items():
+ if method is not None:
+ _add_method(name, method)
+ for method in type_info.methods:
+ _add_method(method.name, method.func)
+
+ # Step 4. Create the new class
+ new_cls = type(cls.__name__, cls_bases, attrs)
+ new_cls.__module__ = cls.__module__
+ new_cls = functools.wraps(cls, updated=())(new_cls) # type: ignore
+ return new_cls
+
+
+def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
+ from .field import Field, field # noqa: PLC0415
+
+ field_name = type_field.name
+ rhs: Any = getattr(type_cls, field_name, MISSING)
+ if rhs is MISSING:
+ rhs = field()
+ elif isinstance(rhs, Field):
+ pass
+ elif isinstance(rhs, (int, float, str, bool, type(None))):
+ rhs = field(default=rhs)
+ else:
+ raise ValueError(f"Cannot recognize field: {type_field.name}: {rhs}")
+ assert isinstance(rhs, Field)
+ rhs.name = type_field.name
+ type_field.dataclass_field = rhs
+
+
+def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
# noqa: PLR0915
+ """Generate an ``__init__`` that forwards to the FFI constructor.
+
+ The generated initializer has a proper Python signature built from the
+ reflected field list, supporting default values and ``__post_init__``.
+ """
+
+ class DefaultFactory(NamedTuple):
+ """Wrapper that marks a parameter as having a default factory."""
+
+ fn: Callable[[], Any]
+
+ fields: list[TypeInfo] = []
+ cur_type_info = type_info
+ while True:
+ fields.extend(reversed(cur_type_info.fields))
+ cur_type_info = cur_type_info.parent_type_info
+ if cur_type_info is None:
+ break
+ fields.reverse()
+ del cur_type_info
+
+ annotations: dict[str, Any] = {"return": None}
+ # Step 1. Split the parameters into two groups to ensure that
+ # those without defaults appear first in the signature.
+ params_without_defaults: list[inspect.Parameter] = []
+ params_with_defaults: list[inspect.Parameter] = []
+ ordering = [0] * len(fields)
+ for i, field in enumerate(fields):
+ assert field.name is not None
+ name: str = field.name
+ annotations[name] = Any # NOTE: We might be able to handle
annotations better
+ assert field.dataclass_field is not None
+ default_factory = field.dataclass_field.default_factory
+ if default_factory is MISSING:
+ ordering[i] = len(params_without_defaults)
+ params_without_defaults.append(
+ inspect.Parameter(name=name,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ )
+ else:
+ ordering[i] = -len(params_with_defaults) - 1
+ params_with_defaults.append(
+ inspect.Parameter(
+ name=name,
+ kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ default=DefaultFactory(fn=default_factory),
+ )
+ )
+ for i, order in enumerate(ordering):
+ if order < 0:
+ ordering[i] = len(params_without_defaults) - order - 1
+ # Step 2. Create the signature object
+ sig = inspect.Signature(parameters=[*params_without_defaults,
*params_with_defaults])
+ signature_str = (
+ f"{type_cls.__module__}.{type_cls.__qualname__}.__init__("
+ + ", ".join(p.name for p in sig.parameters.values())
+ + ")"
+ )
+
+ # Step 3. Create the `binding` method that reorders parameters
+ def touch_arg(x: Any) -> Any:
+ return x.fn() if isinstance(x, DefaultFactory) else x
+
+ def bind_args(*args: Any, **kwargs: Any) -> tuple[Any, ...]:
+ bound = sig.bind(*args, **kwargs)
+ bound.apply_defaults()
+ args = bound.args
+ args = tuple(touch_arg(args[i]) for i in ordering)
+ return args
+
+ for type_method in type_info.methods:
+ if type_method.name == "__ffi_init__":
+ break
+ else:
+ raise ValueError(f"Cannot find constructor method:
`{type_info.type_key}.__ffi_init__`")
+
+ def __init__(self: type, *args: Any, **kwargs: Any) -> None:
+ e = None
+ try:
+ args = bind_args(*args, **kwargs)
+ del kwargs
+ self.__ffi_init__(*args)
+ except Exception as _e:
+ e = TypeError(f"Error in `{signature_str}`:
{_e}").with_traceback(_e.__traceback__)
+ if e is not None:
+ raise e
+ try:
+ fn_post_init = self.__post_init__ # type: ignore[attr-defined]
+ except AttributeError:
+ pass
+ else:
+ fn_post_init()
+
+ __init__.__signature__ = sig # type: ignore[attr-defined]
+ __init__.__annotations__ = annotations
+ return __init__
diff --git a/python/tvm_ffi/dataclasses/c_class.py
b/python/tvm_ffi/dataclasses/c_class.py
new file mode 100644
index 0000000..7507b76
--- /dev/null
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Helpers for mirroring registered C++ FFI types with Python dataclass syntax.
+
+The :func:`c_class` decorator is the primary entry point. It inspects the
+reflection metadata that the C++ runtime exposes via the TVM FFI registry and
+turns it into Python ``dataclass``-style descriptors: annotated attributes
become
+properties that forward to the underlying C++ object, while an ``__init__``
+method is synthesized to call the FFI constructor when requested.
+"""
+
+from collections.abc import Callable
+from dataclasses import InitVar
+from typing import ClassVar, TypeVar, get_origin, get_type_hints
+
+from ..core import TypeField, TypeInfo
+from . import _utils, field
+
+try:
+ from typing import dataclass_transform
+except ImportError:
+ from typing_extensions import dataclass_transform
+
+
+_InputClsType = TypeVar("_InputClsType")
+
+
+@dataclass_transform(field_specifiers=(field.field, field.Field))
+def c_class(
+ type_key: str, init: bool = True
+) -> Callable[[type[_InputClsType]], type[_InputClsType]]:
+ """(Experimental) Create a dataclass-like proxy for a C++ class registered
with TVM FFI.
+
+ The decorator reads the reflection metadata that was registered on the C++
+ side using ``tvm::ffi::reflection::ObjectDef`` and binds it to the
annotated
+ attributes in the decorated Python class. Each field defined in C++ becomes
+ a property on the Python class, and optional default values can be provided
+ with :func:`tvm_ffi.dataclasses.field` in the same way as Python's native
+ ``dataclasses.field``.
+
+ The intent is to offer a familiar dataclass authoring experience while
still
+ exposing the underlying C++ object. The ``type_key`` of the C++ class must
+ match the string passed to :func:`c_class`, and inheritance relationships
are
+ preserved—subclasses registered in C++ can subclass the Python proxy
defined
+ for their parent.
+
+ Parameters
+ ----------
+ type_key : str
+ The reflection key that identifies the C++ type in the FFI registry,
+ e.g. ``"testing.MyClass"`` as registered in
+ ``src/ffi/extra/testing.cc``.
+
+ init : bool, default True
+ If ``True`` and the Python class does not define ``__init__``, an
+ initializer is auto-generated that mirrors the reflected constructor
+ signature. The generated initializer calls the C++ ``__init__``
+ function registered with ``ObjectDef`` and invokes ``__post_init__`` if
+ it exists on the Python class.
+
+ Returns
+ -------
+ Callable[[type], type]
+ A class decorator that materializes the final proxy class.
+
+ Examples
+ --------
+ Register the C++ type and its fields with TVM FFI:
+
+ .. code-block:: c++
+
+ TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<MyClass>()
+ .def_static("__init__", [](int64_t v_i64, int32_t v_i32,
+ double v_f64, float v_f32) -> Any {
+ return ObjectRef(ffi::make_object<MyClass>(
+ v_i64, v_i32, v_f64, v_f32));
+ })
+ .def_rw("v_i64", &MyClass::v_i64)
+ .def_rw("v_i32", &MyClass::v_i32)
+ .def_rw("v_f64", &MyClass::v_f64)
+ .def_rw("v_f32", &MyClass::v_f32);
+ }
+
+ Mirror the same structure in Python using dataclass-style annotations:
+
+ .. code-block:: python
+
+ from tvm_ffi.dataclasses import c_class, field
+
+ @c_class("example.MyClass")
+ class MyClass:
+ v_i64: int
+ v_i32: int
+ v_f64: float = field(default=0.0)
+ v_f32: float = field(default_factory=lambda: 1.0)
+
+ obj = MyClass(v_i64=4, v_i32=8)
+ obj.v_f64 = 3.14 # transparently forwards to the underlying C++ object
+
+ """
+
+ def decorator(super_type_cls: type[_InputClsType]) -> type[_InputClsType]:
+ nonlocal init
+ init = init and "__init__" not in super_type_cls.__dict__
+ # Step 1. Retrieve `type_info` from registry
+ type_info: TypeInfo = _utils._lookup_type_info_from_type_key(type_key)
+ assert type_info.parent_type_info is None, f"Already registered type:
{type_key}"
+ type_info.parent_type_info =
_utils.get_parent_type_info(super_type_cls)
+ # Step 2. Reflect all the fields of the type
+ type_info.fields = _inspect_c_class_fields(super_type_cls, type_info)
+ for type_field in type_info.fields:
+ _utils.fill_dataclass_field(super_type_cls, type_field)
+ # Step 3. Create the proxy class with the fields as properties
+ fn_init = _utils.method_init(super_type_cls, type_info) if init else
None
+ type_cls: type[_InputClsType] = _utils.type_info_to_cls(
+ type_info=type_info,
+ cls=super_type_cls,
+ methods={"__init__": fn_init},
+ )
+ type_info.type_cls = type_cls
+ return type_cls
+
+ return decorator
+
+
+def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) ->
list[TypeField]:
+ type_hints_resolved = get_type_hints(type_cls, include_extras=True)
+ type_hints_py = {
+ name: type_hints_resolved[name]
+ for name in getattr(type_cls, "__annotations__", {}).keys()
+ if get_origin(type_hints_resolved[name])
+ not in [ # ignore non-field annotations
+ ClassVar,
+ InitVar,
+ ]
+ }
+ del type_hints_resolved
+
+ type_fields_cxx: dict[str, TypeField] = {f.name: f for f in
type_info.fields}
+ type_fields: list[TypeField] = []
+ for field_name, _field_ty_py in type_hints_py.items():
+ if field_name.startswith("__tvm_ffi"): # TVM's private fields - skip
+ continue
+ type_field: TypeField = type_fields_cxx.pop(field_name, None)
+ if type_field is None:
+ raise ValueError(
+ f"Extraneous field `{type_cls}.{field_name}`. Defined in
Python but not in C++"
+ )
+ type_fields.append(type_field)
+ if type_fields_cxx:
+ extra_fields = ", ".join(f"`{f.name}`" for f in
type_fields_cxx.values())
+ raise ValueError(
+ f"Missing fields in `{type_cls}`: {extra_fields}. Defined in C++
but not in Python"
+ )
+ return type_fields
diff --git a/python/tvm_ffi/dataclasses/field.py
b/python/tvm_ffi/dataclasses/field.py
new file mode 100644
index 0000000..00170e5
--- /dev/null
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Public helpers for describing dataclass-style defaults on FFI proxies."""
+
+from __future__ import annotations
+
+from dataclasses import MISSING, dataclass
+from typing import Any, Callable
+
+
+@dataclass(kw_only=True)
+class Field:
+ """(Experimental) Descriptor placeholder returned by
:func:`tvm_ffi.dataclasses.field`.
+
+ A ``Field`` mirrors the object returned by :func:`dataclasses.field`, but
it
+ is understood by :func:`tvm_ffi.dataclasses.c_class`. The decorator
inspects
+ the ``Field`` instances, records the ``default_factory`` and later replaces
+ the field with a property that forwards to the underlying C++ attribute.
+
+ Users should not instantiate ``Field`` directly—use :func:`field` instead,
+ which guarantees that ``name`` and ``default_factory`` are populated in a
+ way the decorator understands.
+ """
+
+ name: str | None = None
+ default_factory: Callable[[], Any]
+
+
+def field(*, default: Any = MISSING, default_factory: Any = MISSING) -> Field:
+ """(Experimental) Declare a dataclass-style field on a :func:`c_class`
proxy.
+
+ Use this helper exactly like :func:`dataclasses.field` when defining the
+ Python side of a C++ class. When :func:`c_class` processes the class body
it
+ replaces the placeholder with a property and arranges for ``default`` or
+ ``default_factory`` to be respected by the synthesized ``__init__``.
+
+ Parameters
+ ----------
+ default : Any, optional
+ A literal default value that should populate the field when no argument
+ is given. The value is copied into a closure because TVM FFI does not
+ mutate the Python placeholder instance.
+ default_factory : Callable[[], Any], optional
+ A zero-argument callable that produces the default. This matches the
+ semantics of :func:`dataclasses.field` and is useful for mutable
+ defaults such as ``list`` or ``dict``.
+
+ Returns
+ -------
+ Field
+ A placeholder object that :func:`c_class` will consume during class
+ registration.
+
+ Examples
+ --------
+ ``field`` integrates with :func:`c_class` to express defaults the same way
a
+ Python ``dataclass`` would::
+
+ @c_class("testing.TestCxxClassBase")
+ class PyBase:
+ v_i64: int
+ v_i32: int = field(default=16)
+
+ obj = PyBase(v_i64=4)
+ obj.v_i32 # -> 16
+
+ """
+ if default is not MISSING and default_factory is not MISSING:
+ raise ValueError("Cannot specify both `default` and `default_factory`")
+ if default is not MISSING:
+ default_factory = _make_default_factory(default)
+ return Field(default_factory=default_factory)
+
+
+def _make_default_factory(value: Any) -> Callable[[], Any]:
+ """Make a default factory that returns the given value."""
+
+ def factory() -> Any:
+ return value
+
+ return factory
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 5a540fb..6bf08f6 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -248,6 +248,8 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo)
-> type:
setattr(type_cls, name, property(getter, setter, doc=doc))
for method in type_info.methods:
name = method.name
+ if name == "__ffi_init__":
+ name = "__c_ffi_init__"
doc = method.doc if method.doc else None
method_func = method.func
if method.is_static:
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 3215d8a..825f9cf 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -16,10 +16,11 @@
# under the License.
"""Testing utilities."""
-from typing import Any
+from typing import Any, ClassVar
from . import _ffi_api
from .core import Object
+from .dataclasses import c_class, field
from .registry import register_object
@@ -34,7 +35,7 @@ class TestIntPair(Object):
def __init__(self, a: int, b: int) -> None:
"""Construct the object."""
- self.__init_handle_by_constructor__(TestIntPair.__ffi_init__, a, b)
+ self.__ffi_init__(a, b)
@register_object("testing.TestObjectDerived")
@@ -68,3 +69,26 @@ def create_object(type_key: str, **kwargs: Any) -> Object:
args.append(k)
args.append(v)
return _ffi_api.MakeObjectFromPackedArgs(*args)
+
+
+@c_class("testing.TestCxxClassBase")
+class _TestCxxClassBase:
+ v_i64: int
+ v_i32: int
+ not_field_1 = 1
+ not_field_2: ClassVar[int] = 2
+
+ def __init__(self, v_i64: int, v_i32: int) -> None:
+ self.__ffi_init__(v_i64 + 1, v_i32 + 2)
+
+
+@c_class("testing.TestCxxClassDerived")
+class _TestCxxClassDerived(_TestCxxClassBase):
+ v_f64: float
+ v_f32: float = 8
+
+
+@c_class("testing.TestCxxClassDerivedDerived")
+class _TestCxxClassDerivedDerived(_TestCxxClassDerived):
+ v_str: str = field(default_factory=lambda: "default")
+ v_bool: bool
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index 9c3a019..370a1c6 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -86,6 +86,41 @@ class TestObjectDerived : public TestObjectBase {
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("testing.TestObjectDerived",
TestObjectDerived, TestObjectBase);
};
+class TestCxxClassBase : public Object {
+ public:
+ int64_t v_i64;
+ int32_t v_i32;
+
+ TestCxxClassBase(int64_t v_i64, int32_t v_i32) : v_i64(v_i64), v_i32(v_i32)
{}
+
+ static constexpr bool _type_mutable = true;
+ TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxClassBase", TestCxxClassBase,
Object);
+};
+
+class TestCxxClassDerived : public TestCxxClassBase {
+ public:
+ double v_f64;
+ float v_f32;
+
+ TestCxxClassDerived(int64_t v_i64, int32_t v_i32, double v_f64, float v_f32)
+ : TestCxxClassBase(v_i64, v_i32), v_f64(v_f64), v_f32(v_f32) {}
+
+ TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxClassDerived",
TestCxxClassDerived, TestCxxClassBase);
+};
+
+class TestCxxClassDerivedDerived : public TestCxxClassDerived {
+ public:
+ String v_str;
+ bool v_bool;
+
+ TestCxxClassDerivedDerived(int64_t v_i64, int32_t v_i32, double v_f64, float
v_f32, String v_str,
+ bool v_bool)
+ : TestCxxClassDerived(v_i64, v_i32, v_f64, v_f32), v_str(v_str),
v_bool(v_bool) {}
+
+ TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxClassDerivedDerived",
TestCxxClassDerivedDerived,
+ TestCxxClassDerived);
+};
+
TVM_FFI_NO_INLINE void TestRaiseError(String kind, String msg) {
// keep name and no liner for testing traceback
throw ffi::Error(kind, msg, TVMFFITraceback(__FILE__, __LINE__,
TVM_FFI_FUNC_SIG, 0));
@@ -110,6 +145,22 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_ro("v_map", &TestObjectDerived::v_map)
.def_ro("v_array", &TestObjectDerived::v_array);
+ refl::ObjectDef<TestCxxClassBase>()
+ .def_static("__ffi_init__", refl::init<TestCxxClassBase, int64_t,
int32_t>)
+ .def_rw("v_i64", &TestCxxClassBase::v_i64)
+ .def_rw("v_i32", &TestCxxClassBase::v_i32);
+
+ refl::ObjectDef<TestCxxClassDerived>()
+ .def_static("__ffi_init__", refl::init<TestCxxClassDerived, int64_t,
int32_t, double, float>)
+ .def_rw("v_f64", &TestCxxClassDerived::v_f64)
+ .def_rw("v_f32", &TestCxxClassDerived::v_f32);
+ refl::ObjectDef<TestCxxClassDerivedDerived>()
+ .def_static(
+ "__ffi_init__",
+ refl::init<TestCxxClassDerivedDerived, int64_t, int32_t, double,
float, String, bool>)
+ .def_rw("v_str", &TestCxxClassDerivedDerived::v_str)
+ .def_rw("v_bool", &TestCxxClassDerivedDerived::v_bool);
+
refl::GlobalDef()
.def("testing.test_raise_error", TestRaiseError)
.def_packed("testing.nop", [](PackedArgs args, Any* ret) {})
diff --git a/tests/python/test_dataclasses_c_class.py
b/tests/python/test_dataclasses_c_class.py
new file mode 100644
index 0000000..a2fa80e
--- /dev/null
+++ b/tests/python/test_dataclasses_c_class.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from tvm_ffi.testing import _TestCxxClassBase, _TestCxxClassDerived,
_TestCxxClassDerivedDerived
+
+
+def test_cxx_class_base() -> None:
+ obj = _TestCxxClassBase(v_i64=123, v_i32=456)
+ assert obj.v_i64 == 123 + 1
+ assert obj.v_i32 == 456 + 2
+
+
+def test_cxx_class_derived() -> None:
+ obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.00, v_f32=8.00)
+ assert obj.v_i64 == 123
+ assert obj.v_i32 == 456
+ assert obj.v_f64 == 4.00
+ assert obj.v_f32 == 8.00
+
+
+def test_cxx_class_derived_default() -> None:
+ obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.00)
+ assert obj.v_i64 == 123
+ assert obj.v_i32 == 456
+ assert obj.v_f64 == 4.00
+ assert isinstance(obj.v_f32, float) and obj.v_f32 == 8.00 # default value
+
+
+def test_cxx_class_derived_derived() -> None:
+ obj = _TestCxxClassDerivedDerived(
+ v_i64=123,
+ v_i32=456,
+ v_f64=4.00,
+ v_f32=8.00,
+ v_str="hello",
+ v_bool=True,
+ )
+ assert obj.v_i64 == 123
+ assert obj.v_i32 == 456
+ assert obj.v_f64 == 4.00
+ assert obj.v_f32 == 8.00
+ assert obj.v_str == "hello"
+ assert obj.v_bool is True
+
+
+def test_cxx_class_derived_derived_default() -> None:
+ obj = _TestCxxClassDerivedDerived(123, 456, 4, True)
+ assert obj.v_i64 == 123
+ assert obj.v_i32 == 456
+ assert isinstance(obj.v_f64, float) and obj.v_f64 == 4
+ assert isinstance(obj.v_f32, float) and obj.v_f32 == 8
+ assert obj.v_str == "default"
+ assert isinstance(obj.v_bool, bool) and obj.v_bool is True