This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new fc2630f feat: Add a method to define initializer (#127)
fc2630f is described below
commit fc2630fac7acc17f966aae7fb724b23a9eec93c6
Author: Yixin Dong <[email protected]>
AuthorDate: Wed Oct 15 08:52:26 2025 -0700
feat: Add a method to define initializer (#127)
This PR adds a `def(init<Args...>())` overload to `ObjectDef` for registering
an object's static initializer (`__ffi_init__`). It replaces the previous
`def_static("__ffi_init__", ...)` idiom and fully aligns the API with
nanobind/pybind11.
---
docs/guides/cpp_guide.md | 2 +-
docs/guides/python_guide.md | 12 ++--
include/tvm/ffi/reflection/registry.h | 113 ++++++++++++++++++++++++----------
src/ffi/testing/testing.cc | 18 +++---
tests/cpp/CMakeLists.txt | 2 +-
tests/cpp/test_reflection.cc | 4 +-
6 files changed, 95 insertions(+), 56 deletions(-)
diff --git a/docs/guides/cpp_guide.md b/docs/guides/cpp_guide.md
index e37058b..ef69ed3 100644
--- a/docs/guides/cpp_guide.md
+++ b/docs/guides/cpp_guide.md
@@ -105,7 +105,7 @@ class MyIntPairObj : public tvm::ffi::Object {
// Required: declare type information
// to register a dynamic type index through the system
-TVM_FFI_DECLARE_OBJECT_INFO_FINAL("example.MyIntPair", MyIntPairObj,
tvm::ffi::Object);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("example.MyIntPair", MyIntPairObj,
tvm::ffi::Object);
};
void ExampleObjectPtr() {
diff --git a/docs/guides/python_guide.md b/docs/guides/python_guide.md
index dc34cfe..434f3ce 100644
--- a/docs/guides/python_guide.md
+++ b/docs/guides/python_guide.md
@@ -235,7 +235,7 @@ public:
TestIntPairObj(int64_t a, int64_t b) : a(a), b(b) {}
// Required: declare type information
-TVM_FFI_DECLARE_OBJECT_INFO_FINAL("testing.TestIntPair", TestIntPairObj,
tvm::ffi::Object);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("testing.TestIntPair", TestIntPairObj,
tvm::ffi::Object);
};
// Step 2: Define the reference wrapper (user-facing interface)
@@ -253,13 +253,11 @@ public:
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
// register the object into the system
- // register field accessors and a global static function `__create__` as ffi::Function
+ // register the (a, b) constructor as `__ffi_init__`, plus field accessors for `a` and `b`
refl::ObjectDef<TestIntPairObj>()
+ .def(refl::init<int64_t, int64_t>())
.def_ro("a", &TestIntPairObj::a)
- .def_ro("b", &TestIntPairObj::b)
- .def_static("__create__", [](int64_t a, int64_t b) -> TestIntPair {
- return TestIntPair(a, b);
- });
+ .def_ro("b", &TestIntPairObj::b);
}
```
@@ -274,7 +272,7 @@ class TestIntPair(tvm_ffi.Object):
def __init__(self, a, b):
# This is a special method to call an FFI function whose return
# value exactly initializes the object handle of the object
- self.__init_handle_by_constructor__(TestIntPair.__create__, a, b)
+ self.__ffi_init__(a, b)
test_int_pair = TestIntPair(1, 2)
# We can access the fields by name
diff --git a/include/tvm/ffi/reflection/registry.h
b/include/tvm/ffi/reflection/registry.h
index 78ab7e8..124fc47 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -406,6 +406,56 @@ class GlobalDef : public ReflectionDefBase {
}
};
+/*!
+ * \brief Tag type describing a constructor signature for `ObjectDef::def()`.
+ *
+ * Passing `init<Args...>()` to `ObjectDef<Class>::def()` registers an
+ * `__ffi_init__` method that constructs a `Class` instance from arguments of
+ * types `Args...`.
+ *
+ * \tparam Args The argument types for the constructor.
+ *
+ * Example usage:
+ * \code
+ * class ExampleObject : public Object {
+ *  public:
+ *   int64_t v_i64;
+ *   int32_t v_i32;
+ *
+ *   ExampleObject(int64_t v_i64, int32_t v_i32) : v_i64(v_i64), v_i32(v_i32) {}
+ *   TVM_FFI_DECLARE_OBJECT_INFO("example.ExampleObject", ExampleObject, Object);
+ * };
+ *
+ * // Register the constructor as `__ffi_init__`
+ * refl::ObjectDef<ExampleObject>()
+ *     .def(refl::init<int64_t, int32_t>());
+ * \endcode
+ *
+ * \note The object type is not part of the tag; it is supplied by the
+ *       enclosing `ObjectDef<Class>` when the tag is passed to `def()`.
+ */
+template <typename... Args>
+struct init {
+  // Only ObjectDef may call execute(); users only ever construct the tag.
+  template <typename Class>
+  friend class ObjectDef;
+
+  /*!
+   * \brief Construct the (empty) tag object.
+   */
+  constexpr init() = default;
+
+ private:
+  /*!
+   * \brief Construct a `Class` instance from the given arguments.
+   * \tparam Class The object type to construct; bound by `ObjectDef<Class>`.
+   * \param args The arguments forwarded to `ffi::make_object<Class>`.
+   * \return The constructed object wrapped in an `ObjectRef`.
+   * \note `Args` is fixed by the class template, so `Args&&` here is an rvalue
+   *       reference (not a forwarding reference) unless an `Args` element is
+   *       itself a reference type, in which case reference collapsing applies.
+   */
+  template <typename Class>
+  static ObjectRef execute(Args&&... args) {
+    return ObjectRef(ffi::make_object<Class>(std::forward<Args>(args)...));
+  }
+};
+
/*!
* \brief Helper to register Object's reflection metadata.
* \tparam Class The class type.
@@ -504,6 +554,34 @@ class ObjectDef : public ReflectionDefBase {
return *this;
}
+  /*!
+   * \brief Register a constructor for this object type.
+   *
+   * Registers a static method named `__ffi_init__` that constructs a `Class`
+   * instance from arguments of types `Args...`, so the constructor can be
+   * invoked from Python or other FFI bindings.
+   *
+   * The first argument is a tag object, `init<Args...>()`, whose type alone
+   * carries the constructor signature. Its value is never read, so the
+   * parameter is left unnamed.
+   *
+   * \tparam Args The argument types for the constructor.
+   * \tparam Extra Additional arguments (e.g., docstring).
+   *
+   * \param extra Optional additional metadata such as docstring.
+   *
+   * \return Reference to this `ObjectDef` for method chaining.
+   *
+   * Example:
+   * \code
+   * refl::ObjectDef<MyObject>()
+   *     .def(refl::init<int64_t, std::string>(), "Constructor docstring");
+   * \endcode
+   */
+  template <typename... Args, typename... Extra>
+  TVM_FFI_INLINE ObjectDef& def(init<Args...> /*init_func*/, Extra&&... extra) {
+    // Bind Class here so callers only spell the argument types; the
+    // constructor is registered as a static method (no receiver object).
+    RegisterMethod(INIT_METHOD_NAME, true, &init<Args...>::template execute<Class>,
+                   std::forward<Extra>(extra)...);
+    return *this;
+  }
+
private:
template <typename... ExtraArgs>
void RegisterExtraInfo(ExtraArgs&&... extra_args) {
@@ -576,6 +654,7 @@ class ObjectDef : public ReflectionDefBase {
int32_t type_index_;
const char* type_key_;
+ /*! \brief Method name under which `def(init<Args...>())` registers the constructor. */
+ static constexpr const char* INIT_METHOD_NAME = "__ffi_init__";
};
/*!
@@ -656,40 +735,6 @@ inline void EnsureTypeAttrColumn(std::string_view name) {
reinterpret_cast<const
TVMFFIAny*>(&any_view)));
}
-/*!
- * \brief Invokes the constructor of a particular object type and returns an
`ObjectRef`.
- * \tparam T The object type to be constructed.
- * \tparam Args The argument types.
- * \param args The arguments to be passed to the constructor.
- * \return The constructed object wrapped in an `ObjectRef`.
- * \note This is usually used in FFI reflection boundary to register
`__init__` methods.
- *
- * Example
- *
- * \code
- *
- * class ExampleObject : public Object {
- * public:
- * int64_t v_i64;
- * int32_t v_i32;
- *
- * ExampleObject(int64_t v_i64, int32_t v_i32) : v_i64(v_i64),
v_i32(v_i32) {}
- * TVM_FFI_DECLARE_OBJECT_INFO("example.ExampleObject", ExampleObject,
Object);
- * };
- * refl::ObjectDef<ExampleObject>()
- * .def_static("__init__", refl::init<ExampleObject, int64_t, int32_t>);
- * \endcode
- */
-template <typename T, typename... Args>
-inline ObjectRef init(Args&&... args) {
- if constexpr (std::is_base_of_v<Object, T>) {
- return ObjectRef(ffi::make_object<T>(std::forward<Args>(args)...));
- } else {
- using U = typename T::ContainerType;
- return ObjectRef(ffi::make_object<U>(std::forward<Args>(args)...));
- }
-}
-
} // namespace reflection
} // namespace ffi
} // namespace tvm
diff --git a/src/ffi/testing/testing.cc b/src/ffi/testing/testing.cc
index 95c8b30..74f235e 100644
--- a/src/ffi/testing/testing.cc
+++ b/src/ffi/testing/testing.cc
@@ -68,9 +68,9 @@ class TestIntPair : public tvm::ffi::ObjectRef {
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<TestIntPairObj>()
+ // Register the (int64_t, int64_t) constructor as `__ffi_init__`.
+ .def(refl::init<int64_t, int64_t>())
.def_ro("a", &TestIntPairObj::a, "Field `a`")
.def_ro("b", &TestIntPairObj::b, "Field `b`")
- .def_static("__ffi_init__", refl::init<TestIntPairObj, int64_t, int64_t>)
.def("sum", &TestIntPair::Sum, "Method to compute sum of a and b");
}
@@ -206,40 +206,36 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_ro("v_array", &TestObjectDerived::v_array);
refl::ObjectDef<TestCxxClassBase>()
- .def_static("__ffi_init__", refl::init<TestCxxClassBase, int64_t,
int32_t>)
+ .def(refl::init<int64_t, int32_t>())
.def_rw("v_i64", &TestCxxClassBase::v_i64)
.def_rw("v_i32", &TestCxxClassBase::v_i32);
refl::ObjectDef<TestCxxClassDerived>()
- .def_static("__ffi_init__", refl::init<TestCxxClassDerived, int64_t,
int32_t, double, float>)
+ .def(refl::init<int64_t, int32_t, double, float>())
.def_rw("v_f64", &TestCxxClassDerived::v_f64)
.def_rw("v_f32", &TestCxxClassDerived::v_f32);
refl::ObjectDef<TestCxxClassDerivedDerived>()
- .def_static(
- "__ffi_init__",
- refl::init<TestCxxClassDerivedDerived, int64_t, int32_t, double,
float, String, bool>)
+ .def(refl::init<int64_t, int32_t, double, float, String, bool>())
.def_rw("v_str", &TestCxxClassDerivedDerived::v_str)
.def_rw("v_bool", &TestCxxClassDerivedDerived::v_bool);
refl::ObjectDef<TestCxxInitSubsetObj>()
- .def_static("__ffi_init__", refl::init<TestCxxInitSubsetObj, int64_t,
String>)
+ .def(refl::init<int64_t, String>())
.def_rw("required_field", &TestCxxInitSubsetObj::required_field)
.def_rw("optional_field", &TestCxxInitSubsetObj::optional_field)
.def_rw("note", &TestCxxInitSubsetObj::note);
refl::ObjectDef<TestUnregisteredBaseObject>()
+ .def(refl::init<int64_t>(), "Constructor of TestUnregisteredBaseObject")
.def_ro("v1", &TestUnregisteredBaseObject::v1)
- .def_static("__ffi_init__", refl::init<TestUnregisteredBaseObject,
int64_t>,
- "Constructor of TestUnregisteredBaseObject")
.def("get_v1_plus_one", &TestUnregisteredBaseObject::GetV1PlusOne,
"Get (v1 + 1) from TestUnregisteredBaseObject");
refl::ObjectDef<TestUnregisteredObject>()
+ .def(refl::init<int64_t, int64_t>(), "Constructor of
TestUnregisteredObject")
.def_ro("v1", &TestUnregisteredObject::v1)
.def_ro("v2", &TestUnregisteredObject::v2)
- .def_static("__ffi_init__", refl::init<TestUnregisteredObject, int64_t,
int64_t>,
- "Constructor of TestUnregisteredObject")
.def("get_v2_plus_two", &TestUnregisteredObject::GetV2PlusTwo,
"Get (v2 + 2) from TestUnregisteredObject");
diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt
index 7ab3c2a..e08b2f8 100644
--- a/tests/cpp/CMakeLists.txt
+++ b/tests/cpp/CMakeLists.txt
@@ -5,7 +5,7 @@ if (TVM_FFI_USE_EXTRA_CXX_API)
list(APPEND _test_sources ${_test_extra_sources})
endif ()
-add_executable(tvm_ffi_tests EXCLUDE_FROM_ALL ${_test_sources})
+# NOTE(review): without EXCLUDE_FROM_ALL, tvm_ffi_tests is built by the default
+# `all` target; this looks unrelated to the init<> change -- confirm it is intended.
+add_executable(tvm_ffi_tests ${_test_sources})
set_target_properties(
tvm_ffi_tests
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index c81a20b..8fe6a18 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -63,11 +63,11 @@ TVM_FFI_STATIC_INIT_BLOCK() {
TCustomFuncObj::RegisterReflection();
refl::ObjectDef<TestObjA>()
- .def_static("__ffi_init__", refl::init<TestObjA, int64_t, int64_t>)
+ .def(refl::init<int64_t, int64_t>())
.def_ro("x", &TestObjA::x)
.def_rw("y", &TestObjA::y);
refl::ObjectDef<TestObjADerived>()
- .def_static("__ffi_init__", refl::init<TestObjRefADerived, int64_t,
int64_t, int64_t>)
+ .def(refl::init<int64_t, int64_t, int64_t>())
.def_ro("z", &TestObjADerived::z);
}