gemini-code-assist[bot] commented on code in PR #317:
URL: https://github.com/apache/tvm-ffi/pull/317#discussion_r2594043105


##########
src/ffi/extra/env_c_api.cc:
##########
@@ -146,3 +152,85 @@ int TVMFFIEnvRegisterCAPI(const char* name, void* symbol) {
   tvm::ffi::EnvCAPIRegistry::Global()->Register(s_name, symbol);
   TVM_FFI_SAFE_CALL_END();
 }
+
+namespace {
+
+TVM_FFI_INLINE void* AtomicLoadHandleAcquire(void** src_addr) {
+#ifdef _MSC_VER
+#if defined(_WIN32_WINNT) && (_WIN32_WINNT >= 0x0602)
+  return InterlockedCompareExchangePointerAcquire(reinterpret_cast<PVOID 
volatile*>(src_addr),  //
+                                                  nullptr, nullptr);
+#else
+  return InterlockedCompareExchangePointer(reinterpret_cast<PVOID 
volatile*>(src_addr),  //
+                                           nullptr, nullptr);
+#endif
+#else
+  return __atomic_load_n(src_addr, __ATOMIC_ACQUIRE);
+#endif
+}
+
+TVM_FFI_INLINE bool AtomicCompareExchangeHandleAcqRel(void** src_addr, void** 
expected,
+                                                      void* desired) {
+#ifdef _MSC_VER
+  PVOID result = InterlockedCompareExchangePointer(reinterpret_cast<PVOID 
volatile*>(src_addr),
+                                                   desired, *expected);
+  if (result == *expected) {
+    return true;
+  } else {
+    *expected = result;
+    return false;
+  }
+#else
+  return __atomic_compare_exchange_n(src_addr, expected, desired, 
/*weak=*/false, __ATOMIC_ACQ_REL,
+                                     __ATOMIC_ACQUIRE);
+#endif
+}
+
+TVM_FFI_INLINE void AtomicStoreHandleRelease(void** dst_addr, void* src) {
+#ifdef _MSC_VER
+  _InterlockedExchangePointer(reinterpret_cast<PVOID volatile*>(dst_addr), 
src);
+#else
+  __atomic_store_n(dst_addr, src, __ATOMIC_RELEASE);
+#endif
+}
+}  // namespace
+
+int TVMFFIEnvHandleInitOnce(void** handle_addr, int (*init_func)(void** 
handle_addr)) {

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The parameter name `handle_addr` for the argument of `init_func` is 
inconsistent with the header declaration (`void** result`) and is less 
descriptive. Using `result` as declared in the header improves code readability 
and maintainability.
   
   ```c
   int TVMFFIEnvHandleInitOnce(void** handle_addr, int (*init_func)(void** 
result)) {
   ```



##########
src/ffi/extra/env_c_api.cc:
##########
@@ -146,3 +152,85 @@ int TVMFFIEnvRegisterCAPI(const char* name, void* symbol) {
   tvm::ffi::EnvCAPIRegistry::Global()->Register(s_name, symbol);
   TVM_FFI_SAFE_CALL_END();
 }
+
+namespace {
+
+TVM_FFI_INLINE void* AtomicLoadHandleAcquire(void** src_addr) {
+#ifdef _MSC_VER
+#if defined(_WIN32_WINNT) && (_WIN32_WINNT >= 0x0602)
+  return InterlockedCompareExchangePointerAcquire(reinterpret_cast<PVOID 
volatile*>(src_addr),  //
+                                                  nullptr, nullptr);
+#else
+  return InterlockedCompareExchangePointer(reinterpret_cast<PVOID 
volatile*>(src_addr),  //
+                                           nullptr, nullptr);
+#endif
+#else
+  return __atomic_load_n(src_addr, __ATOMIC_ACQUIRE);
+#endif
+}
+
+TVM_FFI_INLINE bool AtomicCompareExchangeHandleAcqRel(void** src_addr, void** 
expected,
+                                                      void* desired) {
+#ifdef _MSC_VER
+  PVOID result = InterlockedCompareExchangePointer(reinterpret_cast<PVOID 
volatile*>(src_addr),
+                                                   desired, *expected);
+  if (result == *expected) {
+    return true;
+  } else {
+    *expected = result;
+    return false;
+  }
+#else
+  return __atomic_compare_exchange_n(src_addr, expected, desired, 
/*weak=*/false, __ATOMIC_ACQ_REL,
+                                     __ATOMIC_ACQUIRE);
+#endif
+}
+
+TVM_FFI_INLINE void AtomicStoreHandleRelease(void** dst_addr, void* src) {
+#ifdef _MSC_VER
+  _InterlockedExchangePointer(reinterpret_cast<PVOID volatile*>(dst_addr), 
src);
+#else
+  __atomic_store_n(dst_addr, src, __ATOMIC_RELEASE);
+#endif
+}
+}  // namespace
+
+int TVMFFIEnvHandleInitOnce(void** handle_addr, int (*init_func)(void** 
handle_addr)) {
+  // fast path: handle is already initialized
+  // we still need atomic load acquire to ensure the handle is not initialized
+  if (AtomicLoadHandleAcquire(handle_addr) != nullptr) return 0;
+  // slow path: handle is not initialized, call initialization function
+  static std::mutex mutex;
+  std::scoped_lock<std::mutex> lock(mutex);
+  // must check again here, because another thread may have initialized the
+  // handle before we acquired the lock
+  if (*handle_addr != nullptr) return 0;
+  void* result = nullptr;
+  int ret = init_func(&result);
+  if (ret != 0) return ret;
+  if (result == nullptr) {
+    TVMFFIErrorSetRaisedFromCStr("RuntimeError", "init_func must return a 
non-NULL handle");
+    return -1;
+  }
+  // NOTE: we must use atomic store release to ensure the result is
+  // visible to other thread's atomic load acquire
+  AtomicStoreHandleRelease(handle_addr, result);
+  return 0;
+}
+
+int TVMFFIEnvHandleDeinitOnce(void** handle_addr, int (*deinit_func)(void* 
handle_addr)) {
+  // fast path: handle is not initialized
+  void* current_handle = AtomicLoadHandleAcquire(handle_addr);
+  // atomic CAS, only one thread will win the CAS race and obtain the handle 
to call deinit_func
+  while (current_handle != nullptr) {
+    void* expected = current_handle;
+    // set handle_addr to be nullptr if it equals to current_handle
+    // otherwise read the value into expected
+    if (AtomicCompareExchangeHandleAcqRel(handle_addr, &expected, nullptr)) {
+      // this thread wins the CAS race, deinitialize the handle
+      return (*deinit_func)(current_handle);
+    }
+    current_handle = expected;
+  }
+  return 0;
+}

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   The `while` loop in `TVMFFIEnvHandleDeinitOnce` has a race condition. If one 
thread is attempting to deinitialize a handle, and another thread successfully 
deinitializes and then re-initializes the same handle, the first thread's 
`AtomicCompareExchangeHandleAcqRel` will fail, but `current_handle` will be 
updated with the *new* handle value. The loop will then continue and attempt to 
deinitialize the newly created handle, violating the "once" semantics and 
potentially leading to a double-free or use-after-free. This implementation is 
also vulnerable to the ABA problem.
   
   A correct and simpler implementation is to use an atomic exchange operation 
to atomically retrieve the handle and set the pointer to `nullptr`. This is a 
single atomic operation and avoids these issues. The suggested change also 
corrects the function pointer's parameter name to match the header file.
   
   ```c
   int TVMFFIEnvHandleDeinitOnce(void** handle_addr, int (*deinit_func)(void* 
handle)) {
   #ifdef _MSC_VER
     void* old_handle =
         _InterlockedExchangePointer(reinterpret_cast<PVOID 
volatile*>(handle_addr), nullptr);
   #else
     void* old_handle = __atomic_exchange_n(handle_addr, nullptr, 
__ATOMIC_ACQ_REL);
   #endif
     if (old_handle != nullptr) {
       return (*deinit_func)(old_handle);
     }
     return 0;
   }
   ```



##########
src/ffi/extra/env_c_api.cc:
##########
@@ -146,3 +152,85 @@ int TVMFFIEnvRegisterCAPI(const char* name, void* symbol) {
   tvm::ffi::EnvCAPIRegistry::Global()->Register(s_name, symbol);
   TVM_FFI_SAFE_CALL_END();
 }
+
+namespace {
+
+TVM_FFI_INLINE void* AtomicLoadHandleAcquire(void** src_addr) {
+#ifdef _MSC_VER
+#if defined(_WIN32_WINNT) && (_WIN32_WINNT >= 0x0602)
+  return InterlockedCompareExchangePointerAcquire(reinterpret_cast<PVOID 
volatile*>(src_addr),  //
+                                                  nullptr, nullptr);
+#else
+  return InterlockedCompareExchangePointer(reinterpret_cast<PVOID 
volatile*>(src_addr),  //
+                                           nullptr, nullptr);
+#endif
+#else
+  return __atomic_load_n(src_addr, __ATOMIC_ACQUIRE);
+#endif
+}
+
+TVM_FFI_INLINE bool AtomicCompareExchangeHandleAcqRel(void** src_addr, void** 
expected,
+                                                      void* desired) {
+#ifdef _MSC_VER
+  PVOID result = InterlockedCompareExchangePointer(reinterpret_cast<PVOID 
volatile*>(src_addr),
+                                                   desired, *expected);
+  if (result == *expected) {
+    return true;
+  } else {
+    *expected = result;
+    return false;
+  }
+#else
+  return __atomic_compare_exchange_n(src_addr, expected, desired, 
/*weak=*/false, __ATOMIC_ACQ_REL,
+                                     __ATOMIC_ACQUIRE);
+#endif
+}
+
+TVM_FFI_INLINE void AtomicStoreHandleRelease(void** dst_addr, void* src) {
+#ifdef _MSC_VER
+  _InterlockedExchangePointer(reinterpret_cast<PVOID volatile*>(dst_addr), 
src);
+#else
+  __atomic_store_n(dst_addr, src, __ATOMIC_RELEASE);
+#endif
+}
+}  // namespace
+
+int TVMFFIEnvHandleInitOnce(void** handle_addr, int (*init_func)(void** 
handle_addr)) {
+  // fast path: handle is already initialized
+  // we still need atomic load acquire to ensure the handle is not initialized
+  if (AtomicLoadHandleAcquire(handle_addr) != nullptr) return 0;
+  // slow path: handle is not initialized, call initialization function
+  static std::mutex mutex;
+  std::scoped_lock<std::mutex> lock(mutex);

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   The use of a single `static std::mutex` for all initializations creates a 
global lock. This means that if multiple threads try to initialize different 
handles concurrently, they will be serialized, leading to a major performance 
bottleneck. The function name `InitOnce` suggests per-handle "once" semantics, 
but the global lock violates this by making all initializations dependent on 
each other. This could also lead to deadlocks in complex scenarios if other 
locks are acquired within the `init_func`.
   
   A better approach that doesn't require an API change is to use a striped 
lock to reduce contention. For example, an array of mutexes can be used, with 
the handle address hashed to an index.
   
   ```c
     static std::mutex mutexes[256];
     std::scoped_lock<std::mutex> lock(mutexes[std::hash<void*>()(handle_addr) 
% 256]);
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to