This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new e98b94e  feat: Introduce Experimental `tvm_ffi.dataclasses.c_class` 
(#8)
e98b94e is described below

commit e98b94e118dfa5ac4bcf3764a8b1695afee3d596
Author: Junru Shao <[email protected]>
AuthorDate: Sun Sep 21 17:22:49 2025 -0700

    feat: Introduce Experimental `tvm_ffi.dataclasses.c_class` (#8)
    
    Depends on #32.
    
    This PR introduces an experimental `@c_class` decorator that enables
    Python dataclass-like syntax for exposing C++ objects through TVM FFI.
    The decorator automatically handles field reflection, inheritance, and
    constructor generation for FFI-backed classes.
    
    ## Example
    
    On C++ side, register types using `tvm::ffi::reflection::ObjectDef<>`:
    
    ```C++
    TVM_FFI_STATIC_INIT_BLOCK() {
        namespace refl = tvm::ffi::reflection;
        refl::ObjectDef<MyClass>()
            .def_static("__ffi_init__", [](int64_t v_i64, int32_t v_i32, double v_f64, float v_f32) -> Any {
                return ObjectRef(ffi::make_object<MyClass>(v_i64, v_i32, v_f64, v_f32));
            })
            .def_rw("v_i64", &MyClass::v_i64)
            .def_rw("v_i32", &MyClass::v_i32)
            .def_rw("v_f64", &MyClass::v_f64)
            .def_rw("v_f32", &MyClass::v_f32);
    }
    ```
    
    Mirror the same structure in Python using dataclass-style annotations:
    
    ```python
    from tvm_ffi.dataclasses import c_class, field
    
    @c_class("example.MyClass")
    class MyClass:
        v_i64: int
        v_i32: int
        v_f64: float = field(default=0.0)
        v_f32: float = field(default_factory=lambda: 1.0)
    
    obj = MyClass(v_i64=4, v_i32=8)
    obj.v_f64 = 3.14  # transparently forwards to the underlying C++ object
    ```
    
    ## Future work
    
    Support as many `dataclasses` features as possible, such as `repr`,
    `init`, and `order`.
---
 python/tvm_ffi/core.pyi                  |  10 ++
 python/tvm_ffi/cython/object.pxi         |  10 ++
 python/tvm_ffi/cython/type_info.pxi      |   1 +
 python/tvm_ffi/dataclasses/__init__.py   |  24 ++++
 python/tvm_ffi/dataclasses/_utils.py     | 209 +++++++++++++++++++++++++++++++
 python/tvm_ffi/dataclasses/c_class.py    | 171 +++++++++++++++++++++++++
 python/tvm_ffi/dataclasses/field.py      |  95 ++++++++++++++
 python/tvm_ffi/registry.py               |   2 +
 python/tvm_ffi/testing.py                |  28 ++++-
 src/ffi/extra/testing.cc                 |  51 ++++++++
 tests/python/test_dataclasses_c_class.py |  66 ++++++++++
 11 files changed, 665 insertions(+), 2 deletions(-)

diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index afadcd1..9ca9a18 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -45,6 +45,15 @@ class Object:
     def __ne__(self, other: Any) -> bool: ...
     def __hash__(self) -> int: ...
     def __init_handle_by_constructor__(self, fconstructor: Function, *args: 
Any) -> None: ...
+    def __ffi_init__(self, *args: Any) -> None:
+        """Initialize this instance via the ``__ffi_init__`` constructor registered in C++.
+
+        On the Python class, the reflected C++ ``__ffi_init__`` is stored under
+        the name ``__c_ffi_init__``; the object it returns becomes this
+        instance's handle.
+
+        Parameters
+        ----------
+        *args : Any
+            Arguments forwarded unchanged to the C++ constructor.
+
+        """
     def same_as(self, other: Any) -> bool: ...
     def _move(self) -> ObjectRValueRef: ...
     def __move_handle_from__(self, other: Object) -> None: ...
@@ -240,6 +249,7 @@ class TypeField:
     frozen: bool
     getter: Any
     setter: Any
+    dataclass_field: Any | None
 
     def as_property(self, cls: type) -> property: ...
 
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 3d0e33e..326a98b 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -138,6 +138,16 @@ cdef class Object:
             (<Object>fconstructor).chandle, <PyObject*>args, &chandle, NULL)
         self.chandle = chandle
 
+    def __ffi_init__(self, *args) -> None:
+        """Initialize this instance via the ``__ffi_init__`` constructor registered in C++.
+
+        Class registration (``tvm_ffi.registry`` and
+        ``tvm_ffi.dataclasses.c_class``) stores the reflected C++
+        ``__ffi_init__`` on the Python class as ``__c_ffi_init__``, so that it
+        does not replace this method.
+
+        Parameters
+        ----------
+        *args : Any
+            Arguments forwarded unchanged to the C++ constructor.
+        """
+        # Resolved on the class rather than on the instance.
+        self.__init_handle_by_constructor__(type(self).__c_ffi_init__, *args)
+
     def same_as(self, other):
         """Check object identity.
 
diff --git a/python/tvm_ffi/cython/type_info.pxi 
b/python/tvm_ffi/cython/type_info.pxi
index bde25be..a50a95f 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -68,6 +68,7 @@ class TypeField:
     frozen: bool
     getter: FieldGetter
     setter: FieldSetter
+    dataclass_field: object | None = None
 
     def __post_init__(self):
         assert self.setter is not None
diff --git a/python/tvm_ffi/dataclasses/__init__.py 
b/python/tvm_ffi/dataclasses/__init__.py
new file mode 100644
index 0000000..3185413
--- /dev/null
+++ b/python/tvm_ffi/dataclasses/__init__.py
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Experimental FFI interface that exposes C++ classes to Python in dataclass 
syntax."""
+
+from dataclasses import MISSING
+
+from .c_class import c_class
+from .field import Field, field
+
+# Public API of ``tvm_ffi.dataclasses``. ``MISSING`` is re-exported unchanged
+# from the standard-library ``dataclasses`` module.
+__all__ = [
+    "MISSING",
+    "Field",
+    "c_class",
+    "field",
+]
diff --git a/python/tvm_ffi/dataclasses/_utils.py 
b/python/tvm_ffi/dataclasses/_utils.py
new file mode 100644
index 0000000..ef7c7e4
--- /dev/null
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for constructing Python proxies of FFI types."""
+
+from __future__ import annotations
+
+import functools
+import inspect
+from dataclasses import MISSING
+from typing import Any, Callable, NamedTuple, TypeVar
+
+from ..core import (
+    Object,
+    TypeField,
+    TypeInfo,
+    _lookup_type_info_from_type_key,
+)
+
+_InputClsType = TypeVar("_InputClsType")
+
+
+def get_parent_type_info(type_cls: type) -> TypeInfo:
+    """Return the TVM FFI type info that ``type_cls`` inherits from its bases.
+
+    Direct bases are examined in declaration order. The first one exposing a
+    non-``None`` ``__tvm_ffi_type_info__`` (possibly inherited by that base)
+    is used. When no base provides one, the type info of ``ffi.Object`` is
+    returned.
+    """
+    candidates = (getattr(base, "__tvm_ffi_type_info__", None) for base in type_cls.__bases__)
+    parent = next((info for info in candidates if info is not None), None)
+    if parent is None:
+        parent = _lookup_type_info_from_type_key("ffi.Object")
+    return parent
+
+
+def type_info_to_cls(
+    type_info: TypeInfo,
+    cls: type[_InputClsType],
+    methods: dict[str, Callable[..., Any] | None],
+) -> type[_InputClsType]:
+    """Materialize the FFI-backed proxy class for the user-declared ``cls``.
+
+    A new class is built with ``type()`` from the namespace of ``cls``. It
+    declares ``__slots__ = ()`` and gets a ``__tvm_ffi_type_info__`` attribute
+    and one property per reflected field. It then receives the given extra
+    ``methods``, followed by the reflected methods of ``type_info``.
+
+    Parameters
+    ----------
+    type_info : TypeInfo
+        Reflected type information; its ``type_cls`` must not be set yet.
+    cls : type
+        The class written by the user. Every method installed on the new class
+        is also attached to ``cls`` itself via ``setattr``.
+    methods : dict[str, Callable[..., Any] | None]
+        Extra methods to install (e.g. a generated ``__init__``); ``None``
+        entries are ignored.
+
+    Returns
+    -------
+    type
+        The new class. Via ``functools.wraps``, it carries the metadata of
+        ``cls`` (name, qualname, module, docstring) and a ``__wrapped__``
+        reference to ``cls``.
+    """
+    assert type_info.type_cls is None, "Type class is already created"
+    # Step 1. Bases: a class without an explicit base is re-rooted on ``Object``.
+    bases = (Object,) if cls.__bases__ == (object,) else cls.__bases__
+
+    # Step 2. Namespace: start from the class body, without the original
+    # class's ``__dict__``/``__weakref__`` entries (the new class declares
+    # ``__slots__ = ()``).
+    namespace = {
+        key: value for key, value in cls.__dict__.items() if key not in ("__dict__", "__weakref__")
+    }
+    namespace["__slots__"] = ()
+    namespace["__tvm_ffi_type_info__"] = type_info
+
+    # Step 3. Fields: each reflected field becomes a forwarding property.
+    namespace.update((f.name, f.as_property(cls)) for f in type_info.fields)
+
+    # Step 4. Methods: explicit ones first, then the reflected ones. A name that
+    # is already present in the namespace is never overwritten.
+    candidates = [(name, func) for name, func in methods.items() if func is not None]
+    candidates += [(m.name, m.func) for m in type_info.methods]
+    for raw_name, func in candidates:
+        # The reflected constructor is stored where ``Object.__ffi_init__`` looks for it.
+        attr_name = "__c_ffi_init__" if raw_name == "__ffi_init__" else raw_name
+        if attr_name in namespace:
+            continue
+        func.__module__ = cls.__module__
+        func.__name__ = attr_name
+        func.__qualname__ = f"{cls.__qualname__}.{attr_name}"
+        func.__doc__ = f"Method `{attr_name}` of class `{cls.__qualname__}`"
+        namespace[attr_name] = func
+        setattr(cls, attr_name, func)
+
+    # Step 5. Build the class and copy the user class's metadata onto it.
+    new_cls = type(cls.__name__, bases, namespace)
+    new_cls.__module__ = cls.__module__
+    return functools.wraps(cls, updated=())(new_cls)  # type: ignore
+
+
+def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
+    """Attach a :class:`~tvm_ffi.dataclasses.Field` to ``type_field``.
+
+    The class attribute named after the field decides the default:
+
+    * absent -> ``field()`` with no default (the argument is required);
+    * a ``Field`` -> used as-is;
+    * an ``int``/``float``/``str``/``bool`` value or ``None`` -> ``field(default=...)``.
+
+    The chosen ``Field`` receives the field's name and is stored on
+    ``type_field.dataclass_field``.
+
+    Raises
+    ------
+    ValueError
+        If the class attribute is of any other kind.
+    """
+    from .field import Field, field  # noqa: PLC0415
+
+    declared: Any = getattr(type_cls, type_field.name, MISSING)
+    if isinstance(declared, Field):
+        spec = declared
+    elif declared is MISSING:
+        spec = field()
+    elif declared is None or isinstance(declared, (int, float, str, bool)):
+        spec = field(default=declared)
+    else:
+        raise ValueError(f"Cannot recognize field: {type_field.name}: {declared}")
+    assert isinstance(spec, Field)
+    spec.name = type_field.name
+    type_field.dataclass_field = spec
+
+
+def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:  # noqa: PLR0915
+    """Generate an ``__init__`` that forwards to the FFI constructor.
+
+    The generated initializer has a proper Python signature built from the
+    reflected field list (ancestors' fields first), supporting default values
+    and ``__post_init__``. Parameters without a default are placed before those
+    with a default in the signature. The bound arguments are then put back into
+    field order before calling ``self.__ffi_init__``.
+
+    Parameters
+    ----------
+    type_cls : type
+        The user-declared class; used only to name ``__init__`` in error messages.
+    type_info : TypeInfo
+        Type info whose fields, and those of its ancestors, carry ``dataclass_field``.
+
+    Returns
+    -------
+    Callable[..., None]
+        The ``__init__`` function, with ``__signature__`` and ``__annotations__`` set.
+        Any exception raised while binding the arguments or inside the constructor
+        call is re-raised from it as ``TypeError``.
+
+    Raises
+    ------
+    ValueError
+        If a field has no ``dataclass_field``, or ``type_info`` has no reflected
+        ``__ffi_init__`` method.
+    """
+
+    class DefaultFactory(NamedTuple):
+        """Wrapper that marks a parameter as having a default factory."""
+
+        fn: Callable[[], Any]
+
+    # Gather fields from the root ancestor down to ``type_info``, keeping each
+    # type's own order. Arguments are passed to ``__ffi_init__`` in this order.
+    fields: list[TypeField] = []
+    cur_type_info: TypeInfo | None = type_info
+    while cur_type_info is not None:
+        fields.extend(reversed(cur_type_info.fields))
+        cur_type_info = cur_type_info.parent_type_info
+    fields.reverse()
+
+    annotations: dict[str, Any] = {"return": None}
+    # Step 1. Split the parameters into two groups to ensure that
+    # those without defaults appear first in the signature.
+    # ``ordering[i]`` is the position of ``fields[i]`` in that signature.
+    params_without_defaults: list[inspect.Parameter] = []
+    params_with_defaults: list[inspect.Parameter] = []
+    ordering = [0] * len(fields)
+    for i, field in enumerate(fields):
+        assert field.name is not None
+        name: str = field.name
+        annotations[name] = Any  # NOTE: We might be able to handle annotations better
+        if field.dataclass_field is None:
+            # Only fields processed by ``c_class`` carry dataclass metadata.
+            raise ValueError(
+                f"Cannot generate `__init__` for `{type_info.type_key}`: field `{name}` "
+                "has no dataclass metadata (is its class declared with `@c_class`?)"
+            )
+        default_factory = field.dataclass_field.default_factory
+        if default_factory is MISSING:
+            ordering[i] = len(params_without_defaults)
+            params_without_defaults.append(
+                inspect.Parameter(name=name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
+            )
+        else:
+            # Temporarily encode "k-th parameter with a default" as -(k + 1);
+            # resolved below once the number of required parameters is known.
+            ordering[i] = -len(params_with_defaults) - 1
+            params_with_defaults.append(
+                inspect.Parameter(
+                    name=name,
+                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    default=DefaultFactory(fn=default_factory),
+                )
+            )
+    for i, order in enumerate(ordering):
+        if order < 0:
+            ordering[i] = len(params_without_defaults) - order - 1
+    # Step 2. Create the signature object
+    sig = inspect.Signature(parameters=[*params_without_defaults, *params_with_defaults])
+    signature_str = (
+        f"{type_cls.__module__}.{type_cls.__qualname__}.__init__("
+        + ", ".join(p.name for p in sig.parameters.values())
+        + ")"
+    )
+
+    # Step 3. Create the `binding` method that reorders parameters
+    def touch_arg(x: Any) -> Any:
+        # Default factories are called lazily, once per construction.
+        return x.fn() if isinstance(x, DefaultFactory) else x
+
+    def bind_args(*args: Any, **kwargs: Any) -> tuple[Any, ...]:
+        bound = sig.bind(*args, **kwargs)
+        bound.apply_defaults()
+        # All parameters are positional-or-keyword, so ``bound.args`` holds every
+        # value in signature order; map them back to field order.
+        positional = bound.args
+        return tuple(touch_arg(positional[i]) for i in ordering)
+
+    for type_method in type_info.methods:
+        if type_method.name == "__ffi_init__":
+            break
+    else:
+        raise ValueError(f"Cannot find constructor method: `{type_info.type_key}.__ffi_init__`")
+
+    def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
+        e = None
+        try:
+            args = bind_args(*args, **kwargs)
+            del kwargs
+            self.__ffi_init__(*args)
+        except Exception as _e:
+            e = TypeError(f"Error in `{signature_str}`: {_e}").with_traceback(_e.__traceback__)
+        # Raised outside the ``except`` block, so no implicit exception context
+        # is attached; the original traceback is reused instead.
+        if e is not None:
+            raise e
+        try:
+            fn_post_init = self.__post_init__  # type: ignore[attr-defined]
+        except AttributeError:
+            pass
+        else:
+            fn_post_init()
+
+    __init__.__signature__ = sig  # type: ignore[attr-defined]
+    __init__.__annotations__ = annotations
+    return __init__
diff --git a/python/tvm_ffi/dataclasses/c_class.py 
b/python/tvm_ffi/dataclasses/c_class.py
new file mode 100644
index 0000000..7507b76
--- /dev/null
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Helpers for mirroring registered C++ FFI types with Python dataclass syntax.
+
+The :func:`c_class` decorator is the primary entry point.  It inspects the
+reflection metadata that the C++ runtime exposes via the TVM FFI registry and
+turns it into Python ``dataclass``-style descriptors: annotated attributes 
become
+properties that forward to the underlying C++ object, while an ``__init__``
+method is synthesized to call the FFI constructor when requested.
+"""
+
+from collections.abc import Callable
+from dataclasses import InitVar
+from typing import ClassVar, TypeVar, get_origin, get_type_hints
+
+from ..core import TypeField, TypeInfo
+from . import _utils, field
+
+try:
+    from typing import dataclass_transform
+except ImportError:
+    from typing_extensions import dataclass_transform
+
+
+_InputClsType = TypeVar("_InputClsType")
+
+
+@dataclass_transform(field_specifiers=(field.field, field.Field))
+def c_class(
+    type_key: str, init: bool = True
+) -> Callable[[type[_InputClsType]], type[_InputClsType]]:
+    """(Experimental) Create a dataclass-like proxy for a C++ class registered with TVM FFI.
+
+    The decorator reads the reflection metadata that was registered on the C++
+    side using ``tvm::ffi::reflection::ObjectDef`` and binds it to the annotated
+    attributes in the decorated Python class. Each field defined in C++ becomes
+    a property on the Python class, and optional default values can be provided
+    with :func:`tvm_ffi.dataclasses.field` in the same way as Python's native
+    ``dataclasses.field``.
+
+    The intent is to offer a familiar dataclass authoring experience while still
+    exposing the underlying C++ object.  The ``type_key`` of the C++ class must
+    match the string passed to :func:`c_class`, and inheritance relationships are
+    preserved—subclasses registered in C++ can subclass the Python proxy defined
+    for their parent.
+
+    The Python annotations must name exactly the fields the C++ type itself
+    reflects. Inherited fields are declared on the parent proxy. For the
+    generated ``__init__``, the fields (ancestors first) must appear in the
+    same order as the parameters of the C++ ``__ffi_init__``.
+
+    Parameters
+    ----------
+    type_key : str
+        The reflection key that identifies the C++ type in the FFI registry,
+        e.g. ``"testing.TestCxxClassBase"`` as registered in
+        ``src/ffi/extra/testing.cc``.
+
+    init : bool, default True
+        If ``True`` and the Python class does not define ``__init__``, an
+        initializer is auto-generated from the reflected fields (required
+        parameters first, then those with defaults).  The generated initializer
+        calls the C++ ``__ffi_init__`` function registered with ``ObjectDef`` and
+        invokes ``__post_init__`` if it exists on the Python class.
+
+    Returns
+    -------
+    Callable[[type], type]
+        A class decorator that materializes the final proxy class.
+
+    Raises
+    ------
+    ValueError
+        Raised by the returned decorator in any of these cases:
+
+        * the type's ``parent_type_info`` is already set (it was registered before);
+        * the Python and C++ field lists disagree;
+        * a field default cannot be recognized;
+        * an ``__init__`` must be generated but no ``__ffi_init__`` is reflected.
+
+    Examples
+    --------
+    Register the C++ type and its fields with TVM FFI:
+
+    .. code-block:: c++
+
+        TVM_FFI_STATIC_INIT_BLOCK() {
+          namespace refl = tvm::ffi::reflection;
+          refl::ObjectDef<MyClass>()
+              .def_static("__ffi_init__", [](int64_t v_i64, int32_t v_i32,
+                                             double v_f64, float v_f32) -> Any {
+                   return ObjectRef(ffi::make_object<MyClass>(
+                       v_i64, v_i32, v_f64, v_f32));
+               })
+              .def_rw("v_i64", &MyClass::v_i64)
+              .def_rw("v_i32", &MyClass::v_i32)
+              .def_rw("v_f64", &MyClass::v_f64)
+              .def_rw("v_f32", &MyClass::v_f32);
+        }
+
+    Mirror the same structure in Python using dataclass-style annotations:
+
+    .. code-block:: python
+
+        from tvm_ffi.dataclasses import c_class, field
+
+        @c_class("example.MyClass")
+        class MyClass:
+            v_i64: int
+            v_i32: int
+            v_f64: float = field(default=0.0)
+            v_f32: float = field(default_factory=lambda: 1.0)
+
+        obj = MyClass(v_i64=4, v_i32=8)
+        obj.v_f64 = 3.14  # transparently forwards to the underlying C++ object
+
+    """
+
+    def decorator(super_type_cls: type[_InputClsType]) -> type[_InputClsType]:
+        # Decide per decorated class. Rebinding ``init`` through ``nonlocal``
+        # would leak ``False`` into any later use of the same decorator object.
+        generate_init = init and "__init__" not in super_type_cls.__dict__
+        # Step 1. Retrieve `type_info` from registry. `parent_type_info` is
+        # assigned just below, so an already-set value means a prior registration.
+        type_info: TypeInfo = _utils._lookup_type_info_from_type_key(type_key)
+        if type_info.parent_type_info is not None:
+            raise ValueError(f"Already registered type: {type_key}")
+        type_info.parent_type_info = _utils.get_parent_type_info(super_type_cls)
+        # Step 2. Reflect all the fields of the type
+        type_info.fields = _inspect_c_class_fields(super_type_cls, type_info)
+        for type_field in type_info.fields:
+            _utils.fill_dataclass_field(super_type_cls, type_field)
+        # Step 3. Create the proxy class with the fields as properties
+        fn_init = _utils.method_init(super_type_cls, type_info) if generate_init else None
+        type_cls: type[_InputClsType] = _utils.type_info_to_cls(
+            type_info=type_info,
+            cls=super_type_cls,
+            methods={"__init__": fn_init},
+        )
+        type_info.type_cls = type_cls
+        return type_cls
+
+    return decorator
+
+
+def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeField]:
+    """Match the fields annotated on ``type_cls`` against the C++ reflected fields.
+
+    Only annotations declared directly on ``type_cls`` are considered. The
+    following are not fields and are skipped:
+
+    * ``ClassVar`` and ``InitVar`` annotations, subscripted or bare;
+    * names starting with ``__tvm_ffi``.
+
+    Returns
+    -------
+    list[TypeField]
+        The reflected fields of ``type_info``, in the order they are annotated
+        in Python.
+
+    Raises
+    ------
+    ValueError
+        If a Python field has no C++ counterpart, or a C++ field is not
+        annotated in Python.
+    """
+    type_hints_resolved = get_type_hints(type_cls, include_extras=True)
+    type_hints_py = {
+        name: type_hints_resolved[name]
+        for name in getattr(type_cls, "__annotations__", {}).keys()
+        if not _is_pseudo_field(type_hints_resolved[name])  # ignore non-field annotations
+    }
+    del type_hints_resolved
+
+    type_fields_cxx: dict[str, TypeField] = {f.name: f for f in type_info.fields}
+    type_fields: list[TypeField] = []
+    for field_name in type_hints_py:
+        if field_name.startswith("__tvm_ffi"):  # TVM's private fields - skip
+            continue
+        type_field: TypeField | None = type_fields_cxx.pop(field_name, None)
+        if type_field is None:
+            raise ValueError(
+                f"Extraneous field `{type_cls}.{field_name}`. Defined in Python but not in C++"
+            )
+        type_fields.append(type_field)
+    if type_fields_cxx:
+        missing = ", ".join(f"`{f.name}`" for f in type_fields_cxx.values())
+        raise ValueError(
+            f"Missing fields in `{type_cls}`: {missing}. Defined in C++ but not in Python"
+        )
+    return type_fields
+
+
+def _is_pseudo_field(hint: object) -> bool:
+    """Return whether ``hint`` is a ``ClassVar``/``InitVar`` pseudo-field annotation.
+
+    Both subscripted (``ClassVar[int]``, ``InitVar[int]``) and bare forms are
+    recognized. ``get_origin`` alone misses two of them: bare ``ClassVar``,
+    and ``InitVar[T]``, which is an ``InitVar`` instance whose origin is ``None``.
+    """
+    if hint is ClassVar or get_origin(hint) is ClassVar:
+        return True
+    return hint is InitVar or isinstance(hint, InitVar)
diff --git a/python/tvm_ffi/dataclasses/field.py 
b/python/tvm_ffi/dataclasses/field.py
new file mode 100644
index 0000000..00170e5
--- /dev/null
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Public helpers for describing dataclass-style defaults on FFI proxies."""
+
+from __future__ import annotations
+
+from dataclasses import MISSING, dataclass
+from typing import Any, Callable
+
+
+@dataclass(kw_only=True)
+class Field:
+    """(Experimental) Descriptor placeholder returned by :func:`tvm_ffi.dataclasses.field`.
+
+    A ``Field`` mirrors the object returned by :func:`dataclasses.field`, but it
+    is understood by :func:`tvm_ffi.dataclasses.c_class`.  The decorator inspects
+    the ``Field`` instances, records the ``default_factory`` and later replaces
+    the field with a property that forwards to the underlying C++ attribute.
+
+    Users should not instantiate ``Field`` directly—use :func:`field` instead.
+    It rejects conflicting arguments and turns a literal ``default`` into a
+    ``default_factory``. :func:`field` leaves ``name`` as ``None``;
+    :func:`c_class` fills it in during class registration.
+
+    Attributes
+    ----------
+    name : str | None
+        The field name, assigned by :func:`c_class`; ``None`` until then.
+    default_factory : Callable[[], Any]
+        Zero-argument callable producing the default value, or
+        :data:`dataclasses.MISSING` when the field has no default (the
+        corresponding ``__init__`` argument is then required).
+    """
+
+    name: str | None = None
+    # May also hold ``dataclasses.MISSING`` (no default).
+    default_factory: Callable[[], Any]
+
+
+def field(*, default: Any = MISSING, default_factory: Any = MISSING) -> Field:
+    """(Experimental) Declare a dataclass-style field on a :func:`c_class` proxy.
+
+    Use this helper exactly like :func:`dataclasses.field` when defining the
+    Python side of a C++ class.  When :func:`c_class` processes the class body it
+    replaces the placeholder with a property and arranges for ``default`` or
+    ``default_factory`` to be respected by the synthesized ``__init__``.
+
+    Parameters
+    ----------
+    default : Any, optional
+        A literal default value that should populate the field when no argument
+        is given.  It is captured by reference (not copied) in a zero-argument
+        factory, so the same Python object is passed to the C++ constructor
+        every time the default is used; prefer ``default_factory`` for mutable
+        values.
+    default_factory : Callable[[], Any], optional
+        A zero-argument callable that produces the default.  This matches the
+        semantics of :func:`dataclasses.field` and is useful for mutable
+        defaults such as ``list`` or ``dict``.
+
+    Returns
+    -------
+    Field
+        A placeholder object (``name`` is still ``None``) that :func:`c_class`
+        will consume during class registration.  If neither argument is given,
+        its ``default_factory`` is ``MISSING`` and the field is required.
+
+    Raises
+    ------
+    ValueError
+        If both ``default`` and ``default_factory`` are given.
+
+    Examples
+    --------
+    ``field`` integrates with :func:`c_class` to express defaults the same way a
+    Python ``dataclass`` would::
+
+        @c_class("testing.TestCxxClassBase")
+        class PyBase:
+            v_i64: int
+            v_i32: int = field(default=16)
+
+        obj = PyBase(v_i64=4)
+        obj.v_i32  # -> 16
+
+    """
+    if default is not MISSING and default_factory is not MISSING:
+        raise ValueError("Cannot specify both `default` and `default_factory`")
+    if default is not MISSING:
+        # Normalize to a factory so ``__init__`` handles a single kind of default.
+        default_factory = _make_default_factory(default)
+    return Field(default_factory=default_factory)
+
+
+def _make_default_factory(value: Any) -> Callable[[], Any]:
+    """Return a zero-argument callable that always yields ``value`` itself."""
+    captured = value
+
+    def factory() -> Any:
+        # No copy is made: every call hands back the identical object.
+        return captured
+
+    return factory
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 5a540fb..6bf08f6 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -248,6 +248,8 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo) 
-> type:
         setattr(type_cls, name, property(getter, setter, doc=doc))
     for method in type_info.methods:
         name = method.name
+        if name == "__ffi_init__":
+            name = "__c_ffi_init__"
         doc = method.doc if method.doc else None
         method_func = method.func
         if method.is_static:
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 3215d8a..825f9cf 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -16,10 +16,11 @@
 # under the License.
 """Testing utilities."""
 
-from typing import Any
+from typing import Any, ClassVar
 
 from . import _ffi_api
 from .core import Object
+from .dataclasses import c_class, field
 from .registry import register_object
 
 
@@ -34,7 +35,7 @@ class TestIntPair(Object):
 
     def __init__(self, a: int, b: int) -> None:
         """Construct the object."""
-        self.__init_handle_by_constructor__(TestIntPair.__ffi_init__, a, b)
+        # ``Object.__ffi_init__`` calls the reflected C++ constructor (stored as ``__c_ffi_init__``).
+        self.__ffi_init__(a, b)
 
 
 @register_object("testing.TestObjectDerived")
@@ -68,3 +69,26 @@ def create_object(type_key: str, **kwargs: Any) -> Object:
         args.append(k)
         args.append(v)
     return _ffi_api.MakeObjectFromPackedArgs(*args)
+
+
+@c_class("testing.TestCxxClassBase")
+class _TestCxxClassBase:
+    """Python proxy for the C++ test type ``testing.TestCxxClassBase``.
+
+    It defines its own ``__init__``, so ``c_class`` does not synthesize one.
+    """
+
+    # Reflected C++ fields, in the order of the C++ ``__ffi_init__`` arguments.
+    v_i64: int
+    v_i32: int
+    # Not fields (no annotation / ``ClassVar``); they must be ignored by ``c_class``.
+    not_field_1 = 1
+    not_field_2: ClassVar[int] = 2
+
+    def __init__(self, v_i64: int, v_i32: int) -> None:
+        # Deliberately offsets the inputs so tests can tell that this
+        # user-defined ``__init__`` ran instead of a generated one.
+        self.__ffi_init__(v_i64 + 1, v_i32 + 2)
+
+@c_class("testing.TestCxxClassDerived")
+class _TestCxxClassDerived(_TestCxxClassBase):
+    """Python proxy for ``testing.TestCxxClassDerived``, with a generated ``__init__``.
+
+    The generated signature is ``(v_i64, v_i32, v_f64, v_f32=8)``.
+    """
+
+    v_f64: float
+    # An ``int`` default on a ``float`` field; tests check it comes back as a float.
+    v_f32: float = 8
+
+
+@c_class("testing.TestCxxClassDerivedDerived")
+class _TestCxxClassDerivedDerived(_TestCxxClassDerived):
+    """Python proxy for ``testing.TestCxxClassDerivedDerived``.
+
+    ``v_bool`` (required) is declared after ``v_str`` (has a default). The
+    generated ``__init__`` therefore lists the required parameters first:
+    ``(v_i64, v_i32, v_f64, v_bool, v_f32=8, v_str=<factory>)``. It still
+    passes the arguments to C++ in declaration order.
+    """
+
+    v_str: str = field(default_factory=lambda: "default")
+    v_bool: bool
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index 9c3a019..370a1c6 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -86,6 +86,41 @@ class TestObjectDerived : public TestObjectBase {
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("testing.TestObjectDerived", 
TestObjectDerived, TestObjectBase);
 };
 
+// Base test object, mirrored in Python by `tvm_ffi.testing._TestCxxClassBase`.
+class TestCxxClassBase : public Object {
+ public:
+  TestCxxClassBase(int64_t i64, int32_t i32) : v_i64(i64), v_i32(i32) {}
+
+  int64_t v_i64;
+  int32_t v_i32;
+
+  // Opt in to mutability; the fields are exposed with def_rw in the init block.
+  static constexpr bool _type_mutable = true;
+  TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxClassBase", TestCxxClassBase, Object);
+};
+
+// Adds floating-point fields on top of TestCxxClassBase.
+class TestCxxClassDerived : public TestCxxClassBase {
+ public:
+  TestCxxClassDerived(int64_t i64, int32_t i32, double f64, float f32)
+      : TestCxxClassBase(i64, i32), v_f64(f64), v_f32(f32) {}
+
+  double v_f64;
+  float v_f32;
+
+  TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxClassDerived", TestCxxClassDerived,
+                              TestCxxClassBase);
+};
+
+// Adds a string and a boolean field on top of TestCxxClassDerived.
+class TestCxxClassDerivedDerived : public TestCxxClassDerived {
+ public:
+  TestCxxClassDerivedDerived(int64_t i64, int32_t i32, double f64, float f32, String str,
+                             bool flag)
+      : TestCxxClassDerived(i64, i32, f64, f32), v_str(str), v_bool(flag) {}
+
+  String v_str;
+  bool v_bool;
+
+  TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxClassDerivedDerived", TestCxxClassDerivedDerived,
+                              TestCxxClassDerived);
+};
+
 TVM_FFI_NO_INLINE void TestRaiseError(String kind, String msg) {
   // keep name and no liner for testing traceback
   throw ffi::Error(kind, msg, TVMFFITraceback(__FILE__, __LINE__, 
TVM_FFI_FUNC_SIG, 0));
@@ -110,6 +145,22 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def_ro("v_map", &TestObjectDerived::v_map)
       .def_ro("v_array", &TestObjectDerived::v_array);
 
+  refl::ObjectDef<TestCxxClassBase>()
+      .def_static("__ffi_init__", refl::init<TestCxxClassBase, int64_t, 
int32_t>)
+      .def_rw("v_i64", &TestCxxClassBase::v_i64)
+      .def_rw("v_i32", &TestCxxClassBase::v_i32);
+
+  refl::ObjectDef<TestCxxClassDerived>()
+      .def_static("__ffi_init__", refl::init<TestCxxClassDerived, int64_t, 
int32_t, double, float>)
+      .def_rw("v_f64", &TestCxxClassDerived::v_f64)
+      .def_rw("v_f32", &TestCxxClassDerived::v_f32);
+  refl::ObjectDef<TestCxxClassDerivedDerived>()
+      .def_static(
+          "__ffi_init__",
+          refl::init<TestCxxClassDerivedDerived, int64_t, int32_t, double, 
float, String, bool>)
+      .def_rw("v_str", &TestCxxClassDerivedDerived::v_str)
+      .def_rw("v_bool", &TestCxxClassDerivedDerived::v_bool);
+
   refl::GlobalDef()
       .def("testing.test_raise_error", TestRaiseError)
       .def_packed("testing.nop", [](PackedArgs args, Any* ret) {})
diff --git a/tests/python/test_dataclasses_c_class.py 
b/tests/python/test_dataclasses_c_class.py
new file mode 100644
index 0000000..a2fa80e
--- /dev/null
+++ b/tests/python/test_dataclasses_c_class.py
@@ -0,0 +1,66 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from tvm_ffi.testing import _TestCxxClassBase, _TestCxxClassDerived, 
_TestCxxClassDerivedDerived
+
+
+def test_cxx_class_base() -> None:
+    # The user-defined ``__init__`` adds 1 and 2 before calling into C++.
+    obj = _TestCxxClassBase(v_i64=123, v_i32=456)
+    assert (obj.v_i64, obj.v_i32) == (123 + 1, 456 + 2)
+
+
+def test_cxx_class_derived() -> None:
+    # Every keyword argument should round-trip through the C++ object unchanged.
+    kwargs = {"v_i64": 123, "v_i32": 456, "v_f64": 4.00, "v_f32": 8.00}
+    obj = _TestCxxClassDerived(**kwargs)
+    for attr, expected in kwargs.items():
+        assert getattr(obj, attr) == expected
+
+
+def test_cxx_class_derived_default() -> None:
+    obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.00)
+    assert (obj.v_i64, obj.v_i32, obj.v_f64) == (123, 456, 4.00)
+    # ``v_f32`` is omitted: its (int) default must arrive as a float.
+    assert isinstance(obj.v_f32, float)
+    assert obj.v_f32 == 8.00
+
+
+def test_cxx_class_derived_derived() -> None:
+    expected = {
+        "v_i64": 123,
+        "v_i32": 456,
+        "v_f64": 4.00,
+        "v_f32": 8.00,
+        "v_str": "hello",
+        "v_bool": True,
+    }
+    obj = _TestCxxClassDerivedDerived(**expected)
+    for attr, value in expected.items():
+        assert getattr(obj, attr) == value
+    assert obj.v_bool is True
+
+
+def test_cxx_class_derived_derived_default() -> None:
+    # Positional order follows the generated signature:
+    # (v_i64, v_i32, v_f64, v_bool, v_f32=8, v_str=<factory>)
+    obj = _TestCxxClassDerivedDerived(123, 456, 4, True)
+    assert (obj.v_i64, obj.v_i32) == (123, 456)
+    for value, expected in ((obj.v_f64, 4), (obj.v_f32, 8)):
+        assert isinstance(value, float) and value == expected
+    assert obj.v_str == "default"
+    assert isinstance(obj.v_bool, bool) and obj.v_bool is True

Reply via email to