This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 5393647 [DLPack] Further compatibility (#315)
5393647 is described below
commit 539364726ea51d5ea4c7695bf9a1a6cfc37ca8cc
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Dec 5 12:23:26 2025 -0500
[DLPack] Further compatibility (#315)
This PR updates the DLPack support to be compatible with both attribute
names, c_dlpack_exchange_api and dlpack_c_exchange_api, in case of future
DLPack naming changes. Backward compatibility is kept.
---
addons/torch_c_dlpack_ext/build_backend.py | 6 ++--
.../torch_c_dlpack_ext/torch_c_dlpack_ext/core.py | 25 +++++++++++++++-
python/tvm_ffi/_optional_torch_c_dlpack.py | 33 +++++++++++++++-------
python/tvm_ffi/core.pyi | 2 +-
python/tvm_ffi/cython/base.pxi | 4 +--
python/tvm_ffi/cython/function.pxi | 12 ++++----
python/tvm_ffi/cython/tensor.pxi | 17 ++++++-----
python/tvm_ffi/cython/tvm_ffi_python_helpers.h | 22 +++++++--------
tests/python/test_dlpack_exchange_api.py | 32 ++++++++-------------
tests/python/test_load_inline.py | 4 +--
10 files changed, 93 insertions(+), 64 deletions(-)
diff --git a/addons/torch_c_dlpack_ext/build_backend.py
b/addons/torch_c_dlpack_ext/build_backend.py
index e4504fa..639977e 100644
--- a/addons/torch_c_dlpack_ext/build_backend.py
+++ b/addons/torch_c_dlpack_ext/build_backend.py
@@ -66,9 +66,11 @@ def build_wheel(
# build wheel from sdist package, compile the torch c dlpack ext
library locally.
import torch # noqa: PLC0415
- if hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
+ if hasattr(torch.Tensor, "__dlpack_c_exchange_api__") or hasattr(
+ torch.Tensor, "__c_dlpack_exchange_api__"
+ ):
print(
- "torch.Tensor already has attribute __c_dlpack_exchange_api__.
"
+ "torch.Tensor already has attribute __dlpack_c_exchange_api__.
"
"No need to build any torch c dlpackc libs."
)
else:
diff --git a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
index d0313a7..6b63e0a 100644
--- a/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
+++ b/addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py
@@ -19,14 +19,31 @@
import ctypes
import sys
from pathlib import Path
+from typing import Any
import torch
from packaging.version import Version
+def _create_dlpack_exchange_api_capsule(ptr_as_int: int) -> Any:
+ """Create a PyCapsule wrapping the DLPack exchange API pointer."""
+ capsule_name = b"dlpack_exchange_api"
+ pythonapi = ctypes.pythonapi
+ pythonapi.PyCapsule_New.restype = ctypes.py_object
+ pythonapi.PyCapsule_New.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_char_p,
+ ctypes.c_void_p,
+ ]
+ capsule = pythonapi.PyCapsule_New(ctypes.c_void_p(ptr_as_int),
capsule_name, None)
+ return capsule
+
+
def load_torch_c_dlpack_extension() -> None:
"""Load the torch c dlpack extension based on torch version."""
- if hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
+ if hasattr(torch.Tensor, "__dlpack_c_exchange_api__") or hasattr(
+ torch.Tensor, "__c_dlpack_exchange_api__"
+ ):
return None
version = Version(torch.__version__)
if sys.platform.startswith("win32"):
@@ -52,6 +69,12 @@ def load_torch_c_dlpack_extension() -> None:
# We will do eager upgrade to PyCapsule in the tvm-ffi side instead.
dlpack_exchange_api_ptr_as_int = func()
setattr(torch.Tensor, "__c_dlpack_exchange_api__",
dlpack_exchange_api_ptr_as_int)
+ setattr(
+ torch.Tensor,
+ "__dlpack_c_exchange_api__",
+ _create_dlpack_exchange_api_capsule(dlpack_exchange_api_ptr_as_int),
+ )
+
return lib
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py
b/python/tvm_ffi/_optional_torch_c_dlpack.py
index f3d0119..dde0550 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -70,12 +70,31 @@ def _create_dlpack_exchange_api_capsule(ptr_as_int: int) ->
Any:
return capsule
+def _check_and_update_dlpack_c_exchange_api(tensor_cls: object) -> bool:
+ """Check if the DLPack exchange API is available and update the
__dlpack_c_exchange_api__ attribute."""
+ if hasattr(tensor_cls, "__dlpack_c_exchange_api__"):
+ return True
+ # legacy path compatibility handling
+ if hasattr(tensor_cls, "__c_dlpack_exchange_api__"):
+ c_dlpack_attribute = tensor_cls.__c_dlpack_exchange_api__
+ if isinstance(c_dlpack_attribute, int):
+ setattr(
+ tensor_cls,
+ "__dlpack_c_exchange_api__",
+ _create_dlpack_exchange_api_capsule(c_dlpack_attribute),
+ )
+ else:
+ setattr(tensor_cls, "__dlpack_c_exchange_api__",
c_dlpack_attribute)
+ return True
+ return False
+
+
def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912, PLR0915
try:
import torch # noqa: PLC0415
- if hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
- # skip loading the extension if the __c_dlpack_exchange_api__
+ if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
+ # skip loading the extension if the __dlpack_c_exchange_api__
# attribute is already set so we don't have to do it in
# newer version of PyTorch
return None
@@ -86,12 +105,7 @@ def load_torch_c_dlpack_extension() -> Any: # noqa:
PLR0912, PLR0915
try:
import torch_c_dlpack_ext # type: ignore # noqa: PLC0415, F401
- if hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
- if isinstance(torch.Tensor.__c_dlpack_exchange_api__, int):
- # Brings up to speed with the new PyCapsule behavior
- torch.Tensor.__c_dlpack_exchange_api__ =
_create_dlpack_exchange_api_capsule(
- torch.Tensor.__c_dlpack_exchange_api__
- )
+ if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
return None
except ImportError:
pass
@@ -152,8 +166,7 @@ def load_torch_c_dlpack_extension() -> Any: # noqa:
PLR0912, PLR0915
# Create a PyCapsule from the pointer
capsule = _create_dlpack_exchange_api_capsule(func())
# Set the DLPackExchangeAPI pointer on the class
- setattr(torch.Tensor, "__c_dlpack_exchange_api__", capsule)
-
+ setattr(torch.Tensor, "__dlpack_c_exchange_api__", capsule)
return lib
except ImportError:
pass
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 06a78e6..2cee79c 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -175,7 +175,7 @@ def from_dlpack(
) -> Tensor: ...
class DLTensorTestWrapper:
- __c_dlpack_exchange_api__: int
+ __dlpack_c_exchange_api__: int
def __init__(self, tensor: Tensor) -> None: ...
def __tvm_ffi_env_stream__(self) -> int: ...
def __dlpack_device__(self) -> tuple[int, int]: ...
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 933bc86..9e7ff87 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -322,11 +322,11 @@ cdef extern from "tvm_ffi_python_helpers.h":
int device_type
int device_id
TVMFFIStreamHandle stream
- const DLPackExchangeAPI* c_dlpack_exchange_api
+ const DLPackExchangeAPI* dlpack_c_exchange_api
ctypedef struct TVMFFIPyArgSetter:
int (*func)(TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* py_arg, TVMFFIAny* out) except -1
- const DLPackExchangeAPI* c_dlpack_exchange_api
+ const DLPackExchangeAPI* dlpack_c_exchange_api
ctypedef int (*TVMFFIPyArgSetterFactory)(PyObject* value,
TVMFFIPyArgSetter* out) except -1
# The main call function
diff --git a/python/tvm_ffi/cython/function.pxi
b/python/tvm_ffi/cython/function.pxi
index 29af699..01a3366 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -156,10 +156,10 @@ cdef int TVMFFIPyArgSetterDLPackExchangeAPI_(
cdef DLManagedTensorVersioned* temp_managed_tensor
cdef TVMFFIObjectHandle temp_chandle
cdef void* current_stream = NULL
- cdef const DLPackExchangeAPI* exchange_api = this.c_dlpack_exchange_api
+ cdef const DLPackExchangeAPI* exchange_api = this.dlpack_c_exchange_api
# Set the exchange API in context
- ctx.c_dlpack_exchange_api = exchange_api
+ ctx.dlpack_c_exchange_api = exchange_api
# Convert PyObject to DLPack using the struct's function pointer
if exchange_api.managed_tensor_from_py_object_no_sync(arg,
&temp_managed_tensor) != 0:
@@ -239,7 +239,7 @@ cdef int TVMFFIPyArgSetterTorchFallback_(
out.type_index = kTVMFFITensor
out.v_ptr = (<Tensor>arg).chandle
temp_dltensor = TVMFFITensorGetDLTensorPtr((<Tensor>arg).chandle)
- ctx.c_dlpack_exchange_api = GetTorchFallbackExchangeAPI()
+ ctx.dlpack_c_exchange_api = GetTorchFallbackExchangeAPI()
# record the stream and device for torch context
if is_cuda and ctx.device_type == -1:
ctx.device_type = temp_dltensor.device.device_type
@@ -740,12 +740,12 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value,
TVMFFIPyArgSetter* out) exce
# as a member variable
out.func = TVMFFIPyArgSetterFFIObjectProtocol_
return 0
- if os.environ.get("TVM_FFI_SKIP_C_DLPACK_EXCHANGE_API", "0") != "1":
+ if os.environ.get("TVM_FFI_SKIP_DLPACK_C_EXCHANGE_API", "0") != "1":
# Check for DLPackExchangeAPI struct (new approach)
# This is checked on the CLASS, not the instance
- if hasattr(arg_class, "__c_dlpack_exchange_api__"):
+ if hasattr(arg_class, "__dlpack_c_exchange_api__"):
out.func = TVMFFIPyArgSetterDLPackExchangeAPI_
- _get_dlpack_exchange_api(arg_class.__c_dlpack_exchange_api__,
&(out.c_dlpack_exchange_api))
+ _get_dlpack_exchange_api(arg_class.__dlpack_c_exchange_api__,
&(out.dlpack_c_exchange_api))
return 0
if hasattr(arg_class, "__cuda_stream__"):
# cuda stream protocol
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 841687a..1f4973d 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -133,9 +133,9 @@ cdef inline int _from_dlpack_universal(
cdef int favor_legacy_dlpack = True
cdef const DLPackExchangeAPI* exchange_api = NULL
- if hasattr(ext_tensor, "__c_dlpack_exchange_api__"):
+ if hasattr(ext_tensor, "__dlpack_c_exchange_api__"):
try:
- _get_dlpack_exchange_api(ext_tensor.__c_dlpack_exchange_api__,
&exchange_api)
+ _get_dlpack_exchange_api(ext_tensor.__dlpack_c_exchange_api__,
&exchange_api)
return _from_dlpack_exchange_api(
ext_tensor,
exchange_api,
@@ -405,7 +405,7 @@ cdef int _dltensor_test_wrapper_current_work_stream(
# Module-level static DLPackExchangeAPI for DLTensorTestWrapper
cdef DLPackExchangeAPI _dltensor_test_wrapper_static_api
-cdef const DLPackExchangeAPI* _dltensor_test_wrapper_get_exchange_api()
noexcept:
+cdef DLPackExchangeAPI* _dltensor_test_wrapper_get_exchange_api() noexcept:
"""Get the static DLPackExchangeAPI instance for DLTensorTestWrapper."""
global _dltensor_test_wrapper_static_api
@@ -430,15 +430,14 @@ cdef const DLPackExchangeAPI*
_dltensor_test_wrapper_get_exchange_api() noexcept
return &_dltensor_test_wrapper_static_api
-def _dltensor_test_wrapper_exchange_api_ptr():
- """Return the pointer to the DLPackExchangeAPI struct as an integer."""
- return <long long>_dltensor_test_wrapper_get_exchange_api()
-
-
cdef class DLTensorTestWrapper:
"""Wrapper of a Tensor that exposes DLPack protocol, only for testing
purpose.
"""
- __c_dlpack_exchange_api__ = _dltensor_test_wrapper_exchange_api_ptr()
+ __dlpack_c_exchange_api__ = pycapsule.PyCapsule_New(
+ _dltensor_test_wrapper_get_exchange_api(),
+ b"dlpack_exchange_api",
+ NULL
+ )
cdef Tensor tensor
cdef dict __dict__
diff --git a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
index 88c27d7..93d6540 100644
--- a/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+++ b/python/tvm_ffi/cython/tvm_ffi_python_helpers.h
@@ -86,7 +86,7 @@ class TVMFFIPyCallContext {
/*! \brief Detected stream, if any */
void* stream = nullptr;
/*! \brief the DLPack exchange API, if any */
- const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
+ const DLPackExchangeAPI* dlpack_c_exchange_api{nullptr};
/*! \brief pointer to the call stack space */
TVMFFIPyCallStack* call_stack = nullptr;
/*! \brief the temporary arguments to be recycled */
@@ -174,7 +174,7 @@ struct TVMFFIPyArgSetter {
* \brief Optional DLPackExchangeAPI struct pointer.
* This is the new struct-based approach that bundles all DLPack exchange
functions.
*/
- const DLPackExchangeAPI* c_dlpack_exchange_api{nullptr};
+ const DLPackExchangeAPI* dlpack_c_exchange_api{nullptr};
/*!
* \brief Invoke the setter.
* \param call_ctx The call context.
@@ -297,10 +297,10 @@ class TVMFFIPyCallManager {
// setting failed, directly return
if (c_api_ret_code[0] != 0) return 0;
}
- if (ctx.c_dlpack_exchange_api != nullptr &&
- ctx.c_dlpack_exchange_api->managed_tensor_allocator != nullptr) {
+ if (ctx.dlpack_c_exchange_api != nullptr &&
+ ctx.dlpack_c_exchange_api->managed_tensor_allocator != nullptr) {
c_api_ret_code[0] = TVMFFIEnvSetDLPackManagedTensorAllocator(
- ctx.c_dlpack_exchange_api->managed_tensor_allocator, 0,
&prev_tensor_allocator);
+ ctx.dlpack_c_exchange_api->managed_tensor_allocator, 0,
&prev_tensor_allocator);
if (c_api_ret_code[0] != 0) return 0;
}
// call the function
@@ -321,14 +321,14 @@ class TVMFFIPyCallManager {
return -1;
}
}
- if (ctx.c_dlpack_exchange_api != nullptr &&
- prev_tensor_allocator !=
ctx.c_dlpack_exchange_api->managed_tensor_allocator) {
+ if (ctx.dlpack_c_exchange_api != nullptr &&
+ prev_tensor_allocator !=
ctx.dlpack_c_exchange_api->managed_tensor_allocator) {
c_api_ret_code[0] =
TVMFFIEnvSetDLPackManagedTensorAllocator(prev_tensor_allocator, 0,
nullptr);
if (c_api_ret_code[0] != 0) return 0;
}
- if (optional_out_ctx_dlpack_api != nullptr && ctx.c_dlpack_exchange_api
!= nullptr) {
- *optional_out_ctx_dlpack_api = ctx.c_dlpack_exchange_api;
+ if (optional_out_ctx_dlpack_api != nullptr && ctx.dlpack_c_exchange_api
!= nullptr) {
+ *optional_out_ctx_dlpack_api = ctx.dlpack_c_exchange_api;
}
return 0;
} catch (const std::exception& ex) {
@@ -380,8 +380,8 @@ class TVMFFIPyCallManager {
parent_ctx->stream = ctx.stream;
}
// DLPack exchange API
- if (parent_ctx->c_dlpack_exchange_api == nullptr) {
- parent_ctx->c_dlpack_exchange_api = ctx.c_dlpack_exchange_api;
+ if (parent_ctx->dlpack_c_exchange_api == nullptr) {
+ parent_ctx->dlpack_c_exchange_api = ctx.dlpack_c_exchange_api;
}
}
return 0;
diff --git a/tests/python/test_dlpack_exchange_api.py
b/tests/python/test_dlpack_exchange_api.py
index 7df8dcf..70e7586 100644
--- a/tests/python/test_dlpack_exchange_api.py
+++ b/tests/python/test_dlpack_exchange_api.py
@@ -27,7 +27,7 @@ try:
import torch # type: ignore[no-redef]
# Import tvm_ffi to load the DLPack exchange API extension
- # This sets torch.Tensor.__c_dlpack_exchange_api__
+ # This sets torch.Tensor.__dlpack_c_exchange_api__
import tvm_ffi
from torch.utils import cpp_extension # type: ignore
from tvm_ffi import libinfo
@@ -35,7 +35,7 @@ except ImportError:
torch = None
# Check if DLPack Exchange API is available
-_has_dlpack_api = torch is not None and hasattr(torch.Tensor,
"__c_dlpack_exchange_api__")
+_has_dlpack_api = torch is not None and hasattr(torch.Tensor,
"__dlpack_c_exchange_api__")
@pytest.mark.skipif(not _has_dlpack_api, reason="PyTorch DLPack Exchange API
not available")
@@ -45,24 +45,16 @@ def test_dlpack_exchange_api() -> None:
pytest.xfail("DLPack Exchange API test is known to fail on Windows
platform")
assert torch is not None
-
- assert hasattr(torch.Tensor, "__c_dlpack_exchange_api__")
- api_attr = torch.Tensor.__c_dlpack_exchange_api__
-
- # Handle both PyCapsule and integer types
- if isinstance(api_attr, int):
- # Direct integer pointer
- api_ptr = api_attr
- assert api_ptr != 0, "API pointer should not be NULL"
- else:
- # PyCapsule - extract the pointer as integer
- pythonapi = ctypes.pythonapi
- # Set restype to c_size_t to get integer directly (avoids c_void_p
quirks)
- pythonapi.PyCapsule_GetPointer.restype = ctypes.c_size_t
- pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object,
ctypes.c_char_p]
- capsule_name = b"dlpack_exchange_api"
- api_ptr = pythonapi.PyCapsule_GetPointer(api_attr, capsule_name)
- assert api_ptr != 0, "API pointer from PyCapsule should not be NULL"
+ assert hasattr(torch.Tensor, "__dlpack_c_exchange_api__")
+ api_attr = torch.Tensor.__dlpack_c_exchange_api__
+ # PyCapsule - extract the pointer as integer
+ pythonapi = ctypes.pythonapi
+ # Set restype to c_size_t to get integer directly (avoids c_void_p quirks)
+ pythonapi.PyCapsule_GetPointer.restype = ctypes.c_size_t
+ pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object,
ctypes.c_char_p]
+ capsule_name = b"dlpack_exchange_api"
+ api_ptr = pythonapi.PyCapsule_GetPointer(api_attr, capsule_name)
+ assert api_ptr != 0, "API pointer from PyCapsule should not be NULL"
tensor = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index f39c485..77b9f8b 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -213,8 +213,8 @@ def test_load_inline_cuda() -> None:
@pytest.mark.skipif(torch is None, reason="Requires torch")
def test_load_inline_with_env_tensor_allocator() -> None:
assert torch is not None
- if not hasattr(torch.Tensor, "__c_dlpack_exchange_api__"):
- pytest.skip("Torch does not support __c_dlpack_exchange_api__")
+ if not hasattr(torch.Tensor, "__dlpack_c_exchange_api__"):
+ pytest.skip("Torch does not support __dlpack_c_exchange_api__")
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""