This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 71bbe91  Fix TVMFFIEnvSetDLPackManagedTensorAllocator to correctly return the original allocator (#371)
71bbe91 is described below

commit 71bbe91737afd58a330c735369c069317f48cc29
Author: Nan <[email protected]>
AuthorDate: Mon Jan 5 21:19:35 2026 +0800

    Fix TVMFFIEnvSetDLPackManagedTensorAllocator to correctly return the original allocator (#371)
    
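    Before this commit, TVMFFIEnvSetDLPackManagedTensorAllocator assigned the
    new allocator to the thread-local slot before copying that slot into
    opt_out_original_allocator. As a result, a call such as
    TVMFFIEnvSetDLPackManagedTensorAllocator(NewAllocator, 0,
    &pre_allocator);
    stored NewAllocator, rather than the previously active allocator, in
    pre_allocator. The previous allocator is now read via
    GetDLPackManagedTensorAllocator() before any state is modified, and a
    regression test checks that the returned value equals the allocator that
    was active before the call.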
    
    ---------
    
    Co-authored-by: nan <[email protected]>
---
 src/ffi/extra/env_context.cc      | 7 +++----
 tests/cpp/extra/test_c_env_api.cc | 8 ++++++++
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/ffi/extra/env_context.cc b/src/ffi/extra/env_context.cc
index cde76c4..9b2fb25 100644
--- a/src/ffi/extra/env_context.cc
+++ b/src/ffi/extra/env_context.cc
@@ -65,13 +65,12 @@ class EnvContext {
   void SetDLPackManagedTensorAllocator(DLPackManagedTensorAllocator allocator,
                                        int write_to_global_context,
                                        DLPackManagedTensorAllocator* 
opt_out_original_allocator) {
-    dlpack_allocator_ = allocator;
+    if (opt_out_original_allocator != nullptr) {
+      *opt_out_original_allocator = GetDLPackManagedTensorAllocator();
+    }
     if (write_to_global_context != 0) {
       GlobalTensorAllocator() = allocator;
     }
-    if (opt_out_original_allocator != nullptr) {
-      *opt_out_original_allocator = dlpack_allocator_;
-    }
     dlpack_allocator_ = allocator;
   }
 
diff --git a/tests/cpp/extra/test_c_env_api.cc b/tests/cpp/extra/test_c_env_api.cc
index 8a38d75..78e746e 100644
--- a/tests/cpp/extra/test_c_env_api.cc
+++ b/tests/cpp/extra/test_c_env_api.cc
@@ -59,6 +59,14 @@ int TestDLPackManagedTensorAllocatorError(DLTensor* prototype, DLManagedTensorVe
   return -1;
 }
 
+TEST(CEnvAPI, TVMFFIEnvSetDLPackManagedTensorAllocator) {
+  auto old_allocator = TVMFFIEnvGetDLPackManagedTensorAllocator();
+  DLPackManagedTensorAllocator pre_allocator;
+  TVMFFIEnvSetDLPackManagedTensorAllocator(TestDLPackManagedTensorAllocator, 
0, &pre_allocator);
+  EXPECT_EQ(old_allocator, pre_allocator);
+  TVMFFIEnvSetDLPackManagedTensorAllocator(old_allocator, 0, nullptr);
+}
+
 TEST(CEnvAPI, TVMFFIEnvTensorAlloc) {
   auto old_allocator = TVMFFIEnvGetDLPackManagedTensorAllocator();
   TVMFFIEnvSetDLPackManagedTensorAllocator(TestDLPackManagedTensorAllocator, 
0, nullptr);

Reply via email to