This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 22a7894 [DLPack] C Functions for DLPack Speed Exchange and Stream Handling (#96)
22a7894 is described below
commit 22a78943b78306a73011757fa635afa9dce35114
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Sat Oct 11 14:09:54 2025 -0700
[DLPack] C Functions for DLPack Speed Exchange and Stream Handling (#96)
## Summary of Changes
This PR introduces a unified `DLPackExchangeAPI` struct as described in
proposal [175](https://github.com/dmlc/dlpack/issues/175). The new
convention replaces the previous mechanism of exposing separate function
pointers and aligns with the latest DLPack standard as shown in PR
[174](https://github.com/dmlc/dlpack/pull/174). A paraphrase of the struct
layout is sketched below for reference.
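The layout below paraphrases the Cython declarations this PR adds in `base.pxi`;
the authoritative definition lives in `dlpack.h` (dmlc/dlpack#174), so treat this
as an orientation sketch rather than the canonical header.

```cpp
// Sketch of the DLPackExchangeAPI layout as bound in this PR (paraphrased).
struct DLPackExchangeAPIHeader {
  DLPackVersion version;              // DLPACK_MAJOR_VERSION / DLPACK_MINOR_VERSION
  DLPackExchangeAPIHeader* prev_api;  // previously registered API, if any
};

struct DLPackExchangeAPI {
  DLPackExchangeAPIHeader header;
  // Allocate a tensor matching a prototype DLTensor (used as env allocator).
  int (*managed_tensor_allocator)(DLTensor* prototype, DLManagedTensorVersioned** out,
                                  void* error_ctx,
                                  void (*SetError)(void* error_ctx, const char* kind,
                                                   const char* message));
  // Framework object -> owning DLPack tensor, no stream synchronization.
  int (*managed_tensor_from_py_object_no_sync)(void* py_object,
                                               DLManagedTensorVersioned** out);
  // Owning DLPack tensor -> framework object, no stream synchronization.
  int (*managed_tensor_to_py_object_no_sync)(DLManagedTensorVersioned* tensor,
                                             void** out_py_object);
  // Framework object -> non-owning DLTensor view, no stream synchronization.
  int (*dltensor_from_py_object_no_sync)(void* py_object, DLTensor* out);
  // Query the framework's current work stream for a device.
  int (*current_work_stream)(DLDeviceType device_type, int32_t device_id,
                             void** out_current_stream);
};
```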
The new `DLPackExchangeAPI` struct also carries a
`current_work_stream` function pointer that allows more robust and
integrated querying of the current device stream (e.g., the CUDA stream)
during DLPack tensor exchanges. All conversions from/to DLPack have been
updated to `_no_sync` variants, meaning callers should use
`current_work_stream` to handle stream synchronization explicitly (see the
consumer-side sketch below). The struct also provides a non-owning DLTensor
conversion to avoid unnecessary reference counting.
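As an illustration only (not code from this PR), a consumer holding a producer's
`DLPackExchangeAPI*` would roughly proceed as follows; the names `consume`, `api`,
and `py_tensor` are placeholders.

```cpp
#include <dlpack/dlpack.h>

// Hypothetical consumer flow: take a non-owning view without synchronizing,
// then use current_work_stream to order work against the producer's stream.
void consume(const DLPackExchangeAPI* api, void* py_tensor) {
  DLTensor view;
  if (api->dltensor_from_py_object_no_sync(py_tensor, &view) != 0) {
    return;  // the producer has already set a Python error
  }
  void* producer_stream = nullptr;
  if (view.device.device_type != kDLCPU && api->current_work_stream != nullptr) {
    api->current_work_stream(view.device.device_type, view.device.device_id,
                             &producer_stream);
  }
  // The _no_sync conversion did not synchronize: launch on producer_stream,
  // or make the consumer's stream wait on it (e.g. via an event) before use.
}
```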
Following this change, the Python FFI for PyTorch has been updated to
expose the new `DLPackExchangeAPI` struct via
`__c_dlpack_exchange_api__` on `torch.Tensor` (the lookup pattern is
sketched at the end of this summary). The `3rdparty/dlpack` submodule has
been updated to the latest upstream commit.
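For orientation, the sketch below shows how a C++ consumer could read that class
attribute and reinterpret the integer as a struct pointer, mirroring the Cython
lookup in `function.pxi`; `GetExchangeAPI` is a hypothetical helper, not part of
this PR.

```cpp
#include <Python.h>
#include <dlpack/dlpack.h>

// Read __c_dlpack_exchange_api__ from the tensor's class (not the instance)
// and reinterpret the stored integer as a DLPackExchangeAPI pointer.
const DLPackExchangeAPI* GetExchangeAPI(PyObject* py_tensor) {
  PyObject* attr = PyObject_GetAttrString(
      reinterpret_cast<PyObject*>(Py_TYPE(py_tensor)), "__c_dlpack_exchange_api__");
  if (attr == nullptr) {
    PyErr_Clear();  // the class does not publish an exchange API
    return nullptr;
  }
  const DLPackExchangeAPI* api =
      reinterpret_cast<const DLPackExchangeAPI*>(PyLong_AsLongLong(attr));
  Py_DECREF(attr);
  // The producer exports a pointer to a static instance, so it stays valid.
  return api;
}
```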
---
3rdparty/dlpack | 2 +-
python/tvm_ffi/_optional_torch_c_dlpack.py | 89 ++++++++++++++------
python/tvm_ffi/core.pyi | 2 +-
python/tvm_ffi/cython/base.pxi | 53 ++++++++++--
python/tvm_ffi/cython/function.pxi | 109 +++++++++++++++----------
python/tvm_ffi/cython/tensor.pxi | 85 ++++++++++++++-----
python/tvm_ffi/cython/tvm_ffi_python_helpers.h | 54 +++++-------
tests/python/test_load_inline.py | 4 +-
8 files changed, 268 insertions(+), 130 deletions(-)
diff --git a/3rdparty/dlpack b/3rdparty/dlpack
index addbc8b..1117366 160000
--- a/3rdparty/dlpack
+++ b/3rdparty/dlpack
@@ -1 +1 @@
-Subproject commit addbc8b3d9449691d01827ac4a0e0d035cf8ea40
+Subproject commit 111736618e8d1028b23605f76dcaa6a38cfea809
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index c1bb1ef..94bb3d7 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -464,18 +464,41 @@ at::Tensor fromDLPackImpl(T* src, std::function<void(void*)> deleter) {
{device});
}
+void toDLPackNonOwningImpl(const Tensor& tensor, DLTensor& out) {
+ // Fill in the pre-allocated DLTensor struct with direct pointers
+ // This is a non-owning conversion - the caller owns the tensor
+ // and must keep it alive for the duration of DLTensor usage
+ out.data = tensor.data_ptr();
+ out.device = torchDeviceToDLDeviceForDLPackv1(tensor.device());
+ out.ndim = static_cast<int32_t>(tensor.dim());
+ out.dtype = getDLDataTypeForDLPackv1(tensor);
+ // sizes() and strides() return pointers to TensorImpl's stable storage
+ // which remains valid as long as the tensor is alive
+ out.shape = const_cast<int64_t*>(tensor.sizes().data());
+ out.strides = const_cast<int64_t*>(tensor.strides().data());
+ out.byte_offset = 0;
+}
+
} // namespace
} // namespace at
-int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out, void** env_stream) {
+int TorchDLPackDLTensorFromPyObjectNoSync(void* py_obj, DLTensor* out) {
+ try {
+ // Use handle (non-owning) to avoid unnecessary refcount operations
+ py::handle handle(static_cast<PyObject*>(py_obj));
+ at::Tensor tensor = handle.cast<at::Tensor>();
+ at::toDLPackNonOwningImpl(tensor, *out);
+ return 0;
+ } catch (const std::exception& e) {
+ PyErr_SetString(PyExc_RuntimeError, e.what());
+ return -1;
+ }
+}
+
+int TorchDLPackManagedTensorFromPyObjectNoSync(void* py_obj, DLManagedTensorVersioned** out) {
try {
py::handle handle(static_cast<PyObject*>(py_obj));
at::Tensor tensor = handle.cast<at::Tensor>();
-#ifdef BUILD_WITH_CUDA
- if (env_stream != nullptr && tensor.is_cuda()) {
- *env_stream = at::cuda::getCurrentCUDAStream(tensor.device().index()).stream();
- }
-#endif
*out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
return 0;
} catch (const std::exception& e) {
@@ -484,7 +507,7 @@ int TorchDLPackFromPyObject(void* py_obj, DLManagedTensorVersioned** out, void**
}
}
-int TorchDLPackToPyObject(DLManagedTensorVersioned* src, void** py_obj_out) {
+int TorchDLPackManagedTensorToPyObjectNoSync(DLManagedTensorVersioned* src, void** py_obj_out) {
try {
at::Tensor tensor = at::fromDLPackImpl<DLManagedTensorVersioned>(src, nullptr);
*py_obj_out = THPVariable_Wrap(tensor);
@@ -495,7 +518,7 @@ int TorchDLPackToPyObject(DLManagedTensorVersioned* src, void** py_obj_out) {
}
}
-int TorchDLPackTensorAllocator(
+int TorchDLPackManagedTensorAllocator(
DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx,
void (*SetError)(void* error_ctx, const char* kind, const char* message)
) {
@@ -508,21 +531,45 @@ int TorchDLPackTensorAllocator(
*out = at::toDLPackImpl<DLManagedTensorVersioned>(tensor);
return 0;
} catch (const std::exception& e) {
- SetError(error_ctx, "TorchDLPackTensorAllocator", e.what());
+ SetError(error_ctx, "TorchDLPackManagedTensorAllocator", e.what());
return -1;
}
}
-int64_t TorchDLPackFromPyObjectPtr() {
- return reinterpret_cast<int64_t>(TorchDLPackFromPyObject);
+int TorchDLPackCurrentWorkStream(DLDeviceType device_type, int32_t device_id, void** out_stream) {
+ try {
+#ifdef BUILD_WITH_CUDA
+ if (device_type == kDLCUDA || device_type == kDLROCM) {
+ *out_stream = at::cuda::getCurrentCUDAStream(device_id).stream();
+ }
+#endif
+ return 0;
+ } catch (const std::exception& e) {
+ PyErr_SetString(PyExc_RuntimeError, e.what());
+ return -1;
+ }
}
-int64_t TorchDLPackToPyObjectPtr() {
- return reinterpret_cast<int64_t>(TorchDLPackToPyObject);
-}
+struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
+ TorchDLPackExchangeAPI() {
+ header.version.major = DLPACK_MAJOR_VERSION;
+ header.version.minor = DLPACK_MINOR_VERSION;
+ header.prev_api = nullptr;
+ managed_tensor_allocator = TorchDLPackManagedTensorAllocator;
+ managed_tensor_from_py_object_no_sync = TorchDLPackManagedTensorFromPyObjectNoSync;
+ managed_tensor_to_py_object_no_sync = TorchDLPackManagedTensorToPyObjectNoSync;
+ dltensor_from_py_object_no_sync = TorchDLPackDLTensorFromPyObjectNoSync;
+ current_work_stream = TorchDLPackCurrentWorkStream;
+ }
+
+ static const DLPackExchangeAPI* Global() {
+ static TorchDLPackExchangeAPI inst;
+ return &inst;
+ }
+};
-int64_t TorchDLPackTensorAllocatorPtr() {
- return reinterpret_cast<int64_t>(TorchDLPackTensorAllocator);
+int64_t TorchDLPackExchangeAPIPtr() {
+ return reinterpret_cast<int64_t>(TorchDLPackExchangeAPI::Global());
}
"""
try:
@@ -541,17 +588,13 @@ int64_t TorchDLPackTensorAllocatorPtr() {
name="c_dlpack",
cpp_sources=cpp_source,
functions=[
- "TorchDLPackFromPyObjectPtr",
- "TorchDLPackToPyObjectPtr",
- "TorchDLPackTensorAllocatorPtr",
+ "TorchDLPackExchangeAPIPtr",
],
extra_cflags=extra_cflags,
extra_include_paths=include_paths,
)
- # set the dlpack related flags
- setattr(torch.Tensor, "__c_dlpack_from_pyobject__", mod.TorchDLPackFromPyObjectPtr())
- setattr(torch.Tensor, "__c_dlpack_to_pyobject__", mod.TorchDLPackToPyObjectPtr())
- setattr(torch.Tensor, "__c_dlpack_tensor_allocator__", mod.TorchDLPackTensorAllocatorPtr())
+ # Set the DLPackExchangeAPI pointer on the class
+ setattr(torch.Tensor, "__c_dlpack_exchange_api__", mod.TorchDLPackExchangeAPIPtr())
return mod
except ImportError:
pass
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 092515a..05e5a04 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -485,7 +485,7 @@ def from_dlpack(
class DLTensorTestWrapper:
"""Wrapper of a Tensor that exposes DLPack protocol, only for testing
purpose."""
- __c_dlpack_from_pyobject__: int
+ __c_dlpack_exchange_api__: int
def __init__(self, tensor: Tensor) -> None: ...
def __tvm_ffi_env_stream__(self) -> int: ...
def __dlpack_device__(self) -> tuple[int, int]: ...
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index dcc338a..a86ea00 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -25,6 +25,9 @@ from cpython cimport pycapsule, PyCapsule_Destructor
from cpython cimport PyErr_SetNone
cdef extern from "dlpack/dlpack.h":
+ int DLPACK_MAJOR_VERSION
+ int DLPACK_MINOR_VERSION
+
cdef enum:
kDLCPU = 1,
kDLCUDA = 2,
@@ -77,6 +80,47 @@ cdef extern from "dlpack/dlpack.h":
void (*deleter)(DLManagedTensorVersioned* self)
uint64_t flags
+ # DLPack Exchange API function pointer types
+ ctypedef int (*DLPackManagedTensorAllocator)(
+ DLTensor* prototype,
+ DLManagedTensorVersioned** out,
+ void* error_ctx,
+ void (*SetError)(void* error_ctx, const char* kind, const char* message)
+ ) noexcept
+
+ ctypedef int (*DLPackManagedTensorFromPyObjectNoSync)(
+ void* py_object,
+ DLManagedTensorVersioned** out
+ ) noexcept
+
+ ctypedef int (*DLPackManagedTensorToPyObjectNoSync)(
+ DLManagedTensorVersioned* tensor,
+ void** out_py_object
+ ) noexcept
+
+ ctypedef int (*DLPackCurrentWorkStream)(
+ int device_type,
+ int32_t device_id,
+ void** out_current_stream
+ ) noexcept
+
+ ctypedef int (*DLPackDLTensorFromPyObjectNoSync)(
+ void* py_object,
+ DLTensor* out
+ ) noexcept
+
+ ctypedef struct DLPackExchangeAPIHeader:
+ DLPackVersion version
+ DLPackExchangeAPIHeader* prev_api
+
+ ctypedef struct DLPackExchangeAPI:
+ DLPackExchangeAPIHeader header
+ DLPackManagedTensorAllocator managed_tensor_allocator
+ DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync
+ DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync
+ DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync
+ DLPackCurrentWorkStream current_work_stream
+
# Cython binding for TVM FFI C API
cdef extern from "tvm/ffi/c_api.h":
@@ -285,14 +329,11 @@ cdef extern from "tvm_ffi_python_helpers.h":
int device_type
int device_id
TVMFFIStreamHandle stream
- DLPackToPyObject c_dlpack_to_pyobject
- DLPackTensorAllocator c_dlpack_tensor_allocator
+ const DLPackExchangeAPI* c_dlpack_exchange_api
ctypedef struct TVMFFIPyArgSetter:
int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* py_arg, TVMFFIAny* out) except -1
- DLPackFromPyObject c_dlpack_from_pyobject
- DLPackToPyObject c_dlpack_to_pyobject
- DLPackTensorAllocator c_dlpack_tensor_allocator
+ const DLPackExchangeAPI* c_dlpack_exchange_api
ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value,
TVMFFIPyArgSetter* out) except -1
# The main call function
@@ -303,7 +344,7 @@ cdef extern from "tvm_ffi_python_helpers.h":
TVMFFIAny* result,
int* c_api_ret_code,
int release_gil,
- DLPackToPyObject* out_dlpack_importer
+ const DLPackExchangeAPI** out_ctx_dlpack_api
) except -1
int TVMFFIPyConstructorCall(
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index a6cc3be..7dcab4b 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -55,13 +55,13 @@ cdef inline object make_ret_small_bytes(TVMFFIAny result):
return bytearray_to_bytes(&bytes)
-cdef inline object make_ret(TVMFFIAny result, DLPackToPyObject c_dlpack_to_pyobject = NULL):
+cdef inline object make_ret(TVMFFIAny result, const DLPackExchangeAPI* c_ctx_dlpack_api = NULL):
"""convert result to return value."""
cdef int32_t type_index
type_index = result.type_index
if type_index == kTVMFFITensor:
# specially handle Tensor as it needs a special dltensor field
- return make_tensor_from_any(result, c_dlpack_to_pyobject)
+ return make_tensor_from_any(result, c_ctx_dlpack_api)
elif type_index == kTVMFFIOpaquePyObject:
return make_ret_opaque_object(result)
elif type_index >= kTVMFFIStaticObjectBegin:
@@ -142,35 +142,39 @@ cdef int TVMFFIPyArgSetterObject_(
return 0
-cdef int TVMFFIPyArgSetterDLPackCExporter_(
+cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
TVMFFIPyArgSetter* this, TVMFFIPyCallContext* ctx,
PyObject* arg, TVMFFIAny* out
) except -1:
cdef DLManagedTensorVersioned* temp_managed_tensor
cdef TVMFFIObjectHandle temp_chandle
- cdef TVMFFIStreamHandle env_stream = NULL
+ cdef void* current_stream = NULL
+ cdef const DLPackExchangeAPI* api = this.c_dlpack_exchange_api
- if this.c_dlpack_to_pyobject != NULL:
- ctx.c_dlpack_to_pyobject = this.c_dlpack_to_pyobject
- if this.c_dlpack_tensor_allocator != NULL:
- ctx.c_dlpack_tensor_allocator = this.c_dlpack_tensor_allocator
+ # Set the exchange API in context
+ ctx.c_dlpack_exchange_api = api
- if ctx.device_type != -1:
- # already queried device, do not do it again, pass NULL to stream
- if (this.c_dlpack_from_pyobject)(arg, &temp_managed_tensor, NULL) != 0:
- return -1
- else:
- # query string on the envrionment stream
- if (this.c_dlpack_from_pyobject)(arg, &temp_managed_tensor, &env_stream) != 0:
- return -1
- # If device is not CPU, we should set the device type and id
- if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
- ctx.stream = env_stream
- ctx.device_type = temp_managed_tensor.dl_tensor.device.device_type
- ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id
- # run conversion
+ # Convert PyObject to DLPack using the struct's function pointer
+ if api.managed_tensor_from_py_object_no_sync(arg, &temp_managed_tensor) != 0:
+ return -1
+
+ # Query current stream from producer if device is not CPU
+ if temp_managed_tensor.dl_tensor.device.device_type != kDLCPU:
+ if ctx.device_type == -1 and api.current_work_stream != NULL:
+ # First time seeing a device, query the stream
+ if api.current_work_stream(
+ temp_managed_tensor.dl_tensor.device.device_type,
+ temp_managed_tensor.dl_tensor.device.device_id,
+ &current_stream
+ ) == 0:
+ ctx.stream = <TVMFFIStreamHandle>current_stream
+ ctx.device_type = temp_managed_tensor.dl_tensor.device.device_type
+ ctx.device_id = temp_managed_tensor.dl_tensor.device.device_id
+
+ # Convert to TVM Tensor
if TVMFFITensorFromDLPackVersioned(temp_managed_tensor, 0, 0, &temp_chandle) != 0:
raise BufferError("Failed to convert DLManagedTensorVersioned to ffi.Tensor")
+
out.type_index = kTVMFFITensor
out.v_ptr = temp_chandle
TVMFFIPyPushTempFFIObject(ctx, temp_chandle)
@@ -179,15 +183,36 @@ cdef int TVMFFIPyArgSetterDLPackCExporter_(
cdef int TorchDLPackToPyObjectFallback_(
DLManagedTensorVersioned* dltensor, void** py_obj_out
-) except -1:
+) noexcept:
# a bit convoluted but ok as a fallback
cdef TVMFFIObjectHandle temp_chandle
- TVMFFITensorFromDLPackVersioned(dltensor, 0, 0, &temp_chandle)
- tensor = make_tensor_from_chandle(temp_chandle)
- torch_tensor = torch.from_dlpack(tensor)
- Py_INCREF(torch_tensor)
- py_obj_out[0] = <void*>(<PyObject*>torch_tensor)
- return 0
+ if TVMFFITensorFromDLPackVersioned(dltensor, 0, 0, &temp_chandle) != 0:
+ return -1
+ try:
+ tensor = make_tensor_from_chandle(temp_chandle)
+ torch_tensor = torch.from_dlpack(tensor)
+ Py_INCREF(torch_tensor)
+ py_obj_out[0] = <void*>(<PyObject*>torch_tensor)
+ return 0
+ except Exception:
+ return -1
+
+cdef inline const DLPackExchangeAPI* GetTorchFallbackExchangeAPI() noexcept:
+ global _torch_fallback_exchange_api
+
+ _torch_fallback_exchange_api.header.version.major = DLPACK_MAJOR_VERSION
+ _torch_fallback_exchange_api.header.version.minor = DLPACK_MINOR_VERSION
+ _torch_fallback_exchange_api.header.prev_api = NULL
+ _torch_fallback_exchange_api.managed_tensor_allocator = NULL
+ _torch_fallback_exchange_api.managed_tensor_from_py_object_no_sync = NULL
+ _torch_fallback_exchange_api.managed_tensor_to_py_object_no_sync = TorchDLPackToPyObjectFallback_
+ _torch_fallback_exchange_api.dltensor_from_py_object_no_sync = NULL
+ _torch_fallback_exchange_api.current_work_stream = NULL
+
+ return &_torch_fallback_exchange_api
+
+# Static storage for the fallback exchange API
+cdef DLPackExchangeAPI _torch_fallback_exchange_api
cdef int TVMFFIPyArgSetterTorchFallback_(
@@ -202,7 +227,7 @@ cdef int TVMFFIPyArgSetterTorchFallback_(
out.type_index = kTVMFFITensor
out.v_ptr = (<Tensor>arg).chandle
temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
- ctx.c_dlpack_to_pyobject = TorchDLPackToPyObjectFallback_
+ ctx.c_dlpack_exchange_api = GetTorchFallbackExchangeAPI()
# record the stream and device for torch context
if is_cuda and ctx.device_type != -1:
ctx.device_type = temp_dltensor.device.device_type
@@ -546,17 +571,13 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
out.func = TVMFFIPyArgSetterObjectRValueRef_
return 0
if os.environ.get("TVM_FFI_SKIP_c_dlpack_from_pyobject", "0") != "1":
- # external tensors
- if hasattr(arg, "__c_dlpack_from_pyobject__"):
- out.func = TVMFFIPyArgSetterDLPackCExporter_
- temp_ptr = arg.__c_dlpack_from_pyobject__
- out.c_dlpack_from_pyobject = <DLPackFromPyObject>temp_ptr
- if hasattr(arg, "__c_dlpack_to_pyobject__"):
- temp_ptr = arg.__c_dlpack_to_pyobject__
- out.c_dlpack_to_pyobject = <DLPackToPyObject>temp_ptr
- if hasattr(arg, "__c_dlpack_tensor_allocator__"):
- temp_ptr = arg.__c_dlpack_tensor_allocator__
- out.c_dlpack_tensor_allocator = <DLPackTensorAllocator>temp_ptr
+ # Check for DLPackExchangeAPI struct (new approach)
+ # This is checked on the CLASS, not the instance
+ arg_class = type(arg)
+ if hasattr(arg_class, "__c_dlpack_exchange_api__"):
+ out.func = TVMFFIPyArgSetterDLPackExchangeAPI_
+ temp_ptr = arg_class.__c_dlpack_exchange_api__
+ out.c_dlpack_exchange_api = <const DLPackExchangeAPI*>(<long long>temp_ptr)
return 0
if torch is not None and isinstance(arg, torch.Tensor):
out.func = TVMFFIPyArgSetterTorchFallback_
@@ -658,7 +679,7 @@ cdef class Function(Object):
def __call__(self, *args):
cdef TVMFFIAny result
cdef int c_api_ret_code
- cdef DLPackToPyObject c_dlpack_to_pyobject = NULL
+ cdef const DLPackExchangeAPI* c_ctx_dlpack_api = NULL
# IMPORTANT: caller need to initialize result->type_index to kTVMFFINone
result.type_index = kTVMFFINone
result.v_int64 = 0
@@ -668,12 +689,12 @@ cdef class Function(Object):
&result,
&c_api_ret_code,
self.release_gil,
- &c_dlpack_to_pyobject
+ &c_ctx_dlpack_api
)
# NOTE: logic is same as check_call
# directly inline here to simplify the resulting trace
if c_api_ret_code == 0:
- return make_ret(result, c_dlpack_to_pyobject)
+ return make_ret(result, c_ctx_dlpack_api)
elif c_api_ret_code == -2:
raise_existing_error()
raise move_from_last_error().py_error()
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 74b065b..4ebc515 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -275,33 +275,74 @@ _set_class_tensor(Tensor)
_register_object_by_index(kTVMFFITensor, Tensor)
-cdef int _dltensor_test_wrapper_c_dlpack_from_pyobject(
- void* obj, DLManagedTensorVersioned** out, TVMFFIStreamHandle* env_stream
+cdef int _dltensor_test_wrapper_from_pyobject(
+ void* obj, DLManagedTensorVersioned** out
) except -1:
+ """DLPackExchangeAPI: managed_tensor_from_py_object_no_sync"""
cdef PyObject* py_obj = <PyObject*>obj
cdef DLTensorTestWrapper wrapper = <DLTensorTestWrapper>py_obj
- cdef TVMFFIStreamHandle current_stream
- cdef DLManagedTensorVersioned* temp_managed_tensor
- if env_stream != NULL:
- env_stream[0] = TVMFFIEnvGetStream(
- wrapper.tensor.cdltensor.device.device_type,
- wrapper.tensor.cdltensor.device.device_id
- )
-
return TVMFFITensorToDLPackVersioned(wrapper.tensor.chandle, out)
-def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr():
- cdef DLPackFromPyObject converter_func = _dltensor_test_wrapper_c_dlpack_from_pyobject
- cdef void* temp_ptr = <void*>converter_func
- cdef long long temp_int_ptr = <long long>temp_ptr
- return temp_int_ptr
+cdef int _dltensor_test_wrapper_to_pyobject(
+ DLManagedTensorVersioned* tensor, void** out_py_object
+) except -1:
+ """DLPackExchangeAPI: managed_tensor_to_py_object_no_sync"""
+ cdef TVMFFIObjectHandle temp_chandle
+ if TVMFFITensorFromDLPackVersioned(tensor, 0, 0, &temp_chandle) != 0:
+ return -1
+ py_tensor = make_tensor_from_chandle(temp_chandle)
+ Py_INCREF(py_tensor)
+ out_py_object[0] = <void*>(<PyObject*>py_tensor)
+ return 0
+
+
+cdef int _dltensor_test_wrapper_current_work_stream(
+ int device_type, int32_t device_id, void** out_stream
+) except -1:
+ """DLPackExchangeAPI: current_work_stream"""
+ if device_type != kDLCPU:
+ out_stream[0] = <void*>TVMFFIEnvGetStream(device_type, device_id)
+ return 0
+
+
+# Module-level static DLPackExchangeAPI for DLTensorTestWrapper
+cdef DLPackExchangeAPI _dltensor_test_wrapper_static_api
+
+cdef const DLPackExchangeAPI* _dltensor_test_wrapper_get_exchange_api() noexcept:
+ """Get the static DLPackExchangeAPI instance for DLTensorTestWrapper."""
+ global _dltensor_test_wrapper_static_api
+
+ # Initialize header using macros from dlpack.h
+ _dltensor_test_wrapper_static_api.header.version.major = DLPACK_MAJOR_VERSION
+ _dltensor_test_wrapper_static_api.header.version.minor = DLPACK_MINOR_VERSION
+ _dltensor_test_wrapper_static_api.header.prev_api = NULL
+
+ # Initialize function pointers
+ _dltensor_test_wrapper_static_api.managed_tensor_allocator = NULL
+ _dltensor_test_wrapper_static_api.managed_tensor_from_py_object_no_sync = (
+ <DLPackManagedTensorFromPyObjectNoSync>_dltensor_test_wrapper_from_pyobject
+ )
+ _dltensor_test_wrapper_static_api.managed_tensor_to_py_object_no_sync = (
+ <DLPackManagedTensorToPyObjectNoSync>_dltensor_test_wrapper_to_pyobject
+ )
+ _dltensor_test_wrapper_static_api.dltensor_from_py_object_no_sync = NULL
+ _dltensor_test_wrapper_static_api.current_work_stream = (
+ <DLPackCurrentWorkStream>_dltensor_test_wrapper_current_work_stream
+ )
+
+ return &_dltensor_test_wrapper_static_api
+
+
+def _dltensor_test_wrapper_exchange_api_ptr():
+ """Return the pointer to the DLPackExchangeAPI struct as an integer."""
+ return <long long>_dltensor_test_wrapper_get_exchange_api()
cdef class DLTensorTestWrapper:
"""Wrapper of a Tensor that exposes DLPack protocol, only for testing
purpose.
"""
- __c_dlpack_from_pyobject__ =
_dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr()
+ __c_dlpack_exchange_api__ = _dltensor_test_wrapper_exchange_api_ptr()
cdef Tensor tensor
cdef dict __dict__
@@ -334,19 +375,21 @@ cdef inline object make_ret_dltensor(TVMFFIAny result):
return tensor
-cdef inline object make_tensor_from_chandle(TVMFFIObjectHandle chandle, DLPackToPyObject c_dlpack_to_pyobject = NULL):
+cdef inline object make_tensor_from_chandle(
+ TVMFFIObjectHandle chandle, const DLPackExchangeAPI* c_ctx_dlpack_api = NULL
+):
# TODO: Implement
cdef Tensor tensor
cdef void* py_obj
cdef DLManagedTensorVersioned* dlpack
- if c_dlpack_to_pyobject != NULL:
+ if c_ctx_dlpack_api != NULL and c_ctx_dlpack_api.managed_tensor_to_py_object_no_sync != NULL:
# try convert and import into the environment array if possible
if TVMFFITensorToDLPackVersioned(chandle, &dlpack) == 0:
try:
# note that py_obj already holds an extra reference to the tensor
# so we need to decref it after the conversion
- c_dlpack_to_pyobject(dlpack, &py_obj)
+ c_ctx_dlpack_api.managed_tensor_to_py_object_no_sync(dlpack, &py_obj)
tensor = <Tensor>(<PyObject*>py_obj)
Py_DECREF(tensor)
# decref original handle to prevent leak.
@@ -365,5 +408,5 @@ cdef inline object make_tensor_from_chandle(TVMFFIObjectHandle chandle, DLPackTo
return tensor
-cdef inline object make_tensor_from_any(TVMFFIAny any, DLPackToPyObject c_dlpack_to_pyobject):
- return make_tensor_from_chandle(any.v_ptr, c_dlpack_to_pyobject)
+cdef inline object make_tensor_from_any(TVMFFIAny any, const DLPackExchangeAPI* c_ctx_dlpack_api):
+ return make_tensor_from_chandle(any.v_ptr, c_ctx_dlpack_api)
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index ea60d51..da2404b 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -89,10 +89,8 @@ struct TVMFFIPyCallContext {
void** temp_py_objects = nullptr;
/*! \brief the number of temporary arguments */
int num_temp_py_objects = 0;
- /*! \brief the DLPack exporter, if any */
- DLPackToPyObject c_dlpack_to_pyobject{nullptr};
- /*! \brief the DLPack allocator, if any */
- DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr};
+ /*! \brief the DLPack exchange API, if any */
+ const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
};
/*! \brief Argument setter for a given python argument. */
@@ -108,17 +106,10 @@ struct TVMFFIPyArgSetter {
int (*func)(TVMFFIPyArgSetter* self, TVMFFIPyCallContext* call_ctx, PyObject* arg,
TVMFFIAny* out);
/*!
- * \brief Optional DLPack exporter for for setters that leverages DLPack protocol.
+ * \brief Optional DLPackExchangeAPI struct pointer.
+ * This is the new struct-based approach that bundles all DLPack exchange functions.
*/
- DLPackFromPyObject c_dlpack_from_pyobject{nullptr};
- /*!
- * \brief Optional DLPack importer for for setters that leverages DLPack protocol.
- */
- DLPackToPyObject c_dlpack_to_pyobject{nullptr};
- /*!
- * \brief Optional DLPack allocator for for setters that leverages DLPack protocol.
- */
- DLPackTensorAllocator c_dlpack_tensor_allocator{nullptr};
+ const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
/*!
* \brief Invoke the setter.
* \param call_ctx The call context.
@@ -273,13 +264,14 @@ class TVMFFIPyCallManager {
* \param result The result of the function
* \param c_api_ret_code The return code of the C-call
* \param release_gil Whether to release the GIL
- * \param optional_out_dlpack_importer The DLPack importer to be used for the result
+ * \param optional_out_ctx_dlpack_api The DLPack exchange API to be used for the result
* \return 0 on when there is no python error, -1 on python error
* \note When an error happens on FFI side, we should return 0 and set c_api_ret_code
*/
TVM_FFI_INLINE int FuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle,
PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code,
- bool release_gil, DLPackToPyObject* optional_out_dlpack_importer) {
+ bool release_gil,
+ const DLPackExchangeAPI** optional_out_ctx_dlpack_api) {
int64_t num_args = PyTuple_Size(py_arg_tuple);
if (num_args == -1) return -1;
try {
@@ -300,9 +292,10 @@ class TVMFFIPyCallManager {
// setting failed, directly return
if (c_api_ret_code[0] != 0) return 0;
}
- if (ctx.c_dlpack_tensor_allocator != nullptr) {
- c_api_ret_code[0] =
- TVMFFIEnvSetTensorAllocator(ctx.c_dlpack_tensor_allocator, 0, &prev_tensor_allocator);
+ if (ctx.c_dlpack_exchange_api != nullptr &&
+ ctx.c_dlpack_exchange_api->managed_tensor_allocator != nullptr) {
+ c_api_ret_code[0] = TVMFFIEnvSetTensorAllocator(
+ ctx.c_dlpack_exchange_api->managed_tensor_allocator, 0, &prev_tensor_allocator);
if (c_api_ret_code[0] != 0) return 0;
}
// call the function
@@ -323,12 +316,13 @@ class TVMFFIPyCallManager {
return -1;
}
}
- if (prev_tensor_allocator != ctx.c_dlpack_tensor_allocator) {
+ if (ctx.c_dlpack_exchange_api != nullptr &&
+ prev_tensor_allocator != ctx.c_dlpack_exchange_api->managed_tensor_allocator) {
c_api_ret_code[0] = TVMFFIEnvSetTensorAllocator(prev_tensor_allocator, 0, nullptr);
if (c_api_ret_code[0] != 0) return 0;
}
- if (optional_out_dlpack_importer != nullptr && ctx.c_dlpack_to_pyobject != nullptr) {
- *optional_out_dlpack_importer = ctx.c_dlpack_to_pyobject;
+ if (optional_out_ctx_dlpack_api != nullptr && ctx.c_dlpack_exchange_api != nullptr) {
+ *optional_out_ctx_dlpack_api = ctx.c_dlpack_exchange_api;
}
return 0;
} catch (const std::exception& ex) {
@@ -379,13 +373,9 @@ class TVMFFIPyCallManager {
parent_ctx->device_id = ctx.device_id;
parent_ctx->stream = ctx.stream;
}
- // DLPack allocator
- if (parent_ctx->c_dlpack_tensor_allocator == nullptr) {
- parent_ctx->c_dlpack_tensor_allocator = ctx.c_dlpack_tensor_allocator;
- }
- // DLPack importer
- if (parent_ctx->c_dlpack_to_pyobject == nullptr) {
- parent_ctx->c_dlpack_to_pyobject = ctx.c_dlpack_to_pyobject;
+ // DLPack exchange API
+ if (parent_ctx->c_dlpack_exchange_api == nullptr) {
+ parent_ctx->c_dlpack_exchange_api = ctx.c_dlpack_exchange_api;
}
}
return 0;
@@ -490,16 +480,16 @@ class TVMFFIPyCallManager {
* \param result The result of the function
* \param c_api_ret_code The return code of the function
* \param release_gil Whether to release the GIL
- * \param out_dlpack_exporter The DLPack exporter to be used for the result
+ * \param out_ctx_dlpack_api The DLPack exchange API to be used for the result
* \return 0 on success, nonzero on failure
*/
TVM_FFI_INLINE int TVMFFIPyFuncCall(TVMFFIPyArgSetterFactory setter_factory, void* func_handle,
PyObject* py_arg_tuple, TVMFFIAny* result, int* c_api_ret_code,
bool release_gil = true,
- DLPackToPyObject* out_dlpack_importer = nullptr) {
+ const DLPackExchangeAPI** out_ctx_dlpack_api = nullptr) {
return TVMFFIPyCallManager::ThreadLocal()->FuncCall(setter_factory, func_handle, py_arg_tuple,
result, c_api_ret_code, release_gil,
- out_dlpack_importer);
+ out_ctx_dlpack_api);
}
/*!
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index b30471b..795e691 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -213,8 +213,8 @@ def test_load_inline_cuda() -> None:
@pytest.mark.skipif(torch is None, reason="Requires torch")
def test_load_inline_with_env_tensor_allocator() -> None:
assert torch is not None
- if not hasattr(torch.Tensor, "__c_dlpack_tensor_allocator__"):
- pytest.skip("Torch does not support __c_dlpack_tensor_allocator__")
+ if not hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
+ pytest.skip("Torch does not support __c_dlpack_exchange_api__")
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""