This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 7f3f872 [DLPack] Leverage exchange api when possible (#260)
7f3f872 is described below
commit 7f3f8726156ab6e33f781562afafd9c6f219551f
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Nov 12 14:53:20 2025 -0500
[DLPack] Leverage exchange api when possible (#260)
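This PR updates `from_dlpack` to use the DLPack exchange API
(`__c_dlpack_exchange_api__`) when the input object provides it, falling
back to the `__dlpack__` protocol otherwise.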
---
python/tvm_ffi/cython/function.pxi | 13 +++++----
python/tvm_ffi/cython/tensor.pxi | 56 +++++++++++++++++++++++++-------------
2 files changed, 45 insertions(+), 24 deletions(-)
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index c1fb6a2..800427b 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -156,20 +156,20 @@ cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
cdef DLManagedTensorVersioned* temp_managed_tensor
cdef TVMFFIObjectHandle temp_chandle
cdef void* current_stream = NULL
- cdef const DLPackExchangeAPI* api = this.c_dlpack_exchange_api
+ cdef const DLPackExchangeAPI* exchange_api = this.c_dlpack_exchange_api
# Set the exchange API in context
- ctx.c_dlpack_exchange_api = api
+ ctx.c_dlpack_exchange_api = exchange_api
# Convert PyObject to DLPack using the struct's function pointer
- if api.managed_tensor_from_py_object_no_sync(arg, &temp_managed_tensor) != 0:
+ if exchange_api.managed_tensor_from_py_object_no_sync(arg, &temp_managed_tensor) != 0:
return -1
# Query current stream from producer if device is not CPU
if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
- if ctx.device_type == -1 and api.current_work_stream != NULL:
+ if ctx.device_type == -1 and exchange_api.current_work_stream != NULL:
# First time seeing a device, query the stream
- if api.current_work_stream(
+ if exchange_api.current_work_stream(
temp_managed_tensor.dl_tensor.device.device_type,
temp_managed_tensor.dl_tensor.device.device_id,
&current_stream
@@ -180,6 +180,9 @@ cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
# Convert to TVM Tensor
if TVMFFITensorFromDLPackVersioned(temp_managed_tensor, 0, 0,
&temp_chandle) != 0:
+ # recycle the managed tensor to avoid leak
+ if temp_managed_tensor.deleter != NULL:
+ temp_managed_tensor.deleter(temp_managed_tensor)
raise BufferError("Failed to convert DLManagedTensorVersioned to ffi.Tensor")
out.type_index = kTVMFFITensor
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 6a9cf9e..614e487 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -45,20 +45,6 @@ cdef void _c_dlpack_versioned_deleter(object pycaps):
dltensor.deleter(dltensor)
-cdef inline object _from_dlpack_intptr(
- void* dlpack
-):
- cdef TVMFFIObjectHandle chandle
- cdef DLManagedTensor* ptr = <DLManagedTensor*>dlpack
- cdef int c_api_ret_code
- cdef int c_req_alignment = 0
- cdef int c_req_contiguous = 0
- c_api_ret_code = TVMFFITensorFromDLPack(
- ptr, c_req_alignment, c_req_contiguous, &chandle)
- CHECK_CALL(c_api_ret_code)
- return make_tensor_from_chandle(chandle)
-
-
cdef inline int _from_dlpack(
object dltensor, int require_alignment,
int require_contiguous, TVMFFIObjectHandle* out
@@ -100,6 +86,26 @@ cdef inline int _from_dlpack_versioned(
raise ValueError("Expect a dltensor_versioned field, PyCapsule can only be consumed once")
+cdef inline int _from_dlpack_exchange_api(
+ object ext_tensor, DLPackExchangeAPI* exchange_api, int require_alignment,
+ int require_contiguous, TVMFFIObjectHandle* out
+) except -1:
+ cdef DLManagedTensorVersioned* temp_managed_tensor
+ cdef PyObject* ext_tensor_pyobj = <PyObject*>ext_tensor
+ if exchange_api.managed_tensor_from_py_object_no_sync(ext_tensor_pyobj,
&temp_managed_tensor) != 0:
+ return -1
+
+ # Convert to TVM Tensor
+ if TVMFFITensorFromDLPackVersioned(
+ temp_managed_tensor, require_alignment, require_contiguous, out
+ ) != 0:
+ # recycle the managed tensor to avoid leak
+ if temp_managed_tensor.deleter != NULL:
+ temp_managed_tensor.deleter(temp_managed_tensor)
raise BufferError("Failed to convert DLManagedTensorVersioned to ffi.Tensor")
+
+ return 0
+
cdef inline int _from_dlpack_universal(
object ext_tensor, int require_alignment,
int require_contiguous, TVMFFIObjectHandle* out
@@ -108,9 +114,21 @@ cdef inline int _from_dlpack_universal(
# move to false as most frameworks get upgraded.
cdef int favor_legacy_dlpack = True
+ if hasattr(ext_tensor, "__c_dlpack_exchange_api__"):
+ try:
+ return _from_dlpack_exchange_api(
+ ext_tensor,
+ <DLPackExchangeAPI*><long long>(ext_tensor.__c_dlpack_exchange_api__),
+ require_alignment,
+ require_contiguous,
+ out
+ )
+ except BufferError:
+ pass
+
if hasattr(ext_tensor, "__dlpack__"):
if favor_legacy_dlpack:
- _from_dlpack(
+ return _from_dlpack(
ext_tensor.__dlpack__(),
require_alignment,
require_contiguous,
@@ -118,14 +136,14 @@ cdef inline int _from_dlpack_universal(
)
else:
try:
- _from_dlpack_versioned(
+ return _from_dlpack_versioned(
ext_tensor.__dlpack__(max_version=__dlpack_version__),
require_alignment,
require_contiguous,
out
)
except TypeError:
- _from_dlpack(
+ return _from_dlpack(
ext_tensor.__dlpack__(),
require_alignment,
require_contiguous,
@@ -133,14 +151,14 @@ cdef inline int _from_dlpack_universal(
)
else:
if pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor_versioned):
- _from_dlpack_versioned(
+ return _from_dlpack_versioned(
ext_tensor,
require_alignment,
require_contiguous,
out
)
elif pycapsule.PyCapsule_IsValid(ext_tensor, _c_str_dltensor):
- _from_dlpack(
+ return _from_dlpack(
ext_tensor,
require_alignment,
require_contiguous,