This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new e5f3af7b feat(python): reimplement c_class as register_object + structural dunders (#488)
e5f3af7b is described below

commit e5f3af7bb83e6461c45d08117e4eaabe51add3b1
Author: Junru Shao <[email protected]>
AuthorDate: Sat Feb 28 12:03:37 2026 -0800

    feat(python): reimplement c_class as register_object + structural dunders (#488)
    
    ## Summary
    
    Rewrite the `@c_class` decorator from a thin `register_object`
    pass-through into a `dataclass`-style decorator that combines FFI type
    registration with structural dunder methods derived from C++ reflection
    metadata.
    
    - **`@c_class` now installs structural dunders**: `__init__` (synthesized
      from C++ reflection metadata), `__repr__` (via `object_repr`), and,
      when enabled, `__eq__`/`__ne__`, `__hash__`, and the ordering operators
      (`__lt__`, `__le__`, `__gt__`, `__ge__`), which delegate to the
      corresponding C++ recursive operations (`RecursiveEq`,
      `RecursiveHash`, `RecursiveLt`, etc.).
    - **`@dataclass_transform` decorator** added for IDE/type-checker
      support (pyright, mypy).
    - **Migrated all test objects** in `tvm_ffi.testing` from
      `@register_object` to `@c_class`.
    
    ## Architecture
    
    - `c_class.py`: the decorator accepts `init`, `repr`, `eq`, `order`, and
      `unsafe_hash` parameters, and delegates to `register_object` followed
      by `_install_dataclass_dunders`.
    - `registry.py`: `_install_dataclass_dunders` installs the dunders on the
      class; `_install_init` synthesizes a reflection-based `__init__` (or a
      guard that raises `TypeError`); `_make_init` / `_make_init_signature`
      build an `inspect.Signature` from C++ field metadata (respecting the
      `kw_only`, `has_default`, and `c_init` traits). `_is_comparable`
      centralises the bidirectional isinstance guard.
    - Each dunder checks `cls.__dict__` before being installed, so
      user-defined overrides are preserved.
    - `__eq__`/`__ne__` and the ordering operators return `NotImplemented`
      for unrelated types, following Python data model conventions.
    
    ## Public Interfaces
    
    - `@c_class(type_key, *, init, repr, eq, order, unsafe_hash)` — new
      keyword arguments; the old usage `@c_class("key")` continues to work
      with sensible defaults (`init=True`, `repr=True`, others off).
    - No breaking changes — `eq`, `order`, and `unsafe_hash` default to
      `False`.
    
    ## Test Plan
    
    - [x] New `test_dataclass_c_class.py` (26 tests): custom init
      preservation, auto-generated init with defaults, structural equality
      (reflexive, symmetric), hash (dict key, set dedup), ordering
      (reflexive, antisymmetric), `NotImplemented` fallback for unrelated
      types, subclass equality, `kw_only` from C++ reflection,
      `init_subset`, and derived-derived defaults.
    - [x] Renamed `test_copy.py` → `test_dataclass_copy.py` with additional
      cycle/Shape coverage and `deep_copy.cc` branch-coverage tests.
    - [x] `uv run pytest -vvs tests/python` — 960 passed, 23 skipped,
      1 xfailed.
    - [x] `sphinx-build -W --keep-going -b html docs docs/_build/html` —
      build succeeded.
    
    🤖 Generated with [Claude Code](https://claude.com/claude-code)
---
 python/tvm_ffi/cython/object.pxi                   |   7 +-
 python/tvm_ffi/dataclasses/c_class.py              |  89 +++++-
 python/tvm_ffi/registry.py                         | 130 ++++++++-
 python/tvm_ffi/testing/testing.py                  |  60 ++--
 tests/python/test_dataclass_c_class.py             | 319 +++++++++++++++++++++
 .../{test_copy.py => test_dataclass_copy.py}       | 202 +++++++++++++
 tests/python/test_dataclass_init.py                |   2 +-
 7 files changed, 772 insertions(+), 37 deletions(-)

diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 97536fec..ba0a4906 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -203,9 +203,10 @@ class Object(CObject, metaclass=_ObjectSlotsMeta):
       identity unless an overridden implementation is provided on the
       concrete type. Use :py:meth:`same_as` to check whether two
       references point to the same underlying object.
-    - Subclasses that omit ``__slots__`` are treated as ``__slots__ = ()``.
-      Subclasses that need per-instance dynamic attributes can opt in with
-      ``__slots__ = ("__dict__",)``.
+    - Subclasses that omit ``__slots__`` get ``__slots__ = ()`` injected
+      automatically by the metaclass.  Pass ``slots=False`` in the class
+      header (e.g. ``class Foo(Object, slots=False)``) to suppress this
+      and allow a per-instance ``__dict__``.
     - Most users interact with subclasses (e.g. :class:`Tensor`,
       :class:`Function`) rather than :py:class:`Object` directly.
 
diff --git a/python/tvm_ffi/dataclasses/c_class.py b/python/tvm_ffi/dataclasses/c_class.py
index d836f1a9..ee6d4329 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -14,36 +14,103 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""The ``c_class`` decorator: pass-through to ``register_object``."""
+"""The ``c_class`` decorator: register_object + structural dunders."""
 
 from __future__ import annotations
 
 from collections.abc import Callable
-from typing import Any, TypeVar
+from typing import TypeVar
+
+from typing_extensions import dataclass_transform
 
 _T = TypeVar("_T", bound=type)
 
 
-def c_class(type_key: str, **kwargs: Any) -> Callable[[_T], _T]:
-    """Register a C++ FFI class by type key.
+@dataclass_transform(eq_default=False, order_default=False)
+def c_class(
+    type_key: str,
+    *,
+    init: bool = True,
+    repr: bool = True,
+    eq: bool = False,
+    order: bool = False,
+    unsafe_hash: bool = False,
+) -> Callable[[_T], _T]:
+    """Register a C++ FFI class and install structural dunder methods.
 
-    This is a thin wrapper around :func:`~tvm_ffi.register_object` that
-    accepts (and currently ignores) additional keyword arguments for
-    forward compatibility.
+    Combines :func:`~tvm_ffi.register_object` with structural comparison,
+    hashing, and ordering derived from the C++ reflection metadata.
+    User-defined dunders in the class body are never overwritten.
 
     Parameters
     ----------
     type_key
         The reflection key that identifies the C++ type in the FFI registry.
-    kwargs
-        Reserved for future use.
+        Must match a key already registered on the C++ side via
+        ``TVM_FFI_DECLARE_OBJECT_INFO``.
+    init
+        If True (default), install ``__init__`` from C++ reflection metadata.
+        The generated ``__init__`` respects ``Init()``, ``KwOnly()``, and
+        ``Default()`` traits declared on each C++ field.  If the class body
+        already defines ``__init__``, it is kept.
+    repr
+        If True (default), install ``__repr__`` using
+        :func:`~tvm_ffi.core.object_repr`, which formats the object via
+        the C++ ``ReprPrint`` visitor.  Skipped if the class body already
+        defines ``__repr__``.
+    eq
+        If True, install ``__eq__`` and ``__ne__`` using the C++ recursive
+        structural comparison (``RecursiveEq``).  Returns ``NotImplemented``
+        for unrelated types.  Defaults to False.
+    order
+        If True, install ``__lt__``, ``__le__``, ``__gt__``, ``__ge__``
+        using the C++ recursive comparators.  Returns ``NotImplemented``
+        for unrelated types.  Defaults to False.
+    unsafe_hash
+        If True, install ``__hash__`` using ``RecursiveHash``.  Called
+        *unsafe* because mutable fields contribute to the hash, so mutating
+        an object while it is in a set or dict key will break invariants.
+        Defaults to False.
 
     Returns
     -------
     Callable[[type], type]
         A class decorator.
 
+    Examples
+    --------
+    Basic usage with default settings (``init`` and ``repr`` enabled):
+
+    .. code-block:: python
+
+        @c_class("my.Point")
+        class Point(Object):
+            x: float
+            y: float
+
+    Enable structural equality, hashing, and ordering:
+
+    .. code-block:: python
+
+        @c_class("my.Point", eq=True, unsafe_hash=True, order=True)
+        class Point(Object):
+            x: float
+            y: float
+
+    See Also
+    --------
+    :func:`tvm_ffi.register_object`
+        Lower-level decorator that only registers the type without
+        installing structural dunders.
+
     """
-    from ..registry import register_object  # noqa: PLC0415
+    from ..registry import _install_dataclass_dunders, register_object  # 
noqa: PLC0415
+
+    def decorator(cls: _T) -> _T:
+        cls = register_object(type_key)(cls)
+        _install_dataclass_dunders(
+            cls, init=init, repr=repr, eq=eq, order=order, 
unsafe_hash=unsafe_hash
+        )
+        return cls
 
-    return register_object(type_key)
+    return decorator
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index f2801b1f..b45b03e0 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -63,8 +63,8 @@ def register_object(type_key: str | None = None) -> Callable[[_T], _T]:
                 return cls
             raise ValueError(f"Cannot find object type index for 
{object_name}")
         info = core._register_object_by_index(type_index, cls)
-        setattr(cls, "__tvm_ffi_type_info__", info)
         _add_class_attrs(type_cls=cls, type_info=info)
+        setattr(cls, "__tvm_ffi_type_info__", info)
         return cls
 
     if isinstance(type_key, str):
@@ -418,7 +418,6 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
             setattr(type_cls, name, method.as_callable(type_cls))
         elif not hasattr(type_cls, name):
             setattr(type_cls, name, method.as_callable(type_cls))
-    _install_init(type_cls, enabled=True)
     is_container = type_info.type_key in (
         "ffi.Array",
         "ffi.Map",
@@ -456,7 +455,19 @@ def _setup_copy_methods(
 
 
 def _install_init(cls: type, *, enabled: bool) -> None:
-    """Install ``__init__`` from C++ reflection metadata, or a guard."""
+    """Install ``__init__`` from C++ reflection metadata, or a guard.
+
+    When *enabled* is True, looks for a ``__ffi_init__`` method in the
+    type's C++ reflection metadata.  If the method has ``auto_init=True``
+    metadata (set by ``refl::init()`` in C++), a Python ``__init__`` is
+    synthesized with an ``inspect.Signature`` derived from the field
+    metadata (respecting ``Init()``, ``KwOnly()``, ``Default()`` traits).
+    Otherwise the raw ``__ffi_init__`` is exposed as ``__init__`` directly.
+
+    When *enabled* is False, installs a guard that raises ``TypeError``
+    on construction.  Skipped entirely if the class body already defines
+    ``__init__``.
+    """
     if "__init__" in cls.__dict__:
         return
     type_info: TypeInfo | None = getattr(cls, "__tvm_ffi_type_info__", None)
@@ -531,6 +542,119 @@ def _replace_unsupported(self: Any, **kwargs: Any) -> Any:
     )
 
 
+def _install_dataclass_dunders(
+    cls: type,
+    *,
+    init: bool,
+    repr: bool,
+    eq: bool,
+    order: bool,
+    unsafe_hash: bool,
+) -> None:
+    """Install structural dunder methods on *cls*.
+
+    Each dunder delegates to the corresponding C++ recursive structural
+    operation (``RecursiveEq``, ``RecursiveHash``, ``RecursiveLt``, etc.).
+    If the user already defined a dunder in the class body
+    (i.e. it exists in ``cls.__dict__``), it is left untouched.
+
+    Parameters
+    ----------
+    cls
+        The class to install dunders on.  Must have been processed by
+        :func:`register_object` first (so ``__tvm_ffi_type_info__`` exists).
+    init
+        If True, install ``__init__`` from C++ reflection metadata via
+        :func:`_install_init`.
+    repr
+        If True, install :func:`~tvm_ffi.core.object_repr` as ``__repr__``.
+    eq
+        If True, install ``__eq__`` and ``__ne__`` using ``RecursiveEq``.
+        Returns ``NotImplemented`` for unrelated types so Python can
+        fall back to identity comparison.
+    order
+        If True, install ``__lt__``, ``__le__``, ``__gt__``, ``__ge__``
+        using ``RecursiveLt``/``Le``/``Gt``/``Ge``.  Returns
+        ``NotImplemented`` for unrelated types.
+    unsafe_hash
+        If True, install ``__hash__`` using ``RecursiveHash``.
+
+    """
+    _install_init(cls, enabled=init)
+
+    if repr and "__repr__" not in cls.__dict__:
+        from .core import object_repr  # noqa: PLC0415
+
+        cls.__repr__ = object_repr  # type: ignore[attr-defined]
+
+    from . import _ffi_api  # noqa: PLC0415
+
+    def _is_comparable(self: Any, other: Any) -> bool:
+        """Return True if *self* and *other* share a type hierarchy."""
+        return isinstance(other, type(self)) or isinstance(self, type(other))
+
+    dunders: dict[str, Any] = {}
+
+    if eq:
+        recursive_eq = _ffi_api.RecursiveEq
+
+        def __eq__(self: Any, other: Any) -> bool:
+            if not _is_comparable(self, other):
+                return NotImplemented
+            return recursive_eq(self, other)
+
+        def __ne__(self: Any, other: Any) -> bool:
+            if not _is_comparable(self, other):
+                return NotImplemented
+            return not recursive_eq(self, other)
+
+        dunders["__eq__"] = __eq__
+        dunders["__ne__"] = __ne__
+
+    if unsafe_hash:
+        recursive_hash = _ffi_api.RecursiveHash
+
+        def __hash__(self: Any) -> int:
+            return recursive_hash(self)
+
+        dunders["__hash__"] = __hash__
+
+    if order:
+        recursive_lt = _ffi_api.RecursiveLt
+        recursive_le = _ffi_api.RecursiveLe
+        recursive_gt = _ffi_api.RecursiveGt
+        recursive_ge = _ffi_api.RecursiveGe
+
+        def __lt__(self: Any, other: Any) -> bool:
+            if not _is_comparable(self, other):
+                return NotImplemented
+            return recursive_lt(self, other)
+
+        def __le__(self: Any, other: Any) -> bool:
+            if not _is_comparable(self, other):
+                return NotImplemented
+            return recursive_le(self, other)
+
+        def __gt__(self: Any, other: Any) -> bool:
+            if not _is_comparable(self, other):
+                return NotImplemented
+            return recursive_gt(self, other)
+
+        def __ge__(self: Any, other: Any) -> bool:
+            if not _is_comparable(self, other):
+                return NotImplemented
+            return recursive_ge(self, other)
+
+        dunders["__lt__"] = __lt__
+        dunders["__le__"] = __le__
+        dunders["__gt__"] = __gt__
+        dunders["__ge__"] = __ge__
+
+    for name, impl in dunders.items():
+        if name not in cls.__dict__:
+            setattr(cls, name, impl)
+
+
 def get_registered_type_keys() -> Sequence[str]:
     """Get the list of valid type keys registered to TVM-FFI.
 
diff --git a/python/tvm_ffi/testing/testing.py b/python/tvm_ffi/testing/testing.py
index d98374d9..59f1e50d 100644
--- a/python/tvm_ffi/testing/testing.py
+++ b/python/tvm_ffi/testing/testing.py
@@ -35,10 +35,10 @@ from typing import ClassVar
 from .. import _ffi_api
 from ..core import Object
 from ..dataclasses import c_class
-from ..registry import get_global_func, register_object
+from ..registry import get_global_func
 
 
-@register_object("testing.TestObjectBase")
+@c_class("testing.TestObjectBase")
 class TestObjectBase(Object):
     """Test object base class."""
 
@@ -54,10 +54,12 @@ class TestObjectBase(Object):
     # tvm-ffi-stubgen(end)
 
 
-@register_object("testing.TestIntPair")
+@c_class("testing.TestIntPair")
 class TestIntPair(Object):
     """Test Int Pair."""
 
+    __test__ = False
+
     # tvm-ffi-stubgen(begin): object/testing.TestIntPair
     # fmt: off
     a: int
@@ -71,7 +73,7 @@ class TestIntPair(Object):
     # tvm-ffi-stubgen(end)
 
 
-@register_object("testing.TestObjectDerived")
+@c_class("testing.TestObjectDerived")
 class TestObjectDerived(TestObjectBase):
     """Test object derived class."""
 
@@ -85,14 +87,14 @@ class TestObjectDerived(TestObjectBase):
     # tvm-ffi-stubgen(end)
 
 
-@register_object("testing.TestNonCopyable")
+@c_class("testing.TestNonCopyable")
 class TestNonCopyable(Object):
     """Test object with deleted copy constructor."""
 
     value: int
 
 
-@register_object("testing.TestCompare")
+@c_class("testing.TestCompare")
 class TestCompare(Object):
     """Test object with Compare(false) on ignored_field."""
 
@@ -111,7 +113,7 @@ class TestCompare(Object):
     # tvm-ffi-stubgen(end)
 
 
-@register_object("testing.TestCustomCompare")
+@c_class("testing.TestCustomCompare")
 class TestCustomCompare(Object):
     """Test object with custom __ffi_eq__/__ffi_compare__ hooks (compares only 
key)."""
 
@@ -129,7 +131,7 @@ class TestCustomCompare(Object):
     # tvm-ffi-stubgen(end)
 
 
-@register_object("testing.TestEqWithoutHash")
+@c_class("testing.TestEqWithoutHash")
 class TestEqWithoutHash(Object):
     """Test object with __ffi_eq__ but no __ffi_hash__ (exercises hash 
guard)."""
 
@@ -147,7 +149,7 @@ class TestEqWithoutHash(Object):
     # tvm-ffi-stubgen(end)
 
 
-@register_object("testing.TestHash")
+@c_class("testing.TestHash")
 class TestHash(Object):
     """Test object with Hash(false) on hash_ignored."""
 
@@ -166,7 +168,7 @@ class TestHash(Object):
     # tvm-ffi-stubgen(end)
 
 
-@register_object("testing.TestCustomHash")
+@c_class("testing.TestCustomHash")
 class TestCustomHash(Object):
     """Test object with custom __ffi_hash__ hook (hashes only key)."""
 
@@ -184,7 +186,7 @@ class TestCustomHash(Object):
     # tvm-ffi-stubgen(end)
 
 
-@register_object("testing.SchemaAllTypes")
+@c_class("testing.SchemaAllTypes")
 class _SchemaAllTypes:
     # tvm-ffi-stubgen(ty-map): testing.SchemaAllTypes -> 
testing._SchemaAllTypes
     # tvm-ffi-stubgen(begin): object/testing.SchemaAllTypes
@@ -265,16 +267,30 @@ class _TestCxxClassBase(Object):
         self.__ffi_init__(v_i64 + 1, v_i32 + 2)
 
 
-@c_class("testing.TestCxxClassDerived")
+@c_class("testing.TestCxxClassDerived", eq=True, order=True, unsafe_hash=True)
 class _TestCxxClassDerived(_TestCxxClassBase):
     v_f64: float
     v_f32: float
+    if TYPE_CHECKING:
+
+        def __init__(self, v_i64: int, v_i32: int, v_f64: float, v_f32: float 
= ...) -> None: ...
 
 
 @c_class("testing.TestCxxClassDerivedDerived")
 class _TestCxxClassDerivedDerived(_TestCxxClassDerived):
     v_str: str
     v_bool: bool
+    if TYPE_CHECKING:
+
+        def __init__(
+            self,
+            v_i64: int,
+            v_i32: int,
+            v_f64: float,
+            v_bool: bool,
+            v_f32: float = ...,
+            v_str: str = ...,
+        ) -> None: ...
 
 
 @c_class("testing.TestCxxInitSubset")
@@ -282,6 +298,9 @@ class _TestCxxInitSubset(Object):
     required_field: int
     optional_field: int
     note: str
+    if TYPE_CHECKING:
+
+        def __init__(self, required_field: int) -> None: ...
 
 
 @c_class("testing.TestCxxKwOnly")
@@ -290,9 +309,12 @@ class _TestCxxKwOnly(Object):
     y: int
     z: int
     w: int
+    if TYPE_CHECKING:
+
+        def __init__(self, *, x: int, y: int, z: int, w: int = ...) -> None: 
...
 
 
-@register_object("testing.TestCxxAutoInit")
+@c_class("testing.TestCxxAutoInit")
 class _TestCxxAutoInit(Object):
     """Test object with init(false) on b and KwOnly(true) on c."""
 
@@ -307,7 +329,7 @@ class _TestCxxAutoInit(Object):
         def __init__(self, a: int, d: int = ..., *, c: int) -> None: ...
 
 
-@register_object("testing.TestCxxAutoInitSimple")
+@c_class("testing.TestCxxAutoInitSimple")
 class _TestCxxAutoInitSimple(Object):
     """Test object with all fields positional (no init/KwOnly traits)."""
 
@@ -320,7 +342,7 @@ class _TestCxxAutoInitSimple(Object):
         def __init__(self, x: int, y: int) -> None: ...
 
 
-@register_object("testing.TestCxxAutoInitAllInitOff")
+@c_class("testing.TestCxxAutoInitAllInitOff")
 class _TestCxxAutoInitAllInitOff(Object):
     """Test object with all fields excluded from auto-init (init(false))."""
 
@@ -334,7 +356,7 @@ class _TestCxxAutoInitAllInitOff(Object):
         def __init__(self) -> None: ...
 
 
-@register_object("testing.TestCxxAutoInitKwOnlyDefaults")
+@c_class("testing.TestCxxAutoInitKwOnlyDefaults")
 class _TestCxxAutoInitKwOnlyDefaults(Object):
     """Test object with mixed positional/kw-only/default/init=False fields."""
 
@@ -352,7 +374,7 @@ class _TestCxxAutoInitKwOnlyDefaults(Object):
         ) -> None: ...
 
 
-@register_object("testing.TestCxxNoAutoInit")
+@c_class("testing.TestCxxNoAutoInit", init=False)
 class _TestCxxNoAutoInit(Object):
     """Test object with init(false) at class level — no __ffi_init__ 
generated."""
 
@@ -362,7 +384,7 @@ class _TestCxxNoAutoInit(Object):
     y: int
 
 
-@register_object("testing.TestCxxAutoInitParent")
+@c_class("testing.TestCxxAutoInitParent")
 class _TestCxxAutoInitParent(Object):
     """Parent object for inheritance auto-init tests."""
 
@@ -375,7 +397,7 @@ class _TestCxxAutoInitParent(Object):
         def __init__(self, parent_required: int, parent_default: int = ...) -> 
None: ...
 
 
-@register_object("testing.TestCxxAutoInitChild")
+@c_class("testing.TestCxxAutoInitChild")
 class _TestCxxAutoInitChild(_TestCxxAutoInitParent):
     """Child object for inheritance auto-init tests."""
 
diff --git a/tests/python/test_dataclass_c_class.py b/tests/python/test_dataclass_c_class.py
new file mode 100644
index 00000000..52bf38bb
--- /dev/null
+++ b/tests/python/test_dataclass_c_class.py
@@ -0,0 +1,319 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for the c_class decorator (register_object + structural dunders)."""
+
+from __future__ import annotations
+
+import inspect
+
+import pytest
+from tvm_ffi.testing import (
+    _TestCxxClassBase,
+    _TestCxxClassDerived,
+    _TestCxxClassDerivedDerived,
+    _TestCxxInitSubset,
+    _TestCxxKwOnly,
+)
+
+# ---------------------------------------------------------------------------
+# 1. Custom __init__ preservation
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_custom_init() -> None:
+    """c_class preserves user-defined __init__."""
+    obj = _TestCxxClassBase(v_i64=10, v_i32=20)
+    assert obj.v_i64 == 11  # +1 from custom __init__
+    assert obj.v_i32 == 22  # +2 from custom __init__
+
+
+# ---------------------------------------------------------------------------
+# 2. Auto-generated __init__ with defaults
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_auto_init_defaults() -> None:
+    """Derived classes use auto-generated __init__ with C++ defaults."""
+    obj = _TestCxxClassDerived(v_i64=1, v_i32=2, v_f64=3.0)
+    assert obj.v_i64 == 1
+    assert obj.v_i32 == 2
+    assert obj.v_f64 == 3.0
+    assert obj.v_f32 == 8.0  # default from C++
+
+
+def test_c_class_auto_init_all_explicit() -> None:
+    """Auto-generated __init__ accepts all fields explicitly."""
+    obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0, v_f32=9.0)
+    assert obj.v_i64 == 123
+    assert obj.v_i32 == 456
+    assert obj.v_f64 == 4.0
+    assert obj.v_f32 == 9.0
+
+
+# ---------------------------------------------------------------------------
+# 3. Structural equality (__eq__)
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_eq() -> None:
+    """c_class installs __eq__ using RecursiveEq."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    assert a == b
+    assert a is not b  # different objects
+    c = _TestCxxClassDerived(1, 2, 3.0, 5.0)
+    assert a != c
+
+
+def test_c_class_eq_reflexive() -> None:
+    """Equality is reflexive: an object equals itself."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = a  # alias, same object
+    assert a == b
+
+
+def test_c_class_eq_symmetric() -> None:
+    """Equality is symmetric: a == b implies b == a."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    assert a == b
+    assert b == a
+
+
+# ---------------------------------------------------------------------------
+# 4. Structural hash (__hash__)
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_hash() -> None:
+    """c_class installs __hash__ using RecursiveHash."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    assert hash(a) == hash(b)
+
+
+def test_c_class_hash_as_dict_key() -> None:
+    """Equal objects can be used interchangeably as dict keys."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    d = {a: "value"}
+    assert d[b] == "value"
+
+
+# ---------------------------------------------------------------------------
+# 5. Ordering (__lt__, __le__, __gt__, __ge__)
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_ordering() -> None:
+    """c_class installs ordering operators."""
+    small = _TestCxxClassDerived(0, 0, 0.0, 0.0)
+    big = _TestCxxClassDerived(100, 100, 100.0, 100.0)
+    assert small < big  # ty: ignore[unsupported-operator]
+    assert small <= big  # ty: ignore[unsupported-operator]
+    assert big > small  # ty: ignore[unsupported-operator]
+    assert big >= small  # ty: ignore[unsupported-operator]
+    assert not (big < small)  # ty: ignore[unsupported-operator]
+    assert not (small > big)  # ty: ignore[unsupported-operator]
+
+
+def test_c_class_ordering_reflexive() -> None:
+    """<= and >= are reflexive."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = a  # alias, same object
+    assert a <= b  # ty: ignore[unsupported-operator]
+    assert a >= b  # ty: ignore[unsupported-operator]
+
+
+def test_c_class_ordering_antisymmetric() -> None:
+    """If a < b then not b < a."""
+    a = _TestCxxClassDerived(0, 0, 0.0, 0.0)
+    b = _TestCxxClassDerived(100, 100, 100.0, 100.0)
+    if a < b:  # ty: ignore[unsupported-operator]
+        assert not (b < a)  # ty: ignore[unsupported-operator]
+    else:
+        assert not (a < b)  # ty: ignore[unsupported-operator]
+
+
+# ---------------------------------------------------------------------------
+# 6. Equality with different types returns NotImplemented
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_eq_different_type() -> None:
+    """__eq__ returns NotImplemented for unrelated types."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    assert a != "hello"
+    assert a != 42
+    assert a != 3.14
+    assert a is not None
+
+
+def test_c_class_ordering_different_type() -> None:
+    """Ordering against unrelated types raises TypeError."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    with pytest.raises(TypeError):
+        a < "hello"  # ty: ignore[unsupported-operator]
+    with pytest.raises(TypeError):
+        a <= 42  # ty: ignore[unsupported-operator]
+    with pytest.raises(TypeError):
+        a > 3.14  # ty: ignore[unsupported-operator]
+    with pytest.raises(TypeError):
+        a >= None  # ty: ignore[unsupported-operator]
+
+
+# ---------------------------------------------------------------------------
+# 7. Subclass equality
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_subclass_eq() -> None:
+    """Subclass instances can be compared to parent instances without 
crashing."""
+    derived = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    derived_derived = _TestCxxClassDerivedDerived(
+        v_i64=1, v_i32=2, v_f64=3.0, v_f32=4.0, v_str="hello", v_bool=True
+    )
+    # These are different types in the same hierarchy; comparison should
+    # return a bool (the result depends on C++ behavior).
+    result = derived == derived_derived
+    assert isinstance(result, bool)
+
+
+# ---------------------------------------------------------------------------
+# 8. KwOnly from C++ reflection
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_kw_only_signature() -> None:
+    """kw_only trait comes from C++ reflection, not Python decorator."""
+    sig = inspect.signature(_TestCxxKwOnly.__init__)
+    params = sig.parameters
+    for name in ("x", "y", "z", "w"):
+        assert params[name].kind == inspect.Parameter.KEYWORD_ONLY, (
+            f"Expected {name} to be KEYWORD_ONLY"
+        )
+
+
+def test_c_class_kw_only_call() -> None:
+    """KwOnly fields can be supplied as keyword arguments."""
+    obj = _TestCxxKwOnly(x=1, y=2, z=3, w=4)
+    assert obj.x == 1
+    assert obj.y == 2
+    assert obj.z == 3
+    assert obj.w == 4
+
+
+def test_c_class_kw_only_default() -> None:
+    """KwOnly field with a C++ default can be omitted."""
+    obj = _TestCxxKwOnly(x=1, y=2, z=3)
+    assert obj.w == 100
+
+
+def test_c_class_kw_only_rejects_positional() -> None:
+    """KwOnly fields reject positional arguments."""
+    with pytest.raises(TypeError, match="positional"):
+        _TestCxxKwOnly(1, 2, 3, 4)  # ty: ignore[missing-argument, 
too-many-positional-arguments]
+
+
+# ---------------------------------------------------------------------------
+# 9. Init subset from C++ reflection
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_init_subset_signature() -> None:
+    """init=False fields from C++ reflection are excluded from __init__."""
+    sig = inspect.signature(_TestCxxInitSubset.__init__)
+    params = tuple(sig.parameters)
+    assert "required_field" in params
+    assert "optional_field" not in params
+    assert "note" not in params
+
+
+def test_c_class_init_subset_defaults() -> None:
+    """init=False fields get their default values from C++."""
+    obj = _TestCxxInitSubset(required_field=42)
+    assert obj.required_field == 42
+    assert obj.optional_field == -1  # C++ default
+    assert obj.note == "default"  # C++ default
+
+
+def test_c_class_init_subset_positional() -> None:
+    """Init-subset fields can be passed positionally."""
+    obj = _TestCxxInitSubset(7)
+    assert obj.required_field == 7
+    assert obj.optional_field == -1
+
+
+def test_c_class_init_subset_field_writable() -> None:
+    """Fields excluded from __init__ can still be assigned after 
construction."""
+    obj = _TestCxxInitSubset(required_field=0)
+    obj.optional_field = 11
+    assert obj.optional_field == 11
+
+
+# ---------------------------------------------------------------------------
+# 10. DerivedDerived with defaults
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_derived_derived_defaults() -> None:
+    """DerivedDerived uses positional args; C++ defaults fill in omitted 
fields."""
+    obj = _TestCxxClassDerivedDerived(1, 2, 3.0, True)
+    assert obj.v_i64 == 1
+    assert obj.v_i32 == 2
+    assert obj.v_f64 == 3.0
+    assert obj.v_f32 == 8.0  # C++ default
+    assert obj.v_str == "default"  # C++ default
+    assert obj.v_bool is True
+
+
+def test_c_class_derived_derived_all_explicit() -> None:
+    """DerivedDerived with all fields explicitly provided."""
+    obj = _TestCxxClassDerivedDerived(
+        v_i64=123,
+        v_i32=456,
+        v_f64=4.0,
+        v_f32=9.0,
+        v_str="hello",
+        v_bool=True,
+    )
+    assert obj.v_i64 == 123
+    assert obj.v_i32 == 456
+    assert obj.v_f64 == 4.0
+    assert obj.v_f32 == 9.0
+    assert obj.v_str == "hello"
+    assert obj.v_bool is True
+
+
+# ---------------------------------------------------------------------------
+# 11. Hash / set usage
+# ---------------------------------------------------------------------------
+
+
+def test_c_class_usable_in_set() -> None:
+    """Equal objects deduplicate in a set."""
+    a = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    b = _TestCxxClassDerived(1, 2, 3.0, 4.0)
+    c = _TestCxxClassDerived(5, 6, 7.0, 8.0)
+    s = {a, b, c}
+    assert len(s) == 2  # a and b are equal
+
+
+def test_c_class_unequal_objects_in_set() -> None:
+    """Distinct objects are separate entries in a set."""
+    objs = {_TestCxxClassDerived(i, i, float(i), float(i)) for i in range(5)}
+    assert len(objs) == 5
diff --git a/tests/python/test_copy.py b/tests/python/test_dataclass_copy.py
similarity index 80%
rename from tests/python/test_copy.py
rename to tests/python/test_dataclass_copy.py
index 3ec9a55c..df9b6367 100644
--- a/tests/python/test_copy.py
+++ b/tests/python/test_dataclass_copy.py
@@ -685,6 +685,208 @@ class TestDeepCopyBranches:
         assert not pair.same_as(deep_pair)
         assert deep_pair.a == 5
 
+    # --- Cycle preservation with immutable root containers ---
+
+    def test_cycle_list_root_map_backref_preserved(self) -> None:
+        """Control case: List root with Map back-reference should preserve cycle."""
+        root_list = tvm_ffi.List()
+        m = tvm_ffi.Map({"list": root_list})
+        root_list.append(m)  # close the cycle: root_list -> m -> root_list
+
+        deep_list = copy.deepcopy(root_list)
+        assert not root_list.same_as(deep_list)
+        assert deep_list[0]["list"].same_as(deep_list)
+
+    def test_cycle_map_root_list_backref_preserved(self) -> None:
+        """Map root with List child pointing back should preserve cycle to root copy."""
+        l = tvm_ffi.List()
+        m = tvm_ffi.Map({"list": l})
+        l.append(m)  # close the cycle: m -> l -> m
+
+        deep_map = copy.deepcopy(m)
+        assert not m.same_as(deep_map)
+        assert not l.same_as(deep_map["list"])
+        assert deep_map["list"][0].same_as(deep_map)
+
+    def test_cycle_array_root_list_backref_preserved(self) -> None:
+        """Array root with List child pointing back should preserve cycle to root copy."""
+        l = tvm_ffi.List()
+        a = tvm_ffi.Array([l])
+        l.append(a)  # close the cycle: a -> l -> a
+
+        deep_arr = copy.deepcopy(a)
+        assert not a.same_as(deep_arr)
+        assert not l.same_as(deep_arr[0])
+        assert deep_arr[0][0].same_as(deep_arr)
+
+    def test_cycle_array_root_dict_backref_preserved(self) -> None:
+        """Array root with Dict child pointing back should preserve cycle to root copy."""
+        d = tvm_ffi.Dict()
+        a = tvm_ffi.Array([d])
+        d["self"] = a  # close the cycle: a -> d -> a
+
+        deep_arr = copy.deepcopy(a)
+        assert not a.same_as(deep_arr)
+        assert not d.same_as(deep_arr[0])
+        assert deep_arr[0]["self"].same_as(deep_arr)
+
+    def test_cycle_map_root_dict_backref_preserved(self) -> None:
+        """Map root with Dict child pointing back should preserve cycle to root copy."""
+        d = tvm_ffi.Dict()
+        m = tvm_ffi.Map({"dict": d})
+        d["self"] = m  # close the cycle: m -> d -> m
+
+        deep_map = copy.deepcopy(m)
+        assert not m.same_as(deep_map)
+        assert not d.same_as(deep_map["dict"])
+        assert deep_map["dict"]["self"].same_as(deep_map)
+
+    def test_cycle_map_root_backref_identity_not_duplicated(self) -> None:
+        """A List shared under two Map-root keys is copied once and points back to the copied root."""
+        shared_list = tvm_ffi.List()
+        m = tvm_ffi.Map({"l1": shared_list, "l2": shared_list})
+        shared_list.append(m)  # close the cycle through both keys
+
+        deep_map = copy.deepcopy(m)
+        assert deep_map["l1"].same_as(deep_map["l2"])  # shared child copied exactly once
+        assert deep_map["l1"][0].same_as(deep_map)
+
+    def test_cycle_map_root_list_key_backref_preserved(self) -> None:
+        """Map-root cycles through keys should preserve back-reference to copied root."""
+        key_list = tvm_ffi.List()
+        m = tvm_ffi.Map({key_list: 1})
+        key_list.append(m)  # the cycle runs through a key, not a value
+
+        deep_map = copy.deepcopy(m)
+        deep_key = next(iter(deep_map.keys()))
+        assert isinstance(deep_key, tvm_ffi.List)
+        assert deep_key[0].same_as(deep_map)
+
+    def test_cycle_map_root_dict_key_backref_preserved(self) -> None:
+        """Map-root cycles through Dict keys should preserve back-reference to copied root."""
+        key_dict = tvm_ffi.Dict()
+        m = tvm_ffi.Map({key_dict: 1})
+        key_dict["self"] = m  # the cycle runs through a key, not a value
+
+        deep_map = copy.deepcopy(m)
+        deep_key = next(iter(deep_map.keys()))
+        assert isinstance(deep_key, tvm_ffi.Dict)
+        assert deep_key["self"].same_as(deep_map)
+
+    def test_cycle_array_root_dict_contains_root_as_key(self) -> None:
+        """Array root with Dict child using the root as key should fix key to copied root."""
+        d = tvm_ffi.Dict()
+        root = tvm_ffi.Array([d])
+        d[root] = 1  # the root appears as a key inside its own child
+
+        deep_root = copy.deepcopy(root)
+        deep_dict = deep_root[0]
+        deep_key = next(iter(deep_dict.keys()))
+
+        assert not root.same_as(deep_root)
+        assert deep_key.same_as(deep_root)
+        assert not deep_key.same_as(root)
+
+    def test_cycle_map_root_dict_contains_root_as_key(self) -> None:
+        """Map root with Dict child using the root as key should fix key to copied root."""
+        d = tvm_ffi.Dict()
+        root = tvm_ffi.Map({"d": d})
+        d[root] = 1  # the root appears as a key inside its own child
+
+        deep_root = copy.deepcopy(root)
+        deep_dict = deep_root["d"]
+        deep_key = next(iter(deep_dict.keys()))
+
+        assert not root.same_as(deep_root)
+        assert deep_key.same_as(deep_root)
+        assert not deep_key.same_as(root)
+
+    # --- Python deepcopy protocol consistency for immutable Shape ---
+
+    def test_shape_root_python_deepcopy_matches_ffi_deepcopy(self) -> None:
+        """copy.deepcopy(Shape) should agree in value and type with ffi.DeepCopy."""
+        deep_copy_fn = tvm_ffi.get_global_func("ffi.DeepCopy")
+        s = tvm_ffi.Shape((2, 3, 4))
+        ffi_copied = deep_copy_fn(s)  # reference result from the registered global
+        py_copied = copy.deepcopy(s)  # goes through the Python deepcopy protocol
+        assert py_copied == ffi_copied
+        assert isinstance(py_copied, type(s))
+
+    def test_shape_inside_python_container_deepcopy(self) -> None:
+        """Deep-copying a plain Python list/dict holding a Shape should yield equal Shapes."""
+        s = tvm_ffi.Shape((1, 2))
+        payload = [s, {"shape": s}]  # Shape directly in a list and nested in a dict
+        copied = copy.deepcopy(payload)
+        assert copied[0] == s
+        assert copied[1]["shape"] == s  # ty: ignore[invalid-argument-type]
+
+    # --- Cycle fixup: immutable container → reflected object back-reference ---
+
+    def test_cycle_array_root_object_backreference(self) -> None:
+        """Array A → Object X, X.v_array = A.  Deep copy from A; X'.v_array must be A'."""
+        obj = tvm_ffi.testing.create_object(
+            "testing.TestObjectDerived",
+            v_i64=42,
+            v_map=tvm_ffi.Map({}),
+            v_array=tvm_ffi.Array([]),  # placeholder; replaced below to form the cycle
+        )
+        arr = tvm_ffi.Array([obj])
+        obj.v_array = arr  # ty: ignore[unresolved-attribute]
+
+        arr_deep = _deep_copy(arr)
+
+        assert not arr.same_as(arr_deep)
+        obj_deep = arr_deep[0]
+        assert not obj.same_as(obj_deep)
+        assert obj_deep.v_i64 == 42
+        assert not obj_deep.v_array.same_as(arr)  # must not leak the original root
+        assert obj_deep.v_array.same_as(arr_deep)
+
+    def test_cycle_map_root_object_backreference(self) -> None:
+        """Map M → Object X, X.v_map = M.  Deep copy from M; X'.v_map must be M'."""
+        obj = tvm_ffi.testing.create_object(
+            "testing.TestObjectDerived",
+            v_i64=7,
+            v_map=tvm_ffi.Map({}),  # placeholder; replaced below to form the cycle
+            v_array=tvm_ffi.Array([]),
+        )
+        m = tvm_ffi.Map({"key": obj})
+        obj.v_map = m  # ty: ignore[unresolved-attribute]
+
+        m_deep = _deep_copy(m)
+
+        assert not m.same_as(m_deep)
+        obj_deep = m_deep["key"]
+        assert not obj.same_as(obj_deep)
+        assert obj_deep.v_i64 == 7
+        assert not obj_deep.v_map.same_as(m)  # must not leak the original root
+        assert obj_deep.v_map.same_as(m_deep)
+
+    def test_cycle_nested_array_object_array(self) -> None:
+        """Array → Object → Array → Object → back to root Array."""
+        inner = tvm_ffi.testing.create_object(
+            "testing.TestObjectDerived",
+            v_i64=1,
+            v_map=tvm_ffi.Map({}),
+            v_array=tvm_ffi.Array([]),  # placeholder; replaced below to close the cycle
+        )
+        outer = tvm_ffi.testing.create_object(
+            "testing.TestObjectDerived",
+            v_i64=2,
+            v_map=tvm_ffi.Map({}),
+            v_array=tvm_ffi.Array([inner]),  # intermediate Array between outer and inner
+        )
+        root_arr = tvm_ffi.Array([outer])
+        inner.v_array = root_arr  # ty: ignore[unresolved-attribute]
+
+        root_deep = _deep_copy(root_arr)
+
+        assert not root_arr.same_as(root_deep)
+        outer_deep = root_deep[0]
+        inner_deep = outer_deep.v_array[0]  # copy of inner, reached through outer's copy
+        assert not inner_deep.v_array.same_as(root_arr)
+        assert inner_deep.v_array.same_as(root_deep)
+
 
 # --------------------------------------------------------------------------- #
 #  __replace__
diff --git a/tests/python/test_dataclass_init.py b/tests/python/test_dataclass_init.py
index b957f54e..918838cf 100644
--- a/tests/python/test_dataclass_init.py
+++ b/tests/python/test_dataclass_init.py
@@ -886,7 +886,7 @@ class TestClassLevelInitFalse:
         assert field_names == ["x", "y"]
 
     def test_direct_construction_raises(self) -> None:
-        with pytest.raises(TypeError):
+        with pytest.raises(TypeError, match="cannot be constructed directly"):
             _TestCxxNoAutoInit(1, 2)  # ty: ignore[too-many-positional-arguments]
 
     def test_has_shallow_copy(self) -> None:


Reply via email to