This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 98cb8af  feat: Auto-create Python classes when missing (#49)
98cb8af is described below

commit 98cb8af49ff599c217fce96c3d4f57c0f52b8ec4
Author: Junru Shao <[email protected]>
AuthorDate: Thu Sep 25 12:11:04 2025 -0700

    feat: Auto-create Python classes when missing (#49)
---
 include/tvm/ffi/c_api.h               |   4 +-
 include/tvm/ffi/object.h              |   2 +-
 include/tvm/ffi/reflection/accessor.h |   4 +-
 python/tvm_ffi/core.pyi               |  10 ++-
 python/tvm_ffi/cython/base.pxi        |   2 +-
 python/tvm_ffi/cython/object.pxi      | 121 ++++++++++++++++++++++++----------
 python/tvm_ffi/cython/string.pxi      |   1 +
 python/tvm_ffi/cython/type_info.pxi   |  54 ++++++++++++---
 python/tvm_ffi/dataclasses/_utils.py  |  12 ----
 python/tvm_ffi/dataclasses/c_class.py |  15 ++---
 python/tvm_ffi/registry.py            |  36 ++--------
 src/ffi/extra/reflection_extra.cc     |   2 +-
 src/ffi/extra/testing.cc              |  38 +++++++++--
 src/ffi/object.cc                     |  16 ++---
 tests/cpp/test_object.cc              |   4 +-
 tests/python/test_object.py           |  27 +++++---
 16 files changed, 218 insertions(+), 130 deletions(-)

diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index e988e24..62d9001 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -917,11 +917,11 @@ typedef struct TVMFFITypeInfo {
   /*! \brief the unique type key to identify the type. */
   TVMFFIByteArray type_key;
   /*!
-   * \brief type_acenstors[depth] stores the type_index of the acenstors at depth level
+   * \brief type_ancestors[depth] stores the type_index of the acenstors at depth level
    * \note To keep things simple, we do not allow multiple inheritance so the
    *       hieracy stays as a tree
    */
-  const struct TVMFFITypeInfo** type_acenstors;
+  const struct TVMFFITypeInfo** type_ancestors;
   // The following fields are used for reflection
   /*! \brief Cached hash value of the type key, used for consistent structural hashing. */
   uint64_t type_key_hash;
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index 1ebd3d7..93ef6d8 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -1012,7 +1012,7 @@ TVM_FFI_INLINE bool IsObjectInstance(int32_t object_type_index) {
       // the function checks that the info exists
       const TypeInfo* type_info = TVMFFIGetTypeInfo(object_type_index);
       return (type_info->type_depth > TargetType::_type_depth &&
-              type_info->type_acenstors[TargetType::_type_depth]->type_index == target_type_index);
+              type_info->type_ancestors[TargetType::_type_depth]->type_index == target_type_index);
     } else {
       return false;
     }
diff --git a/include/tvm/ffi/reflection/accessor.h b/include/tvm/ffi/reflection/accessor.h
index 5fadd09..b49da51 100644
--- a/include/tvm/ffi/reflection/accessor.h
+++ b/include/tvm/ffi/reflection/accessor.h
@@ -216,7 +216,7 @@ inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) {
   // iterate through acenstors in parent to child order
   // skip the first one since it is always the root object
   for (int i = 1; i < type_info->type_depth; ++i) {
-    const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i];
+    const TVMFFITypeInfo* parent_info = type_info->type_ancestors[i];
     for (int j = 0; j < parent_info->num_fields; ++j) {
       callback(parent_info->fields + j);
     }
@@ -243,7 +243,7 @@ inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo* type_info,
   // iterate through acenstors in parent to child order
   // skip the first one since it is always the root object
   for (int i = 1; i < type_info->type_depth; ++i) {
-    const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i];
+    const TVMFFITypeInfo* parent_info = type_info->type_ancestors[i];
     for (int j = 0; j < parent_info->num_fields; ++j) {
       if (callback_with_early_stop(parent_info->fields + j)) return true;
     }
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 53c54dd..ac04114 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -46,7 +46,7 @@ class Object:
     def __hash__(self) -> int: ...
     def __init_handle_by_constructor__(self, fconstructor: Any, *args: Any) -> None: ...
     def __ffi_init__(self, *args: Any) -> None:
-        """Initialize the instance using the ` __init__` method registered on C++ side.
+        """Initialize the instance using the ` __ffi_init__` method registered on C++ side.
 
         Parameters
         ----------
@@ -83,8 +83,8 @@ class PyNativeObject:
 def _set_class_object(cls: type) -> None: ...
 def _register_object_by_index(type_index: int, type_cls: type) -> TypeInfo: ...
 def _object_type_key_to_index(type_key: str) -> int | None: ...
-def _set_type_cls(type_index: int, type_cls: type) -> None: ...
-def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo: ...
+def _set_type_cls(type_info: TypeInfo, type_cls: type) -> None: ...
+def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo: ...
 
 class Error(Object):
     """Base class for FFI errors."""
@@ -267,6 +267,8 @@ class TypeMethod:
     func: Any
     is_static: bool
 
+    def as_callable(self, cls: type) -> Callable[..., Any]: ...
+
 class TypeInfo:
     """Aggregated type information required to build a proxy class."""
 
@@ -276,3 +278,5 @@ class TypeInfo:
     fields: list[TypeField]
     methods: list[TypeMethod]
     parent_type_info: TypeInfo | None
+
+    def prototype_py(self) -> str: ...
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 3ff3ecc..5c1ba1e 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -191,7 +191,7 @@ cdef extern from "tvm/ffi/c_api.h":
         int32_t type_index
         int32_t type_depth
         TVMFFIByteArray type_key
-        const int32_t* type_acenstors
+        const TVMFFITypeInfo** type_ancestors
         uint64_t type_key_hash
         int32_t num_fields
         int32_t num_methods
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index e8f3593..0bb1e03 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 import warnings
+from typing import Any
+
 
 _CLASS_OBJECT = None
 
@@ -261,29 +263,66 @@ cdef inline object make_ret_opaque_object(TVMFFIAny result):
     (<Object>obj).chandle = result.v_obj
     return obj.pyobject()
 
+cdef inline object make_fallback_cls_for_type_index(int32_t type_index):
+    cdef str type_key = _type_index_to_key(type_index)
+    cdef object type_info = _lookup_or_register_type_info_from_type_key(type_key)
+    cdef object parent_type_info = type_info.parent_type_info
+    assert type_info.type_cls is None
+
+    # Ensure parent classes are created first
+    assert parent_type_info is not None
+    if parent_type_info.type_cls is None:   # recursively create parent class first
+        make_fallback_cls_for_type_index(parent_type_info.type_index)
+    assert parent_type_info.type_cls is not None
+
+    # Create `type_info.type_cls` now
+    class cls(parent_type_info.type_cls):
+        pass
+    attrs = dict(cls.__dict__)
+    attrs.pop("__dict__", None)
+    attrs.pop("__weakref__", None)
+    attrs.update({
+        "__slots__": (),
+        "__tvm_ffi_type_info__": type_info,
+        "__name__": type_key.split(".")[-1],
+        "__qualname__": type_key,
+        "__module__": ".".join(type_key.split(".")[:-1]),
+        "__doc__": f"Auto-generated fallback class for {type_key}.\n"
+                   "This class is generated because the class is not registered.\n"
+                   "Please do not use this class directly, instead register the class\n"
+                   "using `register_object` decorator.",
+    })
+    for field in type_info.fields:
+        attrs[field.name] = field.as_property(cls)
+    for method in type_info.methods:
+        name = method.name
+        if name == "__ffi_init__":
+            name = "__c_ffi_init__"
+        attrs[name] = method.as_callable(cls)
+    for name, val in attrs.items():
+        setattr(cls, name, val)
+    # Update the registry
+    type_info.type_cls = cls
+    _update_registry(type_index, type_key, type_info, cls)
+    return cls
+
 
 cdef inline object make_ret_object(TVMFFIAny result):
-    global TYPE_INDEX_TO_INFO
-    cdef int32_t tindex
-    cdef object cls
-    tindex = result.type_index
-
-    if tindex < len(TYPE_INDEX_TO_CLS):
-        cls = TYPE_INDEX_TO_CLS[tindex]
-        if cls is not None:
-            if issubclass(cls, PyNativeObject):
-                obj = Object.__new__(Object)
-                (<Object>obj).chandle = result.v_obj
-                return cls.__from_tvm_ffi_object__(cls, obj)
-            obj = cls.__new__(cls)
-            (<Object>obj).chandle = result.v_obj
-            return obj
+    cdef int32_t type_index
+    cdef object cls, obj
+    type_index = result.type_index
 
-    # object is not found in registered entry
-    # in this case we need to report an warning
-    type_key = _type_index_to_key(tindex)
-    warnings.warn(f"Returning type `{type_key}` which is not registered via register_object, fallback to Object")
-    obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
+    if type_index < len(TYPE_INDEX_TO_CLS) and (cls := TYPE_INDEX_TO_CLS[type_index]) is not None:
+        if issubclass(cls, PyNativeObject):
+            obj = Object.__new__(Object)
+            (<Object>obj).chandle = result.v_obj
+            return cls.__from_tvm_ffi_object__(cls, obj)
+    else:
+        # Slow path: object is not found in registered entry
+        # In this case create a dummy stub class for future usage.
+        # For every unregistered class, this slow path will be triggered only once.
+        cls = make_fallback_cls_for_type_index(type_index)
+    obj = cls.__new__(cls)
     (<Object>obj).chandle = result.v_obj
     return obj
 
@@ -294,17 +333,21 @@ cdef _get_method_from_method_info(const TVMFFIMethodInfo* method):
     return make_ret(result)
 
 
-def _type_info_create_from_type_key(object type_cls, str type_key):
+cdef _type_info_create_from_type_key(object type_cls, str type_key):
     cdef const TVMFFIFieldInfo* field
     cdef const TVMFFIMethodInfo* method
     cdef const TVMFFITypeInfo* info
     cdef int32_t type_index
+    cdef list ancestors = []
+    cdef int ancestor
     cdef object fields = []
     cdef object methods = []
     cdef FieldGetter getter
     cdef FieldSetter setter
     cdef ByteArrayArg type_key_arg = ByteArrayArg(c_str(type_key))
 
+    # NOTE: `type_key_arg` must be kept alive until after the call to `TVMFFITypeKeyToIndex`,
+    # because Cython doesn't defer the destruction of `type_key_arg` until after the call.
     if TVMFFITypeKeyToIndex(type_key_arg.cptr(), &type_index) != 0:
         raise ValueError(f"Cannot find type key: {type_key}")
     info = TVMFFIGetTypeInfo(type_index)
@@ -339,44 +382,54 @@ def _type_info_create_from_type_key(object type_cls, str type_key):
             )
         )
 
+    for i in range(info.type_depth):
+        ancestor = info.type_ancestors[i].type_index
+        ancestors.append(ancestor)
+
     return TypeInfo(
         type_cls=type_cls,
         type_index=type_index,
         type_key=bytearray_to_str(&info.type_key),
+        type_ancestors=ancestors,
         fields=fields,
         methods=methods,
         parent_type_info=None,
     )
 
 
-def _register_object_by_index(int type_index, object type_cls):
-    global TYPE_INDEX_TO_INFO, TYPE_KEY_TO_INFO, TYPE_INDEX_TO_CLS
-    cdef str type_key = _type_index_to_key(type_index)
-    cdef object info = _type_info_create_from_type_key(type_cls, type_key)
+cdef _update_registry(int type_index, object type_key, object type_info, object type_cls):
+    cdef int extra = type_index + 1 - len(TYPE_INDEX_TO_INFO)
     assert len(TYPE_INDEX_TO_INFO) == len(TYPE_INDEX_TO_CLS)
-    if (extra := type_index + 1 - len(TYPE_INDEX_TO_INFO)) > 0:
+    if extra > 0:
         TYPE_INDEX_TO_INFO.extend([None] * extra)
         TYPE_INDEX_TO_CLS.extend([None] * extra)
     TYPE_INDEX_TO_CLS[type_index] = type_cls
-    TYPE_INDEX_TO_INFO[type_index] = info
-    TYPE_KEY_TO_INFO[type_key] = info
+    TYPE_INDEX_TO_INFO[type_index] = type_info
+    TYPE_KEY_TO_INFO[type_key] = type_info
+
+
+def _register_object_by_index(int type_index, object type_cls):
+    global TYPE_INDEX_TO_INFO, TYPE_KEY_TO_INFO, TYPE_INDEX_TO_CLS
+    cdef str type_key = _type_index_to_key(type_index)
+    cdef object info = _type_info_create_from_type_key(type_cls, type_key)
+    _update_registry(type_index, type_key, info, type_cls)
     return info
 
 
-def _set_type_cls(int type_index, object type_cls):
+def _set_type_cls(object type_info, object type_cls):
     global TYPE_INDEX_TO_INFO, TYPE_INDEX_TO_CLS
-    assert len(TYPE_INDEX_TO_INFO) == len(TYPE_INDEX_TO_CLS)
-    type_info = TYPE_INDEX_TO_INFO[type_index]
     assert type_info.type_cls is None, f"Type already registered for {type_info.type_key}"
+    assert TYPE_INDEX_TO_INFO[type_info.type_index] is type_info
+    assert TYPE_KEY_TO_INFO[type_info.type_key] is type_info
     type_info.type_cls = type_cls
-    TYPE_INDEX_TO_CLS[type_index] = type_cls
+    TYPE_INDEX_TO_CLS[type_info.type_index] = type_cls
 
 
-def _lookup_type_info_from_type_key(type_key: str) -> TypeInfo:
+def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo:
     if info := TYPE_KEY_TO_INFO.get(type_key, None):
         return info
     info = _type_info_create_from_type_key(None, type_key)
-    TYPE_KEY_TO_INFO[type_key] = info
+    _update_registry(info.type_index, type_key, info, None)
     return info
 
 
diff --git a/python/tvm_ffi/cython/string.pxi b/python/tvm_ffi/cython/string.pxi
index 0f9d11b..bb8f6e5 100644
--- a/python/tvm_ffi/cython/string.pxi
+++ b/python/tvm_ffi/cython/string.pxi
@@ -77,3 +77,4 @@ class Bytes(bytes, PyNativeObject):
 
 
 _register_object_by_index(kTVMFFIBytes, Bytes)
+_register_object_by_index(kTVMFFIObject, Object)
diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi
index fcd443b..ca893b4 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -16,6 +16,7 @@
 # under the License.
 import dataclasses
 from typing import Optional, Any
+from io import StringIO
 
 
 cdef class FieldGetter:
@@ -75,20 +76,20 @@ class TypeField:
         assert self.setter is not None
         assert self.getter is not None
 
-    def as_property(self, cls: type):
+    def as_property(self, object cls):
         """Create a Python ``property`` object for this field on ``cls``."""
-        name = self.name
-        fget = self.getter
-        fset = self.setter
+        cdef str name = self.name
+        cdef str doc = self.doc or f"{cls.__module__}.{cls.__qualname__}.{name}"
+        cdef FieldGetter fget = self.getter
+        cdef FieldSetter fset = self.setter
         fget.__name__ = fset.__name__ = name
         fget.__module__ = fset.__module__ = cls.__module__
-        fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}"  # type: ignore[attr-defined]
-        fget.__doc__ = fset.__doc__ = f"Property `{name}` of class `{cls.__qualname__}`"  # type: ignore[attr-defined]
-
+        fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}"
+        fget.__doc__ = fset.__doc__ = f"Property `{name}` of class `{cls.__qualname__}`"
         return property(
-            fget=fget if self.getter is not None else None,
-            fset=fset if (not self.frozen) and self.setter is not None else None,
-            doc=f"{cls.__module__}.{cls.__qualname__}.{name}",
+            fget=fget,
+            fset=fset if (not self.frozen) else None,
+            doc=doc,
         )
 
 
@@ -101,6 +102,21 @@ class TypeMethod:
     func: object
     is_static: bool
 
+    def as_callable(self, object cls):
+        """Create a Python method attribute for this method on ``cls``."""
+        cdef str name = self.name
+        cdef str doc = self.doc or f"Method `{name}` of class `{cls.__qualname__}`"
+        cdef object func = self.func
+        if not self.is_static:
+            func = _member_method_wrapper(func)
+        func.__module__ = cls.__module__
+        func.__name__ = name
+        func.__qualname__ = f"{cls.__qualname__}.{name}"
+        func.__doc__ = doc
+        if self.is_static:
+            func = staticmethod(func)
+        return func
+
 
 @dataclasses.dataclass(eq=False)
 class TypeInfo:
@@ -109,6 +125,24 @@ class TypeInfo:
     type_cls: Optional[type]
     type_index: int
     type_key: str
+    type_ancestors: list[int]
     fields: list[TypeField]
     methods: list[TypeMethod]
     parent_type_info: Optional[TypeInfo]
+
+    def __post_init__(self):
+        cdef int parent_type_index
+        cdef str parent_type_key
+        if not self.type_ancestors:
+            return
+        parent_type_index = self.type_ancestors[-1]
+        parent_type_key = _type_index_to_key(parent_type_index)
+        # ensure parent is registered
+        self.parent_type_info = _lookup_or_register_type_info_from_type_key(parent_type_key)
+
+
+def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[..., Any]:
+    def wrapper(self: Any, *args: Any) -> Any:
+        return method_func(self, *args)
+
+    return wrapper
diff --git a/python/tvm_ffi/dataclasses/_utils.py b/python/tvm_ffi/dataclasses/_utils.py
index b6bcdac..60f31fb 100644
--- a/python/tvm_ffi/dataclasses/_utils.py
+++ b/python/tvm_ffi/dataclasses/_utils.py
@@ -26,23 +26,11 @@ from ..core import (
     Object,
     TypeField,
     TypeInfo,
-    _lookup_type_info_from_type_key,
 )
 
 _InputClsType = TypeVar("_InputClsType")
 
 
-def get_parent_type_info(type_cls: type) -> TypeInfo:
-    """Find the nearest ancestor with registered ``__tvm_ffi_type_info__``.
-
-    If none are found, return the base ``ffi.Object`` type info.
-    """
-    for base in type_cls.__bases__:
-        if (info := getattr(base, "__tvm_ffi_type_info__", None)) is not None:
-            return info
-    return _lookup_type_info_from_type_key("ffi.Object")
-
-
 def type_info_to_cls(
     type_info: TypeInfo,
     cls: type[_InputClsType],
diff --git a/python/tvm_ffi/dataclasses/c_class.py b/python/tvm_ffi/dataclasses/c_class.py
index dc6aeed..42dc4fd 100644
--- a/python/tvm_ffi/dataclasses/c_class.py
+++ b/python/tvm_ffi/dataclasses/c_class.py
@@ -27,16 +27,16 @@ from collections.abc import Callable
 from dataclasses import InitVar
 from typing import ClassVar, TypeVar, get_origin, get_type_hints
 
-from typing_extensions import dataclass_transform  # type: ignore[attr-defined]
+from typing_extensions import dataclass_transform
 
-from ..core import TypeField, TypeInfo
+from ..core import TypeField, TypeInfo, _lookup_or_register_type_info_from_type_key, _set_type_cls
 from . import _utils
-from .field import Field, field
+from .field import field
 
 _InputClsType = TypeVar("_InputClsType")
 
 
-@dataclass_transform(field_specifiers=(field, Field))
+@dataclass_transform(field_specifiers=(field,))
 def c_class(
     type_key: str, init: bool = True
 ) -> Callable[[type[_InputClsType]], type[_InputClsType]]:
@@ -116,9 +116,8 @@ def c_class(
         nonlocal init
         init = init and "__init__" not in super_type_cls.__dict__
         # Step 1. Retrieve `type_info` from registry
-        type_info: TypeInfo = _utils._lookup_type_info_from_type_key(type_key)
-        assert type_info.parent_type_info is None, f"Already registered type: {type_key}"
-        type_info.parent_type_info = _utils.get_parent_type_info(super_type_cls)
+        type_info: TypeInfo = _lookup_or_register_type_info_from_type_key(type_key)
+        assert type_info.parent_type_info is not None
         # Step 2. Reflect all the fields of the type
         type_info.fields = _inspect_c_class_fields(super_type_cls, type_info)
         for type_field in type_info.fields:
@@ -130,7 +129,7 @@ def c_class(
             cls=super_type_cls,
             methods={"__init__": fn_init},
         )
-        type_info.type_cls = type_cls
+        _set_type_cls(type_info, type_cls)
         return type_cls
 
     return decorator
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 3ef4039..45c1da0 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -249,45 +249,17 @@ def init_ffi_api(namespace: str, target_module_name: str | None = None) -> None:
         setattr(target_module, fname, f)
 
 
-def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[..., Any]:
-    def wrapper(self: Any, *args: Any) -> Any:
-        return method_func(self, *args)
-
-    return wrapper
-
-
 def _add_class_attrs(type_cls: type, type_info: TypeInfo) -> type:
     for field in type_info.fields:
-        getter = field.getter
-        setter = field.setter if not field.frozen else None
-        doc = field.doc if field.doc else None
         name = field.name
-        if hasattr(type_cls, name):
-            # skip already defined attributes
-            continue
-        setattr(type_cls, name, property(getter, setter, doc=doc))
+        if not hasattr(type_cls, name):  # skip already defined attributes
+            setattr(type_cls, name, field.as_property(type_cls))
     for method in type_info.methods:
         name = method.name
         if name == "__ffi_init__":
             name = "__c_ffi_init__"
-        doc = method.doc if method.doc else None
-        method_func = method.func
-        if method.is_static:
-            if doc is not None:
-                method_func.__doc__ = doc
-            method_func.__name__ = name
-            method_pyfunc: Any = staticmethod(method_func)
-        else:
-            wrapped_func = _member_method_wrapper(method_func)
-            if doc is not None:
-                wrapped_func.__doc__ = doc
-            wrapped_func.__name__ = name
-            method_pyfunc = wrapped_func
-
-        if hasattr(type_cls, name):
-            # skip already defined attributes
-            continue
-        setattr(type_cls, name, method_pyfunc)
+        if not hasattr(type_cls, name):
+            setattr(type_cls, name, method.as_callable(type_cls))
     return type_cls
 
 
diff --git a/src/ffi/extra/reflection_extra.cc b/src/ffi/extra/reflection_extra.cc
index f923643..d36e1ac 100644
--- a/src/ffi/extra/reflection_extra.cc
+++ b/src/ffi/extra/reflection_extra.cc
@@ -90,7 +90,7 @@ void MakeObjectFromPackedArgs(ffi::PackedArgs args, Any* ret) {
   // iterate through acenstors in parent to child order
   // skip the first one since it is always the root object
   for (int i = 1; i < type_info->type_depth; ++i) {
-    update_fields(type_info->type_acenstors[i]);
+    update_fields(type_info->type_ancestors[i]);
   }
   update_fields(type_info);
 
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index cf55161..f67752f 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -134,13 +134,23 @@ class TestCxxInitSubsetObj : public Object {
   TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxInitSubset", TestCxxInitSubsetObj, Object);
 };
 
-class TestUnregisteredObject : public Object {
+class TestUnregisteredBaseObject : public Object {
  public:
-  int64_t value;
-
-  explicit TestUnregisteredObject(int64_t value) : value(value) {}
+  int64_t v1;
+  explicit TestUnregisteredBaseObject(int64_t v1) : v1(v1) {}
+  int64_t GetV1PlusOne() const { return v1 + 1; }
+  TVM_FFI_DECLARE_OBJECT_INFO("testing.TestUnregisteredBaseObject", TestUnregisteredBaseObject,
+                              Object);
+};
 
-  TVM_FFI_DECLARE_OBJECT_INFO("testing.TestUnregisteredObject", TestUnregisteredObject, Object);
+class TestUnregisteredObject : public TestUnregisteredBaseObject {
+ public:
+  int64_t v2;
+  explicit TestUnregisteredObject(int64_t v1, int64_t v2)
+      : TestUnregisteredBaseObject(v1), v2(v2) {}
+  int64_t GetV2PlusTwo() const { return v2 + 2; }
+  TVM_FFI_DECLARE_OBJECT_INFO("testing.TestUnregisteredObject", TestUnregisteredObject,
+                              TestUnregisteredBaseObject);
 };
 
 TVM_FFI_NO_INLINE void TestRaiseError(String kind, String msg) {
@@ -176,6 +186,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def_static("__ffi_init__", refl::init<TestCxxClassDerived, int64_t, int32_t, double, float>)
       .def_rw("v_f64", &TestCxxClassDerived::v_f64)
       .def_rw("v_f32", &TestCxxClassDerived::v_f32);
+
   refl::ObjectDef<TestCxxClassDerivedDerived>()
       .def_static(
           "__ffi_init__",
@@ -189,6 +200,21 @@ TVM_FFI_STATIC_INIT_BLOCK() {
       .def_rw("optional_field", &TestCxxInitSubsetObj::optional_field)
       .def_rw("note", &TestCxxInitSubsetObj::note);
 
+  refl::ObjectDef<TestUnregisteredBaseObject>()
+      .def_ro("v1", &TestUnregisteredBaseObject::v1)
+      .def_static("__ffi_init__", refl::init<TestUnregisteredBaseObject, int64_t>,
+                  "Constructor of TestUnregisteredBaseObject")
+      .def("get_v1_plus_one", &TestUnregisteredBaseObject::GetV1PlusOne,
+           "Get (v1 + 1) from TestUnregisteredBaseObject");
+
+  refl::ObjectDef<TestUnregisteredObject>()
+      .def_ro("v1", &TestUnregisteredObject::v1)
+      .def_ro("v2", &TestUnregisteredObject::v2)
+      .def_static("__ffi_init__", refl::init<TestUnregisteredObject, int64_t, int64_t>,
+                  "Constructor of TestUnregisteredObject")
+      .def("get_v2_plus_two", &TestUnregisteredObject::GetV2PlusTwo,
+           "Get (v2 + 2) from TestUnregisteredObject");
+
   refl::GlobalDef()
       .def("testing.test_raise_error", TestRaiseError)
       .def_packed("testing.nop", [](PackedArgs args, Any* ret) {})
@@ -206,7 +232,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
            })
       .def("testing.object_use_count", [](const Object* obj) { return obj->use_count(); })
       .def("testing.make_unregistered_object",
-           []() { return ObjectRef(make_object<TestUnregisteredObject>(42)); });
+           []() { return ObjectRef(make_object<TestUnregisteredObject>(41, 42)); });
 }
 
 }  // namespace ffi
diff --git a/src/ffi/object.cc b/src/ffi/object.cc
index 292c8e9..d9c6698 100644
--- a/src/ffi/object.cc
+++ b/src/ffi/object.cc
@@ -54,7 +54,7 @@ class TypeTable {
     /*! \brief stored type key */
     String type_key_data;
     /*! \brief acenstor information */
-    std::vector<const TVMFFITypeInfo*> type_acenstors_data;
+    std::vector<const TVMFFITypeInfo*> type_ancestors_data;
     /*! \brief type fields informaton */
     std::vector<TVMFFIFieldInfo> type_fields_data;
     /*! \brief type methods informaton */
@@ -81,21 +81,21 @@ class TypeTable {
       if (type_depth != 0) {
         TVM_FFI_ICHECK_NOTNULL(parent);
         TVM_FFI_ICHECK_EQ(type_depth, parent->type_depth + 1);
-        type_acenstors_data.resize(type_depth);
+        type_ancestors_data.resize(type_depth);
         // copy over parent's type information
         for (int32_t i = 0; i < parent->type_depth; ++i) {
-          type_acenstors_data[i] = parent->type_acenstors[i];
+          type_ancestors_data[i] = parent->type_ancestors[i];
         }
         // set last type information to be parent
-        type_acenstors_data[parent->type_depth] = parent;
+        type_ancestors_data[parent->type_depth] = parent;
       }
-      // initialize type info: no change to type_key and type_acenstors fields
+      // initialize type info: no change to type_key and type_ancestors fields
       // after this line
       this->type_index = type_index;
       this->type_depth = type_depth;
       this->type_key = TVMFFIByteArray{this->type_key_data.data(), this->type_key_data.length()};
       this->type_key_hash = std::hash<String>()(this->type_key_data);
-      this->type_acenstors = type_acenstors_data.data();
+      this->type_ancestors = type_ancestors_data.data();
       // initialize the reflection information
       this->num_fields = 0;
       this->num_methods = 0;
@@ -280,7 +280,7 @@ class TypeTable {
     for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
       const Entry* ptr = it->get();
       if (ptr != nullptr && ptr->type_depth != 0) {
-        int parent_index = ptr->type_acenstors[ptr->type_depth - 1]->type_index;
+        int parent_index = ptr->type_ancestors[ptr->type_depth - 1]->type_index;
         num_children[parent_index] += num_children[ptr->type_index] + 1;
         if (expected_child_slots[ptr->type_index] + 1 < ptr->num_slots) {
           expected_child_slots[ptr->type_index] = ptr->num_slots - 1;
@@ -293,7 +293,7 @@ class TypeTable {
       if (ptr != nullptr && num_children[ptr->type_index] >= min_children_count) {
         std::cerr << '[' << ptr->type_index << "]\t" << ToStringView(ptr->type_key);
         if (ptr->type_depth != 0) {
-          int32_t parent_index = ptr->type_acenstors[ptr->type_depth - 1]->type_index;
+          int32_t parent_index = ptr->type_ancestors[ptr->type_depth - 1]->type_index;
           std::cerr << "\tparent=" << ToStringView(type_table_[parent_index]->type_key);
         } else {
           std::cerr << "\tparent=root";
diff --git a/tests/cpp/test_object.cc b/tests/cpp/test_object.cc
index ec5c54c..6c8c822 100644
--- a/tests/cpp/test_object.cc
+++ b/tests/cpp/test_object.cc
@@ -55,8 +55,8 @@ TEST(Object, TypeInfo) {
   EXPECT_TRUE(info != nullptr);
   EXPECT_EQ(info->type_index, TIntObj::RuntimeTypeIndex());
   EXPECT_EQ(info->type_depth, 2);
-  EXPECT_EQ(info->type_acenstors[0]->type_index, Object::_type_index);
-  EXPECT_EQ(info->type_acenstors[1]->type_index, TNumberObj::_type_index);
+  EXPECT_EQ(info->type_ancestors[0]->type_index, Object::_type_index);
+  EXPECT_EQ(info->type_ancestors[1]->type_index, TNumberObj::_type_index);
   EXPECT_GE(info->type_index, TypeIndex::kTVMFFIDynObjectBegin);
 }
 
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index 3b36f5b..0c64e46 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -19,6 +19,7 @@ from typing import Any
 
 import pytest
 import tvm_ffi
+from tvm_ffi.core import TypeInfo
 
 
 def test_make_object() -> None:
@@ -103,12 +104,22 @@ def test_opaque_object() -> None:
 
 
 def test_unregistered_object_fallback() -> None:
-    with pytest.warns(
-        UserWarning,
-        match=(
-            r"Returning type `testing\.TestUnregisteredObject` "
-            r"which is not registered via register_object, fallback to Object"
-        ),
-    ):
+    def _check_type(x: Any) -> None:
+        type_info: TypeInfo = type(x).__tvm_ffi_type_info__  # type: ignore[attr-defined]
+        assert type_info.type_key == "testing.TestUnregisteredObject"
+        assert x.v1 == 41
+        assert x.v2 == 42
+        assert x.get_v1_plus_one() == 42  # type: ignore[attr-defined]
+        assert x.get_v2_plus_two() == 44  # type: ignore[attr-defined]
+        assert type(x).__name__ == "TestUnregisteredObject"
+        assert type(x).__module__ == "testing"
+        assert type(x).__qualname__ == "testing.TestUnregisteredObject"
+        assert "Auto-generated fallback class" in type(x).__doc__  # type: ignore[operator]
+        assert "Get (v1 + 1) from TestUnregisteredBaseObject" in type(x).get_v1_plus_one.__doc__  # type: ignore[attr-defined]
+        assert "Get (v2 + 2) from TestUnregisteredObject" in type(x).get_v2_plus_two.__doc__  # type: ignore[attr-defined]
+
+    obj = tvm_ffi.testing.make_unregistered_object()
+    _check_type(obj)
+    for _ in range(5):
         obj = tvm_ffi.testing.make_unregistered_object()
-    assert type(obj) is tvm_ffi.Object
+        _check_type(obj)

Reply via email to