This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 7b57a46 fix: Add schema support for `ObjectRef::MemFn` (#113)
7b57a46 is described below
commit 7b57a46648662b11b483f252787db22ab701231f
Author: Junru Shao <[email protected]>
AuthorDate: Tue Oct 14 04:48:40 2025 -0700
fix: Add schema support for `ObjectRef::MemFn` (#113)
This PR adds specializations of `tvm::ffi::details::FunctionInfo<>` for
member functions of `ObjectRef` types, which fixes a downstream compilation error.
For more context: in downstream TVM use cases, a member function of an
`ObjectRef` may be registered as a global function. The current project did
not account for this case, which resulted in compilation errors.
---
cmake/Utils/AddLibbacktrace.cmake | 2 ++
include/tvm/ffi/function_details.h | 20 +++++++++++++++-----
python/tvm_ffi/testing.py | 1 +
src/ffi/extra/testing.cc | 13 +++++++++----
tests/python/test_metadata.py | 6 ++++++
5 files changed, 33 insertions(+), 9 deletions(-)
diff --git a/cmake/Utils/AddLibbacktrace.cmake b/cmake/Utils/AddLibbacktrace.cmake
index eda095a..17b9c8a 100644
--- a/cmake/Utils/AddLibbacktrace.cmake
+++ b/cmake/Utils/AddLibbacktrace.cmake
@@ -42,6 +42,7 @@ function (_libbacktrace_compile)
PREFIX libbacktrace
SOURCE_DIR ${libbacktrace_source}
BINARY_DIR ${libbacktrace_prefix}
+ LOG_DIR ${libbacktrace_prefix}/logs
CONFIGURE_COMMAND
"sh" #
"${libbacktrace_source}/configure" #
@@ -61,6 +62,7 @@ function (_libbacktrace_compile)
LOG_CONFIGURE ON
LOG_INSTALL ON
LOG_BUILD ON
+ LOG_MERGED_STDOUTERR ON
LOG_OUTPUT_ON_FAILURE ON
)
ExternalProject_Add_Step(
diff --git a/include/tvm/ffi/function_details.h b/include/tvm/ffi/function_details.h
index d003f4b..e7766fd 100644
--- a/include/tvm/ffi/function_details.h
+++ b/include/tvm/ffi/function_details.h
@@ -107,17 +107,27 @@ struct FunctionInfoHelper<R (T::*)(Args...) const> : FuncFunctorImpl<R, Args...>
* \tparam T The function/functor type.
* \note We need a decltype redirection because this helps lambda types.
*/
-template <typename T>
+template <typename T, typename = void>
struct FunctionInfo : FunctionInfoHelper<decltype(&T::operator())> {};
template <typename R, typename... Args>
-struct FunctionInfo<R(Args...)> : FuncFunctorImpl<R, Args...> {};
+struct FunctionInfo<R(Args...), void> : FuncFunctorImpl<R, Args...> {};
template <typename R, typename... Args>
-struct FunctionInfo<R (*)(Args...)> : FuncFunctorImpl<R, Args...> {};
+struct FunctionInfo<R (*)(Args...), void> : FuncFunctorImpl<R, Args...> {};
// Support pointer-to-member functions used in reflection (e.g. &Class::method)
template <typename Class, typename R, typename... Args>
-struct FunctionInfo<R (Class::*)(Args...)> : FuncFunctorImpl<R, Class*, Args...> {};
+struct FunctionInfo<R (Class::*)(Args...), std::enable_if_t<std::is_base_of_v<Object, Class>>>
+ : FuncFunctorImpl<R, Class*, Args...> {};
+template <typename Class, typename R, typename... Args>
+struct FunctionInfo<R (Class::*)(Args...) const, std::enable_if_t<std::is_base_of_v<Object, Class>>>
+ : FuncFunctorImpl<R, const Class*, Args...> {};
+
+template <typename Class, typename R, typename... Args>
+struct FunctionInfo<R (Class::*)(Args...), std::enable_if_t<std::is_base_of_v<ObjectRef, Class>>>
+ : FuncFunctorImpl<R, Class, Args...> {};
template <typename Class, typename R, typename... Args>
-struct FunctionInfo<R (Class::*)(Args...) const> : FuncFunctorImpl<R, const Class*, Args...> {};
+struct FunctionInfo<R (Class::*)(Args...) const,
+ std::enable_if_t<std::is_base_of_v<ObjectRef, Class>>>
+ : FuncFunctorImpl<R, const Class, Args...> {};
/*! \brief Using static function to output typed function signature */
using FGetFuncSignature = std::string (*)();
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index 71467df..820fd60 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -57,6 +57,7 @@ class TestIntPair(Object):
b: int
@staticmethod
def __c_ffi_init__(_0: int, _1: int, /) -> Object: ...
+ def sum(_0: TestIntPair, /) -> int: ...
# fmt: on
# tvm-ffi-stubgen(end)
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index 13f0044..11772ad 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -59,6 +59,8 @@ class TestIntPair : public tvm::ffi::ObjectRef {
data_ = tvm::ffi::make_object<TestIntPairObj>(a, b);
}
+ int64_t Sum() const { return get()->a + get()->b; }
+
// Required: define object reference methods
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TestIntPair, tvm::ffi::ObjectRef,
TestIntPairObj);
};
@@ -68,7 +70,8 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::ObjectDef<TestIntPairObj>()
.def_ro("a", &TestIntPairObj::a, "Field `a`")
.def_ro("b", &TestIntPairObj::b, "Field `b`")
- .def_static("__ffi_init__", refl::init<TestIntPairObj, int64_t, int64_t>);
+ .def_static("__ffi_init__", refl::init<TestIntPairObj, int64_t, int64_t>)
+ .def("sum", &TestIntPair::Sum, "Method to compute sum of a and b");
}
class TestObjectBase : public Object {
@@ -264,9 +267,11 @@ TVM_FFI_STATIC_INIT_BLOCK() {
TVMFFISafeCallType symbol = __add_one_c_symbol;
return reinterpret_cast<int64_t>(reinterpret_cast<void*>(symbol));
})
- .def("testing.get_mlir_add_one_c_symbol", []() {
- return reinterpret_cast<int64_t>(reinterpret_cast<void*>(_mlir_add_one_c_symbol));
- });
+ .def("testing.get_mlir_add_one_c_symbol",
+ []() {
+ return reinterpret_cast<int64_t>(reinterpret_cast<void*>(_mlir_add_one_c_symbol));
+ })
+ .def_method("testing.TestIntPairSum", &TestIntPair::Sum, "Get sum of the pair");
}
} // namespace ffi
diff --git a/tests/python/test_metadata.py b/tests/python/test_metadata.py
index 22e2ce5..787ed91 100644
--- a/tests/python/test_metadata.py
+++ b/tests/python/test_metadata.py
@@ -189,3 +189,9 @@ def test_metadata_member_method() -> None:
break
else:
raise ValueError("Method not found: add_int")
+
+
+def test_mem_fn_as_global_func() -> None:
+ metadata: dict[str, Any] = get_global_func_metadata("testing.TestIntPairSum")
+ type_schema: TypeSchema = TypeSchema.from_json_str(metadata["type_schema"])
+ assert str(type_schema) == "Callable[[testing.TestIntPair], int]"