This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 8503919cdb80a9ab45a801cc9590eb18a8c99424 Author: tqchen <[email protected]> AuthorDate: Mon Mar 10 09:43:07 2025 -0400 [FFI] Enhance the deleter and reflection to be aware of padding. Ensure correct calculation of class ptr based on obj ptr. --- ffi/include/tvm/ffi/c_api.h | 2 +- ffi/include/tvm/ffi/memory.h | 23 +++++++++-------------- ffi/include/tvm/ffi/object.h | 21 ++++++++++++++++++--- ffi/include/tvm/ffi/reflection.h | 7 ++++--- ffi/include/tvm/ffi/type_traits.h | 2 +- ffi/tests/cpp/test_object.cc | 4 ++++ ffi/tests/cpp/test_reflection.cc | 7 ++++--- ffi/tests/cpp/testing_object.h | 12 +++++++++++- 8 files changed, 52 insertions(+), 26 deletions(-) diff --git a/ffi/include/tvm/ffi/c_api.h b/ffi/include/tvm/ffi/c_api.h index 5a626d03c0..a1cf753448 100644 --- a/ffi/include/tvm/ffi/c_api.h +++ b/ffi/include/tvm/ffi/c_api.h @@ -111,7 +111,7 @@ typedef struct TVMFFIObject { /*! \brief Reference counter of the object. */ int32_t ref_counter; /*! \brief Deleter to be invoked when reference counter goes to zero. */ - void (*deleter)(void* self); + void (*deleter)(struct TVMFFIObject* self); } TVMFFIObject; /*! diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h index d0e5995fa5..d3af2dd490 100644 --- a/ffi/include/tvm/ffi/memory.h +++ b/ffi/include/tvm/ffi/memory.h @@ -33,7 +33,7 @@ namespace tvm { namespace ffi { /*! \brief Deleter function for obeject */ -typedef void (*FObjectDeleter)(void* obj); +typedef void (*FObjectDeleter)(TVMFFIObject* obj); /*! * \brief Allocate an object using default allocator.
@@ -75,10 +75,10 @@ class ObjAllocatorBase { static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object"); T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...); TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); - // NOTE: ref_counter is initialized in object constructor + ffi_ptr->ref_counter = 1; ffi_ptr->type_index = T::RuntimeTypeIndex(); ffi_ptr->deleter = Handler::Deleter(); - return details::ObjectUnsafe::ObjectPtrFromUnowned<T>(ptr); + return details::ObjectUnsafe::ObjectPtrFromOwned<T>(ptr); } /*! @@ -96,9 +96,10 @@ class ObjAllocatorBase { ArrayType* ptr = Handler::New(static_cast<Derived*>(this), num_elems, std::forward<Args>(args)...); TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr); + ffi_ptr->ref_counter = 1; ffi_ptr->type_index = ArrayType::RuntimeTypeIndex(); ffi_ptr->deleter = Handler::Deleter(); - return details::ObjectUnsafe::ObjectPtrFromUnowned<ArrayType>(ptr); + return details::ObjectUnsafe::ObjectPtrFromOwned<ArrayType>(ptr); } }; @@ -133,11 +134,8 @@ class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> { static FObjectDeleter Deleter() { return Deleter_; } private: - static void Deleter_(void* objptr) { - // NOTE: this is important to cast back to T* - // because objptr and tptr may not be the same - // depending on how sub-class allocates the space. 
- T* tptr = static_cast<T*>(objptr); + static void Deleter_(TVMFFIObject* objptr) { + T* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned<T>(objptr); // It is important to do tptr->T::~T(), // so that we explicitly call the specific destructor // instead of tptr->~T(), which could mean the intention @@ -182,11 +180,8 @@ class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> { static FObjectDeleter Deleter() { return Deleter_; } private: - static void Deleter_(void* objptr) { - // NOTE: this is important to cast back to ArrayType* - // because objptr and tptr may not be the same - // depending on how sub-class allocates the space. - ArrayType* tptr = static_cast<ArrayType*>(objptr); + static void Deleter_(TVMFFIObject* objptr) { + ArrayType* tptr = details::ObjectUnsafe::RawObjectPtrFromUnowned<ArrayType>(objptr); // It is important to do tptr->ArrayType::~ArrayType(), // so that we explicitly call the specific destructor // instead of tptr->~ArrayType(), which could mean the intention diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 70b68f6fc6..8173a362e2 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -104,7 +104,6 @@ class Object { header_.ref_counter = 0; header_.deleter = nullptr; } - /*! * Check if the object is an instance of TargetType. * \tparam TargetType The target type to be checked. 
@@ -156,7 +155,7 @@ class Object { void DecRef() { if (details::AtomicDecrementRelAcq(&(header_.ref_counter)) == 1) { if (header_.deleter != nullptr) { - header_.deleter(this); + header_.deleter(&(this->header_)); } } } @@ -520,6 +519,14 @@ class ObjectUnsafe { return const_cast<TVMFFIObject*>(&(src->header_)); } + template <typename Class> + static TVM_FFI_INLINE int64_t GetObjectOffsetToSubclass() { + return ( + reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->header_)) - + reinterpret_cast<int64_t>(&(static_cast<Object*>(nullptr)->header_)) + ); + } + template <typename T> static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromOwned(Object* raw_ptr) { tvm::ffi::ObjectPtr<T> ptr; @@ -527,7 +534,15 @@ class ObjectUnsafe { return ptr; } - // Create ObjectPtr from unknowned ptr + template <typename T> + static TVM_FFI_INLINE T* RawObjectPtrFromUnowned(TVMFFIObject* obj_ptr) { + // NOTE: this is important to first cast to Object* + // then cast back to T* because objptr and tptr may not be the same + // depending on how sub-class allocates the space. 
+ return static_cast<T*>(reinterpret_cast<Object*>(obj_ptr)); + } + + // Create ObjectPtr from unowned ptr template <typename T> static TVM_FFI_INLINE ObjectPtr<T> ObjectPtrFromUnowned(Object* raw_ptr) { return tvm::ffi::ObjectPtr<T>(raw_ptr); diff --git a/ffi/include/tvm/ffi/reflection.h b/ffi/include/tvm/ffi/reflection.h index 3e6a1ecc37..6da880ab1f 100644 --- a/ffi/include/tvm/ffi/reflection.h +++ b/ffi/include/tvm/ffi/reflection.h @@ -52,8 +52,9 @@ struct Type2FieldStaticTypeIndex<T, std::enable_if_t<TypeTraits<T>::enabled>> { * \returns The byteoffset */ template <typename Class, typename T> -inline int64_t GetFieldByteOffset(T Class::* field_ptr) { - return reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr)); +inline int64_t GetFieldByteOffsetToObject(T Class::* field_ptr) { + int64_t field_offset_to_class = reinterpret_cast<int64_t>(&(static_cast<Class*>(nullptr)->*field_ptr)); + return field_offset_to_class - details::ObjectUnsafe::GetObjectOffsetToSubclass<Class>(); } class ReflectionDef { @@ -82,7 +83,7 @@ class ReflectionDef { info.field_static_type_index = Type2FieldStaticTypeIndex<T>::value; // store byte offset and setter, getter // so the same setter can be reused for all the same type - info.byte_offset = GetFieldByteOffset<Class, T>(field_ptr); + info.byte_offset = GetFieldByteOffsetToObject<Class, T>(field_ptr); info.readonly = readonly; info.getter = FieldGetter<T>; info.setter = FieldSetter<T>; diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index 69b3738be0..c418dfb900 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -530,7 +530,7 @@ struct TypeTraits<const TObject*, std::enable_if_t<std::is_base_of_v<Object, TOb } static TVM_FFI_INLINE const TObject* CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { - return reinterpret_cast<const TObject*>(src->v_obj); + return details::ObjectUnsafe::RawObjectPtrFromUnowned<TObject>(src->v_obj); } static 
TVM_FFI_INLINE std::optional<const TObject*> TryCopyFromAnyView(const TVMFFIAny* src) { diff --git a/ffi/tests/cpp/test_object.cc b/ffi/tests/cpp/test_object.cc index ed3361a0ba..1e05731f48 100644 --- a/ffi/tests/cpp/test_object.cc +++ b/ffi/tests/cpp/test_object.cc @@ -34,6 +34,10 @@ TEST(Object, RefCounter) { EXPECT_EQ(a->value, 11); EXPECT_EQ(a.use_count(), 2); + ObjectPtr<TIntObj> aa = make_object<TIntObj>(*a); + EXPECT_EQ(aa.use_count(), 1); + EXPECT_EQ(aa->value, 11); + b.reset(); EXPECT_EQ(a.use_count(), 1); EXPECT_TRUE(b == nullptr); diff --git a/ffi/tests/cpp/test_reflection.cc b/ffi/tests/cpp/test_reflection.cc index 9de0015009..a69505935c 100644 --- a/ffi/tests/cpp/test_reflection.cc +++ b/ffi/tests/cpp/test_reflection.cc @@ -28,15 +28,16 @@ namespace { using namespace tvm::ffi; using namespace tvm::ffi::testing; -struct A { +struct A : public Object { ObjectRef obj; int32_t x; int32_t y; }; TEST(Reflection, GetFieldByteOffset) { - EXPECT_EQ(details::GetFieldByteOffset(&A::x), 8); - EXPECT_EQ(details::GetFieldByteOffset(&A::y), 12); + EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::x), 8 + sizeof(TVMFFIObject)); + EXPECT_EQ(details::GetFieldByteOffsetToObject(&A::y), 12 + sizeof(TVMFFIObject)); + EXPECT_EQ(details::GetFieldByteOffsetToObject(&TIntObj::value), sizeof(TVMFFIObject)); } TEST(Reflection, FieldGetter) { diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index a4d6b1353f..cd6055ec1f 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -28,7 +28,17 @@ namespace tvm { namespace ffi { namespace testing { -class TNumberObj : public Object { +// We deliberately pad extra +// in the header to test cases +// where the object subclass address +// do not align with the base object address +// not handling properly will cause buffer overflow +class BasePad { + public: + int64_t extra[4]; +}; + +class TNumberObj : public BasePad, public Object { public: // declare as one slot, with float as 
overflow static constexpr uint32_t _type_child_slots = 1;
