This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new b97ff1ae refactor(dataclasses)!: remove Python-side field descriptor infrastructure (#478)
b97ff1ae is described below
commit b97ff1ae2abd21f5b8a368d5e04f34b53e3985bf
Author: Junru Shao <[email protected]>
AuthorDate: Fri Feb 27 03:43:55 2026 -0800
refactor(dataclasses)!: remove Python-side field descriptor infrastructure (#478)
## Summary
- Remove the Python-side dataclass field descriptor system (`_utils.py`,
`field.py`, complex `c_class.py`) that duplicated C++ reflection
metadata
- Simplify `c_class` decorator to a thin pass-through to
`register_object`
- Fix `_add_class_attrs` to always override `__c_ffi_init__` per type,
preventing inherited base-class constructors from masking derived-class
constructors
## Architecture
The previous `c_class` decorator maintained a parallel Python-side field
descriptor system (`Field`, `KW_ONLY`, `default_factory`, codegen'd
`__init__`) that mirrored what C++ reflection already provides. This PR
removes that duplication:
- **Deleted**: `_utils.py` (210 lines — `type_info_to_cls`,
`fill_dataclass_field`, `method_init` codegen), `field.py` (169 lines —
`Field` class, `KW_ONLY` sentinel, default factory wiring)
- **Simplified**: `c_class.py` reduced from 211 to 49 lines — now
delegates directly to `register_object`
- **Fixed**: `_add_class_attrs` in `registry.py` now always overrides
`__c_ffi_init__` on each type (matching the existing
`__ffi_shallow_copy__` override pattern), preventing a derived class
from inheriting a base class constructor with the wrong field count
## Breaking Changes
- `field()`, `Field`, `KW_ONLY`, and `MISSING` are no longer exported
from `tvm_ffi.dataclasses`
- `c_class`-decorated classes must now explicitly inherit from `Object`
(previously `_utils.type_info_to_cls` injected it)
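- `__init__` on decorated types uses the C++ FFI constructor directly
(positional args in field order) instead of the Python code-generated
`__init__` with keyword-only and default-factory support; fields no longer
carry Python-side defaults, and `c_class` no longer applies
`@dataclass_transform`, so type checkers no longer synthesize an `__init__`
signature for decorated classes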
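- The `init` and `kw_only` parameters of `c_class()` are no longer honored:
they are still accepted via `**kwargs` (reserved for forward compatibility)
but silently ignored, so e.g. `kw_only=True` no longer makes fields
keyword-only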
## Test plan
- [x] `tests/python/test_repr.py` updated — derived constructors use
positional args
- [x] `tests/python/test_dataclasses_c_class.py` deleted — tested
removed infrastructure
- [x] Full CI pass on all platforms
---
python/tvm_ffi/dataclasses/__init__.py | 7 +-
python/tvm_ffi/dataclasses/_utils.py | 210 -------------------------------
python/tvm_ffi/dataclasses/c_class.py | 190 +++-------------------------
python/tvm_ffi/dataclasses/field.py | 169 -------------------------
python/tvm_ffi/registry.py | 3 +
python/tvm_ffi/testing/testing.py | 24 ++--
tests/python/test_dataclasses_c_class.py | 151 ----------------------
tests/python/test_repr.py | 6 +-
8 files changed, 33 insertions(+), 727 deletions(-)
diff --git a/python/tvm_ffi/dataclasses/__init__.py
b/python/tvm_ffi/dataclasses/__init__.py
index bfb44049..912d6bd1 100644
--- a/python/tvm_ffi/dataclasses/__init__.py
+++ b/python/tvm_ffi/dataclasses/__init__.py
@@ -14,11 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Experimental FFI interface that exposes C++ classes to Python in dataclass
syntax."""
-
-from dataclasses import MISSING
+"""C++ FFI classes registered via ``c_class`` decorator."""
from .c_class import c_class
-from .field import KW_ONLY, Field, field
-__all__ = ["KW_ONLY", "MISSING", "Field", "c_class", "field"]
+__all__ = ["c_class"]
diff --git a/python/tvm_ffi/dataclasses/_utils.py
b/python/tvm_ffi/dataclasses/_utils.py
deleted file mode 100644
index 80c39874..00000000
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ /dev/null
@@ -1,210 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Utilities for constructing Python proxies of FFI types."""
-
-from __future__ import annotations
-
-import functools
-from dataclasses import MISSING
-from typing import Any, Callable, Type, TypeVar, cast
-
-from ..core import (
- Object,
- TypeField,
- TypeInfo,
-)
-
-_InputClsType = TypeVar("_InputClsType")
-
-
-def type_info_to_cls(
- type_info: TypeInfo,
- cls: Type[_InputClsType], # noqa: UP006
- methods: dict[str, Callable[..., Any] | None],
-) -> Type[_InputClsType]: # noqa: UP006
- assert type_info.type_cls is None, "Type class is already created"
- # Step 1. Determine the base classes
- cls_bases = cls.__bases__
- if cls_bases == (object,):
- # If the class inherits from `object`, we need to set the base class
to `Object`
- cls_bases = (Object,)
-
- # Step 2. Define the new class attributes
- attrs = dict(cls.__dict__)
- attrs.pop("__dict__", None)
- attrs.pop("__weakref__", None)
- attrs["__slots__"] = ()
- attrs["__tvm_ffi_type_info__"] = type_info
-
- # Step 2. Add fields
- for field in type_info.fields:
- attrs[field.name] = field.as_property(cls)
-
- # Step 3. Add methods
- for name, method_impl in methods.items():
- if method_impl is not None:
- method_impl.__module__ = cls.__module__
- method_impl.__name__ = name # ty: ignore[unresolved-attribute]
- method_impl.__qualname__ = f"{cls.__qualname__}.{name}" # ty:
ignore[unresolved-attribute]
- method_impl.__doc__ = f"Method `{name}` of class
`{cls.__qualname__}`"
- attrs[name] = method_impl
- for method in type_info.methods:
- name = method.name
- if name == "__ffi_init__":
- name = "__c_ffi_init__"
- # as_callable wraps instance methods so `self` is passed to the C++
function,
- # and wraps static methods with staticmethod(); it also sets
__module__,
- # __name__, __qualname__, and __doc__ so we insert directly into attrs.
- func = method.as_callable(cls)
- if name != method.name:
- # Rename was applied (e.g. __ffi_init__ -> __c_ffi_init__)
- inner = func.__func__ if isinstance(func, staticmethod) else func
- inner.__name__ = name # ty: ignore[invalid-assignment]
- inner.__qualname__ = f"{cls.__qualname__}.{name}" # ty:
ignore[invalid-assignment]
- attrs[name] = func
-
- # Step 4. Create the new class
- new_cls = type(cls.__name__, cls_bases, attrs)
- new_cls.__module__ = cls.__module__
- new_cls = functools.wraps(cls, updated=())(new_cls)
- return cast(Type[_InputClsType], new_cls)
-
-
-def fill_dataclass_field(
- type_cls: type,
- type_field: TypeField,
- *,
- class_kw_only: bool = False,
- kw_only_from_sentinel: bool = False,
-) -> None:
- from .field import Field, field # noqa: PLC0415
-
- field_name = type_field.name
- rhs: Any = getattr(type_cls, field_name, MISSING)
- if rhs is MISSING:
- rhs = field()
- elif isinstance(rhs, Field):
- pass
- elif isinstance(rhs, (int, float, str, bool, type(None))):
- rhs = field(default=rhs)
- else:
- raise ValueError(f"Cannot recognize field: {type_field.name}: {rhs}")
- assert isinstance(rhs, Field)
- rhs.name = type_field.name
-
- # Resolve kw_only: field-level > KW_ONLY sentinel > class-level
- if rhs.kw_only is MISSING:
- if kw_only_from_sentinel:
- rhs.kw_only = True
- else:
- rhs.kw_only = class_kw_only
-
- type_field.dataclass_field = rhs
-
-
-def _get_all_fields(type_info: TypeInfo) -> list[TypeField]:
- """Collect all fields from the type hierarchy, from parents to children."""
- fields: list[TypeField] = []
- cur_type_info: TypeInfo | None = type_info
- while cur_type_info is not None:
- fields.extend(reversed(cur_type_info.fields))
- cur_type_info = cur_type_info.parent_type_info
- fields.reverse()
- return fields
-
-
-def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
- """Generate an ``__init__`` that forwards to the FFI constructor.
-
- The generated initializer has a proper Python signature built from the
- reflected field list, supporting default values, keyword-only args, and
``__post_init__``.
- """
- # Step 0. Collect all fields from the type hierarchy
- fields = _get_all_fields(type_info)
- # sanity check
- if not any(m.name == "__ffi_init__" for m in type_info.methods):
- raise ValueError(f"Cannot find constructor method:
`{type_info.type_key}.__ffi_init__`")
- # Step 1. Split args into sections and register default factories
- pos_no_defaults: list[str] = []
- pos_with_defaults: list[str] = []
- kw_no_defaults: list[str] = []
- kw_with_defaults: list[str] = []
- fields_with_defaults: list[tuple[str, bool]] = []
- ffi_arg_order: list[str] = []
- exec_globals: dict[str, Any] = {"MISSING": MISSING}
-
- for field in fields:
- assert field.name is not None
- assert field.dataclass_field is not None
- dataclass_field = field.dataclass_field
- has_default = (default_factory := dataclass_field.default_factory) is
not MISSING
- is_kw_only = dataclass_field.kw_only is True
-
- if dataclass_field.init:
- ffi_arg_order.append(field.name)
- if has_default:
- (kw_with_defaults if is_kw_only else
pos_with_defaults).append(field.name)
- fields_with_defaults.append((field.name, True))
- exec_globals[f"_default_factory_{field.name}"] =
default_factory
- else:
- (kw_no_defaults if is_kw_only else
pos_no_defaults).append(field.name)
- elif has_default:
- ffi_arg_order.append(field.name)
- fields_with_defaults.append((field.name, False))
- exec_globals[f"_default_factory_{field.name}"] = default_factory
-
- # Step 2. Build signature
- args: list[str] = ["self"]
- args.extend(pos_no_defaults)
- args.extend(f"{name}=MISSING" for name in pos_with_defaults)
- if kw_no_defaults or kw_with_defaults:
- args.append("*")
- args.extend(kw_no_defaults)
- args.extend(f"{name}=MISSING" for name in kw_with_defaults)
-
- # Step 3. Build body
- body_lines: list[str] = []
- for field_name, is_init in fields_with_defaults:
- if is_init:
- body_lines.append(
- f"if {field_name} is MISSING: {field_name} =
_default_factory_{field_name}()"
- )
- else:
- body_lines.append(f"{field_name} =
_default_factory_{field_name}()")
- body_lines.append(f"self.__ffi_init__({', '.join(ffi_arg_order)})")
- body_lines.extend(
- [
- "try:",
- " fn_post_init = self.__post_init__",
- "except AttributeError:",
- " pass",
- "else:",
- " fn_post_init()",
- ]
- )
-
- source_lines = [f"def __init__({', '.join(args)}):"]
- source_lines.extend(f" {line}" for line in body_lines)
- source_lines.append(" ...")
- source = "\n".join(source_lines)
- # Note: Code generation in this case is guaranteed to be safe,
- # because the generated code does not contain any untrusted input.
- # This is also a common practice used by `dataclasses` and `pydantic`.
- namespace: dict[str, Any] = {}
- exec(source, exec_globals, namespace)
- __init__ = namespace["__init__"]
- return __init__
diff --git a/python/tvm_ffi/dataclasses/c_class.py
b/python/tvm_ffi/dataclasses/c_class.py
index 295918a9..d836f1a9 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -14,198 +14,36 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Helpers for mirroring registered C++ FFI types with Python dataclass syntax.
-
-The :func:`c_class` decorator is the primary entry point. It inspects the
-reflection metadata that the C++ runtime exposes via the TVM FFI registry and
-turns it into Python ``dataclass``-style descriptors: annotated attributes
become
-properties that forward to the underlying C++ object, while an ``__init__``
-method is synthesized to call the FFI constructor when requested.
-"""
+"""The ``c_class`` decorator: pass-through to ``register_object``."""
from __future__ import annotations
-import sys
from collections.abc import Callable
-from dataclasses import InitVar
-from typing import ClassVar, Type, TypeVar, get_origin, get_type_hints
-
-from typing_extensions import dataclass_transform
-
-from ..core import TypeField, TypeInfo,
_lookup_or_register_type_info_from_type_key, _set_type_cls
-from . import _utils
-from .field import KW_ONLY, field
+from typing import Any, TypeVar
-_InputClsType = TypeVar("_InputClsType")
+_T = TypeVar("_T", bound=type)
-@dataclass_transform(field_specifiers=(field,), kw_only_default=False)
-def c_class(
- type_key: str, init: bool = True, kw_only: bool = False
-) -> Callable[[Type[_InputClsType]], Type[_InputClsType]]: # noqa: UP006
- """(Experimental) Create a dataclass-like proxy for a C++ class registered
with TVM FFI.
+def c_class(type_key: str, **kwargs: Any) -> Callable[[_T], _T]:
+ """Register a C++ FFI class by type key.
- The decorator reads the reflection metadata that was registered on the C++
- side using ``tvm::ffi::reflection::ObjectDef`` and binds it to the
annotated
- attributes in the decorated Python class. Each field defined in C++ becomes
- a property on the Python class, and optional default values can be provided
- with :func:`tvm_ffi.dataclasses.field` in the same way as Python's native
- ``dataclasses.field``.
-
- The intent is to offer a familiar dataclass authoring experience while
still
- exposing the underlying C++ object. The ``type_key`` of the C++ class must
- match the string passed to :func:`c_class`, and inheritance relationships
are
- preserved-subclasses registered in C++ can subclass the Python proxy
defined
- for their parent.
+ This is a thin wrapper around :func:`~tvm_ffi.register_object` that
+ accepts (and currently ignores) additional keyword arguments for
+ forward compatibility.
Parameters
----------
type_key
- The reflection key that identifies the C++ type in the FFI registry,
- e.g. ``"testing.MyClass"`` as registered in
- ``src/ffi/extra/testing.cc``.
-
- init
- If ``True`` and the Python class does not define ``__init__``, an
- initializer is auto-generated that mirrors the reflected constructor
- signature. The generated initializer calls the C++ ``__init__``
- function registered with ``ObjectDef`` and invokes ``__post_init__`` if
- it exists on the Python class.
-
- kw_only
- If ``True``, all fields become keyword-only parameters in the generated
- ``__init__``. Individual fields can override this by setting
- ``kw_only=False`` in :func:`field`. Additionally, a ``KW_ONLY``
sentinel
- annotation can be used to mark all subsequent fields as keyword-only.
+ The reflection key that identifies the C++ type in the FFI registry.
+ kwargs
+ Reserved for future use.
Returns
-------
Callable[[type], type]
- A class decorator that materializes the final proxy class.
-
- Examples
- --------
- Register the C++ type and its fields with TVM FFI:
-
- .. code-block:: c++
-
- TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<MyClass>()
- .def_static("__init__", [](int64_t v_i64, int32_t v_i32,
- double v_f64, float v_f32) -> Any {
- return ObjectRef(ffi::make_object<MyClass>(
- v_i64, v_i32, v_f64, v_f32));
- })
- .def_rw("v_i64", &MyClass::v_i64)
- .def_rw("v_i32", &MyClass::v_i32)
- .def_rw("v_f64", &MyClass::v_f64)
- .def_rw("v_f32", &MyClass::v_f32);
- }
-
- Mirror the same structure in Python using dataclass-style annotations:
-
- .. code-block:: python
-
- from tvm_ffi.dataclasses import c_class, field
-
-
- @c_class("example.MyClass")
- class MyClass:
- v_i64: int
- v_i32: int
- v_f64: float = field(default=0.0)
- v_f32: float = field(default_factory=lambda: 1.0)
-
-
- obj = MyClass(v_i64=4, v_i32=8)
- obj.v_f64 = 3.14 # transparently forwards to the underlying C++ object
+ A class decorator.
"""
+ from ..registry import register_object # noqa: PLC0415
- def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]:
# noqa: UP006
- nonlocal init
- init = init and "__init__" not in super_type_cls.__dict__
- # Step 1. Retrieve `type_info` from registry
- type_info: TypeInfo =
_lookup_or_register_type_info_from_type_key(type_key)
- assert type_info.parent_type_info is not None
- # Step 2. Reflect all the fields of the type
- type_info.fields, kw_only_start_idx =
_inspect_c_class_fields(super_type_cls, type_info)
- for idx, type_field in enumerate(type_info.fields):
- kw_only_from_sentinel = kw_only_start_idx is not None and idx >=
kw_only_start_idx
- _utils.fill_dataclass_field(
- super_type_cls,
- type_field,
- class_kw_only=kw_only,
- kw_only_from_sentinel=kw_only_from_sentinel,
- )
- # Step 3. Create the proxy class with the fields as properties
- fn_init = _utils.method_init(super_type_cls, type_info) if init else
None
- type_cls: Type[_InputClsType] = _utils.type_info_to_cls( # noqa: UP006
- type_info=type_info,
- cls=super_type_cls,
- methods={"__init__": fn_init},
- )
- _set_type_cls(type_info, type_cls)
- # Step 4. Set up __copy__, __deepcopy__, __replace__
- from ..registry import _setup_copy_methods # noqa: PLC0415
-
- has_shallow_copy = any(m.name == "__ffi_shallow_copy__" for m in
type_info.methods)
- _setup_copy_methods(type_cls, has_shallow_copy)
- return type_cls
-
- return decorator
-
-
-def _inspect_c_class_fields(
- type_cls: type, type_info: TypeInfo
-) -> tuple[list[TypeField], int | None]:
- if sys.version_info >= (3, 9):
- type_hints_resolved = get_type_hints(type_cls, include_extras=True)
- else:
- type_hints_resolved = get_type_hints(type_cls)
- type_hints_py = {
- name: type_hints_resolved[name]
- for name in getattr(type_cls, "__annotations__", {}).keys()
- if get_origin(type_hints_resolved[name])
- not in [ # ignore non-field annotations
- ClassVar,
- InitVar,
- ]
- and type_hints_resolved[name] is not KW_ONLY
- }
-
- # Detect KW_ONLY sentinel position
- kw_only_start_idx: int | None = None
- field_count = 0
- for name in getattr(type_cls, "__annotations__", {}).keys():
- resolved_type = type_hints_resolved.get(name)
- if resolved_type is None:
- continue
- if get_origin(resolved_type) in [ClassVar, InitVar]:
- continue
- if resolved_type is KW_ONLY:
- if kw_only_start_idx is not None:
- raise ValueError(f"KW_ONLY may only be used once per class:
{type_cls}")
- kw_only_start_idx = field_count
- continue
- field_count += 1
- del type_hints_resolved
-
- type_fields_cxx: dict[str, TypeField] = {f.name: f for f in
type_info.fields}
- type_fields: list[TypeField] = []
- for field_name, _field_ty_py in type_hints_py.items():
- if field_name.startswith("__tvm_ffi"): # TVM's private fields - skip
- continue
- type_field = type_fields_cxx.pop(field_name, None)
- if type_field is None:
- raise ValueError(
- f"Extraneous field `{type_cls}.{field_name}`. Defined in
Python but not in C++"
- )
- type_fields.append(type_field)
- if type_fields_cxx:
- extra_fields = ", ".join(f"`{f.name}`" for f in
type_fields_cxx.values())
- raise ValueError(
- f"Missing fields in `{type_cls}`: {extra_fields}. Defined in C++
but not in Python"
- )
- return type_fields, kw_only_start_idx
+ return register_object(type_key)
diff --git a/python/tvm_ffi/dataclasses/field.py
b/python/tvm_ffi/dataclasses/field.py
deleted file mode 100644
index 2378da87..00000000
--- a/python/tvm_ffi/dataclasses/field.py
+++ /dev/null
@@ -1,169 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Public helpers for describing dataclass-style defaults on FFI proxies."""
-
-from __future__ import annotations
-
-from dataclasses import _MISSING_TYPE, MISSING
-from typing import Any, Callable, TypeVar, cast
-
-try:
- from dataclasses import KW_ONLY # ty: ignore[unresolved-import]
-except ImportError:
- # Python < 3.10: define our own KW_ONLY sentinel
- class _KW_ONLY_Sentinel:
- __slots__ = ()
-
- KW_ONLY = _KW_ONLY_Sentinel()
-
-_FieldValue = TypeVar("_FieldValue")
-_KW_ONLY_TYPE = type(KW_ONLY)
-
-
-class Field:
- """(Experimental) Descriptor placeholder returned by
:func:`tvm_ffi.dataclasses.field`.
-
- A ``Field`` mirrors the object returned by :func:`dataclasses.field`, but
it
- is understood by :func:`tvm_ffi.dataclasses.c_class`. The decorator
inspects
- the ``Field`` instances, records the ``default_factory`` and later replaces
- the field with a property that forwards to the underlying C++ attribute.
-
- Users should not instantiate ``Field`` directly - use :func:`field`
instead,
- which guarantees that ``name`` and ``default_factory`` are populated in a
- way the decorator understands.
- """
-
- __slots__ = ("default_factory", "init", "kw_only", "name")
-
- def __init__(
- self,
- *,
- name: str | None = None,
- default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
- init: bool = True,
- kw_only: bool | _MISSING_TYPE = MISSING,
- ) -> None:
- """Do not call directly; use :func:`field` instead."""
- self.name = name
- self.default_factory = default_factory
- self.init = init
- self.kw_only = kw_only
-
-
-def field(
- *,
- default: _FieldValue | _MISSING_TYPE = MISSING,
- default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
- init: bool = True,
- kw_only: bool | _MISSING_TYPE = MISSING,
-) -> _FieldValue:
- """(Experimental) Declare a dataclass-style field on a :func:`c_class`
proxy.
-
- Use this helper exactly like :func:`dataclasses.field` when defining the
- Python side of a C++ class. When :func:`c_class` processes the class body
it
- replaces the placeholder with a property and arranges for ``default`` or
- ``default_factory`` to be respected by the synthesized ``__init__``.
-
- Parameters
- ----------
- default
- A literal default value that populates the field when no argument
- is given. At most one of ``default`` or ``default_factory`` may be
- given.
- default_factory
- A zero-argument callable that produces the default. This matches the
- semantics of :func:`dataclasses.field` and is useful for mutable
- defaults such as ``list`` or ``dict``.
- init
- If ``True`` the field is included in the generated ``__init__``.
- If ``False`` the field is omitted from input arguments of ``__init__``.
- kw_only
- If ``True``, the field is a keyword-only argument in ``__init__``.
- If ``MISSING``, inherits from the class-level ``kw_only`` setting or
- from a preceding ``KW_ONLY`` sentinel annotation.
-
- Note
- ----
- The decision to forward a field to the C++ ``__ffi_init__`` constructor
- depends on its configuration:
-
- * If ``init=True``, the field's value (from user input or defaults)
- is forwarded.
-
- * If ``init=False``:
-
- - With a ``default`` or ``default_factory``, its computed value is
- forwarded. The user cannot provide this value via Python
``__init__``.
-
- - Without a ``default`` or ``default_factory``, the field is *not*
- forwarded to C++ ``__ffi_init__`` and must be initialized by the
- C++ constructor.
-
- Returns
- -------
- Field
- A placeholder object that :func:`c_class` will consume during class
- registration.
-
- Examples
- --------
- ``field`` integrates with :func:`c_class` to express defaults the same way
a
- Python ``dataclass`` would:
-
- .. code-block:: python
-
- @c_class("testing.TestCxxClassBase")
- class PyBase:
- v_i64: int
- v_i32: int = field(default=16)
-
-
- obj = PyBase(v_i64=4)
- obj.v_i32 # -> 16
-
- Use ``kw_only=True`` to make a field keyword-only:
-
- .. code-block:: python
-
- @c_class("testing.TestCxxClassBase")
- class PyBase:
- v_i64: int
- v_i32: int = field(kw_only=True)
-
-
- obj = PyBase(4, v_i32=8) # v_i32 must be keyword
-
- """
- if default is not MISSING and default_factory is not MISSING:
- raise ValueError("Cannot specify both `default` and `default_factory`")
- if not isinstance(init, bool):
- raise TypeError("`init` must be a bool")
- if kw_only is not MISSING and not isinstance(kw_only, bool):
- raise TypeError(f"`kw_only` must be a bool, got
{type(kw_only).__name__!r}")
- if default is not MISSING:
- default_factory = _make_default_factory(default)
- ret = Field(default_factory=default_factory, init=init, kw_only=kw_only)
- return cast(_FieldValue, ret)
-
-
-def _make_default_factory(value: Any) -> Callable[[], Any]:
- """Make a default factory that returns the given value."""
-
- def factory() -> Any:
- return value
-
- return factory
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index e39cc44d..46074c28 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -346,6 +346,9 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo)
-> type:
has_shallow_copy = True
# Always override: shallow copy is type-specific and must not be
inherited
setattr(type_cls, name, method.as_callable(type_cls))
+ elif name == "__c_ffi_init__":
+ # Always override: each type has its own constructor signature
+ setattr(type_cls, name, method.as_callable(type_cls))
elif not hasattr(type_cls, name):
setattr(type_cls, name, method.as_callable(type_cls))
if "__init__" not in type_cls.__dict__:
diff --git a/python/tvm_ffi/testing/testing.py
b/python/tvm_ffi/testing/testing.py
index c5578910..4d9502da 100644
--- a/python/tvm_ffi/testing/testing.py
+++ b/python/tvm_ffi/testing/testing.py
@@ -34,7 +34,7 @@ from typing import ClassVar
from .. import _ffi_api
from ..core import Object
-from ..dataclasses import c_class, field
+from ..dataclasses import c_class
from ..registry import get_global_func, register_object
@@ -163,38 +163,38 @@ def add_one(x: int) -> int:
@c_class("testing.TestCxxClassBase")
-class _TestCxxClassBase:
+class _TestCxxClassBase(Object):
v_i64: int
v_i32: int
not_field_1 = 1
not_field_2: ClassVar[int] = 2
def __init__(self, v_i64: int, v_i32: int) -> None:
- self.__ffi_init__(v_i64 + 1, v_i32 + 2) # ty:
ignore[unresolved-attribute]
+ self.__ffi_init__(v_i64 + 1, v_i32 + 2)
@c_class("testing.TestCxxClassDerived")
class _TestCxxClassDerived(_TestCxxClassBase):
v_f64: float
- v_f32: float = 8
+ v_f32: float
@c_class("testing.TestCxxClassDerivedDerived")
class _TestCxxClassDerivedDerived(_TestCxxClassDerived):
- v_str: str = field(default_factory=lambda: "default")
- v_bool: bool # ty: ignore[dataclass-field-order] # Required field after
fields with defaults
+ v_str: str
+ v_bool: bool
@c_class("testing.TestCxxInitSubset")
-class _TestCxxInitSubset:
+class _TestCxxInitSubset(Object):
required_field: int
- optional_field: int = field(init=False)
- note: str = field(default_factory=lambda: "py-default", init=False)
+ optional_field: int
+ note: str
-@c_class("testing.TestCxxKwOnly", kw_only=True)
-class _TestCxxKwOnly:
+@c_class("testing.TestCxxKwOnly")
+class _TestCxxKwOnly(Object):
x: int
y: int
z: int
- w: int = 100
+ w: int
diff --git a/tests/python/test_dataclasses_c_class.py
b/tests/python/test_dataclasses_c_class.py
deleted file mode 100644
index ff2df3c9..00000000
--- a/tests/python/test_dataclasses_c_class.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import inspect
-from dataclasses import MISSING
-
-import pytest
-from tvm_ffi.dataclasses import KW_ONLY, field
-from tvm_ffi.dataclasses.field import _KW_ONLY_TYPE, Field
-from tvm_ffi.testing import (
- _TestCxxClassBase,
- _TestCxxClassDerived,
- _TestCxxClassDerivedDerived,
- _TestCxxInitSubset,
- _TestCxxKwOnly,
-)
-
-
-def test_cxx_class_base() -> None:
- obj = _TestCxxClassBase(v_i64=123, v_i32=456)
- assert obj.v_i64 == 123 + 1
- assert obj.v_i32 == 456 + 2
-
-
-def test_cxx_class_derived() -> None:
- obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.00, v_f32=8.00)
- assert obj.v_i64 == 123
- assert obj.v_i32 == 456
- assert obj.v_f64 == 4.00
- assert obj.v_f32 == 8.00
-
-
-def test_cxx_class_derived_default() -> None:
- obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.00)
- assert obj.v_i64 == 123
- assert obj.v_i32 == 456
- assert obj.v_f64 == 4.00
- assert isinstance(obj.v_f32, float) and obj.v_f32 == 8.00 # default value
-
-
-def test_cxx_class_derived_derived() -> None:
- obj = _TestCxxClassDerivedDerived(
- v_i64=123,
- v_i32=456,
- v_f64=4.00,
- v_f32=8.00,
- v_str="hello",
- v_bool=True,
- )
- assert obj.v_i64 == 123
- assert obj.v_i32 == 456
- assert obj.v_f64 == 4.00
- assert obj.v_f32 == 8.00
- assert obj.v_str == "hello"
- assert obj.v_bool is True
-
-
-def test_cxx_class_derived_derived_default() -> None:
- obj = _TestCxxClassDerivedDerived(123, 456, 4, True) # ty:
ignore[missing-argument]
- assert obj.v_i64 == 123
- assert obj.v_i32 == 456
- assert isinstance(obj.v_f64, float) and obj.v_f64 == 4
- assert isinstance(obj.v_f32, float) and obj.v_f32 == 8
- assert obj.v_str == "default"
- assert isinstance(obj.v_bool, bool) and obj.v_bool is True
-
-
-def test_cxx_class_init_subset_signature() -> None:
- sig = inspect.signature(_TestCxxInitSubset.__init__)
- params = tuple(sig.parameters)
- assert "required_field" in params
- assert "optional_field" not in params
- assert "note" not in params
-
-
-def test_cxx_class_init_subset_defaults() -> None:
- obj = _TestCxxInitSubset(required_field=42)
- assert obj.required_field == 42
- assert obj.optional_field == -1
- assert obj.note == "py-default"
-
-
-def test_cxx_class_init_subset_positional() -> None:
- obj = _TestCxxInitSubset(7)
- assert obj.required_field == 7
- assert obj.optional_field == -1
- obj.optional_field = 11
- assert obj.optional_field == 11
-
-
-def test_kw_only_class_level_signature() -> None:
- sig = inspect.signature(_TestCxxKwOnly.__init__)
- params = sig.parameters
- assert params["x"].kind == inspect.Parameter.KEYWORD_ONLY
- assert params["y"].kind == inspect.Parameter.KEYWORD_ONLY
- assert params["z"].kind == inspect.Parameter.KEYWORD_ONLY
- assert params["w"].kind == inspect.Parameter.KEYWORD_ONLY
-
-
-def test_kw_only_class_level_call() -> None:
- obj = _TestCxxKwOnly(x=1, y=2, z=3, w=4)
- assert obj.x == 1
- assert obj.y == 2
- assert obj.z == 3
- assert obj.w == 4
-
-
-def test_kw_only_class_level_with_default() -> None:
- obj = _TestCxxKwOnly(x=1, y=2, z=3)
- assert obj.w == 100
-
-
-def test_kw_only_class_level_rejects_positional() -> None:
- with pytest.raises(TypeError, match="positional"):
- _TestCxxKwOnly(1, 2, 3, 4) # ty: ignore[missing-argument,
too-many-positional-arguments]
-
-
-def test_field_kw_only_parameter() -> None:
- f1: Field = field(kw_only=True)
- assert isinstance(f1, Field)
- assert f1.kw_only is True
-
- f2: Field = field(kw_only=False)
- assert f2.kw_only is False
-
- f3: Field = field()
- assert f3.kw_only is MISSING
-
-
-def test_field_kw_only_with_default() -> None:
- f = field(default=42, kw_only=True)
- assert isinstance(f, Field)
- assert f.kw_only is True
- assert f.default_factory() == 42
-
-
-def test_kw_only_sentinel_exists() -> None:
- assert isinstance(KW_ONLY, _KW_ONLY_TYPE)
diff --git a/tests/python/test_repr.py b/tests/python/test_repr.py
index 312c517b..768e6df1 100644
--- a/tests/python/test_repr.py
+++ b/tests/python/test_repr.py
@@ -191,7 +191,7 @@ def test_repr_user_object_all_fields() -> None:
def test_repr_user_object_repr_off() -> None:
"""Test repr of object with Repr(false) fields excluded."""
- obj = tvm_ffi.testing._TestCxxClassDerived(v_i64=1, v_i32=2, v_f64=3.5,
v_f32=4.5)
+ obj = tvm_ffi.testing._TestCxxClassDerived(1, 2, 3.5, 4.5) # ty:
ignore[too-many-positional-arguments]
assert ReprPrint(obj) == "testing.TestCxxClassDerived(v_f64=3.5,
v_f32=4.5)"
@@ -390,9 +390,7 @@ def test_repr_map_with_object_values() -> None:
def test_repr_derived_derived_shows_all_own_fields() -> None:
"""TestCxxClassDerivedDerived should show v_f64, v_f32, v_str, v_bool (not
v_i64, v_i32)."""
- obj = tvm_ffi.testing._TestCxxClassDerivedDerived(
- v_i64=1, v_i32=2, v_f64=3.0, v_f32=4.0, v_str="test", v_bool=True
- )
+ obj = tvm_ffi.testing._TestCxxClassDerivedDerived(1, 2, 3.0, 4.0, "test",
True) # ty: ignore[too-many-positional-arguments]
assert (
ReprPrint(obj)
== 'testing.TestCxxClassDerivedDerived(v_f64=3, v_f32=4, v_str="test",
v_bool=True)'