This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 43d13e8  [CORE] Update logic to use combined ref count  (#58)
43d13e8 is described below

commit 43d13e86ee24d1558f929e3b0faa3182ca1af872
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Sep 26 13:05:50 2025 -0400

    [CORE] Update logic to use combined ref count  (#58)
    
    This PR follows the previous PR to update the change to use
    combined ref count optimization.
---
 docs/concepts/abi_overview.md  |  14 +++--
 include/tvm/ffi/c_api.h        |  27 +++++---
 include/tvm/ffi/memory.h       |   6 +-
 include/tvm/ffi/object.h       | 136 ++++++++++++++++++++++++++---------------
 pyproject.toml                 |   2 +-
 python/tvm_ffi/cython/base.pxi |   5 +-
 6 files changed, 119 insertions(+), 71 deletions(-)

diff --git a/docs/concepts/abi_overview.md b/docs/concepts/abi_overview.md
index b93397a..47639b5 100644
--- a/docs/concepts/abi_overview.md
+++ b/docs/concepts/abi_overview.md
@@ -191,8 +191,7 @@ we adopt a unified object storage format, defined as 
follows:
 
 ```c++
 typedef struct TVMFFIObject {
-  uint32_t strong_ref_count;
-  uint32_t weak_ref_count;
+  uint64_t combined_ref_count;
   int32_t type_index;
   uint32_t __padding;
   union {
@@ -204,13 +203,16 @@ typedef struct TVMFFIObject {
 
 `TVMFFIObject` defines a common 24-byte intrusive header that all in-memory 
objects share:
 
-- `strong_ref_count` stores the strong atomic reference counter of the object.
-- `weak_ref_count` stores the weak atomic reference counter of the object.
+- `combined_ref_count` packs the strong and weak reference counters of the 
object into a single 64-bit field
+  - The lower 32 bits store the strong atomic reference counter:
+    `strong_ref_count = combined_ref_count & 0xFFFFFFFF`
+  - The upper 32 bits store the weak atomic reference counter:
+    `weak_ref_count = (combined_ref_count >> 32) & 0xFFFFFFFF`
 - `type_index` helps us identify the type being stored, which is consistent 
with `TVMFFIAny.type_index`.
 - `deleter` should be called when either the strong or weak ref counter goes 
to zero.
   - The flags are set to indicate the event of either weak or strong going to 
zero, or both.
-  - When `strong_ref_count` gets to zero, the deleter needs to call the 
destructor of the object.
-  - When `weak_ref_count` gets to zero, the deleter needs to free the memory 
allocated by self.
+  - When strong reference counter gets to zero, the deleter needs to call the 
destructor of the object.
+  - When weak reference counter gets to zero, the deleter needs to free the 
memory allocated by self.
 
 **Rationales:** There are several considerations when designing the data 
structure:
 
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index d2b9fab..0cac1f7 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -219,17 +219,26 @@ typedef enum {
  * \brief C-based type of all FFI object header that allocates on heap.
  */
 typedef struct {
-  // Ref counter goes first to align ABI with most intrusive ptr designs.
-  // It is also likely more efficient as rc operations can be quite common
-  // ABI note: Strong ref counter and weak ref counter can be packed into a 
single 64-bit field
-  // Hopefully in future being able to use 64bit atomic that avoids extra 
reading of
-  // weak counter during deletion.
-  /*! \brief Strong reference counter of the object. */
-  uint32_t strong_ref_count;
   /*!
-   * \brief Weak reference counter of the object, for compatiblity with 
weak_ptr design.
+   * \brief Combined strong and weak reference counter of the object.
+   *
+   * Strong ref counter is packed into the lower 32 bits.
+   * Weak ref counter is packed into the upper 32 bits.
+   *
+   * It is equivalent to { uint32_t strong_ref_count; uint32_t weak_ref_count; }
+   * in little-endian structure:
+   *
+   * - strong_ref_count: `combined_ref_count & 0xFFFFFFFF`
+   * - weak_ref_count: `(combined_ref_count >> 32) & 0xFFFFFFFF`
+   *
+   * Rationale: atomic ops on the strong ref counter remain the same as +1/-1,
+   * this combined ref counter allows us to use u64 atomic once
+   * instead of a separate atomic read of weak counter during deletion.
+   *
+   * The ref counter goes first to align ABI with most intrusive ptr designs.
+   * It is also likely more efficient as rc operations can be quite common.
    */
-  uint32_t weak_ref_count;
+  uint64_t combined_ref_count;
   /*!
    * \brief type index of the object.
    * \note The type index of Object and Any are shared in FFI.
diff --git a/include/tvm/ffi/memory.h b/include/tvm/ffi/memory.h
index 1fa9d65..76c9003 100644
--- a/include/tvm/ffi/memory.h
+++ b/include/tvm/ffi/memory.h
@@ -66,8 +66,7 @@ class ObjAllocatorBase {
     static_assert(std::is_base_of<Object, T>::value, "make can only be used to 
create Object");
     T* ptr = Handler::New(static_cast<Derived*>(this), 
std::forward<Args>(args)...);
     TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr);
-    ffi_ptr->strong_ref_count = 1;
-    ffi_ptr->weak_ref_count = 1;
+    ffi_ptr->combined_ref_count = kCombinedRefCountBothOne;
     ffi_ptr->type_index = T::RuntimeTypeIndex();
     ffi_ptr->deleter = Handler::Deleter();
     return details::ObjectUnsafe::ObjectPtrFromOwned<T>(ptr);
@@ -88,8 +87,7 @@ class ObjAllocatorBase {
     ArrayType* ptr =
         Handler::New(static_cast<Derived*>(this), num_elems, 
std::forward<Args>(args)...);
     TVMFFIObject* ffi_ptr = details::ObjectUnsafe::GetHeader(ptr);
-    ffi_ptr->strong_ref_count = 1;
-    ffi_ptr->weak_ref_count = 1;
+    ffi_ptr->combined_ref_count = kCombinedRefCountBothOne;
     ffi_ptr->type_index = ArrayType::RuntimeTypeIndex();
     ffi_ptr->deleter = Handler::Deleter();
     return details::ObjectUnsafe::ObjectPtrFromOwned<ArrayType>(ptr);
diff --git a/include/tvm/ffi/object.h b/include/tvm/ffi/object.h
index e5f955c..6eac9a4 100644
--- a/include/tvm/ffi/object.h
+++ b/include/tvm/ffi/object.h
@@ -129,6 +129,15 @@ namespace details {
 // unsafe operations related to object
 struct ObjectUnsafe;
 
+/*! \brief One counter for weak reference. */
+constexpr uint64_t kCombinedRefCountWeakOne = static_cast<uint64_t>(1) << 32;
+/*! \brief One counter for strong reference. */
+constexpr uint64_t kCombinedRefCountStrongOne = 1;
+/*! \brief Both reference counts. */
+constexpr uint64_t kCombinedRefCountBothOne = kCombinedRefCountWeakOne | 
kCombinedRefCountStrongOne;
+/*! \brief Mask to get the lower 32 bits of the combined reference count. */
+constexpr uint64_t kCombinedRefCountMaskUInt32 = (static_cast<uint64_t>(1) << 
32) - 1;
+
 /*!
  * Check if the type_index is an instance of TargetObjectType.
  *
@@ -192,8 +201,7 @@ class Object {
 
  public:
   Object() {
-    header_.strong_ref_count = 0;
-    header_.weak_ref_count = 0;
+    header_.combined_ref_count = 0;
     header_.deleter = nullptr;
   }
   /*!
@@ -247,12 +255,16 @@ class Object {
    * \return The usage count of the cell.
    * \note We use STL style naming to be consistent with known API in 
shared_ptr.
    */
-  int32_t use_count() const {
+  uint64_t use_count() const {
     // only need relaxed load of counters
 #ifdef _MSC_VER
-    return (reinterpret_cast<const volatile 
long*>(&header_.strong_ref_count))[0];  // NOLINT(*)
+    return ((reinterpret_cast<const volatile uint64_t*>(
+               &header_.combined_ref_count))[0]  // NOLINT(*)
+            ) &
+           kCombinedRefCountMaskUInt32;
 #else
-    return __atomic_load_n(&(header_.strong_ref_count), __ATOMIC_RELAXED);
+    return __atomic_load_n(&(header_.combined_ref_count), __ATOMIC_RELAXED) &
+           kCombinedRefCountMaskUInt32;
 #endif
   }
 
@@ -290,13 +302,18 @@ class Object {
   static int32_t _GetOrAllocRuntimeTypeIndex() { return 
TypeIndex::kTVMFFIObject; }
 
  private:
+  // exposing detailed constants to here
+  static constexpr uint64_t kCombinedRefCountMaskUInt32 = 
details::kCombinedRefCountMaskUInt32;
+  static constexpr uint64_t kCombinedRefCountStrongOne = 
details::kCombinedRefCountStrongOne;
+  static constexpr uint64_t kCombinedRefCountWeakOne = 
details::kCombinedRefCountWeakOne;
+  static constexpr uint64_t kCombinedRefCountBothOne = 
details::kCombinedRefCountBothOne;
   /*! \brief increase strong reference count, the caller must already hold a 
strong reference */
   void IncRef() {
 #ifdef _MSC_VER
-    _InterlockedIncrement(
-        reinterpret_cast<volatile long*>(&header_.strong_ref_count));  // 
NOLINT(*)
+    _InterlockedIncrement64(
+        reinterpret_cast<volatile __int64*>(&header_.combined_ref_count));  // 
NOLINT(*)
 #else
-    __atomic_fetch_add(&(header_.strong_ref_count), 1, __ATOMIC_RELAXED);
+    __atomic_fetch_add(&(header_.combined_ref_count), 1, __ATOMIC_RELAXED);
 #endif
   }
   /*!
@@ -306,12 +323,12 @@ class Object {
    */
   bool TryPromoteWeakPtr() {
 #ifdef _MSC_VER
-    uint32_t old_count =
-        (reinterpret_cast<const volatile 
long*>(&header_.strong_ref_count))[0];  // NOLINT(*)
-    while (old_count > 0) {
-      uint32_t new_count = old_count + 1;
-      uint32_t old_count_loaded = _InterlockedCompareExchange(
-          reinterpret_cast<volatile long*>(&header_.strong_ref_count), 
new_count, old_count);
+    uint64_t old_count =
+        (reinterpret_cast<const volatile 
__int64*>(&header_.combined_ref_count))[0];  // NOLINT(*)
+    while ((old_count & kCombinedRefCountMaskUInt32) != 0) {
+      uint64_t new_count = old_count + kCombinedRefCountStrongOne;
+      uint64_t old_count_loaded = _InterlockedCompareExchange64(
+          reinterpret_cast<volatile __int64*>(&header_.combined_ref_count), 
new_count, old_count);
       if (old_count == old_count_loaded) {
         return true;
       }
@@ -319,13 +336,13 @@ class Object {
     }
     return false;
 #else
-    uint32_t old_count = __atomic_load_n(&(header_.strong_ref_count), 
__ATOMIC_RELAXED);
-    while (old_count > 0) {
+    uint64_t old_count = __atomic_load_n(&(header_.combined_ref_count), 
__ATOMIC_RELAXED);
+    while ((old_count & kCombinedRefCountMaskUInt32) != 0) {
       // must do CAS to ensure that we are the only one that increases the 
reference count
       // avoid condition when two threads tries to promote weak to strong at 
same time
       // or when strong deletion happens between the load and the CAS
-      uint32_t new_count = old_count + 1;
-      if (__atomic_compare_exchange_n(&(header_.strong_ref_count), &old_count, 
new_count, true,
+      uint64_t new_count = old_count + kCombinedRefCountStrongOne;
+      if (__atomic_compare_exchange_n(&(header_.combined_ref_count), 
&old_count, new_count, true,
                                       __ATOMIC_ACQ_REL, __ATOMIC_RELAXED)) {
         return true;
       }
@@ -337,9 +354,11 @@ class Object {
   /*! \brief increase weak reference count */
   void IncWeakRef() {
 #ifdef _MSC_VER
-    _InterlockedIncrement(reinterpret_cast<volatile 
long*>(&header_.weak_ref_count));  // NOLINT(*)
+    _InlineInterlockedAdd64(
+        reinterpret_cast<volatile __int64*>(&header_.combined_ref_count),  // 
NOLINT(*)
+        kCombinedRefCountWeakOne);
 #else
-    __atomic_fetch_add(&(header_.weak_ref_count), 1, __ATOMIC_RELAXED);
+    __atomic_fetch_add(&(header_.combined_ref_count), 
kCombinedRefCountWeakOne, __ATOMIC_RELAXED);
 #endif
   }
 
@@ -347,43 +366,62 @@ class Object {
   void DecRef() {
 #ifdef _MSC_VER
     // use simpler impl in windows to ensure correctness
-    if (_InterlockedDecrement(                                                 
   //
-            reinterpret_cast<volatile long*>(&header_.strong_ref_count)) == 0) 
{  // NOLINT(*)
-      // full barrrier is implicit in InterlockedDecrement
+    uint64_t count_before_sub =
+        _InterlockedDecrement64(                                              
//
+            reinterpret_cast<volatile __int64*>(&header_.combined_ref_count)  
// NOLINT(*)
+            ) +
+        1;
+    if (count_before_sub == kCombinedRefCountBothOne) {  // NOLINT(*)
+      // fast path: both reference counts will go to zero
+      if (header_.deleter != nullptr) {
+        // full barrier is implicit in InterlockedDecrement
+        header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth);
+      }
+    } else if ((count_before_sub & kCombinedRefCountMaskUInt32) == 
kCombinedRefCountStrongOne) {
+      // strong reference count becomes zero, we need to first do strong 
deletion
+      // then decrease weak reference count
+      // full barrier is implicit in InterlockedAdd
       if (header_.deleter != nullptr) {
         header_.deleter(&(this->header_), 
kTVMFFIObjectDeleterFlagBitMaskStrong);
       }
-      if (_InterlockedDecrement(                                               
   //
-              reinterpret_cast<volatile long*>(&header_.weak_ref_count)) == 0) 
{  // NOLINT(*)
+      // decrease weak reference count
+      if (_InlineInterlockedAdd64(  //
+              reinterpret_cast<volatile __int64*>(&header_.combined_ref_count),
+              -kCombinedRefCountWeakOne) == 0) {  // NOLINT(*)
         if (header_.deleter != nullptr) {
+          // full barrier is implicit in InterlockedAdd
           header_.deleter(&(this->header_), 
kTVMFFIObjectDeleterFlagBitMaskWeak);
         }
       }
     }
 #else
     // first do a release, note we only need to acquire for deleter
-    if (__atomic_fetch_sub(&(header_.strong_ref_count), 1, __ATOMIC_RELEASE) 
== 1) {
-      if (__atomic_load_n(&(header_.weak_ref_count), __ATOMIC_RELAXED) == 1) {
-        // common case, we need to delete both the object and the memory block
-        // only acquire when we need to call deleter
-        __atomic_thread_fence(__ATOMIC_ACQUIRE);
-        if (header_.deleter != nullptr) {
-          // call deleter once
-          header_.deleter(&(this->header_), 
kTVMFFIObjectDeleterFlagBitMaskBoth);
-        }
-      } else {
-        // Slower path: there is still a weak reference left
+    // optimization: we only need one atomic to tell the common case
+    // where both reference counts are zero
+    uint64_t count_before_sub = 
__atomic_fetch_sub(&(header_.combined_ref_count),
+                                                   kCombinedRefCountStrongOne, 
__ATOMIC_RELEASE);
+    if (count_before_sub == kCombinedRefCountBothOne) {
+      // common case, we need to delete both the object and the memory block
+      // only acquire when we need to call deleter
+      __atomic_thread_fence(__ATOMIC_ACQUIRE);
+      if (header_.deleter != nullptr) {
+        // call deleter once
+        header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskBoth);
+      }
+    } else if ((count_before_sub & kCombinedRefCountMaskUInt32) == 
kCombinedRefCountStrongOne) {
+      // strong count is already zero
+      // Slower path: there is still a weak reference left
+      __atomic_thread_fence(__ATOMIC_ACQUIRE);
+      // call destructor first, then decrease weak reference count
+      if (header_.deleter != nullptr) {
+        header_.deleter(&(this->header_), 
kTVMFFIObjectDeleterFlagBitMaskStrong);
+      }
+      // now decrease weak reference count
+      if (__atomic_fetch_sub(&(header_.combined_ref_count), 
kCombinedRefCountWeakOne,
+                             __ATOMIC_RELEASE) == kCombinedRefCountWeakOne) {
         __atomic_thread_fence(__ATOMIC_ACQUIRE);
-        // call destructor first, then decrease weak reference count
         if (header_.deleter != nullptr) {
-          header_.deleter(&(this->header_), 
kTVMFFIObjectDeleterFlagBitMaskStrong);
-        }
-        // now decrease weak reference count
-        if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE) 
== 1) {
-          __atomic_thread_fence(__ATOMIC_ACQUIRE);
-          if (header_.deleter != nullptr) {
-            header_.deleter(&(this->header_), 
kTVMFFIObjectDeleterFlagBitMaskWeak);
-          }
+          header_.deleter(&(this->header_), 
kTVMFFIObjectDeleterFlagBitMaskWeak);
         }
       }
     }
@@ -393,15 +431,17 @@ class Object {
   /*! \brief decrease weak reference count */
   void DecWeakRef() {
 #ifdef _MSC_VER
-    if (_InterlockedDecrement(                                                 
 //
-            reinterpret_cast<volatile long*>(&header_.weak_ref_count)) == 0) { 
 // NOLINT(*)
+    if (_InlineInterlockedAdd64(                                               
//
+            reinterpret_cast<volatile __int64*>(&header_.combined_ref_count),  
// NOLINT(*)
+            -kCombinedRefCountWeakOne) == 0) {
       if (header_.deleter != nullptr) {
         header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
       }
     }
 #else
     // now decrease weak reference count
-    if (__atomic_fetch_sub(&(header_.weak_ref_count), 1, __ATOMIC_RELEASE) == 
1) {
+    if (__atomic_fetch_sub(&(header_.combined_ref_count), 
kCombinedRefCountWeakOne,
+                           __ATOMIC_RELEASE) == kCombinedRefCountWeakOne) {
       __atomic_thread_fence(__ATOMIC_ACQUIRE);
       if (header_.deleter != nullptr) {
         header_.deleter(&(this->header_), kTVMFFIObjectDeleterFlagBitMaskWeak);
diff --git a/pyproject.toml b/pyproject.toml
index 83a280b..8166de4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
 
 [project]
 name = "apache-tvm-ffi"
-version = "0.1.0b7"
+version = "0.1.0b8"
 description = "tvm ffi"
 
 authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index ff532ea..a3ab73e 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -111,10 +111,9 @@ cdef extern from "tvm/ffi/c_api.h":
     ctypedef void* TVMFFIObjectHandle
 
     ctypedef struct TVMFFIObject:
-        uint32_t strong_ref_count
-        uint32_t weak_ref_count
+        uint64_t combined_ref_count
         int32_t type_index
-        int32_t __padding
+        uint32_t __padding
         void (*deleter)(TVMFFIObject* self)
 
     ctypedef struct TVMFFIAny:

Reply via email to