This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new d68c8d8  fix: Fallback for unregistered object types (#44)
d68c8d8 is described below

commit d68c8d8d4520318d3598c39c71c444169b1244bc
Author: Junru Shao <[email protected]>
AuthorDate: Mon Sep 22 17:53:53 2025 -0700

    fix: Fallback for unregistered object types (#44)
    
    - guard the FFI return path so unregistered object types fall back to
    `Object` without throwing
    - add an unregistered testing object plus a helper to obtain it through
    the testing API
    - cover the fallback behavior with a new Python unit test
---
 python/tvm_ffi/cython/object.pxi | 18 ++++++++++--------
 python/tvm_ffi/testing.py        |  7 ++++++-
 src/ffi/extra/testing.cc         | 13 ++++++++++++-
 tests/python/test_object.py      | 12 ++++++++++++
 4 files changed, 40 insertions(+), 10 deletions(-)

diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 326a98b..4cc737a 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -269,15 +269,17 @@ cdef inline object make_ret_object(TVMFFIAny result):
     tindex = result.type_index
 
     if tindex < len(TYPE_INDEX_TO_INFO):
-        cls = TYPE_INDEX_TO_INFO[tindex].type_cls
-        if cls is not None:
-            if issubclass(cls, PyNativeObject):
-                obj = Object.__new__(Object)
+        type_info = TYPE_INDEX_TO_INFO[tindex]
+        if type_info is not None:
+            cls = type_info.type_cls
+            if cls is not None:
+                if issubclass(cls, PyNativeObject):
+                    obj = Object.__new__(Object)
+                    (<Object>obj).chandle = result.v_obj
+                    return cls.__from_tvm_ffi_object__(cls, obj)
+                obj = cls.__new__(cls)
                 (<Object>obj).chandle = result.v_obj
-                return cls.__from_tvm_ffi_object__(cls, obj)
-            obj = cls.__new__(cls)
-            (<Object>obj).chandle = result.v_obj
-            return obj
+                return obj
 
     # object is not found in registered entry
     # in this case we need to report an warning
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 6d302bc..0053564 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -24,7 +24,7 @@ from . import _ffi_api
 from .container import Array, Map
 from .core import Object
 from .dataclasses import c_class, field
-from .registry import register_object
+from .registry import get_global_func, register_object
 
 
 @register_object("testing.TestObjectBase")
@@ -84,6 +84,11 @@ def create_object(type_key: str, **kwargs: Any) -> Object:
     return _ffi_api.MakeObjectFromPackedArgs(*args)
 
 
+def make_unregistered_object() -> Object:
+    """Return an object whose type is not registered on the Python side."""
+    return get_global_func("testing.make_unregistered_object")()
+
+
 @c_class("testing.TestCxxClassBase")
 class _TestCxxClassBase:
     v_i64: int
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index 79b6164..afd952b 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -121,6 +121,15 @@ class TestCxxClassDerivedDerived : public TestCxxClassDerived {
                               TestCxxClassDerived);
 };
 
+class TestUnregisteredObject : public Object {
+ public:
+  int64_t value;
+
+  explicit TestUnregisteredObject(int64_t value) : value(value) {}
+
+  TVM_FFI_DECLARE_OBJECT_INFO("testing.TestUnregisteredObject", TestUnregisteredObject, Object);
+};
+
 TVM_FFI_NO_INLINE void TestRaiseError(String kind, String msg) {
   // keep name and no liner for testing backtrace
   throw ffi::Error(kind, msg, TVMFFIBacktrace(__FILE__, __LINE__, TVM_FFI_FUNC_SIG, 0));
@@ -176,7 +185,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
              }
              std::cout << "Function finished without catching signal" << std::endl;
            })
-      .def("testing.object_use_count", [](const Object* obj) { return obj->use_count(); });
+      .def("testing.object_use_count", [](const Object* obj) { return obj->use_count(); })
+      .def("testing.make_unregistered_object",
+           []() { return ObjectRef(make_object<TestUnregisteredObject>(42)); });
 }
 
 }  // namespace ffi
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index aa1a791..3b36f5b 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -100,3 +100,15 @@ def test_opaque_object() -> None:
     assert sys.getrefcount(obj0) == 3
     obj0_cpy = None
     assert sys.getrefcount(obj0) == 2
+
+
+def test_unregistered_object_fallback() -> None:
+    with pytest.warns(
+        UserWarning,
+        match=(
+            r"Returning type `testing\.TestUnregisteredObject` "
+            r"which is not registered via register_object, fallback to Object"
+        ),
+    ):
+        obj = tvm_ffi.testing.make_unregistered_object()
+    assert type(obj) is tvm_ffi.Object

Reply via email to