This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 4147ba7 [FIX] Fixes from_dlpack issue fallback (#301)
4147ba7 is described below
commit 4147ba7db7500c19f0c78a2b85cf0a9bc4be4ac3
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Dec 2 12:20:45 2025 -0500
[FIX] Fixes from_dlpack issue fallback (#301)
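This PR fixes the from_dlpack exchange-API fast path after the
`__c_dlpack_exchange_api__` attribute changed from an integer address to a
PyCapsule: the pointer is now extracted with `_get_dlpack_exchange_api`
instead of being cast directly, and a regression test is added.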
---
python/tvm_ffi/cython/tensor.pxi | 6 ++++--
tests/python/test_dlpack_exchange_api.py | 11 ++++++++++-
2 files changed, 14 insertions(+), 3 deletions(-)
diff --git a/python/tvm_ffi/cython/tensor.pxi b/python/tvm_ffi/cython/tensor.pxi
index 0521dcb..844f7ef 100644
--- a/python/tvm_ffi/cython/tensor.pxi
+++ b/python/tvm_ffi/cython/tensor.pxi
@@ -105,7 +105,7 @@ cdef inline int _from_dlpack_versioned(
cdef inline int _from_dlpack_exchange_api(
- object ext_tensor, DLPackExchangeAPI* exchange_api, int require_alignment,
+ object ext_tensor, const DLPackExchangeAPI* exchange_api, int
require_alignment,
int require_contiguous, TVMFFIObjectHandle* out
) except -1:
cdef DLManagedTensorVersioned* temp_managed_tensor
@@ -131,12 +131,14 @@ cdef inline int _from_dlpack_universal(
# as of most frameworks do not yet support v1.1
# move to false as most frameworks get upgraded.
cdef int favor_legacy_dlpack = True
+ cdef const DLPackExchangeAPI* exchange_api = NULL
if hasattr(ext_tensor, "__c_dlpack_exchange_api__"):
try:
+ _get_dlpack_exchange_api(ext_tensor.__c_dlpack_exchange_api__,
&exchange_api)
return _from_dlpack_exchange_api(
ext_tensor,
- <DLPackExchangeAPI*><long
long>(ext_tensor.__c_dlpack_exchange_api__),
+ exchange_api,
require_alignment,
require_contiguous,
out
diff --git a/tests/python/test_dlpack_exchange_api.py b/tests/python/test_dlpack_exchange_api.py
index 048ade5..7df8dcf 100644
--- a/tests/python/test_dlpack_exchange_api.py
+++ b/tests/python/test_dlpack_exchange_api.py
@@ -28,7 +28,7 @@ try:
# Import tvm_ffi to load the DLPack exchange API extension
# This sets torch.Tensor.__c_dlpack_exchange_api__
- import tvm_ffi # noqa: F401
+ import tvm_ffi
from torch.utils import cpp_extension # type: ignore
from tvm_ffi import libinfo
except ImportError:
@@ -222,5 +222,14 @@ def test_dlpack_exchange_api() -> None:
mod.test_dlpack_api(tensor, api_ptr, torch.cuda.is_available())
+@pytest.mark.skipif(not _has_dlpack_api, reason="PyTorch DLPack Exchange API not available")
+def test_from_dlpack_torch() -> None:
+ # Covers from_dlpack to use fallback fastpath
+ tensor = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
+ tensor_from_dlpack = tvm_ffi.from_dlpack(tensor)
+ assert tensor_from_dlpack.shape == tensor.shape
+ assert tensor_from_dlpack.dtype == tvm_ffi.float32
+
+
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])