This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new c01dadf feat: Introduce `tvm_ffi::reflection::init<>` (#32)
c01dadf is described below
commit c01dadf31a66e74cdbfd7fdb1ffc81a75007965f
Author: Junru Shao <[email protected]>
AuthorDate: Sun Sep 21 04:27:24 2025 -0700
feat: Introduce `tvm_ffi::reflection::init<>` (#32)
This PR introduces `tvm::ffi::reflection::init<T, Args...>` (where `T` is an `Object` or
`ObjectRef` subclass), which simplifies registering `__ffi_init__` constructor methods.
---
include/tvm/ffi/object.h | 4 ++--
include/tvm/ffi/reflection/registry.h | 33 +++++++++++++++++++++++++++++++++
python/tvm_ffi/cython/type_info.pxi | 2 +-
python/tvm_ffi/testing.py | 9 +++++++++
src/ffi/extra/testing.cc | 3 +--
tests/cpp/test_reflection.cc | 32 ++++++++++++++++++++++++++++++--
tests/python/test_object.py | 6 ++++++
7 files changed, 82 insertions(+), 7 deletions(-)
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index 6dcc30e..1ebd3d7 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -960,7 +960,7 @@ struct ObjectPtrEqual {
using __PtrType = std::conditional_t<ObjectName::_type_mutable, ObjectName*,
const ObjectName*>; \
__PtrType operator->() const { return static_cast<__PtrType>(data_.get()); }
\
__PtrType get() const { return static_cast<__PtrType>(data_.get()); }
\
- static constexpr bool _type_is_nullable = true;
\
+ [[maybe_unused]] static constexpr bool _type_is_nullable = true;
\
using ContainerType = ObjectName
/*!
@@ -976,7 +976,7 @@ struct ObjectPtrEqual {
using __PtrType = std::conditional_t<ObjectName::_type_mutable, ObjectName*,
const ObjectName*>; \
__PtrType operator->() const { return static_cast<__PtrType>(data_.get()); }
\
__PtrType get() const { return static_cast<__PtrType>(data_.get()); }
\
- static constexpr bool _type_is_nullable = false;
\
+ [[maybe_unused]] static constexpr bool _type_is_nullable = false;
\
using ContainerType = ObjectName
namespace details {
diff --git a/include/tvm/ffi/reflection/registry.h
b/include/tvm/ffi/reflection/registry.h
index c0d984f..f72fd3c 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -558,6 +558,39 @@ inline void EnsureTypeAttrColumn(std::string_view name) {
reinterpret_cast<const
TVMFFIAny*>(&any_view)));
}
+/*!
+ * \brief Constructs an object of a given type and returns it as an `ObjectRef`.
+ *
+ * \tparam T The type to construct. Either an `Object` subclass, which is created directly,
+ *   or an `ObjectRef` subclass, in which case its `ContainerType` is created.
+ * \tparam Args The constructor argument types. They are normally spelled out explicitly
+ *   (e.g. `init<ExampleObject, int64_t, int32_t>`) so that one concrete specialization
+ *   can be passed to `def_static`.
+ * \param args The arguments forwarded to the constructor of the object type.
+ * \return The newly constructed object wrapped in an `ObjectRef`.
+ * \note This is usually used at the FFI reflection boundary to register
+ *   `__ffi_init__` methods.
+ *
+ * Example
+ *
+ * \code
+ *
+ * class ExampleObject : public Object {
+ *  public:
+ *   int64_t v_i64;
+ *   int32_t v_i32;
+ *
+ *   ExampleObject(int64_t v_i64, int32_t v_i32) : v_i64(v_i64), v_i32(v_i32) {}
+ *   TVM_FFI_DECLARE_OBJECT_INFO("example.ExampleObject", ExampleObject, Object);
+ * };
+ *
+ * refl::ObjectDef<ExampleObject>()
+ *     .def_static("__ffi_init__", refl::init<ExampleObject, int64_t, int32_t>);
+ *
+ * \endcode
+ */
+template <typename T, typename... Args>
+inline ObjectRef init(Args&&... args) {
+  if constexpr (std::is_base_of_v<Object, T>) {
+    // T is the object (container) type itself.
+    return ObjectRef(ffi::make_object<T>(std::forward<Args>(args)...));
+  } else {
+    // T is a reference type; construct the container type it wraps. The static_assert
+    // turns an otherwise opaque "no ContainerType" error into a readable diagnostic.
+    static_assert(std::is_base_of_v<ObjectRef, T>,
+                  "reflection::init<T, ...>: T must derive from Object or ObjectRef");
+    using U = typename T::ContainerType;
+    return ObjectRef(ffi::make_object<U>(std::forward<Args>(args)...));
+  }
+}
+
} // namespace reflection
} // namespace ffi
} // namespace tvm
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 2abb204..bde25be 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -73,7 +73,7 @@ class TypeField:
assert self.setter is not None
assert self.getter is not None
- def as_property(self, cls: type) -> property:
+ def as_property(self, cls: type):
"""Create a Python ``property`` object for this field on ``cls``."""
name = self.name
fget = self.getter
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index e58c115..3215d8a 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -28,6 +28,15 @@ class TestObjectBase(Object):
"""Test object base class."""
+@register_object("testing.TestIntPair")
+class TestIntPair(Object):
+ """Test Int Pair."""
+
+ def __init__(self, a: int, b: int) -> None:
+ """Construct the object."""
+ self.__init_handle_by_constructor__(TestIntPair.__ffi_init__, a, b)
+
+
@register_object("testing.TestObjectDerived")
class TestObjectDerived(TestObjectBase):
"""Test object derived class."""
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index 3d9501d..9c3a019 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -60,8 +60,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::ObjectDef<TestIntPairObj>()
.def_ro("a", &TestIntPairObj::a)
.def_ro("b", &TestIntPairObj::b)
- .def_static("__create__",
- [](int64_t a, int64_t b) -> TestIntPair { return
TestIntPair(a, b); });
+ .def_static("__ffi_init__", refl::init<TestIntPairObj, int64_t,
int64_t>);
}
class TestObjectBase : public Object {
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index c9aa500..89b7ccd 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -36,6 +36,7 @@ using namespace tvm::ffi::testing;
struct TestObjA : public Object {
int64_t x;
int64_t y;
+ TestObjA(int64_t x, int64_t y) : x(x), y(y) {}
static constexpr bool _type_mutable = true;
TVM_FFI_DECLARE_OBJECT_INFO("test.TestObjA", TestObjA, Object);
@@ -43,9 +44,14 @@ struct TestObjA : public Object {
struct TestObjADerived : public TestObjA {
int64_t z;
+  // Forwards x and y to TestObjA's constructor and stores z. This is the constructor that
+  // refl::init<TestObjRefADerived, int64_t, int64_t, int64_t> reaches via ContainerType.
+ TestObjADerived(int64_t x, int64_t y, int64_t z) : TestObjA(x, y), z(z) {}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.TestObjADerived", TestObjADerived,
TestObjA);
};
+// Nullable ObjectRef wrapper whose ContainerType is TestObjADerived. It is used as the `T`
+// argument of refl::init to exercise init's ObjectRef (ContainerType) branch.
+struct TestObjRefADerived : public ObjectRef {
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TestObjRefADerived, ObjectRef,
TestObjADerived);
+};
+
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
@@ -56,8 +62,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
TFuncObj::RegisterReflection();
TCustomFuncObj::RegisterReflection();
- refl::ObjectDef<TestObjA>().def_ro("x", &TestObjA::x).def_rw("y",
&TestObjA::y);
- refl::ObjectDef<TestObjADerived>().def_ro("z", &TestObjADerived::z);
+ refl::ObjectDef<TestObjA>()
+ .def_static("__ffi_init__", refl::init<TestObjA, int64_t, int64_t>)
+ .def_ro("x", &TestObjA::x)
+ .def_rw("y", &TestObjA::y);
+ refl::ObjectDef<TestObjADerived>()
+ .def_static("__ffi_init__", refl::init<TestObjRefADerived, int64_t,
int64_t, int64_t>)
+ .def_ro("z", &TestObjADerived::z);
}
TEST(Reflection, GetFieldByteOffset) {
@@ -131,6 +142,23 @@ TEST(Reflection, CallMethod) {
EXPECT_EQ(prim_expr_sub(TPrimExpr("float", 1), 2.0).cast<double>(), -1.0);
}
+// Calls the `__ffi_init__` registered via refl::init<TestObjA, ...> and checks that the
+// result is a TestObjA whose fields hold the forwarded constructor arguments.
+TEST(Reflection, InitFunction_Base) {
+  Function init_fn = reflection::GetMethod("test.TestObjA", "__ffi_init__");
+  Any obj_a = init_fn(1, 2);
+  // ASSERT (fatal), not EXPECT: the field reads below would dereference a null pointer.
+  const auto* a = obj_a.as<TestObjA>();
+  ASSERT_NE(a, nullptr);
+  EXPECT_EQ(a->x, 1);
+  EXPECT_EQ(a->y, 2);
+}
+
+// Calls the `__ffi_init__` registered via refl::init<TestObjRefADerived, ...>, which builds
+// the reference's ContainerType (TestObjADerived), and checks base and derived fields.
+TEST(Reflection, InitFunction_Derived) {
+  Function init_fn = reflection::GetMethod("test.TestObjADerived", "__ffi_init__");
+  Any obj_derived = init_fn(1, 2, 3);
+  // ASSERT (fatal), not EXPECT: the field reads below would dereference a null pointer.
+  const auto* derived = obj_derived.as<TestObjADerived>();
+  ASSERT_NE(derived, nullptr);
+  EXPECT_EQ(derived->x, 1);
+  EXPECT_EQ(derived->y, 2);
+  EXPECT_EQ(derived->z, 3);
+}
+
TEST(Reflection, ForEachFieldInfo) {
const TypeInfo* info =
TVMFFIGetTypeInfo(TestObjADerived::RuntimeTypeIndex());
Map<String, int> field_name_to_offset;
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index bcfb52e..ea54adf 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -29,6 +29,12 @@ def test_make_object() -> None:
assert obj0.v_str == "hello"
+def test_make_object_via_init() -> None:
+ obj0 = tvm_ffi.testing.TestIntPair(1, 2)
+ assert obj0.a == 1
+ assert obj0.b == 2
+
+
def test_method() -> None:
obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12)
assert obj0.add_i64(1) == 13