This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new d68c8d8 fix: Fallback for unregistered object types (#44)
d68c8d8 is described below
commit d68c8d8d4520318d3598c39c71c444169b1244bc
Author: Junru Shao <[email protected]>
AuthorDate: Mon Sep 22 17:53:53 2025 -0700
fix: Fallback for unregistered object types (#44)
- guard the FFI return path so unregistered object types fall back to
`Object` without throwing
- add an unregistered testing object plus a helper to obtain it through
the testing API
- cover the fallback behavior with a new Python unit test
---
python/tvm_ffi/cython/object.pxi | 18 ++++++++++--------
python/tvm_ffi/testing.py | 7 ++++++-
src/ffi/extra/testing.cc | 13 ++++++++++++-
tests/python/test_object.py | 12 ++++++++++++
4 files changed, 40 insertions(+), 10 deletions(-)
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 326a98b..4cc737a 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -269,15 +269,17 @@ cdef inline object make_ret_object(TVMFFIAny result):
tindex = result.type_index
if tindex < len(TYPE_INDEX_TO_INFO):
- cls = TYPE_INDEX_TO_INFO[tindex].type_cls
- if cls is not None:
- if issubclass(cls, PyNativeObject):
- obj = Object.__new__(Object)
+ type_info = TYPE_INDEX_TO_INFO[tindex]
+ if type_info is not None:
+ cls = type_info.type_cls
+ if cls is not None:
+ if issubclass(cls, PyNativeObject):
+ obj = Object.__new__(Object)
+ (<Object>obj).chandle = result.v_obj
+ return cls.__from_tvm_ffi_object__(cls, obj)
+ obj = cls.__new__(cls)
(<Object>obj).chandle = result.v_obj
- return cls.__from_tvm_ffi_object__(cls, obj)
- obj = cls.__new__(cls)
- (<Object>obj).chandle = result.v_obj
- return obj
+ return obj
# object is not found in registered entry
# in this case we need to report an warning
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 6d302bc..0053564 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -24,7 +24,7 @@ from . import _ffi_api
from .container import Array, Map
from .core import Object
from .dataclasses import c_class, field
-from .registry import register_object
+from .registry import get_global_func, register_object
@register_object("testing.TestObjectBase")
@@ -84,6 +84,11 @@ def create_object(type_key: str, **kwargs: Any) -> Object:
return _ffi_api.MakeObjectFromPackedArgs(*args)
+def make_unregistered_object() -> Object:
+ """Return an object whose type is not registered on the Python side."""
+ return get_global_func("testing.make_unregistered_object")()
+
+
@c_class("testing.TestCxxClassBase")
class _TestCxxClassBase:
v_i64: int
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index 79b6164..afd952b 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -121,6 +121,15 @@ class TestCxxClassDerivedDerived : public
TestCxxClassDerived {
TestCxxClassDerived);
};
+class TestUnregisteredObject : public Object {
+ public:
+ int64_t value;
+
+ explicit TestUnregisteredObject(int64_t value) : value(value) {}
+
+ TVM_FFI_DECLARE_OBJECT_INFO("testing.TestUnregisteredObject",
TestUnregisteredObject, Object);
+};
+
TVM_FFI_NO_INLINE void TestRaiseError(String kind, String msg) {
// keep name and no liner for testing backtrace
throw ffi::Error(kind, msg, TVMFFIBacktrace(__FILE__, __LINE__,
TVM_FFI_FUNC_SIG, 0));
@@ -176,7 +185,9 @@ TVM_FFI_STATIC_INIT_BLOCK() {
}
std::cout << "Function finished without catching signal" <<
std::endl;
})
- .def("testing.object_use_count", [](const Object* obj) { return
obj->use_count(); });
+ .def("testing.object_use_count", [](const Object* obj) { return
obj->use_count(); })
+ .def("testing.make_unregistered_object",
+ []() { return ObjectRef(make_object<TestUnregisteredObject>(42));
});
}
} // namespace ffi
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index aa1a791..3b36f5b 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -100,3 +100,15 @@ def test_opaque_object() -> None:
assert sys.getrefcount(obj0) == 3
obj0_cpy = None
assert sys.getrefcount(obj0) == 2
+
+
+def test_unregistered_object_fallback() -> None:
+ with pytest.warns(
+ UserWarning,
+ match=(
+ r"Returning type `testing\.TestUnregisteredObject` "
+ r"which is not registered via register_object, fallback to Object"
+ ),
+ ):
+ obj = tvm_ffi.testing.make_unregistered_object()
+ assert type(obj) is tvm_ffi.Object