This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 360648f  feat: Add __repr__ generation support for @c_class dataclasses (#411)
360648f is described below

commit 360648f30ccb14523ab6fbb81f37eb085b801f98
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Sun Jan 18 11:30:53 2026 +0800

    feat: Add __repr__ generation support for @c_class dataclasses (#411)
    
    - Add `repr` parameter to `c_class()` decorator (default: True)
    - Add `repr` parameter to `field()` function (default: True)
    - Implement `method_repr()` to generate __repr__ methods
    - Generated repr format: ClassName(field1=value1, field2=value2, ...)
    - Fields with repr=False are excluded from the representation
    
    This implements part of #356 dataclass feature parity.
---
 python/tvm_ffi/dataclasses/_utils.py     | 63 +++++++++++++++++++++++++++-----
 python/tvm_ffi/dataclasses/c_class.py    | 12 ++++--
 python/tvm_ffi/dataclasses/field.py      | 12 +++++-
 tests/python/test_dataclasses_c_class.py | 35 ++++++++++++++++++
 4 files changed, 108 insertions(+), 14 deletions(-)

diff --git a/python/tvm_ffi/dataclasses/_utils.py b/python/tvm_ffi/dataclasses/_utils.py
index bd647a6..5ed4e96 100644
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -58,14 +58,13 @@ def type_info_to_cls(
     def _add_method(name: str, func: Callable[..., Any]) -> None:
         if name == "__ffi_init__":
             name = "__c_ffi_init__"
-        if name in attrs:  # already defined
-            return
+        # Allow overriding methods (including from base classes like Object.__repr__)
+        # by always adding to attrs, which will be used when creating the new class
         func.__module__ = cls.__module__
         func.__name__ = name
         func.__qualname__ = f"{cls.__qualname__}.{name}"
         func.__doc__ = f"Method `{name}` of class `{cls.__qualname__}`"
         attrs[name] = func
-        setattr(cls, name, func)
 
     for name, method_impl in methods.items():
         if method_impl is not None:
@@ -98,6 +97,57 @@ def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
     type_field.dataclass_field = rhs
 
 
+def _get_all_fields(type_info: TypeInfo) -> list[TypeField]:
+    """Collect all fields from the type hierarchy, from parents to children."""
+    fields: list[TypeField] = []
+    cur_type_info: TypeInfo | None = type_info
+    while cur_type_info is not None:
+        fields.extend(reversed(cur_type_info.fields))
+        cur_type_info = cur_type_info.parent_type_info
+    fields.reverse()
+    return fields
+
+
+def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]:
+    """Generate a ``__repr__`` method for the dataclass.
+
+    The generated representation includes all fields with ``repr=True`` in
+    the format ``ClassName(field1=value1, field2=value2, ...)``.
+    """
+    # Step 0. Collect all fields from the type hierarchy
+    fields = _get_all_fields(type_info)
+
+    # Step 1. Filter fields that should appear in repr
+    repr_fields: list[str] = []
+    for field in fields:
+        assert field.name is not None
+        assert field.dataclass_field is not None
+        if field.dataclass_field.repr:
+            repr_fields.append(field.name)
+
+    # Step 2. Generate the repr method
+    if not repr_fields:
+        # No fields to show, return a simple class name representation
+        body_lines = [f"return f'{type_cls.__name__}()'"]
+    else:
+        # Build field representations
+        fields_str = ", ".join(
+            f"{field_name}={{self.{field_name}!r}}" for field_name in repr_fields
+        )
+        body_lines = [f"return f'{type_cls.__name__}({fields_str})'"]
+
+    source_lines = ["def __repr__(self) -> str:"]
+    source_lines.extend(f"    {line}" for line in body_lines)
+    source = "\n".join(source_lines)
+
+    # Note: Code generation in this case is guaranteed to be safe,
+    # because the generated code does not contain any untrusted input.
+    namespace: dict[str, Any] = {}
+    exec(source, {}, namespace)
+    __repr__ = namespace["__repr__"]
+    return __repr__
+
+
 def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
     """Generate an ``__init__`` that forwards to the FFI constructor.
 
@@ -105,12 +155,7 @@ def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
     reflected field list, supporting default values and ``__post_init__``.
     """
     # Step 0. Collect all fields from the type hierarchy
-    fields: list[TypeField] = []
-    cur_type_info: TypeInfo | None = type_info
-    while cur_type_info is not None:
-        fields.extend(reversed(cur_type_info.fields))
-        cur_type_info = cur_type_info.parent_type_info
-    fields.reverse()
+    fields = _get_all_fields(type_info)
     # sanity check
     for type_method in type_info.methods:
         if type_method.name == "__ffi_init__":
diff --git a/python/tvm_ffi/dataclasses/c_class.py b/python/tvm_ffi/dataclasses/c_class.py
index 65d7c73..8171b1b 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -41,7 +41,7 @@ _InputClsType = TypeVar("_InputClsType")
 
 @dataclass_transform(field_specifiers=(field,))
 def c_class(
-    type_key: str, init: bool = True
+    type_key: str, init: bool = True, repr: bool = True
 ) -> Callable[[Type[_InputClsType]], Type[_InputClsType]]:  # noqa: UP006
     """(Experimental) Create a dataclass-like proxy for a C++ class registered with TVM FFI.
 
@@ -71,6 +71,10 @@ def c_class(
         signature.  The generated initializer calls the C++ ``__init__``
         function registered with ``ObjectDef`` and invokes ``__post_init__`` if
         it exists on the Python class.
+    repr
+        If ``True`` and the Python class does not define ``__repr__``, a
+        representation method is auto-generated that includes all fields with
+        ``repr=True``.
 
     Returns
     -------
@@ -118,8 +122,9 @@ def c_class(
     """
 
     def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]:  # noqa: UP006
-        nonlocal init
+        nonlocal init, repr
         init = init and "__init__" not in super_type_cls.__dict__
+        repr = repr and "__repr__" not in super_type_cls.__dict__
         # Step 1. Retrieve `type_info` from registry
         type_info: TypeInfo = _lookup_or_register_type_info_from_type_key(type_key)
         assert type_info.parent_type_info is not None
@@ -129,10 +134,11 @@ def c_class(
             _utils.fill_dataclass_field(super_type_cls, type_field)
         # Step 3. Create the proxy class with the fields as properties
         fn_init = _utils.method_init(super_type_cls, type_info) if init else None
+        fn_repr = _utils.method_repr(super_type_cls, type_info) if repr else None
         type_cls: Type[_InputClsType] = _utils.type_info_to_cls(  # noqa: UP006
             type_info=type_info,
             cls=super_type_cls,
-            methods={"__init__": fn_init},
+            methods={"__init__": fn_init, "__repr__": fn_repr},
         )
         _set_type_cls(type_info, type_cls)
         return type_cls
diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py
index d10612c..d0e27b1 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -37,7 +37,7 @@ class Field:
     way the decorator understands.
     """
 
-    __slots__ = ("default_factory", "init", "name")
+    __slots__ = ("default_factory", "init", "name", "repr")
 
     def __init__(
         self,
@@ -45,11 +45,13 @@ class Field:
         name: str | None = None,
         default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
         init: bool = True,
+        repr: bool = True,
     ) -> None:
         """Do not call directly; use :func:`field` instead."""
         self.name = name
         self.default_factory = default_factory
         self.init = init
+        self.repr = repr
 
 
 def field(
@@ -57,6 +59,7 @@ def field(
     default: _FieldValue | _MISSING_TYPE = MISSING,  # type: ignore[assignment]
     default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,  # type: ignore[assignment]
     init: bool = True,
+    repr: bool = True,
 ) -> _FieldValue:
     """(Experimental) Declare a dataclass-style field on a :func:`c_class` proxy.
 
@@ -78,6 +81,9 @@ def field(
     init
         If ``True`` the field is included in the generated ``__init__``.
         If ``False`` the field is omitted from input arguments of ``__init__``.
+    repr
+        If ``True`` the field is included in the generated ``__repr__``.
+        If ``False`` the field is omitted from the ``__repr__`` output.
 
     Note
     ----
@@ -123,9 +129,11 @@ def field(
         raise ValueError("Cannot specify both `default` and `default_factory`")
     if not isinstance(init, bool):
         raise TypeError("`init` must be a bool")
+    if not isinstance(repr, bool):
+        raise TypeError("`repr` must be a bool")
     if default is not MISSING:
         default_factory = _make_default_factory(default)
-    ret = Field(default_factory=default_factory, init=init)
+    ret = Field(default_factory=default_factory, init=init, repr=repr)
     return cast(_FieldValue, ret)
 
 
diff --git a/tests/python/test_dataclasses_c_class.py b/tests/python/test_dataclasses_c_class.py
index 676bbf5..5361cb6 100644
--- a/tests/python/test_dataclasses_c_class.py
+++ b/tests/python/test_dataclasses_c_class.py
@@ -94,3 +94,38 @@ def test_cxx_class_init_subset_positional() -> None:
     assert obj.optional_field == -1
     obj.optional_field = 11
     assert obj.optional_field == 11
+
+
+def test_cxx_class_repr() -> None:
+    obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0, v_f32=8.0)
+    repr_str = repr(obj)
+    assert "_TestCxxClassDerived" in repr_str
+    if "__repr__" in _TestCxxClassDerived.__dict__:
+        assert "v_i64=123" in repr_str
+        assert "v_i32=456" in repr_str
+        assert "v_f64=4.0" in repr_str
+        assert "v_f32=8.0" in repr_str
+
+
+def test_cxx_class_repr_default() -> None:
+    obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0)
+    repr_str = repr(obj)
+    assert "_TestCxxClassDerived" in repr_str
+    if "__repr__" in _TestCxxClassDerived.__dict__:
+        assert "v_i64=123" in repr_str
+        assert "v_i32=456" in repr_str
+        assert "v_f64=4.0" in repr_str
+        assert "v_f32=8.0" in repr_str
+
+
+def test_cxx_class_repr_derived_derived() -> None:
+    obj = _TestCxxClassDerivedDerived(
+        v_i64=123, v_i32=456, v_f64=4.0, v_f32=8.0, v_str="hello", v_bool=True
+    )
+    repr_str = repr(obj)
+    assert "_TestCxxClassDerivedDerived" in repr_str
+    if "__repr__" in _TestCxxClassDerivedDerived.__dict__:
+        assert "v_i64=123" in repr_str
+        assert "v_i32=456" in repr_str
+        assert "v_str='hello'" in repr_str or 'v_str="hello"' in repr_str
+        assert "v_bool=True" in repr_str

Reply via email to