This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 22a7894  [DLPack] C Functions for DLPack Speed Exchange and Stream Handling (#96)
22a7894 is described below

commit 22a78943b78306a73011757fa635afa9dce35114
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Sat Oct 11 14:09:54 2025 -0700

    [DLPack] C Functions for DLPack Speed Exchange and Stream Handling (#96)
    
    ## Summary of Changes
    
    This PR introduces a unified `DLPackExchangeAPI` struct as described in
    proposal [175](https://github.com/dmlc/dlpack/issues/175). This new
    convention replaces the previous mechanism of separate function
    pointers, and aligns with the latest DLPack standard as shown in PR
    [174](https://github.com/dmlc/dlpack/pull/174).
    
    The new `DLPackExchangeAPI` struct also includes a
    `current_work_stream` function pointer that allows more robust and
    integrated querying of the current device stream (e.g., CUDA stream)
    during DLPack tensor exchanges. All conversions from/to DLPack are
    now `_no_sync` variants, meaning callers should use
    `current_work_stream` to explicitly handle stream synchronization. The
    struct also includes a non-owning DLTensor conversion to avoid
    unnecessary reference counting.
    
    Following this change, the Python FFI for PyTorch has been updated to
    expose the new `DLPackExchangeAPI` struct via
    `__c_dlpack_exchange_api__` on torch.Tensor.
    
    The `3rdparty/dlpack` submodule has been updated to the latest commit.
---
 3rdparty/dlpack                                |   2 +-
 python/tvm_ffi/_optional_torch_c_dlpack.py     |  89 ++++++++++++++------
 python/tvm_ffi/core.pyi                        |   2 +-
 python/tvm_ffi/cython/base.pxi                 |  53 ++++++++++--
 python/tvm_ffi/cython/function.pxi             | 109 +++++++++++++++----------
 python/tvm_ffi/cython/tensor.pxi               |  85 ++++++++++++++-----
 python/tvm_ffi/cython/tvm_ffi_python_helpers.h |  54 +++++-------
 tests/python/test_load_inline.py               |   4 +-
 8 files changed, 268 insertions(+), 130 deletions(-)

diff --git a/3rdparty/dlpack b/3rdparty/dlpack
index addbc8b..1117366 160000
--- a/3rdparty/dlpack
+++ b/3rdparty/dlpack
@@ -1 +1 @@
-Subproject commit addbc8b3d9449691d01827ac4a0e0d035cf8ea40
+Subproject commit 111736618e8d1028b23605f76dcaa6a38cfea809
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py 
b/python/tvm_ffi/_optional_torch_c_dlpack.py
index c1bb1ef..94bb3d7 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -464,18 +464,41 @@ at::Tensor fromDLPackImpl(T* src, 
std::function<void(void*)> deleter) {
       {device});
 }
 
+void toDLPackNonOwningImpl(const Tensor& tensor, DLTensor& out) {
+  // Fill in the pre-allocated DLTensor struct with direct pointers
+  // This is a non-owning conversion - the caller owns the tensor
+  // and must keep it alive for the duration of DLTensor usage
+  out.data = tensor.data_ptr();
+  out.device = torchDeviceToDLDeviceForDLPackv1(tensor.device());
+  out.ndim = static_cast<int32_t>(tensor.dim());
+  out.dtype = getDLDataTypeForDLPackv1(tensor);
+  // sizes() and strides() return pointers to TensorImpl's stable storage
+  // which remains valid as long as the tensor is alive
+  out.shape = const_cast<int64_t*>(tensor.sizes().data());
+  out.strides = const_cast<int64_t*>(tensor.strides().data());
+  out.byte_offset = 0;
+}
+
 } // namespace
 } // namespace at
 
-int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out, 
void** env_stream) {
+int TorchDLPackDLTensorFromPyObjectNoSync(void* py_obj, DLTensor* out) {
+  try {
+    // Use handle (non-owning) to avoid unnecessary refcount operations
+    py::handle handle(static_cast<PyObject*>(py_obj));
+    at::Tensor tensor = handle.cast<at::Tensor>();
+    at::toDLPackNonOwningImpl(tensor, *out);
+    return 0;
+  } catch (const std::exception& e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
+}
+
+int TorchDLPackManagedTensorFromPyObjectNoSync(void* py_obj, 
DLManagedTensorVersioned** out) {
   try {
     py::handle handle(static_cast<PyObject*>(py_obj));
     at::Tensor tensor = handle.cast<at::Tensor>();
-#ifdef BUILD_WITH_CUDA
-    if (env_stream != nullptr && tensor.is_cuda()) {
-      *env_stream = 
at::cuda::getCurrentCUDAStream(tensor.device().index()).stream();
-    }
-#endif
     *out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
     return 0;
   } catch (const std::exception& e) {
@@ -484,7 +507,7 @@ int TorchDLPackFromPyObject(void* py_obj, 
DLManagedTensorVersioned** out, void**
   }
 }
 
-int TorchDLPackToPyObject(DLManagedTensorVersioned* src, void** py_obj_out) {
+int TorchDLPackManagedTensorToPyObjectNoSync(DLManagedTensorVersioned* src, 
void** py_obj_out) {
   try {
     at::Tensor tensor = at::fromDLPackImpl<DLManagedTensorVersioned>(src, 
nullptr);
     *py_obj_out = THPVariable_Wrap(tensor);
@@ -495,7 +518,7 @@ int TorchDLPackToPyObject(DLManagedTensorVersioned* src, 
void** py_obj_out) {
   }
 }
 
-int TorchDLPackTensorAllocator(
+int TorchDLPackManagedTensorAllocator(
     DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
     void (*SetError)(void* error_ctx, const char* kind, const char* message)
 ) {
@@ -508,21 +531,45 @@ int TorchDLPackTensorAllocator(
     *out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
     return 0;
   } catch (const std::exception& e) {
-    SetError(error_ctx, "TorchDLPackTensorAllocator", e.what());
+    SetError(error_ctx, "TorchDLPackManagedTensorAllocator", e.what());
     return -1;
   }
 }
 
-int64_t TorchDLPackFromPyObjectPtr() {
-  return reinterpret_cast<int64_t>(TorchDLPackFromPyObject);
+int TorchDLPackCurrentWorkStream(DLDeviceType device_type, int32_t device_id, 
void** out_stream) {
+  try {
+#ifdef BUILD_WITH_CUDA
+    if (device_type == kDLCUDA || device_type == kDLROCM) {
+      *out_stream = at::cuda::getCurrentCUDAStream(device_id).stream();
+    }
+#endif
+    return 0;
+  } catch (const std::exception& e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+    return -1;
+  }
 }
 
-int64_t TorchDLPackToPyObjectPtr() {
-  return reinterpret_cast<int64_t>(TorchDLPackToPyObject);
-}
+struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
+  TorchDLPackExchangeAPI() {
+    header.version.major = DLPACK_MAJOR_VERSION;
+    header.version.minor = DLPACK_MINOR_VERSION;
+    header.prev_api = nullptr;
+    managed_tensor_allocator = TorchDLPackManagedTensorAllocator;
+    managed_tensor_from_py_object_no_sync = 
TorchDLPackManagedTensorFromPyObjectNoSync;
+    managed_tensor_to_py_object_no_sync = 
TorchDLPackManagedTensorToPyObjectNoSync;
+    dltensor_from_py_object_no_sync = TorchDLPackDLTensorFromPyObjectNoSync;
+    current_work_stream = TorchDLPackCurrentWorkStream;
+  }
+
+  static const DLPackExchangeAPI* Global() {
+    static TorchDLPackExchangeAPI inst;
+    return &inst;
+  }
+};
 
-int64_t TorchDLPackTensorAllocatorPtr() {
-  return reinterpret_cast<int64_t>(TorchDLPackTensorAllocator);
+int64_t TorchDLPackExchangeAPIPtr() {
+  return reinterpret_cast<int64_t>(TorchDLPackExchangeAPI::Global());
 }
     """
     try:
@@ -541,17 +588,13 @@ int64_t TorchDLPackTensorAllocatorPtr() {
             name="c_dlpack",
             cpp_sources=cpp_source,
             functions=[
-                "TorchDLPackFromPyObjectPtr",
-                "TorchDLPackToPyObjectPtr",
-                "TorchDLPackTensorAllocatorPtr",
+                "TorchDLPackExchangeAPIPtr",
             ],
             extra_cflags=extra_cflags,
             extra_include_paths=include_paths,
         )
-        # set the dlpack related flags
-        setattr(torch.Tensor, "__c_dlpack_from_pyobject__", 
mod.TorchDLPackFromPyObjectPtr())
-        setattr(torch.Tensor, "__c_dlpack_to_pyobject__", 
mod.TorchDLPackToPyObjectPtr())
-        setattr(torch.Tensor, "__c_dlpack_tensor_allocator__", 
mod.TorchDLPackTensorAllocatorPtr())
+        # Set the DLPackExchangeAPI pointer on the class
+        setattr(torch.Tensor, "__c_dlpack_exchange_api__", 
mod.TorchDLPackExchangeAPIPtr())
         return mod
     except ImportError:
         pass
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 092515a..05e5a04 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -485,7 +485,7 @@ def from_dlpack(
 class DLTensorTestWrapper:
     """Wrapper of a Tensor that exposes DLPack protocol, only for testing 
purpose."""
 
-    __c_dlpack_from_pyobject__: int
+    __c_dlpack_exchange_api__: int
     def __init__(self, tensor: Tensor) -> None: ...
     def __tvm_ffi_env_stream__(self) -> int: ...
     def __dlpack_device__(self) -> tuple[int, int]: ...
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index dcc338a..a86ea00 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -25,6 +25,9 @@ from cpython cimport pycapsule, PyCapsule_Destructor
 from cpython cimport PyErr_SetNone
 
 cdef extern from "dlpack/dlpack.h":
+    int DLPACK_MAJOR_VERSION
+    int DLPACK_MINOR_VERSION
+
     cdef enum:
         kDLCPU = 1,
         kDLCUDA = 2,
@@ -77,6 +80,47 @@ cdef extern from "dlpack/dlpack.h":
         void (*deleter)(DLManagedTensorVersioned* self)
         uint64_t flags
 
+    # DLPack Exchange API function pointer types
+    ctypedef int (*DLPackManagedTensorAllocator)(
+        DLTensor* prototype,
+        DLManagedTensorVersioned** out,
+        void* error_ctx,
+        void (*SetError)(void* error_ctx, const char* kind, const char* 
message)
+    ) noexcept
+
+    ctypedef int (*DLPackManagedTensorFromPyObjectNoSync)(
+        void* py_object,
+        DLManagedTensorVersioned** out
+    ) noexcept
+
+    ctypedef int (*DLPackManagedTensorToPyObjectNoSync)(
+        DLManagedTensorVersioned* tensor,
+        void** out_py_object
+    ) noexcept
+
+    ctypedef int (*DLPackCurrentWorkStream)(
+        int device_type,
+        int32_t device_id,
+        void** out_current_stream
+    ) noexcept
+
+    ctypedef int (*DLPackDLTensorFromPyObjectNoSync)(
+        void* py_object,
+        DLTensor* out
+    ) noexcept
+
+    ctypedef struct DLPackExchangeAPIHeader:
+        DLPackVersion version
+        DLPackExchangeAPIHeader* prev_api
+
+    ctypedef struct DLPackExchangeAPI:
+        DLPackExchangeAPIHeader header
+        DLPackManagedTensorAllocator managed_tensor_allocator
+        DLPackManagedTensorFromPyObjectNoSync 
managed_tensor_from_py_object_no_sync
+        DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync
+        DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync
+        DLPackCurrentWorkStream current_work_stream
+
 
 # Cython binding for TVM FFI C API
 cdef extern from "tvm/ffi/c_api.h":
@@ -285,14 +329,11 @@ cdef extern from "tvm_ffi_python_helpers.h":
         int device_type
         int device_id
         TVMFFIStreamHandle stream
-        DLPackToPyObject c_dlpack_to_pyobject
-        DLPackTensorAllocator c_dlpack_tensor_allocator
+        const DLPackExchangeAPI* c_dlpack_exchange_api
 
     ctypedef struct TVMFFIPyArgSetter:
         int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,  
PyObject* py_arg, TVMFFIAny* out) except -1
-        DLPackFromPyObject c_dlpack_from_pyobject
-        DLPackToPyObject c_dlpack_to_pyobject
-        DLPackTensorAllocator c_dlpack_tensor_allocator
+        const DLPackExchangeAPI* c_dlpack_exchange_api
 
     ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value, 
TVMFFIPyArgSetter* out) except -1
     # The main call function
@@ -303,7 +344,7 @@ cdef extern from "tvm_ffi_python_helpers.h":
         TVMFFIAny* result,
         int* c_api_ret_code,
         int release_gil,
-        DLPackToPyObject* out_dlpack_importer
+        const DLPackExchangeAPI** out_ctx_dlpack_api
     ) except -1
 
     int TVMFFIPyConstructorCall(
diff --git a/python/tvm_ffi/cython/function.pxi 
b/python/tvm_ffi/cython/function.pxi
index a6cc3be..7dcab4b 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -55,13 +55,13 @@ cdef inline object make_ret_small_bytes(TVMFFIAny result):
     return bytearray_to_bytes(&bytes)
 
 
-cdef inline object make_ret(TVMFFIAny result, DLPackToPyObject 
c_dlpack_to_pyobject = NULL):
+cdef inline object make_ret(TVMFFIAny result, const DLPackExchangeAPI* 
c_ctx_dlpack_api = NULL):
     """convert result to return value."""
     cdef int32_t type_index
     type_index = result.type_index
     if type_index == kTVMFFITensor:
         # specially handle Tensor as it needs a special dltensor field
-        return make_tensor_from_any(result, c_dlpack_to_pyobject)
+        return make_tensor_from_any(result, c_ctx_dlpack_api)
     elif type_index == kTVMFFIOpaquePyObject:
         return make_ret_opaque_object(result)
     elif type_index >= kTVMFFIStaticObjectBegin:
@@ -142,35 +142,39 @@ cdef int TVMFFIPyArgSetterObject_(
     return 0
 
 
-cdef int TVMFFIPyArgSetterDLPackCExporter_(
+cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
     TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx,
     PyObject* arg, TVMFFIAny* out
 ) except -1:
     cdef DLManagedTensorVersioned* temp_managed_tensor
     cdef TVMFFIObjectHandle temp_chandle
-    cdef TVMFFIStreamHandle env_stream = NULL
+    cdef void* current_stream = NULL
+    cdef const DLPackExchangeAPI* api = this.c_dlpack_exchange_api
 
-    if this.c_dlpack_to_pyobject != NULL:
-        ctx.c_dlpack_to_pyobject = this.c_dlpack_to_pyobject
-    if this.c_dlpack_tensor_allocator != NULL:
-        ctx.c_dlpack_tensor_allocator = this.c_dlpack_tensor_allocator
+    # Set the exchange API in context
+    ctx.c_dlpack_exchange_api = api
 
-    if ctx.device_type != -1:
-        # already queried device, do not do it again, pass NULL to stream
-        if (this.c_dlpack_from_pyobject)(arg, &temp_managed_tensor, NULL) != 0:
-            return -1
-    else:
-        # query string on the envrionment stream
-        if (this.c_dlpack_from_pyobject)(arg, &temp_managed_tensor, 
&env_stream) != 0:
-            return -1
-        # If device is not CPU, we should set the device type and id
-        if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
-            ctx.stream = env_stream
-            ctx.device_type = temp_managed_tensor.dl_tensor.device.device_type
-            ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id
-    # run conversion
+    # Convert PyObject to DLPack using the struct's function pointer
+    if api.managed_tensor_from_py_object_no_sync(arg, &temp_managed_tensor) != 
0:
+        return -1
+
+    # Query current stream from producer if device is not CPU
+    if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
+        if ctx.device_type == -1 and api.current_work_stream != NULL:
+            # First time seeing a device, query the stream
+            if api.current_work_stream(
+                temp_managed_tensor.dl_tensor.device.device_type,
+                temp_managed_tensor.dl_tensor.device.device_id,
+                &current_stream
+            ) == 0:
+                ctx.stream = <TVMFFIStreamHandle>current_stream
+                ctx.device_type = 
temp_managed_tensor.dl_tensor.device.device_type
+                ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id
+
+    # Convert to TVM Tensor
     if TVMFFITensorFromDLPackVersioned(temp_managed_tensor, 0, 0, 
&temp_chandle) != 0:
         raise BufferError("Failed to convert DLManagedTensorVersioned to 
ffi.Tensor")
+
     out.type_index = kTVMFFITensor
     out.v_ptr = temp_chandle
     TVMFFIPyPushTempFFIObject(ctx, temp_chandle)
@@ -179,15 +183,36 @@ cdef int TVMFFIPyArgSetterDLPackCExporter_(
 
 cdef int TorchDLPackToPyObjectFallback_(
     DLManagedTensorVersioned* dltensor, void** py_obj_out
-) except -1:
+) noexcept:
     # a bit convoluted but ok as a fallback
     cdef TVMFFIObjectHandle temp_chandle
-    TVMFFITensorFromDLPackVersioned(dltensor, 0, 0, &temp_chandle)
-    tensor = make_tensor_from_chandle(temp_chandle)
-    torch_tensor = torch.from_dlpack(tensor)
-    Py_INCREF(torch_tensor)
-    py_obj_out[0] = <void*>(<PyObject*>torch_tensor)
-    return 0
+    if TVMFFITensorFromDLPackVersioned(dltensor, 0, 0, &temp_chandle) != 0:
+        return -1
+    try:
+        tensor = make_tensor_from_chandle(temp_chandle)
+        torch_tensor = torch.from_dlpack(tensor)
+        Py_INCREF(torch_tensor)
+        py_obj_out[0] = <void*>(<PyObject*>torch_tensor)
+        return 0
+    except Exception:
+        return -1
+
+cdef inline const DLPackExchangeAPI* GetTorchFallbackExchangeAPI() noexcept:
+    global _torch_fallback_exchange_api
+
+    _torch_fallback_exchange_api.header.version.major = DLPACK_MAJOR_VERSION
+    _torch_fallback_exchange_api.header.version.minor = DLPACK_MINOR_VERSION
+    _torch_fallback_exchange_api.header.prev_api = NULL
+    _torch_fallback_exchange_api.managed_tensor_allocator = NULL
+    _torch_fallback_exchange_api.managed_tensor_from_py_object_no_sync = NULL
+    _torch_fallback_exchange_api.managed_tensor_to_py_object_no_sync = 
TorchDLPackToPyObjectFallback_
+    _torch_fallback_exchange_api.dltensor_from_py_object_no_sync = NULL
+    _torch_fallback_exchange_api.current_work_stream = NULL
+
+    return &_torch_fallback_exchange_api
+
+# Static storage for the fallback exchange API
+cdef DLPackExchangeAPI _torch_fallback_exchange_api
 
 
 cdef int TVMFFIPyArgSetterTorchFallback_(
@@ -202,7 +227,7 @@ cdef int TVMFFIPyArgSetterTorchFallback_(
     out.type_index = kTVMFFITensor
     out.v_ptr = (<Tensor>arg).chandle
     temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
-    ctx.c_dlpack_to_pyobject = TorchDLPackToPyObjectFallback_
+    ctx.c_dlpack_exchange_api = GetTorchFallbackExchangeAPI()
     # record the stream and device for torch context
     if is_cuda and ctx.device_type != -1:
         ctx.device_type = temp_dltensor.device.device_type
@@ -546,17 +571,13 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, 
TVMFFIPyArgSetter* out) exce
         out.func = TVMFFIPyArgSetterObjectRValueRef_
         return 0
     if os.environ.get("TVM_FFI_SKIP_c_dlpack_from_pyobject", "0") != "1":
-        # external tensors
-        if hasattr(arg, "__c_dlpack_from_pyobject__"):
-            out.func = TVMFFIPyArgSetterDLPackCExporter_
-            temp_ptr = arg.__c_dlpack_from_pyobject__
-            out.c_dlpack_from_pyobject = <DLPackFromPyObject>temp_ptr
-            if hasattr(arg, "__c_dlpack_to_pyobject__"):
-                temp_ptr = arg.__c_dlpack_to_pyobject__
-                out.c_dlpack_to_pyobject = <DLPackToPyObject>temp_ptr
-            if hasattr(arg, "__c_dlpack_tensor_allocator__"):
-                temp_ptr = arg.__c_dlpack_tensor_allocator__
-                out.c_dlpack_tensor_allocator = <DLPackTensorAllocator>temp_ptr
+        # Check for DLPackExchangeAPI struct (new approach)
+        # This is checked on the CLASS, not the instance
+        arg_class = type(arg)
+        if hasattr(arg_class, "__c_dlpack_exchange_api__"):
+            out.func = TVMFFIPyArgSetterDLPackExchangeAPI_
+            temp_ptr = arg_class.__c_dlpack_exchange_api__
+            out.c_dlpack_exchange_api = <const DLPackExchangeAPI*>(<long 
long>temp_ptr)
             return 0
     if torch is not None and isinstance(arg, torch.Tensor):
         out.func = TVMFFIPyArgSetterTorchFallback_
@@ -658,7 +679,7 @@ cdef class Function(Object):
     def __call__(self, *args):
         cdef TVMFFIAny result
         cdef int c_api_ret_code
-        cdef DLPackToPyObject c_dlpack_to_pyobject = NULL
+        cdef const DLPackExchangeAPI* c_ctx_dlpack_api = NULL
         # IMPORTANT: caller need to initialize result->type_index to 
kTVMFFINone
         result.type_index = kTVMFFINone
         result.v_int64 = 0
@@ -668,12 +689,12 @@ cdef class Function(Object):
             &result,
             &c_api_ret_code,
             self.release_gil,
-            &c_dlpack_to_pyobject
+            &c_ctx_dlpack_api
         )
         # NOTE: logic is same as check_call
         # directly inline here to simplify the resulting trace
         if c_api_ret_code == 0:
-            return make_ret(result, c_dlpack_to_pyobject)
+            return make_ret(result, c_ctx_dlpack_api)
         elif c_api_ret_code == -2:
             raise_existing_error()
         raise move_from_last_error().py_error()
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 74b065b..4ebc515 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -275,33 +275,74 @@ _set_class_tensor(Tensor)
 _register_object_by_index(kTVMFFITensor, Tensor)
 
 
-cdef int _dltensor_test_wrapper_c_dlpack_from_pyobject(
-    void* obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream
+cdef int _dltensor_test_wrapper_from_pyobject(
+    void* obj, DLManagedTensorVersioned** out
 ) except -1:
+    """DLPackExchangeAPI: managed_tensor_from_py_object_no_sync"""
     cdef PyObject* py_obj = <PyObject*>obj
     cdef DLTensorTestWrapper wrapper = <DLTensorTestWrapper>py_obj
-    cdef TVMFFIStreamHandle current_stream
-    cdef DLManagedTensorVersioned* temp_managed_tensor
-    if env_stream != NULL:
-        env_stream[0] = TVMFFIEnvGetStream(
-            wrapper.tensor.cdltensor.device.device_type,
-            wrapper.tensor.cdltensor.device.device_id
-        )
-
     return TVMFFITensorToDLPackVersioned(wrapper.tensor.chandle, out)
 
 
-def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr():
-    cdef DLPackFromPyObject converter_func = 
_dltensor_test_wrapper_c_dlpack_from_pyobject
-    cdef void* temp_ptr = <void*>converter_func
-    cdef long long temp_int_ptr = <long long>temp_ptr
-    return temp_int_ptr
+cdef int _dltensor_test_wrapper_to_pyobject(
+    DLManagedTensorVersioned* tensor, void** out_py_object
+) except -1:
+    """DLPackExchangeAPI: managed_tensor_to_py_object_no_sync"""
+    cdef TVMFFIObjectHandle temp_chandle
+    if TVMFFITensorFromDLPackVersioned(tensor, 0, 0, &temp_chandle) != 0:
+        return -1
+    py_tensor = make_tensor_from_chandle(temp_chandle)
+    Py_INCREF(py_tensor)
+    out_py_object[0] = <void*>(<PyObject*>py_tensor)
+    return 0
+
+
+cdef int _dltensor_test_wrapper_current_work_stream(
+    int device_type, int32_t device_id, void** out_stream
+) except -1:
+    """DLPackExchangeAPI: current_work_stream"""
+    if device_type != kDLCPU:
+        out_stream[0] = <void*>TVMFFIEnvGetStream(device_type, device_id)
+    return 0
+
+
+# Module-level static DLPackExchangeAPI for DLTensorTestWrapper
+cdef DLPackExchangeAPI _dltensor_test_wrapper_static_api
+
+cdef const DLPackExchangeAPI* _dltensor_test_wrapper_get_exchange_api() 
noexcept:
+    """Get the static DLPackExchangeAPI instance for DLTensorTestWrapper."""
+    global _dltensor_test_wrapper_static_api
+
+    # Initialize header using macros from dlpack.h
+    _dltensor_test_wrapper_static_api.header.version.major = 
DLPACK_MAJOR_VERSION
+    _dltensor_test_wrapper_static_api.header.version.minor = 
DLPACK_MINOR_VERSION
+    _dltensor_test_wrapper_static_api.header.prev_api = NULL
+
+    # Initialize function pointers
+    _dltensor_test_wrapper_static_api.managed_tensor_allocator = NULL
+    _dltensor_test_wrapper_static_api.managed_tensor_from_py_object_no_sync = (
+        
<DLPackManagedTensorFromPyObjectNoSync>_dltensor_test_wrapper_from_pyobject
+    )
+    _dltensor_test_wrapper_static_api.managed_tensor_to_py_object_no_sync = (
+        <DLPackManagedTensorToPyObjectNoSync>_dltensor_test_wrapper_to_pyobject
+    )
+    _dltensor_test_wrapper_static_api.dltensor_from_py_object_no_sync = NULL
+    _dltensor_test_wrapper_static_api.current_work_stream = (
+        <DLPackCurrentWorkStream>_dltensor_test_wrapper_current_work_stream
+    )
+
+    return &_dltensor_test_wrapper_static_api
+
+
+def _dltensor_test_wrapper_exchange_api_ptr():
+    """Return the pointer to the DLPackExchangeAPI struct as an integer."""
+    return <long long>_dltensor_test_wrapper_get_exchange_api()
 
 
 cdef class DLTensorTestWrapper:
     """Wrapper of a Tensor that exposes DLPack protocol, only for testing 
purpose.
     """
-    __c_dlpack_from_pyobject__ = 
_dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr()
+    __c_dlpack_exchange_api__ = _dltensor_test_wrapper_exchange_api_ptr()
 
     cdef Tensor tensor
     cdef dict __dict__
@@ -334,19 +375,21 @@ cdef inline object make_ret_dltensor(TVMFFIAny result):
     return tensor
 
 
-cdef inline object make_tensor_from_chandle(TVMFFIObjectHandle chandle, 
DLPackToPyObject c_dlpack_to_pyobject = NULL):
+cdef inline object make_tensor_from_chandle(
+    TVMFFIObjectHandle chandle, const DLPackExchangeAPI* c_ctx_dlpack_api = 
NULL
+):
     # TODO: Implement
     cdef Tensor tensor
     cdef void* py_obj
     cdef DLManagedTensorVersioned* dlpack
 
-    if c_dlpack_to_pyobject != NULL:
+    if c_ctx_dlpack_api != NULL and 
c_ctx_dlpack_api.managed_tensor_to_py_object_no_sync != NULL:
         # try convert and import into the environment array if possible
         if TVMFFITensorToDLPackVersioned(chandle, &dlpack) == 0:
             try:
                 # note that py_obj already holds an extra reference to the 
tensor
                 # so we need to decref it after the conversion
-                c_dlpack_to_pyobject(dlpack, &py_obj)
+                c_ctx_dlpack_api.managed_tensor_to_py_object_no_sync(dlpack, 
&py_obj)
                 tensor = <Tensor>(<PyObject*>py_obj)
                 Py_DECREF(tensor)
                 # decref original handle to prevent leak.
@@ -365,5 +408,5 @@ cdef inline object 
make_tensor_from_chandle(TVMFFIObjectHandle chandle, DLPackTo
     return tensor
 
 
-cdef inline object make_tensor_from_any(TVMFFIAny any, DLPackToPyObject 
c_dlpack_to_pyobject):
-    return make_tensor_from_chandle(any.v_ptr, c_dlpack_to_pyobject)
+cdef inline object make_tensor_from_any(TVMFFIAny any, const 
DLPackExchangeAPI* c_ctx_dlpack_api):
+    return make_tensor_from_chandle(any.v_ptr, c_ctx_dlpack_api)
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h 
b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index ea60d51..da2404b 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -89,10 +89,8 @@ struct TVMFFIPyCallContext {
   void** temp_py_objects = nullptr;
   /*! \brief the number of temporary arguments */
   int num_temp_py_objects = 0;
-  /*! \brief the DLPack exporter, if any */
-  DLPackToPyObject c_dlpack_to_pyobject{nullptr};
-  /*! \brief the DLPack allocator, if any */
-  DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr};
+  /*! \brief the DLPack exchange API, if any */
+  const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
 };
 
 /*! \brief Argument setter for a given python argument. */
@@ -108,17 +106,10 @@ struct TVMFFIPyArgSetter {
   int (*func)(TVMFFIPyArgSetter* self, TVMFFIPyCallContext* call_ctx, 
PyObject* arg,
               TVMFFIAny* out);
   /*!
-   * \brief Optional DLPack exporter for for setters that leverages DLPack 
protocol.
+   * \brief Optional DLPackExchangeAPI struct pointer.
+   * This is the new struct-based approach that bundles all DLPack exchange 
functions.
    */
-  DLPackFromPyObject c_dlpack_from_pyobject{nullptr};
-  /*!
-   * \brief Optional DLPack importer for for setters that leverages DLPack 
protocol.
-   */
-  DLPackToPyObject c_dlpack_to_pyobject{nullptr};
-  /*!
-   * \brief Optional DLPack allocator for for setters that leverages DLPack 
protocol.
-   */
-  DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr};
+  const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
   /*!
    * \brief Invoke the setter.
    * \param call_ctx The call context.
@@ -273,13 +264,14 @@ class TVMFFIPyCallManager {
    * \param result The result of the function
    * \param c_api_ret_code The return code of the C-call
    * \param release_gil Whether to release the GIL
-   * \param optional_out_dlpack_importer The DLPack importer to be used for 
the result
+   * \param optional_out_ctx_dlpack_api The DLPack exchange API to be used for 
the result
    * \return 0 on when there is no python error, -1 on python error
    * \note When an error happens on FFI side, we should return 0 and set 
c_api_ret_code
    */
   TVM_FFI_INLINE int FuncCall(TVMFFIPyArgSetterFactory setter_factory, void* 
func_handle,
                               PyObject* py_arg_tuple, TVMFFIAny* result, int* 
c_api_ret_code,
-                              bool release_gil, DLPackToPyObject* 
optional_out_dlpack_importer) {
+                              bool release_gil,
+                              const DLPackExchangeAPI** 
optional_out_ctx_dlpack_api) {
     int64_t num_args = PyTuple_Size(py_arg_tuple);
     if (num_args == -1) return -1;
     try {
@@ -300,9 +292,10 @@ class TVMFFIPyCallManager {
         // setting failed, directly return
         if (c_api_ret_code[0] != 0) return 0;
       }
-      if (ctx.c_dlpack_tensor_allocator != nullptr) {
-        c_api_ret_code[0] =
-            TVMFFIEnvSetTensorAllocator(ctx.c_dlpack_tensor_allocator, 0, 
&prev_tensor_allocator);
+      if (ctx.c_dlpack_exchange_api != nullptr &&
+          ctx.c_dlpack_exchange_api->managed_tensor_allocator != nullptr) {
+        c_api_ret_code[0] = TVMFFIEnvSetTensorAllocator(
+            ctx.c_dlpack_exchange_api->managed_tensor_allocator, 0, 
&prev_tensor_allocator);
         if (c_api_ret_code[0] != 0) return 0;
       }
       // call the function
@@ -323,12 +316,13 @@ class TVMFFIPyCallManager {
           return -1;
         }
       }
-      if (prev_tensor_allocator != ctx.c_dlpack_tensor_allocator) {
+      if (ctx.c_dlpack_exchange_api != nullptr &&
+          prev_tensor_allocator != 
ctx.c_dlpack_exchange_api->managed_tensor_allocator) {
         c_api_ret_code[0] = TVMFFIEnvSetTensorAllocator(prev_tensor_allocator, 
0, nullptr);
         if (c_api_ret_code[0] != 0) return 0;
       }
-      if (optional_out_dlpack_importer != nullptr && ctx.c_dlpack_to_pyobject 
!= nullptr) {
-        *optional_out_dlpack_importer = ctx.c_dlpack_to_pyobject;
+      if (optional_out_ctx_dlpack_api != nullptr && ctx.c_dlpack_exchange_api 
!= nullptr) {
+        *optional_out_ctx_dlpack_api = ctx.c_dlpack_exchange_api;
       }
       return 0;
     } catch (const std::exception& ex) {
@@ -379,13 +373,9 @@ class TVMFFIPyCallManager {
           parent_ctx->device_id = ctx.device_id;
           parent_ctx->stream = ctx.stream;
         }
-        // DLPack allocator
-        if (parent_ctx->c_dlpack_tensor_allocator == nullptr) {
-          parent_ctx->c_dlpack_tensor_allocator = 
ctx.c_dlpack_tensor_allocator;
-        }
-        // DLPack importer
-        if (parent_ctx->c_dlpack_to_pyobject == nullptr) {
-          parent_ctx->c_dlpack_to_pyobject = ctx.c_dlpack_to_pyobject;
+        // DLPack exchange API
+        if (parent_ctx->c_dlpack_exchange_api == nullptr) {
+          parent_ctx->c_dlpack_exchange_api = ctx.c_dlpack_exchange_api;
         }
       }
       return 0;
@@ -490,16 +480,16 @@ class TVMFFIPyCallManager {
  * \param result The result of the function
  * \param c_api_ret_code The return code of the function
  * \param release_gil Whether to release the GIL
- * \param out_dlpack_exporter The DLPack exporter to be used for the result
+ * \param out_ctx_dlpack_api The DLPack exchange API to be used for the result
  * \return 0 on success, nonzero on failure
  */
 TVM_FFI_INLINE int TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, 
void* func_handle,
                                     PyObject* py_arg_tuple, TVMFFIAny* result, 
int* c_api_ret_code,
                                     bool release_gil = true,
-                                    DLPackToPyObject* out_dlpack_importer = 
nullptr) {
+                                    const DLPackExchangeAPI** 
out_ctx_dlpack_api = nullptr) {
   return TVMFFIPyCallManager::ThreadLocal()->FuncCall(setter_factory, 
func_handle, py_arg_tuple,
                                                       result, c_api_ret_code, 
release_gil,
-                                                      out_dlpack_importer);
+                                                      out_ctx_dlpack_api);
 }
 
 /*!
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index b30471b..795e691 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -213,8 +213,8 @@ def test_load_inline_cuda() -> None:
 @pytest.mark.skipif(torch is None, reason="Requires torch")
 def test_load_inline_with_env_tensor_allocator() -> None:
     assert torch is not None
-    if not hasattr(torch.Tensor, "__c_dlpack_tensor_allocator__"):
-        pytest.skip("Torch does not support __c_dlpack_tensor_allocator__")
+    if not hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
+        pytest.skip("Torch does not support __c_dlpack_exchange_api__")
     mod: Module = tvm_ffi.cpp.load_inline(
         name="hello",
         cpp_sources=r"""

Reply via email to