This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new b0537f0 [CYTHON] Support cuda stream protocol (#109)
b0537f0 is described below
commit b0537f045b30334a12bf3365438aedb2c3bc7285
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Oct 13 14:25:37 2025 -0400
[CYTHON] Support cuda stream protocol (#109)
This PR updates the ABI to support cuda stream protocol
https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
For now we turn the stream into a void* when passing it into the ABI
---
python/tvm_ffi/cython/function.pxi | 20 +++++++++++++++++++-
tests/python/test_device.py | 16 ++++++++++++++++
2 files changed, 35 insertions(+), 1 deletion(-)
diff --git a/python/tvm_ffi/cython/function.pxi
b/python/tvm_ffi/cython/function.pxi
index 0305fec..6fc2a7b 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -287,6 +287,20 @@ cdef int TVMFFIPyArgSetterFFITensorCompatible_(
return 0
+cdef int TVMFFIPyArgSetterCUDAStream_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for cuda stream protocol"""
+ cdef object arg = <object>py_arg
+ # cuda stream is a subclass of str, so this check occur before str
+ cdef tuple cu_stream_tuple = arg.__cuda_stream__()
+ cdef long long long_ptr = <long long>cu_stream_tuple[1]
+ out.type_index = kTVMFFIOpaquePtr
+ out.v_ptr = <void*>long_ptr
+ return 0
+
+
cdef int TVMFFIPyArgSetterDType_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* py_arg, TVMFFIAny* out
@@ -605,10 +619,14 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value,
TVMFFIPyArgSetter* out) exce
temp_ptr = arg_class.__c_dlpack_exchange_api__
out.c_dlpack_exchange_api = <const DLPackExchangeAPI*>(<long
long>temp_ptr)
return 0
+ if hasattr(arg_class, "__cuda_stream__"):
+ # cuda stream protocol
+ out.func = TVMFFIPyArgSetterCUDAStream_
+ return 0
if torch is not None and isinstance(arg, torch.Tensor):
out.func = TVMFFIPyArgSetterTorchFallback_
return 0
- if hasattr(arg, "__dlpack__"):
+ if hasattr(arg_class, "__dlpack__"):
out.func = TVMFFIPyArgSetterDLPack_
return 0
if isinstance(arg, bool):
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 7a3638b..71dc0e4 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -17,6 +17,7 @@
from __future__ import annotations
+import ctypes
import pickle
import pytest
@@ -109,3 +110,18 @@ def test_device_class_override() -> None:
device = tvm_ffi.device("cuda", 0)
assert isinstance(device, MyDevice)
tvm_ffi.core._set_class_device(old_device)
+
+
+def test_cuda_stream_handling() -> None:
+ class MyDummyStream:
+ def __init__(self, stream: int) -> None:
+ self.stream = stream
+
+ def __cuda_stream__(self) -> tuple[str, int]:
+ return ("cuda", self.stream)
+
+ stream = MyDummyStream(1)
+ echo = tvm_ffi.get_global_func("testing.echo")
+ y = echo(stream)
+ assert isinstance(y, ctypes.c_void_p)
+ assert y.value == 1