This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s0
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit efd253828e432eaff97b6b478a25f8765500607d
Author: tqchen <[email protected]>
AuthorDate: Tue Aug 6 16:06:58 2024 -0400

    [FFI] Enable ObjectPtr and tests
    
    This commit enables ObjectPtr and adds tests.
    Move most of the friend classes into details::ObjectInternal.
---
 ffi/include/tvm/ffi/c_ffi_abi.h  |   7 +-
 ffi/include/tvm/ffi/memory.h     | 213 +++++++++++++++++++++++++++++++++++++++
 ffi/include/tvm/ffi/object.h     |  98 ++++++++++++++++--
 ffi/tests/example/test_object.cc |  27 ++++-
 4 files changed, 335 insertions(+), 10 deletions(-)

diff --git a/ffi/include/tvm/ffi/c_ffi_abi.h b/ffi/include/tvm/ffi/c_ffi_abi.h
index ccba3f58df..f1a1ae78da 100644
--- a/ffi/include/tvm/ffi/c_ffi_abi.h
+++ b/ffi/include/tvm/ffi/c_ffi_abi.h
@@ -51,7 +51,7 @@ extern "C" {
 #endif
 
 #ifdef __cplusplus
-enum class TVMFFITypeIndex : int32_t {
+enum TVMFFITypeIndex : int32_t {
 #else
 typedef enum {
 #endif
@@ -77,7 +77,10 @@ typedef enum {
   kTVMFFIFunc = 68,
   kTVMFFIStr = 69,
   // [Section] Dynamic Boxed: [kTVMFFIDynObjectBegin, +oo)
-  kTVMFFIDynObjectBegin = 128,
+  // kTVMFFIDynObject is used to indicate that the type index
+  // is dynamic and needs to be looked up at runtime
+  kTVMFFIDynObject = 128,
+  kTVMFFIDynObjectBegin = 129
 #ifdef __cplusplus
 };
 #else
diff --git a/ffi/include/tvm/ffi/memory.h b/ffi/include/tvm/ffi/memory.h
index e69de29bb2..cb7067502e 100644
--- a/ffi/include/tvm/ffi/memory.h
+++ b/ffi/include/tvm/ffi/memory.h
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/memory.h
+ * \brief Runtime memory management to allocate on heap object.
+ */
+#ifndef TVM_FFI_MEMORY_H_
+#define TVM_FFI_MEMORY_H_
+
+#include <tvm/ffi/object.h>
+
+#include <cstdlib>
+#include <new>
+#include <type_traits>
+#include <utility>
+
+namespace tvm {
+namespace ffi {
+
+/*! \brief Deleter function for object */
+typedef void (*FObjectDeleter)(TVMFFIObject* obj);
+
+/*!
+ * \brief Allocate an object using default allocator.
+ * \param args arguments to the constructor.
+ * \tparam T the node type.
+ * \return The ObjectPtr to the allocated object.
+ */
+template <typename T, typename... Args>
+inline ObjectPtr<T> make_object(Args&&... args);
+
+// Detail implementations after this
+//
+// The current design allows swapping the
+// allocator pattern when necessary.
+//
+// Possible future allocator optimizations:
+// - Arena allocator that gives ownership of memory to arena (deleter = nullptr)
+// - Thread-local object pools: one pool per size and alignment requirement.
+// - Can specialize by type of object to give the specific allocator to each object.
+
+/*!
+ * \brief Base class of object allocators that implements make.
+ *  Use curiously recurring template pattern.
+ *
+ * \tparam Derived The derived class.
+ */
+template <typename Derived>
+class ObjAllocatorBase {
+ public:
+  /*!
+   * \brief Make a new object using the allocator.
+   * \tparam T The type to be allocated.
+   * \tparam Args The constructor signature.
+   * \param args The arguments.
+   */
+  template <typename T, typename... Args>
+  inline ObjectPtr<T> make_object(Args&&... args) {
+    using Handler = typename Derived::template Handler<T>;
+    static_assert(std::is_base_of<Object, T>::value, "make can only be used to create Object");
+    T* ptr = Handler::New(static_cast<Derived*>(this), std::forward<Args>(args)...);
+    TVMFFIObject* ffi_ptr = details::ObjectInternal::StaticCast<TVMFFIObject*>(ptr);
+    // NOTE: ref_counter is initialized in object constructor
+    ffi_ptr->type_index = T::RuntimeTypeIndex();
+    ffi_ptr->deleter = Handler::Deleter();
+    return details::ObjectInternal::ObjectPtr<T>(ptr);
+  }
+
+  /*!
+   * \brief Make an object followed by raw storage for num_elems elements (not constructed).
+   * \tparam ArrayType The type to be allocated.
+   * \tparam ElemType The type of array element.
+   * \tparam Args The constructor signature.
+   * \param num_elems The number of array elements.
+   * \param args The arguments.
+   */
+  template <typename ArrayType, typename ElemType, typename... Args>
+  inline ObjectPtr<ArrayType> make_inplace_array(size_t num_elems, Args&&... args) {
+    using Handler = typename Derived::template ArrayHandler<ArrayType, ElemType>;
+    static_assert(std::is_base_of<Object, ArrayType>::value,
+                  "make_inplace_array can only be used to create Object");
+    ArrayType* ptr =
+        Handler::New(static_cast<Derived*>(this), num_elems, std::forward<Args>(args)...);
+    // Same header setup as make_object. TVMFFIObject is a private base of Object,
+    // so its type_index/deleter fields are reached through ObjectInternal.
+    TVMFFIObject* ffi_ptr = details::ObjectInternal::StaticCast<TVMFFIObject*>(ptr);
+    ffi_ptr->type_index = ArrayType::RuntimeTypeIndex();
+    ffi_ptr->deleter = Handler::Deleter();
+    return details::ObjectInternal::ObjectPtr<ArrayType>(ptr);
+  }
+};
+
+// Simple allocator that uses new/delete.
+class SimpleObjAllocator : public ObjAllocatorBase<SimpleObjAllocator> {
+ public:
+  template <typename T>
+  class Handler {
+   public:
+    using StorageType = typename std::aligned_storage<sizeof(T), alignof(T)>::type;
+
+    template <typename... Args>
+    static T* New(SimpleObjAllocator*, Args&&... args) {
+      // NOTE: the first argument is not needed for SimpleObjAllocator
+      // It is reserved for special allocators that need to recycle
+      // the object to itself (e.g. in the case of object pool).
+      //
+      // In the case of an object pool, an allocator needs to create
+      // a special chunk memory that hides reference to the allocator
+      // and call allocator's release function in the deleter.
+
+      // NOTE2: Use inplace new to allocate
+      // This is used to get rid of warning when deleting a virtual
+      // class with non-virtual destructor.
+      // We are fine here as we captured the right deleter during construction.
+      // This is also the right way to get storage type for an object pool.
+      StorageType* data = new StorageType();
+      new (data) T(std::forward<Args>(args)...);
+      return reinterpret_cast<T*>(data);
+    }
+
+    static FObjectDeleter Deleter() { return Deleter_; }
+
+   private:
+    static void Deleter_(TVMFFIObject* objptr) {
+      // NOTE: this is important to cast back to T*
+      // because objptr and tptr may not be the same
+      // depending on how sub-class allocates the space.
+      T* tptr = details::ObjectInternal::StaticCast<T*>(objptr);
+      // It is important to do tptr->T::~T(),
+      // so that we explicitly call the specific destructor
+      // instead of tptr->~T(), which could mean the intention
+      // call a virtual destructor(which may not be available and is not required).
+      tptr->T::~T();
+      delete reinterpret_cast<StorageType*>(tptr);
+    }
+  };
+
+  // Array handler that uses new/delete.
+  template <typename ArrayType, typename ElemType>
+  class ArrayHandler {
+   public:
+    using StorageType = typename std::aligned_storage<sizeof(ArrayType), alignof(ArrayType)>::type;
+    // for now only support elements that aligns with array header.
+    static_assert(alignof(ArrayType) % alignof(ElemType) == 0 &&
+                      sizeof(ArrayType) % alignof(ElemType) == 0,
+                  "element alignment constraint");
+
+    template <typename... Args>
+    static ArrayType* New(SimpleObjAllocator*, size_t num_elems, Args&&... args) {
+      // NOTE: the first argument is not needed for SimpleObjAllocator
+      // It is reserved for special allocators that need to recycle
+      // the object to itself (e.g. in the case of object pool).
+      //
+      // In the case of an object pool, an allocator needs to create
+      // a special chunk memory that hides reference to the allocator
+      // and call allocator's release function in the deleter.
+      // NOTE2: Use inplace new to allocate
+      // This is used to get rid of warning when deleting a virtual
+      // class with non-virtual destructor.
+      // We are fine here as we captured the right deleter during construction.
+      // This is also the right way to get storage type for an object pool.
+      size_t unit = sizeof(StorageType);
+      size_t requested_size = num_elems * sizeof(ElemType) + sizeof(ArrayType);
+      size_t num_storage_slots = (requested_size + unit - 1) / unit;
+      StorageType* data = new StorageType[num_storage_slots];
+      new (data) ArrayType(std::forward<Args>(args)...);
+      return reinterpret_cast<ArrayType*>(data);
+    }
+
+    static FObjectDeleter Deleter() { return Deleter_; }
+
+   private:
+    // Signature must match FObjectDeleter, which takes a TVMFFIObject*.
+    static void Deleter_(TVMFFIObject* objptr) {
+      // NOTE: this is important to cast back to ArrayType*
+      // because objptr and tptr may not be the same
+      // depending on how sub-class allocates the space.
+      ArrayType* tptr = details::ObjectInternal::StaticCast<ArrayType*>(objptr);
+      // It is important to do tptr->ArrayType::~ArrayType(),
+      // so that we explicitly call the specific destructor
+      // instead of tptr->~ArrayType(), which could mean the intention
+      // call a virtual destructor(which may not be available and is not required).
+      tptr->ArrayType::~ArrayType();
+      StorageType* p = reinterpret_cast<StorageType*>(tptr);
+      delete[] p;
+    }
+  };
+};
+
+template <typename T, typename... Args>
+inline ObjectPtr<T> make_object(Args&&... args) {
+  return SimpleObjAllocator().make_object<T>(std::forward<Args>(args)...);
+}
+
+template <typename ArrayType, typename ElemType, typename... Args>
+inline ObjectPtr<ArrayType> make_inplace_array_object(size_t num_elems, Args&&... args) {
+  return SimpleObjAllocator().make_inplace_array<ArrayType, ElemType>(
+      num_elems, std::forward<Args>(args)...);
+}
+
+}  // namespace ffi
+}  // namespace tvm
+#endif  // TVM_FFI_MEMORY_H_
diff --git a/ffi/include/tvm/ffi/object.h b/ffi/include/tvm/ffi/object.h
index 90e78defae..6b973f943c 100644
--- a/ffi/include/tvm/ffi/object.h
+++ b/ffi/include/tvm/ffi/object.h
@@ -29,22 +29,89 @@
 namespace tvm {
 namespace ffi {
 
+using TypeIndex = TVMFFITypeIndex;  // C++-side alias for the C ABI type-index enum
+
 namespace details {
 // forward declare object internal
 struct ObjectInternal;
 }  // namespace details
 
-class Object : protected TVMFFIObject {
+/*!
+ * \brief Base class of all object containers.
+ *
+ * Sub-classes of Object should declare the following static constexpr fields:
+ *
+ * - _type_index:
+ *       Static type index of the object. If assigned to TypeIndex::kTVMFFIDynObject,
+ *       the type index will be assigned at runtime.
+ *       The runtime type index can be accessed by ObjectType::RuntimeTypeIndex().
+ * - _type_key:
+ *       The unique string identifier of the type.
+ * - _type_final:
+ *       Whether the type is a terminal type (no subclass of it exists in the object system).
+ *       This field is automatically set by the macro TVM_FFI_DECLARE_FINAL_OBJECT_INFO.
+ *       It is still OK to sub-class a terminal object type T and construct it using make_object,
+ *       but the IsInstance check will only report the object type as T (not the sub-class).
+ *
+ * The following two fields are necessary for base classes that can be sub-classed.
+ *
+ * - _type_child_slots:
+ *       Number of reserved type index slots for child classes.
+ *       Used for runtime optimization of type checking in IsInstance.
+ *       If an object's type_index is within [type_index, type_index + _type_child_slots],
+ *       then the object can quickly be decided to be a sub-class of the current class.
+ *       If not, a fallback mechanism is used to check the global type table.
+ *       Recommendation: set to an estimate of the number of children needed.
+ *
+ * - _type_child_slots_can_overflow:
+ *       Whether we can add additional child classes even if their number exceeds
+ *       _type_child_slots. A fallback mechanism to check the global type table will be used.
+ *       Recommendation: set to false for optimal runtime speed if the exact count is known.
+ *
+ * Two macros are used to declare helper functions in the object:
+ * - Use TVM_FFI_DECLARE_BASE_OBJECT_INFO for object classes that can be sub-classed.
+ * - Use TVM_FFI_DECLARE_FINAL_OBJECT_INFO for object classes that cannot be sub-classed.
+ *
+ * New objects can be created using the make_object function,
+ * which automatically populates the type_index and deleter of the object.
+ */
+class Object : private TVMFFIObject {
  public:
   Object() {
     TVMFFIObject::ref_counter = 0;
     TVMFFIObject::deleter = nullptr;
   }
 
+  // Information about the object
+  static constexpr const char* _type_key = "runtime.Object";
+
+  // Default object type properties for sub-classes
+  static constexpr bool _type_final = false;
+  static constexpr uint32_t _type_child_slots = 0;
+  static constexpr bool _type_child_slots_can_overflow = true;
+  // NOTE: the following field is not the type index of Object itself;
+  // it is intended to be inherited by sub-classes as the default value.
+  // The type index of Object is TypeIndex::kTVMFFIObject (see RuntimeTypeIndex).
+  static constexpr int32_t _type_index = TypeIndex::kTVMFFIDynObject;
+
+  // For sub-classes, the following functions are provided by the macros
+  // TVM_FFI_DECLARE_BASE_OBJECT_INFO and TVM_FFI_DECLARE_FINAL_OBJECT_INFO
+  /*!
+   * \brief Get the runtime allocated type index of the type
+   * \note Getting this information may need dynamic calls into a global table.
+   */
+  static int32_t RuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; }
+  /*!
+   * \brief Internal function to get or allocate a runtime index.
+   * \return The type index; for Object itself this is TypeIndex::kTVMFFIObject.
+   */
+  static int32_t _GetOrAllocRuntimeTypeIndex() { return TypeIndex::kTVMFFIObject; }
+
  private:
-  /*! \brief decreas*/
+  /*! \brief increase reference count */
   void IncRef() { details::AtomicIncrementRelaxed(&(this->ref_counter)); }
 
+  /*! \brief decrease reference count and delete the object */
   void DecRef() {
     if (details::AtomicDecrementRelAcq(&(this->ref_counter)) == 1) {
       if (this->deleter != nullptr) {
@@ -177,9 +244,9 @@ class ObjectPtr {
   /*! \return Whether two ObjectPtr equals each other */
   bool operator!=(const ObjectPtr<T>& other) const { return data_ != other.data_; }
   /*! \return Whether the pointer is nullptr */
-  bool operator==(std::nullptr_t null) const { return data_ == nullptr; }
+  bool operator==(std::nullptr_t) const { return data_ == nullptr; }
   /*! \return Whether the pointer is not nullptr */
-  bool operator!=(std::nullptr_t null) const { return data_ != nullptr; }
+  bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
 
  private:
   /*! \brief internal pointer field */
@@ -207,6 +274,9 @@ class ObjectPtr {
   // friend classes
   friend class Object;
   friend class ObjectRef;
+  friend struct ObjectPtrHash;
+  template <typename>
+  friend class ObjectPtr;
   friend class tvm::ffi::details::ObjectInternal;
   template <typename RelayRefType, typename ObjType>
   friend RelayRefType GetRef(const ObjType* ptr);
@@ -215,8 +285,24 @@ class ObjectPtr {
 };
 
 namespace details {
-/*! \brief Namespace to internally manipulate object class. */
-struct ObjectInternal {};
+/*!
+ * \brief Helper struct (used like a namespace) to internally manipulate objects.
+ * \note These functions are only supposed to be used by internal
+ *  implementations, not by external users of tvm::ffi.
+ */
+struct ObjectInternal {
+  // static_cast helper; also converts between Object (sub)classes and their private
+  // TVMFFIObject base -- assumes Object befriends this struct (TODO confirm).
+  template <typename T, typename U>
+  static TVM_FFI_INLINE T StaticCast(U src) {
+    return static_cast<T>(src);
+  }
+  // Return type is fully qualified: unqualified ObjectPtr would change meaning here.
+  template <typename T>
+  static TVM_FFI_INLINE ::tvm::ffi::ObjectPtr<T> ObjectPtr(Object* raw_ptr) {
+    return ::tvm::ffi::ObjectPtr<T>(raw_ptr);
+  }
+};
 }  // namespace details
 }  // namespace ffi
 }  // namespace tvm
diff --git a/ffi/tests/example/test_object.cc b/ffi/tests/example/test_object.cc
index 95b7d0717f..959246fda0 100644
--- a/ffi/tests/example/test_object.cc
+++ b/ffi/tests/example/test_object.cc
@@ -1,12 +1,35 @@
 #include <gtest/gtest.h>
+#include <tvm/ffi/memory.h>
 #include <tvm/ffi/object.h>
 
 namespace {
 
 using namespace tvm::ffi;
 
-TEST(Object, Default) {
-  // Object* x = new Object();
+// Minimal Object subclass holding an integer, used by the tests below.
+class IntObj : public Object {
+ public:
+  int64_t value;
+  explicit IntObj(int64_t value) : value(value) {}
+};
+
+TEST(Object, RefCounter) {
+  ObjectPtr<IntObj> a = make_object<IntObj>(11);
+  ObjectPtr<IntObj> b = a;
+  // Copying an ObjectPtr shares ownership: both handles see use_count 2.
+  EXPECT_EQ(a->value, 11);
+
+  EXPECT_EQ(a.use_count(), 2);
+  b.reset();  // releasing b leaves a as the sole owner
+  EXPECT_EQ(a.use_count(), 1);
+  EXPECT_TRUE(b == nullptr);
+  EXPECT_EQ(b.use_count(), 0);
+  // Moving transfers ownership; the count stays at 1.
+  ObjectPtr<IntObj> c = std::move(a);
+  EXPECT_EQ(c.use_count(), 1);
+  EXPECT_TRUE(a == nullptr);  // the moved-from ObjectPtr is expected to be null
+
+  EXPECT_EQ(c->value, 11);
 }
 
 }  // namespace

Reply via email to