This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s0 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit efd253828e432eaff97b6b478a25f8765500607d Author: tqchen <[email protected]> AuthorDate: Tue Aug 6 16:06:58 2024 -0400 [FFI] Enable ObjectPtr and tests This commit enables object ptr and tests. Move most of the friend class to internal. --- ffi/include/tvm/ffi/c_ffi_abi.h | 7 +- ffi/include/tvm/ffi/memory.h | 213 +++++++++++++++++++++++++++++++++++++++ ffi/include/tvm/ffi/object.h | 98 ++++++++++++++++-- ffi/tests/example/test_object.cc | 27 ++++- 4 files changed, 335 insertions(+), 10 deletions(-) diff --git a/ffi/include/tvm/ffi/c_ffi_abi.h b/ffi/include/tvm/ffi/c_ffi_abi.h index ccba3f58df..f1a1ae78da 100644 --- a/ffi/include/tvm/ffi/c_ffi_abi.h +++ b/ffi/include/tvm/ffi/c_ffi_abi.h @@ -51,7 +51,7 @@ extern "C" { #endif #ifdef __cplusplus -enum class TVMFFITypeIndex : int32_t { +enum TVMFFITypeIndex : int32_t { #else typedef enum { #endif @@ -77,7 +77,10 @@ typedef enum { kTVMFFIFunc = 68, kTVMFFIStr = 69, // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo) - kTVMFFIDynObjectBegin = 128, + // kTVMFFIDynObject is used to indicate that the type index + // is dynamic and needs to be looked up at runtime + kTVMFFIDynObject = 128, + kTVMFFIDynObjectBegin = 129 #ifdef __cplusplus }; #else diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h index e69de29bb2..cb7067502e 100644 --- a/ffi/include/tvm/ffi/memory.h +++ b/ffi/include/tvm/ffi/memory.h @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/memory.h + * \brief Runtime memory management for heap-allocated objects. + */ +#ifndef TVM_FFI_MEMORY_H_ +#define TVM_FFI_MEMORY_H_ + +#include <tvm/ffi/object.h> + +#include <cstdlib> +#include <type_traits> +#include <utility> + +namespace tvm { +namespace ffi { + +/*! \brief Deleter function for object */ +typedef void (*FObjectDeleter)(TVMFFIObject* obj); + +/*! + * \brief Allocate an object using default allocator. + * \param args arguments to the constructor. + * \tparam T the node type. + * \return The ObjectPtr to the allocated object. + */ +template <typename T, typename... Args> +inline ObjectPtr<T> make_object(Args&&... args); + +// Detail implementations after this +// +// The current design allows swapping the +// allocator pattern when necessary. +// +// Possible future allocator optimizations: +// - Arena allocator that gives ownership of memory to arena (deleter = nullptr) +// - Thread-local object pools: one pool per size and alignment requirement. +// - Can specialize by type of object to give the specific allocator to each object. + +/*! + * \brief Base class of object allocators that implements make. + * Use curiously recurring template pattern. + * + * \tparam Derived The derived class. + */ +template <typename Derived> +class ObjAllocatorBase { + public: + /*! + * \brief Make a new object using the allocator. + * \tparam T The type to be allocated. + * \tparam Args The constructor signature. + * \param args The arguments. + */ + template <typename T, typename... 
Args> + inline ObjectPtr<T> make_object(Args&&... args) { + using Handler = typename Derived::template Handler<T>; + static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object"); + T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...); + TVMFFIObject* ffi_ptr = details::ObjectInternal::StaticCast<TVMFFIObject*>(ptr); + // NOTE: ref_counter is initialized in object constructor + ffi_ptr->type_index = T::RuntimeTypeIndex(); + ffi_ptr->deleter = Handler::Deleter(); + return details::ObjectInternal::ObjectPtr<T>(ptr); + } + + /*! + * \tparam ArrayType The type to be allocated. + * \tparam ElemType The type of array element. + * \tparam Args The constructor signature. + * \param num_elems The number of array elements. + * \param args The arguments. + */ + template <typename ArrayType, typename ElemType, typename... Args> + inline ObjectPtr<ArrayType> make_inplace_array(size_t num_elems, Args&&... args) { + using Handler = typename Derived::template ArrayHandler<ArrayType, ElemType>; + static_assert(std::is_base_of<Object, ArrayType>::value, + "make_inplace_array can only be used to create Object"); + ArrayType* ptr = + Handler::New(static_cast<Derived*>(this), num_elems, std::forward<Args>(args)...); + ptr->type_index_ = ArrayType::RuntimeTypeIndex(); + ptr->deleter_ = Handler::Deleter(); + return ObjectPtr<ArrayType>(ptr); + } +}; + +// Simple allocator that uses new/delete. +class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> { + public: + template <typename T> + class Handler { + public: + using StorageType = typename std::aligned_storage<sizeof(T), alignof(T)>::type; + + template <typename... Args> + static T* New(SimpleObjAllocator*, Args&&... args) { + // NOTE: the first argument is not needed for SimpleObjAllocator + // It is reserved for special allocators that needs to recycle + // the object to itself (e.g. in the case of object pool). 
+ // + // In the case of an object pool, an allocator needs to create + // a special chunk memory that hides reference to the allocator + // and call allocator's release function in the deleter. + + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. + StorageType* data = new StorageType(); + new (data) T(std::forward<Args>(args)...); + return reinterpret_cast<T*>(data); + } + + static FObjectDeleter Deleter() { return Deleter_; } + + private: + static void Deleter_(TVMFFIObject* objptr) { + // NOTE: this is important to cast back to T* + // because objptr and tptr may not be the same + // depending on how sub-class allocates the space. + T* tptr = details::ObjectInternal::StaticCast<T*>(objptr); + // It is important to do tptr->T::~T(), + // so that we explicitly call the specific destructor + // instead of tptr->~T(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->T::~T(); + delete reinterpret_cast<StorageType*>(tptr); + } + }; + + // Array handler that uses new/delete. + template <typename ArrayType, typename ElemType> + class ArrayHandler { + public: + using StorageType = typename std::aligned_storage<sizeof(ArrayType), alignof(ArrayType)>::type; + // for now only support elements that aligns with array header. + static_assert(alignof(ArrayType) % alignof(ElemType) == 0 && + sizeof(ArrayType) % alignof(ElemType) == 0, + "element alignment constraint"); + + template <typename... Args> + static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) { + // NOTE: the first argument is not needed for ArrayObjAllocator + // It is reserved for special allocators that needs to recycle + // the object to itself (e.g. in the case of object pool). 
+ // + // In the case of an object pool, an allocator needs to create + // a special chunk memory that hides reference to the allocator + // and call allocator's release function in the deleter. + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. + size_t unit = sizeof(StorageType); + size_t requested_size = num_elems * sizeof(ElemType) + sizeof(ArrayType); + size_t num_storage_slots = (requested_size + unit - 1) / unit; + StorageType* data = new StorageType[num_storage_slots]; + new (data) ArrayType(std::forward<Args>(args)...); + return reinterpret_cast<ArrayType*>(data); + } + + static FObjectDeleter Deleter() { return Deleter_; } + + private: + static void Deleter_(Object* objptr) { + // NOTE: this is important to cast back to ArrayType* + // because objptr and tptr may not be the same + // depending on how sub-class allocates the space. + ArrayType* tptr = static_cast<ArrayType*>(objptr); + // It is important to do tptr->ArrayType::~ArrayType(), + // so that we explicitly call the specific destructor + // instead of tptr->~ArrayType(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->ArrayType::~ArrayType(); + StorageType* p = reinterpret_cast<StorageType*>(tptr); + delete[] p; + } + }; +}; + +template <typename T, typename... Args> +inline ObjectPtr<T> make_object(Args&&... args) { + return SimpleObjAllocator().make_object<T>(std::forward<Args>(args)...); +} + +template <typename ArrayType, typename ElemType, typename... Args> +inline ObjectPtr<ArrayType> make_inplace_array_object(size_t num_elems, Args&&... 
args) { + return SimpleObjAllocator().make_inplace_array<ArrayType, ElemType>(num_elems, + std::forward<Args>(args)...); +} + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_MEMORY_H_ diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h index 90e78defae..6b973f943c 100644 --- a/ffi/include/tvm/ffi/object.h +++ b/ffi/include/tvm/ffi/object.h @@ -29,22 +29,89 @@ namespace tvm { namespace ffi { +using TypeIndex = TVMFFITypeIndex; + namespace details { // forward declare object internal struct ObjectInternal; } // namespace details -class Object : protected TVMFFIObject { +/*! + * \brief base class of all object containers. + * + * Sub-class of objects should declare the following static constexpr fields: + * + * - _type_index: + * Static type index of the object, if assigned to TypeIndex::kTVMFFIDynObject + * the type index will be assigned during runtime. + * Runtime type index can be accessed by ObjectType::TypeIndex(); + * - _type_key: + * The unique string identifier of the type. + * - _type_final: + * Whether the type is terminal type(there is no subclass of the type in the object system). + * This field is automatically set by macro TVM_DECLARE_FINAL_OBJECT_INFO + * It is still OK to sub-class a terminal object type T and construct it using make_object. + * But IsInstance check will only show that the object type is T(instead of the sub-class). + * + * The following two fields are necessary for base classes that can be sub-classed. + * + * - _type_child_slots: + * Number of reserved type index slots for child classes. + * Used for runtime optimization for type checking in IsInstance. + * If an object's type_index is within range of [type_index, type_index + _type_child_slots] + * Then the object can be quickly decided as sub-class of the current object class. + * If not, a fallback mechanism is used to check the global type table. + * Recommendation: set to estimate number of children needed. 
+ * + * - _type_child_slots_can_overflow: + * Whether we can add additional child classes even if the number of child classes + * exceeds the _type_child_slots. A fallback mechanism to check global type table will be + * used. Recommendation: set to false for optimal runtime speed if we know exact number of children. + * + * Two macros are used to declare helper functions in the object: + * - Use TVM_FFI_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed. + * - Use TVM_FFI_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed. + * + * New objects can be created using the make_object function, + * which automatically populates the type_index and deleter of the object. + */ +class Object : private TVMFFIObject { public: Object() { TVMFFIObject::ref_counter = 0; TVMFFIObject::deleter = nullptr; } + // Information about the object + static constexpr const char* _type_key = "runtime.Object"; + + // Default object type properties for sub-classes + static constexpr bool _type_final = false; + static constexpr uint32_t _type_child_slots = 0; + static constexpr bool _type_child_slots_can_overflow = true; + // NOTE: the following field is not type index of Object + // but was intended to be used by sub-classes as default value. + // The type index of Object is TypeIndex::kTVMFFIObject + static constexpr int32_t _type_index = TypeIndex::kTVMFFIDynObject; + + // The following functions are provided by macro + // TVM_FFI_DECLARE_BASE_OBJECT_INFO and TVM_FFI_DECLARE_FINAL_OBJECT_INFO + /*! + * \brief Get the runtime allocated type index of the type + * \note Getting this information may need dynamic calls into a global table. + */ + static int32_t RuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } + /*! + * \brief Internal function to get or allocate a runtime index. + * \note + */ + static int32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; } + private: - /*! \brief decreas*/ + /*! 
\brief increase reference count */ void IncRef() { details::AtomicIncrementRelaxed(&(this->ref_counter)); } + /*! \brief decrease reference count and delete the object */ void DecRef() { if (details::AtomicDecrementRelAcq(&(this->ref_counter)) == 1) { if (this->deleter != nullptr) { @@ -177,9 +244,9 @@ class ObjectPtr { /*! \return Whether two ObjectPtr equals each other */ bool operator!=(const ObjectPtr<T>& other) const { return data_ != other.data_; } /*! \return Whether the pointer is nullptr */ - bool operator==(std::nullptr_t null) const { return data_ == nullptr; } + bool operator==(std::nullptr_t) const { return data_ == nullptr; } /*! \return Whether the pointer is not nullptr */ - bool operator!=(std::nullptr_t null) const { return data_ != nullptr; } + bool operator!=(std::nullptr_t) const { return data_ != nullptr; } private: /*! \brief internal pointer field */ @@ -207,6 +274,9 @@ class ObjectPtr { // friend classes friend class Object; friend class ObjectRef; + friend struct ObjectPtrHash; + template <typename> + friend class ObjectPtr; friend class tvm::ffi::details::ObjectInternal; template <typename RelayRefType, typename ObjType> friend RelayRefType GetRef(const ObjType* ptr); @@ -215,8 +285,24 @@ class ObjectPtr { }; namespace details { -/*! \brief Namespace to internally manipulate object class. */ -struct ObjectInternal {}; +/*! + * \brief Namespace to internally manipulate object class. 
+ * \note These functions are only supposed to be used by internal + * implementations and not external users of the tvm::ffi + */ +struct ObjectInternal { + // NOTE: these helpers perform static casts + // that also allow conversion from/to FFI values + template <typename T, typename U> + static TVM_FFI_INLINE T StaticCast(U src) { + return static_cast<T>(src); + } + + template <typename T> + static TVM_FFI_INLINE ObjectPtr<T> ObjectPtr(Object* raw_ptr) { + return tvm::ffi::ObjectPtr<T>(raw_ptr); + } +}; } // namespace details } // namespace ffi } // namespace tvm diff --git a/ffi/tests/example/test_object.cc b/ffi/tests/example/test_object.cc index 95b7d0717f..959246fda0 100644 --- a/ffi/tests/example/test_object.cc +++ b/ffi/tests/example/test_object.cc @@ -1,12 +1,35 @@ #include <gtest/gtest.h> +#include <tvm/ffi/memory.h> #include <tvm/ffi/object.h> namespace { using namespace tvm::ffi; -TEST(Object, Default) { - // Object* x = new Object(); +class IntObj : public Object { + public: + int64_t value; + + IntObj(int64_t value) : value(value) {} +}; + +TEST(Object, RefCounter) { + ObjectPtr<IntObj> a = make_object<IntObj>(11); + ObjectPtr<IntObj> b = a; + + EXPECT_EQ(a->value, 11); + + EXPECT_EQ(a.use_count(), 2); + b.reset(); + EXPECT_EQ(a.use_count(), 1); + EXPECT_TRUE(b == nullptr); + EXPECT_EQ(b.use_count(), 0); + + ObjectPtr<IntObj> c = std::move(a); + EXPECT_EQ(c.use_count(), 1); + EXPECT_TRUE(a == nullptr); + + EXPECT_EQ(c->value, 11); } } // namespace
