This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 8503919cdb80a9ab45a801cc9590eb18a8c99424
Author: tqchen <[email protected]>
AuthorDate: Mon Mar 10 09:43:07 2025 -0400

    [FFI] Enhance the deleter and reflection to be aware of padding
    
    Ensure correct calculation of class ptr based on obj ptr.
---
 ffi/include/tvm/ffi/c_api.h       |  2 +-
 ffi/include/tvm/ffi/memory.h      | 23 +++++++++--------------
 ffi/include/tvm/ffi/object.h      | 21 ++++++++++++++++++---
 ffi/include/tvm/ffi/reflection.h  |  7 ++++---
 ffi/include/tvm/ffi/type_traits.h |  2 +-
 ffi/tests/cpp/test_object.cc      |  4 ++++
 ffi/tests/cpp/test_reflection.cc  |  7 ++++---
 ffi/tests/cpp/testing_object.h    | 12 +++++++++++-
 8 files changed, 52 insertions(+), 26 deletions(-)

diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h
index 5a626d03c0..a1cf753448 100644
--- a/ffi/include/tvm/ffi/c_api.h
+++ b/ffi/include/tvm/ffi/c_api.h
@@ -111,7 +111,7 @@ typedef struct TVMFFIObject {
   /*! \brief Reference counter of the object. */
   int32_t ref_counter;
   /*! \brief Deleter to be invoked when reference counter goes to zero. */
-  void (*deleter)(void* self);
+  void (*deleter)(struct TVMFFIObject* self);
 } TVMFFIObject;
 
 /*!
diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h
index d0e5995fa5..d3af2dd490 100644
--- a/ffi/include/tvm/ffi/memory.h
+++ b/ffi/include/tvm/ffi/memory.h
@@ -33,7 +33,7 @@ namespace tvm {
 namespace ffi {
 
 /*! \brief Deleter function for obeject */
-typedef void (*FObjectDeleter)(void* obj);
+typedef void (*FObjectDeleter)(TVMFFIObject* obj);
 
 /*!
  * \brief Allocate an object using default allocator.
@@ -75,10 +75,10 @@ class ObjAllocatorBase {
     static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
     T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...);
     TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr);
-    // NOTE: ref_counter is initialized in object constructor
+    ffi_ptr->ref_counter = 1;
     ffi_ptr->type_index = T::RuntimeTypeIndex();
     ffi_ptr->deleter = Handler::Deleter();
-    return details::ObjectUnsafe::ObjectPtrFromUnowned<T>(ptr);
+    return details::ObjectUnsafe::ObjectPtrFromOwned<T>(ptr);
   }
 
   /*!
@@ -96,9 +96,10 @@ class ObjAllocatorBase {
     ArrayType* ptr =
         Handler::New(static_cast<Derived*>(this), num_elems, std::forward<Args>(args)...);
     TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr);
+    ffi_ptr->ref_counter = 1;
     ffi_ptr->type_index = ArrayType::RuntimeTypeIndex();
     ffi_ptr->deleter = Handler::Deleter();
-    return details::ObjectUnsafe::ObjectPtrFromUnowned<ArrayType>(ptr);
+    return details::ObjectUnsafe::ObjectPtrFromOwned<ArrayType>(ptr);
   }
 };
 
@@ -133,11 +134,8 @@ class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> {
     static FObjectDeleter Deleter() { return Deleter_; }
 
    private:
-    static void Deleter_(void* objptr) {
-      // NOTE: this is important to cast back to T*
-      // because objptr and tptr may not be the same
-      // depending on how sub-class allocates the space.
-      T* tptr = static_cast<T*>(objptr);
+    static void Deleter_(TVMFFIObject* objptr) {
+      T* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned<T>(objptr);
       // It is important to do tptr->T::~T(),
       // so that we explicitly call the specific destructor
       // instead of tptr->~T(), which could mean the intention
@@ -182,11 +180,8 @@ class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> {
     static FObjectDeleter Deleter() { return Deleter_; }
 
    private:
-    static void Deleter_(void* objptr) {
-      // NOTE: this is important to cast back to ArrayType*
-      // because objptr and tptr may not be the same
-      // depending on how sub-class allocates the space.
-      ArrayType* tptr = static_cast<ArrayType*>(objptr);
+    static void Deleter_(TVMFFIObject* objptr) {
+      ArrayType* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned<ArrayType>(objptr);
       // It is important to do tptr->ArrayType::~ArrayType(),
       // so that we explicitly call the specific destructor
       // instead of tptr->~ArrayType(), which could mean the intention
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index 70b68f6fc6..8173a362e2 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -104,7 +104,6 @@ class Object {
     header_.ref_counter = 0;
     header_.deleter = nullptr;
   }
-
   /*!
    * Check if the object is an instance of TargetType.
    * \tparam TargetType The target type to be checked.
@@ -156,7 +155,7 @@ class Object {
   void DecRef() {
     if (details::AtomicDecrementRelAcq(&(header_.ref_counter)) == 1) {
       if (header_.deleter != nullptr) {
-        header_.deleter(this);
+        header_.deleter(&(this->header_));
       }
     }
   }
@@ -520,6 +519,14 @@ class ObjectUnsafe {
     return const_cast<TVMFFIObject*>(&(src->header_));
   }
 
+  template <typename Class>
+  static TVM_FFI_INLINE int64_t GetObjectOffsetToSubclass() {
+    return (
+      reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->header_)) -
+      reinterpret_cast<int64_t>(&(static_cast<Object*>(nullptr)->header_))
+    );
+  }
+
   template <typename T>
   static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromOwned(Object* raw_ptr) {
     tvm::ffi::ObjectPtr<T> ptr;
@@ -527,7 +534,15 @@ class ObjectUnsafe {
     return ptr;
   }
 
-  // Create ObjectPtr from unknowned ptr
+  template <typename T>
+  static TVM_FFI_INLINE T* RawObjectPtrFromUnowned(TVMFFIObject* obj_ptr) {
+    // NOTE: this is important to first cast to Object*
+    // then cast back to T* because objptr and tptr may not be the same
+    // depending on how sub-class allocates the space.
+    return static_cast<T*>(reinterpret_cast<Object*>(obj_ptr));
+  }
+
+  // Create ObjectPtr from unowned ptr
   template <typename T>
   static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromUnowned(Object* raw_ptr) {
     return tvm::ffi::ObjectPtr<T>(raw_ptr);
diff --git a/ffi/include/tvm/ffi/reflection.h b/ffi/include/tvm/ffi/reflection.h
index 3e6a1ecc37..6da880ab1f 100644
--- a/ffi/include/tvm/ffi/reflection.h
+++ b/ffi/include/tvm/ffi/reflection.h
@@ -52,8 +52,9 @@ struct Type2FieldStaticTypeIndex<T, std::enable_if_t<TypeTraits<T>::enabled>> {
  * \returns The byteoffset
  */
 template <typename Class, typename T>
-inline int64_t GetFieldByteOffset(T Class::* field_ptr) {
-  return reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr));
+inline int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) {
+  int64_t field_offset_to_class = reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr));
+  return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>();
 }
 
 class ReflectionDef {
@@ -82,7 +83,7 @@ class ReflectionDef {
     info.field_static_type_index = Type2FieldStaticTypeIndex<T>::value;
     // store byte offset and setter, getter
     // so the same setter can be reused for all the same type
-    info.byte_offset = GetFieldByteOffset<Class, T>(field_ptr);
+    info.byte_offset = GetFieldByteOffsetToObject<Class, T>(field_ptr);
     info.readonly = readonly;
     info.getter = FieldGetter<T>;
     info.setter = FieldSetter<T>;
diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h
index 69b3738be0..c418dfb900 100644
--- a/ffi/include/tvm/ffi/type_traits.h
+++ b/ffi/include/tvm/ffi/type_traits.h
@@ -530,7 +530,7 @@ struct TypeTraits<const TObject*, std::enable_if_t<std::is_base_of_v<Object, TOb
   }
 
   static TVM_FFI_INLINE const TObject* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
-    return reinterpret_cast<const TObject*>(src->v_obj);
+    return details::ObjectUnsafe::RawObjectPtrFromUnowned<TObject>(src->v_obj);
   }
 
   static TVM_FFI_INLINE std::optional<const TObject*> TryCopyFromAnyView(const TVMFFIAny* src) {
diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc
index ed3361a0ba..1e05731f48 100644
--- a/ffi/tests/cpp/test_object.cc
+++ b/ffi/tests/cpp/test_object.cc
@@ -34,6 +34,10 @@ TEST(Object, RefCounter) {
   EXPECT_EQ(a->value, 11);
 
   EXPECT_EQ(a.use_count(), 2);
+  ObjectPtr<TIntObj> aa = make_object<TIntObj>(*a);
+  EXPECT_EQ(aa.use_count(), 1);
+  EXPECT_EQ(aa->value, 11);
+
   b.reset();
   EXPECT_EQ(a.use_count(), 1);
   EXPECT_TRUE(b == nullptr);
diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc
index 9de0015009..a69505935c 100644
--- a/ffi/tests/cpp/test_reflection.cc
+++ b/ffi/tests/cpp/test_reflection.cc
@@ -28,15 +28,16 @@ namespace {
 using namespace tvm::ffi;
 using namespace tvm::ffi::testing;
 
-struct A {
+struct A : public Object {
   ObjectRef obj;
   int32_t x;
   int32_t y;
 };
 
 TEST(Reflection, GetFieldByteOffset) {
-  EXPECT_EQ(details::GetFieldByteOffset(&A::x), 8);
-  EXPECT_EQ(details::GetFieldByteOffset(&A::y), 12);
+  EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::x), 8 + sizeof(TVMFFIObject));
+  EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::y), 12 + sizeof(TVMFFIObject));
+  EXPECT_EQ(details::GetFieldByteOffsetToObject(&TIntObj::value), sizeof(TVMFFIObject));
 }
 
 TEST(Reflection, FieldGetter) {
diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h
index a4d6b1353f..cd6055ec1f 100644
--- a/ffi/tests/cpp/testing_object.h
+++ b/ffi/tests/cpp/testing_object.h
@@ -28,7 +28,17 @@ namespace tvm {
 namespace ffi {
 namespace testing {
 
-class TNumberObj : public Object {
+// BasePad deliberately adds extra bytes
+// ahead of the Object base, to test cases
+// where the subclass address differs from
+// the base Object address; mishandling that
+// offset would cause a buffer overflow.
+class BasePad {
+ public:
+  int64_t extra[4];
+};
+
+class TNumberObj : public BasePad, public Object {
  public:
   // declare as one slot, with float as overflow
   static constexpr uint32_t _type_child_slots = 1;

Reply via email to