This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new daeb235  feat: Support `field(init=...)` in `@c_class` fields (#52)
daeb235 is described below

commit daeb235a29c576d8702d447fa5f4773170bb1e8f
Author: Junru Shao <[email protected]>
AuthorDate: Wed Sep 24 12:56:41 2025 -0700

    feat: Support `field(init=...)` in `@c_class` fields (#52)
    
    This PR enhances the `@c_class` decorator by integrating `field(...,
    init: bool=True)`.
    
    This feature gives finer control over the generated Python `__init__`
    method: specific fields can be excluded from the constructor's
    signature, which makes tvm-ffi's dataclass system more flexible and
    brings it closer to Python's native dataclass capabilities.
    
    ## Example
    
    ```cpp
    // C++ side
    class TestCxxInitSubsetObj : public Object {
     public:
      int64_t required_field;
      int64_t optional_field;
      String note;
    
      explicit TestCxxInitSubsetObj(int64_t value, String note)
          : required_field(value), optional_field(-1), note(note) {}
                                  ^^^^^^^^^^^^^^^^^^^^
                                  this field is not a constructor parameter
    
      TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxInitSubset", TestCxxInitSubsetObj, Object);
    };
    ```
    
    ```python
    # Python side
    @c_class("testing.TestCxxInitSubset")
    class _TestCxxInitSubset:
        required_field: int
        optional_field: int = field(init=False)
                                    ^^^^^^^^^^
                                doesn't appear in `__init__`;
                                won't be forwarded to C++
        note: str = field(default_factory=lambda: "py-default", init=False)
                                                                ^^^^^^^^^^
                                              doesn't appear in `__init__`
                                              the default value will be forwarded to C++
    
    obj = _TestCxxInitSubset(required_field=42)
    assert obj.optional_field == -1
    assert obj.note == "py-default"
    ```
---
 docs/reference/python/cpp/index.rst      |  27 ------
 docs/reference/python/index.rst          |  18 +++-
 python/tvm_ffi/__init__.py               |   2 +
 python/tvm_ffi/dataclasses/_utils.py     | 136 ++++++++++++++-----------------
 python/tvm_ffi/dataclasses/field.py      |  39 +++++++--
 python/tvm_ffi/testing.py                |   7 ++
 src/ffi/extra/testing.cc                 |  19 +++++
 tests/python/test_dataclasses_c_class.py |  32 +++++++-
 8 files changed, 166 insertions(+), 114 deletions(-)

diff --git a/docs/reference/python/cpp/index.rst b/docs/reference/python/cpp/index.rst
deleted file mode 100644
index c104335..0000000
--- a/docs/reference/python/cpp/index.rst
+++ /dev/null
@@ -1,27 +0,0 @@
-..  Licensed to the Apache Software Foundation (ASF) under one
-    or more contributor license agreements.  See the NOTICE file
-    distributed with this work for additional information
-    regarding copyright ownership.  The ASF licenses this file
-    to you under the Apache License, Version 2.0 (the
-    "License"); you may not use this file except in compliance
-    with the License.  You may obtain a copy of the License at
-
-..    http://www.apache.org/licenses/LICENSE-2.0
-
-..  Unless required by applicable law or agreed to in writing,
-    software distributed under the License is distributed on an
-    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-    KIND, either express or implied.  See the License for the
-    specific language governing permissions and limitations
-    under the License.
-
-tvm_ffi.cpp
------------
-
-.. automodule:: tvm_ffi.cpp
-  :no-members:
-
-.. autosummary::
-  :toctree: generated/
-
-  load_inline
diff --git a/docs/reference/python/index.rst b/docs/reference/python/index.rst
index 756af4c..8b478e8 100644
--- a/docs/reference/python/index.rst
+++ b/docs/reference/python/index.rst
@@ -81,7 +81,19 @@ Stream Context
 Utility
 -------
 
-.. toctree::
-  :maxdepth: 1
+C++ integration helpers for building and loading inline modules.
 
-  cpp/index.rst
+.. autosummary::
+  :toctree: cpp/generated/
+
+  cpp.load_inline
+
+
+.. (Experimental) Dataclasses
+.. --------------------------
+
+.. .. autosummary::
+..   :toctree: generated/
+
+..   dataclasses.c_class
+..   dataclasses.field
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 2381363..c742ee2 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -42,6 +42,7 @@ from .stream import StreamContext, use_raw_stream, use_torch_stream
 from . import serialization
 from . import access_path
 from . import testing
+from . import dataclasses
 
 # optional module to speedup dlpack conversion
 from . import _optional_torch_c_dlpack
@@ -61,6 +62,7 @@ __all__ = [
     "Tensor",
     "access_path",
     "convert",
+    "dataclasses",
     "device",
     "dtype",
     "from_dlpack",
diff --git a/python/tvm_ffi/dataclasses/_utils.py b/python/tvm_ffi/dataclasses/_utils.py
index 7a28e01..b6bcdac 100644
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -19,9 +19,8 @@
 from __future__ import annotations
 
 import functools
-import inspect
 from dataclasses import MISSING
-from typing import Any, Callable, NamedTuple, TypeVar, cast
+from typing import Any, Callable, TypeVar, cast
 
 from ..core import (
     Object,
@@ -111,96 +110,79 @@ def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
     type_field.dataclass_field = rhs
 
 
-def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:  
# noqa: PLR0915
+def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
     """Generate an ``__init__`` that forwards to the FFI constructor.
 
     The generated initializer has a proper Python signature built from the
     reflected field list, supporting default values and ``__post_init__``.
     """
-
-    class DefaultFactory(NamedTuple):
-        """Wrapper that marks a parameter as having a default factory."""
-
-        fn: Callable[[], Any]
-
+    # Step 0. Collect all fields from the type hierarchy
     fields: list[TypeField] = []
     cur_type_info: TypeInfo | None = type_info
     while cur_type_info is not None:
         fields.extend(reversed(cur_type_info.fields))
         cur_type_info = cur_type_info.parent_type_info
     fields.reverse()
-
-    annotations: dict[str, Any] = {"return": None}
-    # Step 1. Split the parameters into two groups to ensure that
-    # those without defaults appear first in the signature.
-    params_without_defaults: list[inspect.Parameter] = []
-    params_with_defaults: list[inspect.Parameter] = []
-    ordering = [0] * len(fields)
-    for i, field in enumerate(fields):
-        assert field.name is not None
-        name: str = field.name
-        annotations[name] = Any  # NOTE: We might be able to handle 
annotations better
-        assert field.dataclass_field is not None
-        default_factory = field.dataclass_field.default_factory
-        if default_factory is MISSING:
-            ordering[i] = len(params_without_defaults)
-            params_without_defaults.append(
-                inspect.Parameter(name=name, 
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD)
-            )
-        else:
-            ordering[i] = -len(params_with_defaults) - 1
-            params_with_defaults.append(
-                inspect.Parameter(
-                    name=name,
-                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
-                    default=DefaultFactory(fn=default_factory),
-                )
-            )
-    for i, order in enumerate(ordering):
-        if order < 0:
-            ordering[i] = len(params_without_defaults) - order - 1
-    # Step 2. Create the signature object
-    sig = inspect.Signature(parameters=[*params_without_defaults, 
*params_with_defaults])
-    signature_str = (
-        f"{type_cls.__module__}.{type_cls.__qualname__}.__init__("
-        + ", ".join(p.name for p in sig.parameters.values())
-        + ")"
-    )
-
-    # Step 3. Create the `binding` method that reorders parameters
-    def touch_arg(x: Any) -> Any:
-        return x.fn() if isinstance(x, DefaultFactory) else x
-
-    def bind_args(*args: Any, **kwargs: Any) -> tuple[Any, ...]:
-        bound = sig.bind(*args, **kwargs)
-        bound.apply_defaults()
-        args = bound.args
-        args = tuple(touch_arg(args[i]) for i in ordering)
-        return args
-
+    # sanity check
     for type_method in type_info.methods:
         if type_method.name == "__ffi_init__":
             break
     else:
         raise ValueError(f"Cannot find constructor method: 
`{type_info.type_key}.__ffi_init__`")
-
-    def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
-        e = None
-        try:
-            args = bind_args(*args, **kwargs)
-            del kwargs
-            self.__ffi_init__(*args)
-        except Exception as _e:
-            e = TypeError(f"Error in `{signature_str}`: 
{_e}").with_traceback(_e.__traceback__)
-        if e is not None:
-            raise e
-        try:
-            fn_post_init = self.__post_init__  # type: ignore[attr-defined]
-        except AttributeError:
-            pass
+    # Step 1. Split args into sections and register default factories
+    args_no_defaults: list[str] = []
+    args_with_defaults: list[str] = []
+    fields_with_defaults: list[tuple[str, bool]] = []
+    ffi_arg_order: list[str] = []
+    exec_globals = {"MISSING": MISSING}
+    for field in fields:
+        assert field.name is not None
+        assert field.dataclass_field is not None
+        dataclass_field = field.dataclass_field
+        has_default_factory = (default_factory := 
dataclass_field.default_factory) is not MISSING
+        if dataclass_field.init:
+            ffi_arg_order.append(field.name)
+            if has_default_factory:
+                args_with_defaults.append(field.name)
+                fields_with_defaults.append((field.name, True))
+                exec_globals[f"_default_factory_{field.name}"] = 
default_factory
+            else:
+                args_no_defaults.append(field.name)
+        elif has_default_factory:
+            ffi_arg_order.append(field.name)
+            fields_with_defaults.append((field.name, False))
+            exec_globals[f"_default_factory_{field.name}"] = default_factory
+
+    args: list[str] = ["self"]
+    args.extend(args_no_defaults)
+    args.extend(f"{name}=MISSING" for name in args_with_defaults)
+    body_lines: list[str] = []
+    for field_name, is_init in fields_with_defaults:
+        if is_init:
+            body_lines.append(
+                f"if {field_name} is MISSING: {field_name} = 
_default_factory_{field_name}()"
+            )
         else:
-            fn_post_init()
-
-    __init__.__signature__ = sig  # type: ignore[attr-defined]
-    __init__.__annotations__ = annotations
+            body_lines.append(f"{field_name} = 
_default_factory_{field_name}()")
+    body_lines.append(f"self.__ffi_init__({', '.join(ffi_arg_order)})")
+    body_lines.extend(
+        [
+            "try:",
+            "    fn_post_init = self.__post_init__",
+            "except AttributeError:",
+            "    pass",
+            "else:",
+            "    fn_post_init()",
+        ]
+    )
+    source_lines = [f"def __init__({', '.join(args)}):"]
+    source_lines.extend(f"    {line}" for line in body_lines)
+    source_lines.append("    ...")
+    source = "\n".join(source_lines)
+    # Note: Code generation in this case is guaranteed to be safe,
+    # because the generated code does not contain any untrusted input.
+    # This is also a common practice used by `dataclasses` and `pydantic`.
+    namespace: dict[str, Any] = {}
+    exec(source, exec_globals, namespace)
+    __init__ = namespace["__init__"]
     return __init__
diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py
index 2a62d66..92a28cb 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -37,23 +37,26 @@ class Field:
     way the decorator understands.
     """
 
-    __slots__ = ("default_factory", "name")
+    __slots__ = ("default_factory", "init", "name")
 
     def __init__(
         self,
         *,
         name: str | None = None,
         default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
+        init: bool = True,
     ) -> None:
         """Do not call directly; use :func:`field` instead."""
         self.name = name
         self.default_factory = default_factory
+        self.init = init
 
 
 def field(
     *,
     default: _FieldValue | _MISSING_TYPE = MISSING,  # type: ignore[assignment]
     default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,  # 
type: ignore[assignment]
+    init: bool = True,
 ) -> _FieldValue:
     """(Experimental) Declare a dataclass-style field on a :func:`c_class` 
proxy.
 
@@ -65,13 +68,33 @@ def field(
     Parameters
     ----------
     default : Any, optional
-        A literal default value that should populate the field when no argument
-        is given.  The value is copied into a closure because TVM FFI does not
-        mutate the Python placeholder instance.
+        A literal default value that populates the field when no argument
+        is given. At most one of ``default`` or ``default_factory`` may be
+        given.
     default_factory : Callable[[], Any], optional
         A zero-argument callable that produces the default.  This matches the
         semantics of :func:`dataclasses.field` and is useful for mutable
         defaults such as ``list`` or ``dict``.
+    init : bool, default True
+        If ``True`` the field is included in the generated ``__init__``.
+        If ``False`` the field is omitted from input arguments of ``__init__``.
+
+    Note
+    ----
+    The decision to forward a field to the C++ ``__ffi_init__`` constructor
+    depends on its configuration:
+
+    *   If ``init=True``, the field's value (from user input or defaults)
+        is forwarded.
+
+    *   If ``init=False``:
+
+        -   With a ``default`` or ``default_factory``, its computed value is
+            forwarded. The user cannot provide this value via Python 
``__init__``.
+
+        -   Without a ``default`` or ``default_factory``, the field is *not*
+            forwarded to C++ ``__ffi_init__`` and must be initialized by the
+            C++ constructor.
 
     Returns
     -------
@@ -82,7 +105,9 @@ def field(
     Examples
     --------
     ``field`` integrates with :func:`c_class` to express defaults the same way 
a
-    Python ``dataclass`` would::
+    Python ``dataclass`` would:
+
+    .. code-block:: python
 
         @c_class("testing.TestCxxClassBase")
         class PyBase:
@@ -95,9 +120,11 @@ def field(
     """
     if default is not MISSING and default_factory is not MISSING:
         raise ValueError("Cannot specify both `default` and `default_factory`")
+    if not isinstance(init, bool):
+        raise TypeError("`init` must be a bool")
     if default is not MISSING:
         default_factory = _make_default_factory(default)
-    ret = Field(default_factory=default_factory)
+    ret = Field(default_factory=default_factory, init=init)
     return cast(_FieldValue, ret)
 
 
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 0053564..155030b 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -110,3 +110,10 @@ class _TestCxxClassDerived(_TestCxxClassBase):
 class _TestCxxClassDerivedDerived(_TestCxxClassDerived):
     v_str: str = field(default_factory=lambda: "default")
     v_bool: bool  # type: ignore[misc]  # Suppress: Attributes without a 
default cannot follow attributes with one
+
+
+@c_class("testing.TestCxxInitSubset")
+class _TestCxxInitSubset:
+    required_field: int
+    optional_field: int = field(init=False)
+    note: str = field(default_factory=lambda: "py-default", init=False)
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index afd952b..cf55161 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -121,6 +121,19 @@ class TestCxxClassDerivedDerived : public TestCxxClassDerived {
                               TestCxxClassDerived);
 };
 
+class TestCxxInitSubsetObj : public Object {
+ public:
+  int64_t required_field;
+  int64_t optional_field;
+  String note;
+
+  explicit TestCxxInitSubsetObj(int64_t value, String note)
+      : required_field(value), optional_field(-1), note(note) {}
+
+  static constexpr bool _type_mutable = true;
+  TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxInitSubset", 
TestCxxInitSubsetObj, Object);
+};
+
 class TestUnregisteredObject : public Object {
  public:
   int64_t value;
@@ -170,6 +183,12 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def_rw("v_str", &TestCxxClassDerivedDerived::v_str)
       .def_rw("v_bool", &TestCxxClassDerivedDerived::v_bool);
 
+  refl::ObjectDef<TestCxxInitSubsetObj>()
+      .def_static("__ffi_init__", refl::init<TestCxxInitSubsetObj, int64_t, 
String>)
+      .def_rw("required_field", &TestCxxInitSubsetObj::required_field)
+      .def_rw("optional_field", &TestCxxInitSubsetObj::optional_field)
+      .def_rw("note", &TestCxxInitSubsetObj::note);
+
   refl::GlobalDef()
       .def("testing.test_raise_error", TestRaiseError)
       .def_packed("testing.nop", [](PackedArgs args, Any* ret) {})
diff --git a/tests/python/test_dataclasses_c_class.py b/tests/python/test_dataclasses_c_class.py
index 0050e6c..676bbf5 100644
--- a/tests/python/test_dataclasses_c_class.py
+++ b/tests/python/test_dataclasses_c_class.py
@@ -14,7 +14,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from tvm_ffi.testing import _TestCxxClassBase, _TestCxxClassDerived, 
_TestCxxClassDerivedDerived
+import inspect
+
+from tvm_ffi.testing import (
+    _TestCxxClassBase,
+    _TestCxxClassDerived,
+    _TestCxxClassDerivedDerived,
+    _TestCxxInitSubset,
+)
 
 
 def test_cxx_class_base() -> None:
@@ -64,3 +71,26 @@ def test_cxx_class_derived_derived_default() -> None:
     assert isinstance(obj.v_f32, float) and obj.v_f32 == 8
     assert obj.v_str == "default"
     assert isinstance(obj.v_bool, bool) and obj.v_bool is True
+
+
+def test_cxx_class_init_subset_signature() -> None:
+    sig = inspect.signature(_TestCxxInitSubset.__init__)
+    params = tuple(sig.parameters)
+    assert "required_field" in params
+    assert "optional_field" not in params
+    assert "note" not in params
+
+
+def test_cxx_class_init_subset_defaults() -> None:
+    obj = _TestCxxInitSubset(required_field=42)
+    assert obj.required_field == 42
+    assert obj.optional_field == -1
+    assert obj.note == "py-default"
+
+
+def test_cxx_class_init_subset_positional() -> None:
+    obj = _TestCxxInitSubset(7)  # type: ignore[call-arg]
+    assert obj.required_field == 7
+    assert obj.optional_field == -1
+    obj.optional_field = 11
+    assert obj.optional_field == 11

Reply via email to