This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new daeb235 feat: Support `field(init=...)` in `@c_class` fields (#52)
daeb235 is described below
commit daeb235a29c576d8702d447fa5f4773170bb1e8f
Author: Junru Shao <[email protected]>
AuthorDate: Wed Sep 24 12:56:41 2025 -0700
feat: Support `field(init=...)` in `@c_class` fields (#52)
This PR enhances the `@c_class` decorator by supporting the `init`
parameter of `field(..., init: bool = True)`.
This gives finer control over the generated Python `__init__` method:
specific fields can be excluded from the constructor's signature. This
increases the flexibility of tvm-ffi's dataclass system and aligns it
more closely with Python's native dataclass capabilities.
## Example
```cpp
// C++ side
class TestCxxInitSubsetObj : public Object {
 public:
  int64_t required_field;
  int64_t optional_field;
  String note;

  // `optional_field` is not a constructor parameter; the constructor
  // always initializes it to -1.
  explicit TestCxxInitSubsetObj(int64_t value, String note)
      : required_field(value), optional_field(-1), note(note) {}

  TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxInitSubset", TestCxxInitSubsetObj, Object);
};
```
```python
# Python side
@c_class("testing.TestCxxInitSubset")
class _TestCxxInitSubset:
    required_field: int
    # init=False, no default: doesn't appear in `__init__` and
    # won't be forwarded to C++ (the C++ constructor initializes it).
    optional_field: int = field(init=False)
    # init=False, with a default: doesn't appear in `__init__`, but
    # the default value is still forwarded to C++.
    note: str = field(default_factory=lambda: "py-default", init=False)

obj = _TestCxxInitSubset(required_field=42)
assert obj.optional_field == -1
assert obj.note == "py-default"
```
---
docs/reference/python/cpp/index.rst | 27 ------
docs/reference/python/index.rst | 18 +++-
python/tvm_ffi/__init__.py | 2 +
python/tvm_ffi/dataclasses/_utils.py | 136 ++++++++++++++-----------------
python/tvm_ffi/dataclasses/field.py | 39 +++++++--
python/tvm_ffi/testing.py | 7 ++
src/ffi/extra/testing.cc | 19 +++++
tests/python/test_dataclasses_c_class.py | 32 +++++++-
8 files changed, 166 insertions(+), 114 deletions(-)
diff --git a/docs/reference/python/cpp/index.rst b/docs/reference/python/cpp/index.rst
deleted file mode 100644
index c104335..0000000
--- a/docs/reference/python/cpp/index.rst
+++ /dev/null
@@ -1,27 +0,0 @@
-.. Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
-.. http://www.apache.org/licenses/LICENSE-2.0
-
-.. Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
-
-tvm_ffi.cpp
------------
-
-.. automodule:: tvm_ffi.cpp
- :no-members:
-
-.. autosummary::
- :toctree: generated/
-
- load_inline
diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst
index 756af4c..8b478e8 100644
--- a/docs/reference/python/index.rst
+++ b/docs/reference/python/index.rst
@@ -81,7 +81,19 @@ Stream Context
Utility
-------
-.. toctree::
- :maxdepth: 1
+C++ integration helpers for building and loading inline modules.
- cpp/index.rst
+.. autosummary::
+ :toctree: cpp/generated/
+
+ cpp.load_inline
+
+
+.. (Experimental) Dataclasses
+.. --------------------------
+
+.. .. autosummary::
+.. :toctree: generated/
+
+.. dataclasses.c_class
+.. dataclasses.field
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 2381363..c742ee2 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -42,6 +42,7 @@ from .stream import StreamContext, use_raw_stream, use_torch_stream
from . import serialization
from . import access_path
from . import testing
+from . import dataclasses
# optional module to speedup dlpack conversion
from . import _optional_torch_c_dlpack
@@ -61,6 +62,7 @@ __all__ = [
"Tensor",
"access_path",
"convert",
+ "dataclasses",
"device",
"dtype",
"from_dlpack",
diff --git a/python/tvm_ffi/dataclasses/_utils.py b/python/tvm_ffi/dataclasses/_utils.py
index 7a28e01..b6bcdac 100644
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -19,9 +19,8 @@
from __future__ import annotations
import functools
-import inspect
from dataclasses import MISSING
-from typing import Any, Callable, NamedTuple, TypeVar, cast
+from typing import Any, Callable, TypeVar, cast
from ..core import (
Object,
@@ -111,96 +110,79 @@ def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
type_field.dataclass_field = rhs
-def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
# noqa: PLR0915
+def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
"""Generate an ``__init__`` that forwards to the FFI constructor.
The generated initializer has a proper Python signature built from the
reflected field list, supporting default values and ``__post_init__``.
"""
-
- class DefaultFactory(NamedTuple):
- """Wrapper that marks a parameter as having a default factory."""
-
- fn: Callable[[], Any]
-
+ # Step 0. Collect all fields from the type hierarchy
fields: list[TypeField] = []
cur_type_info: TypeInfo | None = type_info
while cur_type_info is not None:
fields.extend(reversed(cur_type_info.fields))
cur_type_info = cur_type_info.parent_type_info
fields.reverse()
-
- annotations: dict[str, Any] = {"return": None}
- # Step 1. Split the parameters into two groups to ensure that
- # those without defaults appear first in the signature.
- params_without_defaults: list[inspect.Parameter] = []
- params_with_defaults: list[inspect.Parameter] = []
- ordering = [0] * len(fields)
- for i, field in enumerate(fields):
- assert field.name is not None
- name: str = field.name
- annotations[name] = Any # NOTE: We might be able to handle
annotations better
- assert field.dataclass_field is not None
- default_factory = field.dataclass_field.default_factory
- if default_factory is MISSING:
- ordering[i] = len(params_without_defaults)
- params_without_defaults.append(
- inspect.Parameter(name=name,
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
- )
- else:
- ordering[i] = -len(params_with_defaults) - 1
- params_with_defaults.append(
- inspect.Parameter(
- name=name,
- kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
- default=DefaultFactory(fn=default_factory),
- )
- )
- for i, order in enumerate(ordering):
- if order < 0:
- ordering[i] = len(params_without_defaults) - order - 1
- # Step 2. Create the signature object
- sig = inspect.Signature(parameters=[*params_without_defaults,
*params_with_defaults])
- signature_str = (
- f"{type_cls.__module__}.{type_cls.__qualname__}.__init__("
- + ", ".join(p.name for p in sig.parameters.values())
- + ")"
- )
-
- # Step 3. Create the `binding` method that reorders parameters
- def touch_arg(x: Any) -> Any:
- return x.fn() if isinstance(x, DefaultFactory) else x
-
- def bind_args(*args: Any, **kwargs: Any) -> tuple[Any, ...]:
- bound = sig.bind(*args, **kwargs)
- bound.apply_defaults()
- args = bound.args
- args = tuple(touch_arg(args[i]) for i in ordering)
- return args
-
+ # sanity check
for type_method in type_info.methods:
if type_method.name == "__ffi_init__":
break
else:
raise ValueError(f"Cannot find constructor method:
`{type_info.type_key}.__ffi_init__`")
-
- def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
- e = None
- try:
- args = bind_args(*args, **kwargs)
- del kwargs
- self.__ffi_init__(*args)
- except Exception as _e:
- e = TypeError(f"Error in `{signature_str}`:
{_e}").with_traceback(_e.__traceback__)
- if e is not None:
- raise e
- try:
- fn_post_init = self.__post_init__ # type: ignore[attr-defined]
- except AttributeError:
- pass
+ # Step 1. Split args into sections and register default factories
+ args_no_defaults: list[str] = []
+ args_with_defaults: list[str] = []
+ fields_with_defaults: list[tuple[str, bool]] = []
+ ffi_arg_order: list[str] = []
+ exec_globals = {"MISSING": MISSING}
+ for field in fields:
+ assert field.name is not None
+ assert field.dataclass_field is not None
+ dataclass_field = field.dataclass_field
+ has_default_factory = (default_factory :=
dataclass_field.default_factory) is not MISSING
+ if dataclass_field.init:
+ ffi_arg_order.append(field.name)
+ if has_default_factory:
+ args_with_defaults.append(field.name)
+ fields_with_defaults.append((field.name, True))
+ exec_globals[f"_default_factory_{field.name}"] =
default_factory
+ else:
+ args_no_defaults.append(field.name)
+ elif has_default_factory:
+ ffi_arg_order.append(field.name)
+ fields_with_defaults.append((field.name, False))
+ exec_globals[f"_default_factory_{field.name}"] = default_factory
+
+ args: list[str] = ["self"]
+ args.extend(args_no_defaults)
+ args.extend(f"{name}=MISSING" for name in args_with_defaults)
+ body_lines: list[str] = []
+ for field_name, is_init in fields_with_defaults:
+ if is_init:
+ body_lines.append(
+ f"if {field_name} is MISSING: {field_name} =
_default_factory_{field_name}()"
+ )
else:
- fn_post_init()
-
- __init__.__signature__ = sig # type: ignore[attr-defined]
- __init__.__annotations__ = annotations
+ body_lines.append(f"{field_name} =
_default_factory_{field_name}()")
+ body_lines.append(f"self.__ffi_init__({', '.join(ffi_arg_order)})")
+ body_lines.extend(
+ [
+ "try:",
+ " fn_post_init = self.__post_init__",
+ "except AttributeError:",
+ " pass",
+ "else:",
+ " fn_post_init()",
+ ]
+ )
+ source_lines = [f"def __init__({', '.join(args)}):"]
+ source_lines.extend(f" {line}" for line in body_lines)
+ source_lines.append(" ...")
+ source = "\n".join(source_lines)
+ # Note: Code generation in this case is guaranteed to be safe,
+ # because the generated code does not contain any untrusted input.
+ # This is also a common practice used by `dataclasses` and `pydantic`.
+ namespace: dict[str, Any] = {}
+ exec(source, exec_globals, namespace)
+ __init__ = namespace["__init__"]
return __init__
diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py
index 2a62d66..92a28cb 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -37,23 +37,26 @@ class Field:
way the decorator understands.
"""
- __slots__ = ("default_factory", "name")
+ __slots__ = ("default_factory", "init", "name")
def __init__(
self,
*,
name: str | None = None,
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
+ init: bool = True,
) -> None:
"""Do not call directly; use :func:`field` instead."""
self.name = name
self.default_factory = default_factory
+ self.init = init
def field(
*,
default: _FieldValue | _MISSING_TYPE = MISSING, # type: ignore[assignment]
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING, #
type: ignore[assignment]
+ init: bool = True,
) -> _FieldValue:
"""(Experimental) Declare a dataclass-style field on a :func:`c_class`
proxy.
@@ -65,13 +68,33 @@ def field(
Parameters
----------
default : Any, optional
- A literal default value that should populate the field when no argument
- is given. The value is copied into a closure because TVM FFI does not
- mutate the Python placeholder instance.
+ A literal default value that populates the field when no argument
+ is given. At most one of ``default`` or ``default_factory`` may be
+ given.
default_factory : Callable[[], Any], optional
A zero-argument callable that produces the default. This matches the
semantics of :func:`dataclasses.field` and is useful for mutable
defaults such as ``list`` or ``dict``.
+ init : bool, default True
+ If ``True`` the field is included in the generated ``__init__``.
+ If ``False`` the field is omitted from input arguments of ``__init__``.
+
+ Note
+ ----
+ The decision to forward a field to the C++ ``__ffi_init__`` constructor
+ depends on its configuration:
+
+ * If ``init=True``, the field's value (from user input or defaults)
+ is forwarded.
+
+ * If ``init=False``:
+
+ - With a ``default`` or ``default_factory``, its computed value is
+ forwarded. The user cannot provide this value via Python
``__init__``.
+
+ - Without a ``default`` or ``default_factory``, the field is *not*
+ forwarded to C++ ``__ffi_init__`` and must be initialized by the
+ C++ constructor.
Returns
-------
@@ -82,7 +105,9 @@ def field(
Examples
--------
``field`` integrates with :func:`c_class` to express defaults the same way
a
- Python ``dataclass`` would::
+ Python ``dataclass`` would:
+
+ .. code-block:: python
@c_class("testing.TestCxxClassBase")
class PyBase:
@@ -95,9 +120,11 @@ def field(
"""
if default is not MISSING and default_factory is not MISSING:
raise ValueError("Cannot specify both `default` and `default_factory`")
+ if not isinstance(init, bool):
+ raise TypeError("`init` must be a bool")
if default is not MISSING:
default_factory = _make_default_factory(default)
- ret = Field(default_factory=default_factory)
+ ret = Field(default_factory=default_factory, init=init)
return cast(_FieldValue, ret)
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 0053564..155030b 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -110,3 +110,10 @@ class _TestCxxClassDerived(_TestCxxClassBase):
class _TestCxxClassDerivedDerived(_TestCxxClassDerived):
v_str: str = field(default_factory=lambda: "default")
v_bool: bool # type: ignore[misc] # Suppress: Attributes without a
default cannot follow attributes with one
+
+
+@c_class("testing.TestCxxInitSubset")
+class _TestCxxInitSubset:
+ required_field: int
+ optional_field: int = field(init=False)
+ note: str = field(default_factory=lambda: "py-default", init=False)
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index afd952b..cf55161 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -121,6 +121,19 @@ class TestCxxClassDerivedDerived : public TestCxxClassDerived {
TestCxxClassDerived);
};
+class TestCxxInitSubsetObj : public Object {
+ public:
+ int64_t required_field;
+ int64_t optional_field;
+ String note;
+
+ explicit TestCxxInitSubsetObj(int64_t value, String note)
+ : required_field(value), optional_field(-1), note(note) {}
+
+ static constexpr bool _type_mutable = true;
+ TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxInitSubset",
TestCxxInitSubsetObj, Object);
+};
+
class TestUnregisteredObject : public Object {
public:
int64_t value;
@@ -170,6 +183,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_rw("v_str", &TestCxxClassDerivedDerived::v_str)
.def_rw("v_bool", &TestCxxClassDerivedDerived::v_bool);
+ refl::ObjectDef<TestCxxInitSubsetObj>()
+ .def_static("__ffi_init__", refl::init<TestCxxInitSubsetObj, int64_t,
String>)
+ .def_rw("required_field", &TestCxxInitSubsetObj::required_field)
+ .def_rw("optional_field", &TestCxxInitSubsetObj::optional_field)
+ .def_rw("note", &TestCxxInitSubsetObj::note);
+
refl::GlobalDef()
.def("testing.test_raise_error", TestRaiseError)
.def_packed("testing.nop", [](PackedArgs args, Any* ret) {})
diff --git a/tests/python/test_dataclasses_c_class.py b/tests/python/test_dataclasses_c_class.py
index 0050e6c..676bbf5 100644
--- a/tests/python/test_dataclasses_c_class.py
+++ b/tests/python/test_dataclasses_c_class.py
@@ -14,7 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from tvm_ffi.testing import _TestCxxClassBase, _TestCxxClassDerived,
_TestCxxClassDerivedDerived
+import inspect
+
+from tvm_ffi.testing import (
+ _TestCxxClassBase,
+ _TestCxxClassDerived,
+ _TestCxxClassDerivedDerived,
+ _TestCxxInitSubset,
+)
def test_cxx_class_base() -> None:
@@ -64,3 +71,26 @@ def test_cxx_class_derived_derived_default() -> None:
assert isinstance(obj.v_f32, float) and obj.v_f32 == 8
assert obj.v_str == "default"
assert isinstance(obj.v_bool, bool) and obj.v_bool is True
+
+
+def test_cxx_class_init_subset_signature() -> None:
+ sig = inspect.signature(_TestCxxInitSubset.__init__)
+ params = tuple(sig.parameters)
+ assert "required_field" in params
+ assert "optional_field" not in params
+ assert "note" not in params
+
+
+def test_cxx_class_init_subset_defaults() -> None:
+ obj = _TestCxxInitSubset(required_field=42)
+ assert obj.required_field == 42
+ assert obj.optional_field == -1
+ assert obj.note == "py-default"
+
+
+def test_cxx_class_init_subset_positional() -> None:
+ obj = _TestCxxInitSubset(7) # type: ignore[call-arg]
+ assert obj.required_field == 7
+ assert obj.optional_field == -1
+ obj.optional_field = 11
+ assert obj.optional_field == 11