This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new c01dadf  feat: Introduce `tvm_ffi::reflection::init<>` (#32)
c01dadf is described below

commit c01dadf31a66e74cdbfd7fdb1ffc81a75007965f
Author: Junru Shao <[email protected]>
AuthorDate: Sun Sep 21 04:27:24 2025 -0700

    feat: Introduce `tvm_ffi::reflection::init<>` (#32)
    
    This PR introduces `tvm_ffi::reflection::init<ObjectType, Args...>`,
    which can be used to simplify registration of the `__init__` method.
---
 include/tvm/ffi/object.h              |  4 ++--
 include/tvm/ffi/reflection/registry.h | 33 +++++++++++++++++++++++++++++++++
 python/tvm_ffi/cython/type_info.pxi   |  2 +-
 python/tvm_ffi/testing.py             |  9 +++++++++
 src/ffi/extra/testing.cc              |  3 +--
 tests/cpp/test_reflection.cc          | 32 ++++++++++++++++++++++++++++++--
 tests/python/test_object.py           |  6 ++++++
 7 files changed, 82 insertions(+), 7 deletions(-)

diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index 6dcc30e..1ebd3d7 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -960,7 +960,7 @@ struct ObjectPtrEqual {
   using __PtrType = std::conditional_t<ObjectName::_type_mutable, ObjectName*, 
const ObjectName*>; \
   __PtrType operator->() const { return static_cast<__PtrType>(data_.get()); } 
                    \
   __PtrType get() const { return static_cast<__PtrType>(data_.get()); }        
                    \
-  static constexpr bool _type_is_nullable = true;                              
                    \
+  [[maybe_unused]] static constexpr bool _type_is_nullable = true;             
                    \
   using ContainerType = ObjectName
 
 /*!
@@ -976,7 +976,7 @@ struct ObjectPtrEqual {
   using __PtrType = std::conditional_t<ObjectName::_type_mutable, ObjectName*, 
const ObjectName*>; \
   __PtrType operator->() const { return static_cast<__PtrType>(data_.get()); } 
                    \
   __PtrType get() const { return static_cast<__PtrType>(data_.get()); }        
                    \
-  static constexpr bool _type_is_nullable = false;                             
                    \
+  [[maybe_unused]] static constexpr bool _type_is_nullable = false;            
                    \
   using ContainerType = ObjectName
 
 namespace details {
diff --git a/include/tvm/ffi/reflection/registry.h 
b/include/tvm/ffi/reflection/registry.h
index c0d984f..f72fd3c 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -558,6 +558,39 @@ inline void EnsureTypeAttrColumn(std::string_view name) {
                                                  reinterpret_cast<const 
TVMFFIAny*>(&any_view)));
 }
 
+/*!
+ * \brief Constructs an object via `ffi::make_object` and returns it as an `ObjectRef`.
+ * \tparam T An `Object` subclass, or a reference type exposing `T::ContainerType`.
+ * \tparam Args The argument types.
+ * \param args The arguments forwarded to the constructor of the object type.
+ * \return The constructed object wrapped in an `ObjectRef`.
+ * \note Usually registered at the FFI reflection boundary as the `__ffi_init__` static method.
+ *
+ * Example
+ *
+ * \code
+ *   class ExampleObject : public Object {
+ *    public:
+ *      int64_t v_i64;
+ *      int32_t v_i32;
+ *
+ *      ExampleObject(int64_t v_i64, int32_t v_i32) : v_i64(v_i64), v_i32(v_i32) {}
+ *      TVM_FFI_DECLARE_OBJECT_INFO("example.ExampleObject", ExampleObject, Object);
+ *   };
+ *   refl::ObjectDef<ExampleObject>()
+ *      .def_static("__ffi_init__", refl::init<ExampleObject, int64_t, int32_t>);
+ * \endcode
+ */
+template <typename T, typename... Args>
+inline ObjectRef init(Args&&... args) {
+  if constexpr (std::is_base_of_v<Object, T>) {
+    return ObjectRef(ffi::make_object<T>(std::forward<Args>(args)...));
+  } else {
+    using U = typename T::ContainerType;
+    return ObjectRef(ffi::make_object<U>(std::forward<Args>(args)...));
+  }
+}
+
 }  // namespace reflection
 }  // namespace ffi
 }  // namespace tvm
diff --git a/python/tvm_ffi/cython/type_info.pxi 
b/python/tvm_ffi/cython/type_info.pxi
index 2abb204..bde25be 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -73,7 +73,7 @@ class TypeField:
         assert self.setter is not None
         assert self.getter is not None
 
-    def as_property(self, cls: type) -> property:
+    def as_property(self, cls: type):
         """Create a Python ``property`` object for this field on ``cls``."""
         name = self.name
         fget = self.getter
diff --git a/python/tvm_ffi/testing.py b/python/tvm_ffi/testing.py
index e58c115..3215d8a 100644
--- a/python/tvm_ffi/testing.py
+++ b/python/tvm_ffi/testing.py
@@ -28,6 +28,15 @@ class TestObjectBase(Object):
     """Test object base class."""
 
 
+@register_object("testing.TestIntPair")
+class TestIntPair(Object):
+    """Test Int Pair."""
+
+    def __init__(self, a: int, b: int) -> None:
+        """Construct the object."""
+        self.__init_handle_by_constructor__(TestIntPair.__ffi_init__, a, b)
+
+
 @register_object("testing.TestObjectDerived")
 class TestObjectDerived(TestObjectBase):
     """Test object derived class."""
diff --git a/src/ffi/extra/testing.cc b/src/ffi/extra/testing.cc
index 3d9501d..9c3a019 100644
--- a/src/ffi/extra/testing.cc
+++ b/src/ffi/extra/testing.cc
@@ -60,8 +60,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   refl::ObjectDef<TestIntPairObj>()
       .def_ro("a", &TestIntPairObj::a)
       .def_ro("b", &TestIntPairObj::b)
-      .def_static("__create__",
-                  [](int64_t a, int64_t b) -> TestIntPair { return 
TestIntPair(a, b); });
+      .def_static("__ffi_init__", refl::init<TestIntPairObj, int64_t, 
int64_t>);
 }
 
 class TestObjectBase : public Object {
diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc
index c9aa500..89b7ccd 100644
--- a/tests/cpp/test_reflection.cc
+++ b/tests/cpp/test_reflection.cc
@@ -36,6 +36,7 @@ using namespace tvm::ffi::testing;
 struct TestObjA : public Object {
   int64_t x;
   int64_t y;
+  TestObjA(int64_t x, int64_t y) : x(x), y(y) {}
 
   static constexpr bool _type_mutable = true;
   TVM_FFI_DECLARE_OBJECT_INFO("test.TestObjA", TestObjA, Object);
@@ -43,9 +44,14 @@ struct TestObjA : public Object {
 
 struct TestObjADerived : public TestObjA {
   int64_t z;
+  TestObjADerived(int64_t x, int64_t y, int64_t z) : TestObjA(x, y), z(z) {}
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.TestObjADerived", TestObjADerived, 
TestObjA);
 };
 
+struct TestObjRefADerived : public ObjectRef {
+  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TestObjRefADerived, ObjectRef, 
TestObjADerived);
+};
+
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
 
@@ -56,8 +62,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   TFuncObj::RegisterReflection();
   TCustomFuncObj::RegisterReflection();
 
-  refl::ObjectDef<TestObjA>().def_ro("x", &TestObjA::x).def_rw("y", 
&TestObjA::y);
-  refl::ObjectDef<TestObjADerived>().def_ro("z", &TestObjADerived::z);
+  refl::ObjectDef<TestObjA>()
+      .def_static("__ffi_init__", refl::init<TestObjA, int64_t, int64_t>)
+      .def_ro("x", &TestObjA::x)
+      .def_rw("y", &TestObjA::y);
+  refl::ObjectDef<TestObjADerived>()
+      .def_static("__ffi_init__", refl::init<TestObjRefADerived, int64_t, 
int64_t, int64_t>)
+      .def_ro("z", &TestObjADerived::z);
 }
 
 TEST(Reflection, GetFieldByteOffset) {
@@ -131,6 +142,23 @@ TEST(Reflection, CallMethod) {
   EXPECT_EQ(prim_expr_sub(TPrimExpr("float", 1), 2.0).cast<double>(), -1.0);
 }
 
+TEST(Reflection, InitFunction_Base) {
+  Function int_init = reflection::GetMethod("test.TestObjA", "__ffi_init__");
+  Any obj_a = int_init(1, 2);
+  EXPECT_TRUE(obj_a.as<TestObjA>() != nullptr);
+  EXPECT_EQ(obj_a.as<TestObjA>()->x, 1);
+  EXPECT_EQ(obj_a.as<TestObjA>()->y, 2);
+}
+
+TEST(Reflection, InitFunction_Derived) {
+  Function derived_init = reflection::GetMethod("test.TestObjADerived", 
"__ffi_init__");
+  Any obj_derived = derived_init(1, 2, 3);
+  EXPECT_TRUE(obj_derived.as<TestObjADerived>() != nullptr);
+  EXPECT_EQ(obj_derived.as<TestObjADerived>()->x, 1);
+  EXPECT_EQ(obj_derived.as<TestObjADerived>()->y, 2);
+  EXPECT_EQ(obj_derived.as<TestObjADerived>()->z, 3);
+}
+
 TEST(Reflection, ForEachFieldInfo) {
   const TypeInfo* info = 
TVMFFIGetTypeInfo(TestObjADerived::RuntimeTypeIndex());
   Map<String, int> field_name_to_offset;
diff --git a/tests/python/test_object.py b/tests/python/test_object.py
index bcfb52e..ea54adf 100644
--- a/tests/python/test_object.py
+++ b/tests/python/test_object.py
@@ -29,6 +29,12 @@ def test_make_object() -> None:
     assert obj0.v_str == "hello"
 
 
+def test_make_object_via_init() -> None:
+    obj0 = tvm_ffi.testing.TestIntPair(1, 2)
+    assert obj0.a == 1
+    assert obj0.b == 2
+
+
 def test_method() -> None:
     obj0 = tvm_ffi.testing.create_object("testing.TestObjectBase", v_i64=12)
     assert obj0.add_i64(1) == 13

Reply via email to