This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 3a5bf5e feat: add kw_only support for dataclass init generation (#384)
3a5bf5e is described below
commit 3a5bf5e68ad1b4108045ef6b336a13efcd2037d9
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sun Jan 18 15:13:10 2026 +0800
feat: add kw_only support for dataclass init generation (#384)
## Related Issue
#356
## Why
Python's standard dataclasses support a `kw_only` parameter that makes fields
keyword-only in `__init__`. This feature was missing from the `@c_class`
decorator.
## How
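- Add a `KW_ONLY` sentinel (re-exported from `dataclasses` on Python 3.10+)
  that marks all subsequent fields as keyword-only
- Add a `kw_only` parameter to the `field()` function and the `@c_class` decorator
- Update `method_init()` to emit keyword-only fields after a bare `*` in the
  generated `__init__` signature
- Add a C++ test class (`testing.TestCxxKwOnly`) and Python tests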
---
python/tvm_ffi/dataclasses/__init__.py | 4 +--
python/tvm_ffi/dataclasses/_utils.py | 58 ++++++++++++++++++++++----------
python/tvm_ffi/dataclasses/c_class.py | 47 +++++++++++++++++++++-----
python/tvm_ffi/dataclasses/field.py | 35 +++++++++++++++++--
python/tvm_ffi/testing/__init__.py | 1 +
python/tvm_ffi/testing/testing.py | 8 +++++
src/ffi/testing/testing.cc | 20 +++++++++++
tests/python/test_dataclasses_c_class.py | 55 ++++++++++++++++++++++++++++++
8 files changed, 199 insertions(+), 29 deletions(-)
diff --git a/python/tvm_ffi/dataclasses/__init__.py b/python/tvm_ffi/dataclasses/__init__.py
index 3185413..bfb4404 100644
--- a/python/tvm_ffi/dataclasses/__init__.py
+++ b/python/tvm_ffi/dataclasses/__init__.py
@@ -19,6 +19,6 @@
from dataclasses import MISSING
from .c_class import c_class
-from .field import Field, field
+from .field import KW_ONLY, Field, field
-__all__ = ["MISSING", "Field", "c_class", "field"]
+__all__ = ["KW_ONLY", "MISSING", "Field", "c_class", "field"]
diff --git a/python/tvm_ffi/dataclasses/_utils.py b/python/tvm_ffi/dataclasses/_utils.py
index 5ed4e96..7c0afb4 100644
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -79,7 +79,13 @@ def type_info_to_cls(
return cast(Type[_InputClsType], new_cls)
-def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
+def fill_dataclass_field(
+ type_cls: type,
+ type_field: TypeField,
+ *,
+ class_kw_only: bool = False,
+ kw_only_from_sentinel: bool = False,
+) -> None:
from .field import Field, field # noqa: PLC0415
field_name = type_field.name
@@ -94,6 +100,14 @@ def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
raise ValueError(f"Cannot recognize field: {type_field.name}: {rhs}")
assert isinstance(rhs, Field)
rhs.name = type_field.name
+
+ # Resolve kw_only: field-level > KW_ONLY sentinel > class-level
+ if rhs.kw_only is MISSING:
+ if kw_only_from_sentinel:
+ rhs.kw_only = True
+ else:
+ rhs.kw_only = class_kw_only
+
type_field.dataclass_field = rhs
@@ -148,47 +162,56 @@ def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]:
return __repr__
-def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
+def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
"""Generate an ``__init__`` that forwards to the FFI constructor.
The generated initializer has a proper Python signature built from the
- reflected field list, supporting default values and ``__post_init__``.
+ reflected field list, supporting default values, keyword-only args, and
``__post_init__``.
"""
# Step 0. Collect all fields from the type hierarchy
fields = _get_all_fields(type_info)
# sanity check
- for type_method in type_info.methods:
- if type_method.name == "__ffi_init__":
- break
- else:
+ if not any(m.name == "__ffi_init__" for m in type_info.methods):
raise ValueError(f"Cannot find constructor method:
`{type_info.type_key}.__ffi_init__`")
# Step 1. Split args into sections and register default factories
- args_no_defaults: list[str] = []
- args_with_defaults: list[str] = []
+ pos_no_defaults: list[str] = []
+ pos_with_defaults: list[str] = []
+ kw_no_defaults: list[str] = []
+ kw_with_defaults: list[str] = []
fields_with_defaults: list[tuple[str, bool]] = []
ffi_arg_order: list[str] = []
- exec_globals = {"MISSING": MISSING}
+ exec_globals: dict[str, Any] = {"MISSING": MISSING}
+
for field in fields:
assert field.name is not None
assert field.dataclass_field is not None
dataclass_field = field.dataclass_field
- has_default_factory = (default_factory :=
dataclass_field.default_factory) is not MISSING
+ has_default = (default_factory := dataclass_field.default_factory) is
not MISSING
+ is_kw_only = dataclass_field.kw_only is True
+
if dataclass_field.init:
ffi_arg_order.append(field.name)
- if has_default_factory:
- args_with_defaults.append(field.name)
+ if has_default:
+ (kw_with_defaults if is_kw_only else
pos_with_defaults).append(field.name)
fields_with_defaults.append((field.name, True))
exec_globals[f"_default_factory_{field.name}"] =
default_factory
else:
- args_no_defaults.append(field.name)
- elif has_default_factory:
+ (kw_no_defaults if is_kw_only else
pos_no_defaults).append(field.name)
+ elif has_default:
ffi_arg_order.append(field.name)
fields_with_defaults.append((field.name, False))
exec_globals[f"_default_factory_{field.name}"] = default_factory
+ # Step 2. Build signature
args: list[str] = ["self"]
- args.extend(args_no_defaults)
- args.extend(f"{name}=MISSING" for name in args_with_defaults)
+ args.extend(pos_no_defaults)
+ args.extend(f"{name}=MISSING" for name in pos_with_defaults)
+ if kw_no_defaults or kw_with_defaults:
+ args.append("*")
+ args.extend(kw_no_defaults)
+ args.extend(f"{name}=MISSING" for name in kw_with_defaults)
+
+ # Step 3. Build body
body_lines: list[str] = []
for field_name, is_init in fields_with_defaults:
if is_init:
@@ -208,6 +231,7 @@ def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
" fn_post_init()",
]
)
+
source_lines = [f"def __init__({', '.join(args)}):"]
source_lines.extend(f" {line}" for line in body_lines)
source_lines.append(" ...")
diff --git a/python/tvm_ffi/dataclasses/c_class.py b/python/tvm_ffi/dataclasses/c_class.py
index 8171b1b..8dd5e5a 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -34,14 +34,14 @@ from typing_extensions import dataclass_transform
from ..core import TypeField, TypeInfo,
_lookup_or_register_type_info_from_type_key, _set_type_cls
from . import _utils
-from .field import field
+from .field import KW_ONLY, field
_InputClsType = TypeVar("_InputClsType")
-@dataclass_transform(field_specifiers=(field,))
+@dataclass_transform(field_specifiers=(field,), kw_only_default=False)
def c_class(
- type_key: str, init: bool = True, repr: bool = True
+ type_key: str, init: bool = True, kw_only: bool = False, repr: bool = True
) -> Callable[[Type[_InputClsType]], Type[_InputClsType]]: # noqa: UP006
"""(Experimental) Create a dataclass-like proxy for a C++ class registered
with TVM FFI.
@@ -71,6 +71,12 @@ def c_class(
signature. The generated initializer calls the C++ ``__init__``
function registered with ``ObjectDef`` and invokes ``__post_init__`` if
it exists on the Python class.
+
+ kw_only
+ If ``True``, all fields become keyword-only parameters in the generated
+ ``__init__``. Individual fields can override this by setting
+ ``kw_only=False`` in :func:`field`. Additionally, a ``KW_ONLY``
sentinel
+ annotation can be used to mark all subsequent fields as keyword-only.
repr
If ``True`` and the Python class does not define ``__repr__``, a
representation method is auto-generated that includes all fields with
@@ -129,9 +135,15 @@ def c_class(
type_info: TypeInfo =
_lookup_or_register_type_info_from_type_key(type_key)
assert type_info.parent_type_info is not None
# Step 2. Reflect all the fields of the type
- type_info.fields = _inspect_c_class_fields(super_type_cls, type_info)
- for type_field in type_info.fields:
- _utils.fill_dataclass_field(super_type_cls, type_field)
+ type_info.fields, kw_only_start_idx =
_inspect_c_class_fields(super_type_cls, type_info)
+ for idx, type_field in enumerate(type_info.fields):
+ kw_only_from_sentinel = kw_only_start_idx is not None and idx >=
kw_only_start_idx
+ _utils.fill_dataclass_field(
+ super_type_cls,
+ type_field,
+ class_kw_only=kw_only,
+ kw_only_from_sentinel=kw_only_from_sentinel,
+ )
# Step 3. Create the proxy class with the fields as properties
fn_init = _utils.method_init(super_type_cls, type_info) if init else
None
fn_repr = _utils.method_repr(super_type_cls, type_info) if repr else
None
@@ -146,7 +158,9 @@ def c_class(
return decorator
-def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) ->
list[TypeField]:
+def _inspect_c_class_fields(
+ type_cls: type, type_info: TypeInfo
+) -> tuple[list[TypeField], int | None]:
if sys.version_info >= (3, 9):
type_hints_resolved = get_type_hints(type_cls, include_extras=True)
else:
@@ -159,7 +173,24 @@ def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeFie
ClassVar,
InitVar,
]
+ and type_hints_resolved[name] is not KW_ONLY
}
+
+ # Detect KW_ONLY sentinel position
+ kw_only_start_idx: int | None = None
+ field_count = 0
+ for name in getattr(type_cls, "__annotations__", {}).keys():
+ resolved_type = type_hints_resolved.get(name)
+ if resolved_type is None:
+ continue
+ if get_origin(resolved_type) in [ClassVar, InitVar]:
+ continue
+ if resolved_type is KW_ONLY:
+ if kw_only_start_idx is not None:
+ raise ValueError(f"KW_ONLY may only be used once per class:
{type_cls}")
+ kw_only_start_idx = field_count
+ continue
+ field_count += 1
del type_hints_resolved
type_fields_cxx: dict[str, TypeField] = {f.name: f for f in
type_info.fields}
@@ -178,4 +209,4 @@ def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeFie
raise ValueError(
f"Missing fields in `{type_cls}`: {extra_fields}. Defined in C++
but not in Python"
)
- return type_fields
+ return type_fields, kw_only_start_idx
diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py
index d0e27b1..a395e50 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -21,7 +21,17 @@ from __future__ import annotations
from dataclasses import _MISSING_TYPE, MISSING
from typing import Any, Callable, TypeVar, cast
+try:
+ from dataclasses import KW_ONLY # type: ignore[attr-defined]
+except ImportError:
+ # Python < 3.10: define our own KW_ONLY sentinel
+ class _KW_ONLY_Sentinel:
+ __slots__ = ()
+
+ KW_ONLY = _KW_ONLY_Sentinel()
+
_FieldValue = TypeVar("_FieldValue")
+_KW_ONLY_TYPE = type(KW_ONLY)
class Field:
@@ -37,7 +47,7 @@ class Field:
way the decorator understands.
"""
- __slots__ = ("default_factory", "init", "name", "repr")
+ __slots__ = ("default_factory", "init", "kw_only", "name", "repr")
def __init__(
self,
@@ -46,12 +56,14 @@ class Field:
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
init: bool = True,
repr: bool = True,
+ kw_only: bool | _MISSING_TYPE = MISSING,
) -> None:
"""Do not call directly; use :func:`field` instead."""
self.name = name
self.default_factory = default_factory
self.init = init
self.repr = repr
+ self.kw_only = kw_only
def field(
@@ -60,6 +72,7 @@ def field(
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING, #
type: ignore[assignment]
init: bool = True,
repr: bool = True,
+ kw_only: bool | _MISSING_TYPE = MISSING, # type: ignore[assignment]
) -> _FieldValue:
"""(Experimental) Declare a dataclass-style field on a :func:`c_class`
proxy.
@@ -84,6 +97,10 @@ def field(
repr
If ``True`` the field is included in the generated ``__repr__``.
If ``False`` the field is omitted from the ``__repr__`` output.
+ kw_only
+ If ``True``, the field is a keyword-only argument in ``__init__``.
+ If ``MISSING``, inherits from the class-level ``kw_only`` setting or
+ from a preceding ``KW_ONLY`` sentinel annotation.
Note
----
@@ -124,6 +141,18 @@ def field(
obj = PyBase(v_i64=4)
obj.v_i32 # -> 16
+ Use ``kw_only=True`` to make a field keyword-only:
+
+ .. code-block:: python
+
+ @c_class("testing.TestCxxClassBase")
+ class PyBase:
+ v_i64: int
+ v_i32: int = field(kw_only=True)
+
+
+ obj = PyBase(4, v_i32=8) # v_i32 must be keyword
+
"""
if default is not MISSING and default_factory is not MISSING:
raise ValueError("Cannot specify both `default` and `default_factory`")
@@ -131,9 +160,11 @@ def field(
raise TypeError("`init` must be a bool")
if not isinstance(repr, bool):
raise TypeError("`repr` must be a bool")
+ if kw_only is not MISSING and not isinstance(kw_only, bool):
+ raise TypeError(f"`kw_only` must be a bool, got
{type(kw_only).__name__!r}")
if default is not MISSING:
default_factory = _make_default_factory(default)
- ret = Field(default_factory=default_factory, init=init, repr=repr)
+ ret = Field(default_factory=default_factory, init=init, repr=repr,
kw_only=kw_only)
return cast(_FieldValue, ret)
diff --git a/python/tvm_ffi/testing/__init__.py b/python/tvm_ffi/testing/__init__.py
index cd35736..af22210 100644
--- a/python/tvm_ffi/testing/__init__.py
+++ b/python/tvm_ffi/testing/__init__.py
@@ -25,6 +25,7 @@ from .testing import (
_TestCxxClassDerived,
_TestCxxClassDerivedDerived,
_TestCxxInitSubset,
+ _TestCxxKwOnly,
add_one,
create_object,
make_unregistered_object,
diff --git a/python/tvm_ffi/testing/testing.py b/python/tvm_ffi/testing/testing.py
index b905b5b..0ffeb49 100644
--- a/python/tvm_ffi/testing/testing.py
+++ b/python/tvm_ffi/testing/testing.py
@@ -178,3 +178,11 @@ class _TestCxxInitSubset:
required_field: int
optional_field: int = field(init=False)
note: str = field(default_factory=lambda: "py-default", init=False)
+
+
+@c_class("testing.TestCxxKwOnly", kw_only=True)
+class _TestCxxKwOnly:
+ x: int
+ y: int
+ z: int
+ w: int = 100
diff --git a/src/ffi/testing/testing.cc b/src/ffi/testing/testing.cc
index 7ee6ffd..0df7f1e 100644
--- a/src/ffi/testing/testing.cc
+++ b/src/ffi/testing/testing.cc
@@ -147,6 +147,19 @@ class TestCxxInitSubsetObj : public Object {
TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxInitSubset",
TestCxxInitSubsetObj, Object);
};
+class TestCxxKwOnly : public Object {
+ public:
+ int64_t x;
+ int64_t y;
+ int64_t z;
+ int64_t w;
+
+ TestCxxKwOnly(int64_t x, int64_t y, int64_t z, int64_t w) : x(x), y(y),
z(z), w(w) {}
+
+ static constexpr bool _type_mutable = true;
+ TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxKwOnly", TestCxxKwOnly, Object);
+};
+
class TestUnregisteredBaseObject : public Object {
public:
int64_t v1;
@@ -229,6 +242,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_rw("optional_field", &TestCxxInitSubsetObj::optional_field)
.def_rw("note", &TestCxxInitSubsetObj::note);
+ refl::ObjectDef<TestCxxKwOnly>()
+ .def(refl::init<int64_t, int64_t, int64_t, int64_t>())
+ .def_rw("x", &TestCxxKwOnly::x)
+ .def_rw("y", &TestCxxKwOnly::y)
+ .def_rw("z", &TestCxxKwOnly::z)
+ .def_rw("w", &TestCxxKwOnly::w);
+
refl::ObjectDef<TestUnregisteredBaseObject>()
.def(refl::init<int64_t>(), "Constructor of TestUnregisteredBaseObject")
.def_ro("v1", &TestUnregisteredBaseObject::v1)
diff --git a/tests/python/test_dataclasses_c_class.py b/tests/python/test_dataclasses_c_class.py
index 5361cb6..3a757d0 100644
--- a/tests/python/test_dataclasses_c_class.py
+++ b/tests/python/test_dataclasses_c_class.py
@@ -15,12 +15,17 @@
# specific language governing permissions and limitations
# under the License.
import inspect
+from dataclasses import MISSING
+import pytest
+from tvm_ffi.dataclasses import KW_ONLY, field
+from tvm_ffi.dataclasses.field import _KW_ONLY_TYPE, Field
from tvm_ffi.testing import (
_TestCxxClassBase,
_TestCxxClassDerived,
_TestCxxClassDerivedDerived,
_TestCxxInitSubset,
+ _TestCxxKwOnly,
)
@@ -129,3 +134,53 @@ def test_cxx_class_repr_derived_derived() -> None:
assert "v_i32=456" in repr_str
assert "v_str='hello'" in repr_str or 'v_str="hello"' in repr_str
assert "v_bool=True" in repr_str
+
+
+def test_kw_only_class_level_signature() -> None:
+ sig = inspect.signature(_TestCxxKwOnly.__init__)
+ params = sig.parameters
+ assert params["x"].kind == inspect.Parameter.KEYWORD_ONLY
+ assert params["y"].kind == inspect.Parameter.KEYWORD_ONLY
+ assert params["z"].kind == inspect.Parameter.KEYWORD_ONLY
+ assert params["w"].kind == inspect.Parameter.KEYWORD_ONLY
+
+
+def test_kw_only_class_level_call() -> None:
+ obj = _TestCxxKwOnly(x=1, y=2, z=3, w=4)
+ assert obj.x == 1
+ assert obj.y == 2
+ assert obj.z == 3
+ assert obj.w == 4
+
+
+def test_kw_only_class_level_with_default() -> None:
+ obj = _TestCxxKwOnly(x=1, y=2, z=3)
+ assert obj.w == 100
+
+
+def test_kw_only_class_level_rejects_positional() -> None:
+ with pytest.raises(TypeError, match="positional"):
+ _TestCxxKwOnly(1, 2, 3, 4) # type: ignore[misc]
+
+
+def test_field_kw_only_parameter() -> None:
+ f1: Field = field(kw_only=True)
+ assert isinstance(f1, Field)
+ assert f1.kw_only is True
+
+ f2: Field = field(kw_only=False)
+ assert f2.kw_only is False
+
+ f3: Field = field()
+ assert f3.kw_only is MISSING
+
+
+def test_field_kw_only_with_default() -> None:
+ f = field(default=42, kw_only=True)
+ assert isinstance(f, Field)
+ assert f.kw_only is True
+ assert f.default_factory() == 42
+
+
+def test_kw_only_sentinel_exists() -> None:
+ assert isinstance(KW_ONLY, _KW_ONLY_TYPE)