This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 25c25ae [C-API] Introduce TVMFFIHandleInitOnce (#317)
25c25ae is described below
commit 25c25aec22acadcf1aeb839297fe156bc0cf7183
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Dec 5 17:06:34 2025 -0500
[C-API] Introduce TVMFFIHandleInitOnce (#317)
In DSL settings, it is sometimes helpful to have the ability to do
thread-safe initialization of static handles.
This PR adds two functions to make it easy to do so without relying on
C++ std::call_once support.
---
CMakeLists.txt | 1 +
include/tvm/ffi/c_api.h | 38 +++++
src/ffi/init_once.cc | 94 +++++++++++++
tests/cpp/extra/test_c_env_api.cc | 288 ++++++++++++++++++++++++++++++++++++++
4 files changed, 421 insertions(+)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4c55a13..5e2fb4e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -66,6 +66,7 @@ set(_tvm_ffi_objs_sources
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/tensor.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/dtype.cc"
"${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/container.cc"
+ "${CMAKE_CURRENT_SOURCE_DIR}/src/ffi/init_once.cc"
)
set(_tvm_ffi_extra_objs_sources
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 54fd8e7..9ed1ce8 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -1127,6 +1127,44 @@ TVM_FFI_DLL int32_t TVMFFITypeGetOrAllocIndex(const TVMFFIByteArray* type_key,
*/
TVM_FFI_DLL const TVMFFITypeInfo* TVMFFIGetTypeInfo(int32_t type_index);
+// ----------------------------------------------------------------------------
+// Static handle initialization and deinitialization API
+// ----------------------------------------------------------------------------
+/*!
+ * \brief Initialize a handle once in a thread-safe manner.
+ *
+ * If *handle_addr is already non-NULL, this function returns 0 immediately.
+ * Otherwise it takes a process-wide lock, re-checks *handle_addr, calls
+ * init_func, and on success publishes the result into *handle_addr with
+ * release semantics so that other threads observe a fully initialized handle.
+ *
+ * This function is thread-safe and is meant to be used by DSLs that,
+ * unlike C++, may not have static initialization support.
+ *
+ * \param handle_addr The address of the handle to be initialized. Must not be NULL.
+ * \param init_func The initialization function to be called once to create the result handle.
+ *        On success it must store a non-NULL handle into *result and return 0;
+ *        a NULL result is reported as a RuntimeError.
+ * \return 0 on success, nonzero on failure.
+ *
+ * \note If init_func encounters an error, it should call TVMFFIErrorSetRaisedFromCStr
+ *       to set the error and return nonzero, which will then be propagated to the
+ *       caller of TVMFFIHandleInitOnce.
+ * \note On any failure *handle_addr is left as NULL, so a later call retries the
+ *       initialization.
+ */
+TVM_FFI_DLL int TVMFFIHandleInitOnce(void** handle_addr, int (*init_func)(void** result));
+
+/*!
+ * \brief Deinitialize a handle once in a thread-safe manner.
+ *
+ * This function atomically swaps *handle_addr with NULL. If the previous value
+ * was non-NULL, deinit_func is called on it. When several threads call this
+ * concurrently on the same handle, exactly one of them observes the non-NULL
+ * value and runs deinit_func.
+ *
+ * This function is thread-safe and is meant to be used by DSLs that,
+ * unlike C++, may not have static deinitialization support.
+ *
+ * \param handle_addr The address of the handle to be deinitialized. Must not be NULL.
+ * \param deinit_func The deinitialization function to be called if *handle_addr is not NULL.
+ * \return 0 on success (including when *handle_addr was already NULL),
+ *         otherwise the nonzero value returned by deinit_func.
+ *
+ * \note *handle_addr is cleared before deinit_func runs, so it stays NULL even if
+ *       deinit_func fails. The caller is responsible for ensuring that no other
+ *       thread is still using the handle when it is deinitialized.
+ */
+TVM_FFI_DLL int TVMFFIHandleDeinitOnce(void** handle_addr, int (*deinit_func)(void* handle));
#ifdef __cplusplus
} // TVM_FFI_EXTERN_C
#endif
diff --git a/src/ffi/init_once.cc b/src/ffi/init_once.cc
new file mode 100644
index 0000000..cc74a85
--- /dev/null
+++ b/src/ffi/init_once.cc
@@ -0,0 +1,94 @@
+
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * \file src/ffi/init_once.cc
+ * \brief Handle Init Once C API implementation.
+ */
+#include <tvm/ffi/base_details.h>
+#include <tvm/ffi/c_api.h>
+
+#include <mutex>
+
+#ifdef _MSC_VER
+#include <windows.h>
+#endif
+
+namespace {
+
+/*!
+ * \brief Read *addr with (at least) acquire ordering.
+ *
+ * On MSVC there is no plain acquire-load intrinsic for pointers, so the load is
+ * emulated by a compare-exchange of nullptr against nullptr: it never changes
+ * the stored value but returns it with the ordering of the interlocked call.
+ */
+TVM_FFI_INLINE void* AtomicLoadHandleAcquire(void** addr) {
+#if defined(_MSC_VER) && defined(_WIN32_WINNT) && (_WIN32_WINNT >= 0x0602)
+  return InterlockedCompareExchangePointerAcquire(reinterpret_cast<PVOID volatile*>(addr),
+                                                  nullptr, nullptr);
+#elif defined(_MSC_VER)
+  return InterlockedCompareExchangePointer(reinterpret_cast<PVOID volatile*>(addr), nullptr,
+                                           nullptr);
+#else
+  return __atomic_load_n(addr, __ATOMIC_ACQUIRE);
+#endif
+}
+
+/*!
+ * \brief Write value into *addr with (at least) release ordering.
+ *
+ * On MSVC an interlocked exchange is used; it acts as a full barrier, which is
+ * stronger than the release ordering required here.
+ */
+TVM_FFI_INLINE void AtomicStoreHandleRelease(void** addr, void* value) {
+#if defined(_MSC_VER)
+  _InterlockedExchangePointer(reinterpret_cast<PVOID volatile*>(addr), value);
+#else
+  __atomic_store_n(addr, value, __ATOMIC_RELEASE);
+#endif
+}
+}  // namespace
+
+int TVMFFIHandleInitOnce(void** handle_addr, int (*init_func)(void** result)) {
+  // Fast path: the handle is already published. The acquire load pairs with the
+  // release store below, so everything init_func wrote is visible to this thread.
+  if (AtomicLoadHandleAcquire(handle_addr) != nullptr) return 0;
+  // Slow path: serialize initialization behind one process-wide lock. The slow
+  // path is not meant to run frequently, so a single lock is sufficient.
+  // The lock is recursive so that init_func may itself call TVMFFIHandleInitOnce
+  // for a *different* handle (e.g. a dependency); relocking a std::mutex on the
+  // same thread is undefined behavior and in practice deadlocks.
+  // init_func must not (directly or indirectly) initialize the same handle, and
+  // must not block on other threads that call TVMFFIHandleInitOnce.
+  static std::recursive_mutex mutex;
+  std::scoped_lock<std::recursive_mutex> lock(mutex);
+  // Re-check under the lock: another thread may have initialized the handle
+  // before we acquired it. The load stays atomic because TVMFFIHandleDeinitOnce
+  // may modify the slot concurrently without taking this lock.
+  if (AtomicLoadHandleAcquire(handle_addr) != nullptr) return 0;
+  void* result = nullptr;
+  int ret = init_func(&result);
+  // On failure *handle_addr is left untouched (nullptr) so a later call retries;
+  // init_func is expected to have set the raised error.
+  if (ret != 0) return ret;
+  if (result == nullptr) {
+    TVMFFIErrorSetRaisedFromCStr("RuntimeError", "init_func must return a non-NULL handle");
+    return -1;
+  }
+  // Publish with release semantics so the fast-path acquire load in other
+  // threads observes a fully initialized handle.
+  AtomicStoreHandleRelease(handle_addr, result);
+  return 0;
+}
+
+int TVMFFIHandleDeinitOnce(void** handle_addr, int (*deinit_func)(void* handle)) {
+  // Atomically take the current handle and leave nullptr behind; among several
+  // concurrent callers exactly one receives the non-null value.
+  void* detached =
+#ifdef _MSC_VER
+      _InterlockedExchangePointer(reinterpret_cast<PVOID volatile*>(handle_addr), nullptr);
+#else
+      __atomic_exchange_n(handle_addr, nullptr, __ATOMIC_ACQ_REL);
+#endif
+  // Nothing to tear down: never initialized, or another caller already did it.
+  if (detached == nullptr) {
+    return 0;
+  }
+  // The slot was cleared above, so it stays nullptr even if deinit_func fails.
+  return deinit_func(detached);
+}
diff --git a/tests/cpp/extra/test_c_env_api.cc b/tests/cpp/extra/test_c_env_api.cc
index 9ac51e3..f4b37ad 100644
--- a/tests/cpp/extra/test_c_env_api.cc
+++ b/tests/cpp/extra/test_c_env_api.cc
@@ -21,6 +21,13 @@
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/extra/c_env_api.h>
+#include <atomic>
+#include <chrono>
+#include <condition_variable>
+#include <mutex>
+#include <thread>
+#include <vector>
+
namespace {
using namespace tvm::ffi;
@@ -89,4 +96,285 @@ TEST(CEnvAPI, TVMFFIEnvTensorAllocError) {
tvm::ffi::Error);
TVMFFIEnvSetDLPackManagedTensorAllocator(old_allocator, 0, nullptr);
}
+
+// Helper callbacks for the TVMFFIHandleInitDeinitOnce test.
+// Handles created here are heap-allocated ints released with DeinitSuccess.
+static int InitSuccess(void** handle_addr) {
+  *handle_addr = new int(42);
+  return 0;
+}
+
+static int InitShouldNotBeCalled(void** handle_addr) {
+  // Reaching this means InitOnce re-ran initialization on an initialized handle.
+  ADD_FAILURE() << "InitShouldNotBeCalled was invoked";
+  *handle_addr = new int(999);
+  return 0;
+}
+
+static int DeinitSuccess(void* h) {
+  delete static_cast<int*>(h);
+  return 0;
+}
+
+static int DeinitShouldNotBeCalled(void* /*h*/) {
+  // Should not be called when handle is already null
+  ADD_FAILURE() << "DeinitShouldNotBeCalled was invoked";
+  return -1;
+}
+
+static int InitWithError(void** /*handle_addr*/) {
+  TVMFFIErrorSetRaisedFromCStr("RuntimeError", "Initialization failed");
+  return -1;
+}
+
+static int InitReturnsNull(void** handle_addr) {
+  *handle_addr = nullptr;  // Invalid: must return non-null handle
+  return 0;
+}
+
+static int InitForDeinitError(void** handle_addr) {
+  *handle_addr = new int(100);
+  return 0;
+}
+
+static int DeinitWithError(void* h) {
+  // Free the handle first so the failing deinit does not leak it.
+  delete static_cast<int*>(h);
+  TVMFFIErrorSetRaisedFromCStr("RuntimeError", "Deinitialization failed");
+  return -1;
+}
+
+static int InitValue123(void** handle_addr) {
+  *handle_addr = new int(123);
+  return 0;
+}
+
+static int InitValue456(void** handle_addr) {
+  *handle_addr = new int(456);
+  return 0;
+}
+
+TEST(CEnvAPI, TVMFFIHandleInitDeinitOnce) {
+  // Takes the error currently raised on this thread (if any) and releases it.
+  // Returns true when an error was present, so failing cases can verify that an
+  // error was actually raised while keeping the thread-local error slot clean.
+  auto take_raised_error = []() -> bool {
+    TVMFFIObjectHandle err = nullptr;
+    TVMFFIErrorMoveFromRaised(&err);
+    if (err == nullptr) return false;
+    TVMFFIObjectDecRef(err);
+    return true;
+  };
+
+  // Test 1: Successful initialization
+  void* handle = nullptr;
+  int ret = TVMFFIHandleInitOnce(&handle, InitSuccess);
+  EXPECT_EQ(ret, 0);
+  ASSERT_NE(handle, nullptr);  // stop before dereferencing a null handle
+  EXPECT_EQ(*static_cast<int*>(handle), 42);
+
+  // Test 2: Multiple init calls should not re-initialize (idempotent)
+  void* original_handle = handle;
+  ret = TVMFFIHandleInitOnce(&handle, InitShouldNotBeCalled);
+  EXPECT_EQ(ret, 0);
+  EXPECT_EQ(handle, original_handle);         // Handle should remain unchanged
+  EXPECT_EQ(*static_cast<int*>(handle), 42);  // Value should still be 42
+
+  // Test 3: Successful deinitialization
+  ret = TVMFFIHandleDeinitOnce(&handle, DeinitSuccess);
+  EXPECT_EQ(ret, 0);
+  EXPECT_EQ(handle, nullptr);
+
+  // Test 4: Multiple deinit calls should be safe (idempotent)
+  ret = TVMFFIHandleDeinitOnce(&handle, DeinitShouldNotBeCalled);
+  EXPECT_EQ(ret, 0);
+  EXPECT_EQ(handle, nullptr);
+
+  // Test 5: Init error - init_func returns error code; the error must be raised
+  // and the handle must stay unset.
+  void* handle2 = nullptr;
+  ret = TVMFFIHandleInitOnce(&handle2, InitWithError);
+  EXPECT_NE(ret, 0);
+  EXPECT_EQ(handle2, nullptr);
+  EXPECT_TRUE(take_raised_error());
+
+  // Test 6: Init error - init_func returns nullptr (invalid); InitOnce itself
+  // must raise the error.
+  void* handle3 = nullptr;
+  ret = TVMFFIHandleInitOnce(&handle3, InitReturnsNull);
+  EXPECT_NE(ret, 0);
+  EXPECT_EQ(handle3, nullptr);
+  EXPECT_TRUE(take_raised_error());
+
+  // Test 7: Deinit error - deinit_func returns error
+  void* handle4 = nullptr;
+  ret = TVMFFIHandleInitOnce(&handle4, InitForDeinitError);
+  EXPECT_EQ(ret, 0);
+  EXPECT_NE(handle4, nullptr);
+
+  ret = TVMFFIHandleDeinitOnce(&handle4, DeinitWithError);
+  EXPECT_NE(ret, 0);
+  EXPECT_EQ(handle4, nullptr);  // Handle is cleared even though deinit failed
+  EXPECT_TRUE(take_raised_error());
+
+  // Test 8: Init-deinit lifecycle
+  void* handle5 = nullptr;
+  ret = TVMFFIHandleInitOnce(&handle5, InitValue123);
+  EXPECT_EQ(ret, 0);
+  ASSERT_NE(handle5, nullptr);
+  EXPECT_EQ(*static_cast<int*>(handle5), 123);
+
+  ret = TVMFFIHandleDeinitOnce(&handle5, DeinitSuccess);
+  EXPECT_EQ(ret, 0);
+  EXPECT_EQ(handle5, nullptr);
+
+  // Test 9: Ensure subsequent init after deinit works
+  ret = TVMFFIHandleInitOnce(&handle5, InitValue456);
+  EXPECT_EQ(ret, 0);
+  ASSERT_NE(handle5, nullptr);
+  EXPECT_EQ(*static_cast<int*>(handle5), 456);
+
+  // Clean up
+  ret = TVMFFIHandleDeinitOnce(&handle5, DeinitSuccess);
+  EXPECT_EQ(ret, 0);
+}
+
+// Handle payload used by the multithreaded test. Each instance remembers which
+// counters to bump, so DeinitWithCounter does not need to consult the globals.
+struct ThreadSafeCounter {
+  ThreadSafeCounter(int v, std::atomic<int>* init_ptr, std::atomic<int>* deinit_ptr)
+      : value(v), init_count_ptr(init_ptr), deinit_count_ptr(deinit_ptr) {}
+
+  int value;                           // payload checked by the test
+  std::atomic<int>* init_count_ptr;    // incremented by InitWithCounter; may be null
+  std::atomic<int>* deinit_count_ptr;  // incremented by DeinitWithCounter; may be null
+};
+
+// Counters that InitWithCounter copies into each new ThreadSafeCounter. Each test
+// case points these at its own local atomics and resets them to null afterwards.
+static std::atomic<int>* g_init_count{nullptr};
+static std::atomic<int>* g_deinit_count{nullptr};
+
+// Init callback for the multithreaded test: creates a ThreadSafeCounter holding 42
+// that records the currently installed global counters and bumps the init counter.
+static int InitWithCounter(void** handle_addr) {
+  ThreadSafeCounter* created = new ThreadSafeCounter(42, g_init_count, g_deinit_count);
+  std::atomic<int>* init_counter = created->init_count_ptr;
+  if (init_counter != nullptr) init_counter->fetch_add(1, std::memory_order_relaxed);
+  // Small delay to increase the race window
+  std::this_thread::sleep_for(std::chrono::microseconds(100));
+  *handle_addr = created;
+  return 0;
+}
+
+// Deinit callback for the multithreaded test: bumps the deinit counter recorded
+// in the handle (if any), waits briefly, then frees the handle.
+static int DeinitWithCounter(void* h) {
+  ThreadSafeCounter* counter = static_cast<ThreadSafeCounter*>(h);
+  if (std::atomic<int>* deinit_counter = counter->deinit_count_ptr) {
+    deinit_counter->fetch_add(1, std::memory_order_relaxed);
+  }
+  // Small delay to increase the race window
+  std::this_thread::sleep_for(std::chrono::microseconds(100));
+  delete counter;
+  return 0;
+}
+
+// Starts num_threads threads that all wait on a shared start gate, releases them
+// together to maximize contention, then joins them. fn(i) runs on the i-th thread
+// and may be invoked concurrently, so it must only touch per-index or atomic state.
+template <typename Fn>
+void RunConcurrently(int num_threads, const Fn& fn) {
+  std::mutex mtx;
+  std::condition_variable cv;
+  bool ready = false;
+  std::vector<std::thread> threads;
+  threads.reserve(static_cast<size_t>(num_threads));
+  for (int i = 0; i < num_threads; ++i) {
+    threads.emplace_back([&mtx, &cv, &ready, &fn, i]() {
+      // Wait for all threads to be ready
+      {
+        std::unique_lock<std::mutex> lock(mtx);
+        cv.wait(lock, [&ready] { return ready; });
+      }
+      fn(i);
+    });
+  }
+  // Signal all threads to start
+  {
+    std::lock_guard<std::mutex> lock(mtx);
+    ready = true;
+  }
+  cv.notify_all();
+  for (auto& t : threads) {
+    t.join();
+  }
+}
+
+TEST(CEnvAPI, TVMFFIHandleInitDeinitOnceMultithreaded) {
+  constexpr int kNumThreads = 4;
+
+  // Test 1: Multiple threads calling InitOnce - should initialize only once
+  {
+    void* handle = nullptr;
+    std::vector<int> results(kNumThreads);
+    std::atomic<int> init_count{0};
+
+    // Point the global counters at this scope's atomics for InitWithCounter
+    g_init_count = &init_count;
+    g_deinit_count = nullptr;
+
+    RunConcurrently(kNumThreads, [&handle, &results](int i) {
+      results[i] = TVMFFIHandleInitOnce(&handle, InitWithCounter);
+    });
+
+    // Initialization is finished; reset the globals before any ASSERT can return
+    // early so they never point at this scope's locals after it ends.
+    g_init_count = nullptr;
+
+    // All threads should succeed
+    for (int result : results) {
+      EXPECT_EQ(result, 0);
+    }
+
+    // Handle should be initialized, by exactly one init call
+    ASSERT_NE(handle, nullptr);
+    EXPECT_EQ(static_cast<ThreadSafeCounter*>(handle)->value, 42);
+    EXPECT_EQ(init_count.load(), 1);
+
+    // Clean up
+    EXPECT_EQ(TVMFFIHandleDeinitOnce(&handle, DeinitWithCounter), 0);
+  }
+
+  // Test 2: Multiple threads calling DeinitOnce - should deinitialize only once
+  {
+    void* handle = nullptr;
+    std::atomic<int> init_count{0};
+    std::atomic<int> deinit_count{0};
+
+    // Set global counter pointers, then initialize first
+    g_init_count = &init_count;
+    g_deinit_count = &deinit_count;
+    int ret = TVMFFIHandleInitOnce(&handle, InitWithCounter);
+    // The created counter already holds pointers to the atomics above, so the
+    // globals can be reset before any ASSERT can return early.
+    g_init_count = nullptr;
+    g_deinit_count = nullptr;
+    ASSERT_EQ(ret, 0);
+    ASSERT_NE(handle, nullptr);
+    EXPECT_EQ(init_count.load(), 1);
+
+    std::vector<int> results(kNumThreads);
+    RunConcurrently(kNumThreads, [&handle, &results](int i) {
+      results[i] = TVMFFIHandleDeinitOnce(&handle, DeinitWithCounter);
+    });
+
+    // All threads should succeed
+    for (int result : results) {
+      EXPECT_EQ(result, 0);
+    }
+
+    // Handle should be null, and deinit should have been called exactly once
+    EXPECT_EQ(handle, nullptr);
+    EXPECT_EQ(deinit_count.load(), 1);
+  }
+}
} // namespace