This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 6c85e56 [STREAM] Enable compact with cuda-python driver stream (#236)
6c85e56 is described below
commit 6c85e562c00f098743ed257ff4516a250f5145e7
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Nov 7 14:07:05 2025 -0500
[STREAM] Enable compact with cuda-python driver stream (#236)
As of now the cuda-python driver stream does not yet support the CUDA stream
protocol. This PR enables a compat mode so we can take
cuda_driver.CUstream arguments and treat them as void_p.
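
For context, a minimal sketch of the usage this compat mode enables
(assuming cuda-python is installed and the "testing.echo" global function
is registered, mirroring the new test added below):

    import ctypes

    import tvm_ffi
    from cuda.bindings import driver as cuda_driver

    echo = tvm_ffi.get_global_func("testing.echo")

    # A raw driver stream handle is now accepted and comes back out
    # as an opaque pointer (ctypes.c_void_p).
    out = echo(cuda_driver.CUstream(1))
    assert isinstance(out, ctypes.c_void_p)
    assert out.value == 1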
---
python/tvm_ffi/_optional_torch_c_dlpack.py | 4 ++--
python/tvm_ffi/cython/function.pxi | 27 +++++++++++++++++++++++++--
tests/python/test_stream.py | 19 +++++++++++++++++++
3 files changed, 46 insertions(+), 4 deletions(-)
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 302a29c..7673c03 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -109,7 +109,7 @@ def load_torch_c_dlpack_extension() -> Any:
return None
-def patch_torch_cuda_stream_protocol() -> Any:
+def patch_torch_cuda_stream_protocol() -> None:
"""Load the torch cuda stream protocol for older versions of torch."""
try:
import torch # noqa: PLC0415
@@ -118,7 +118,7 @@ def patch_torch_cuda_stream_protocol() -> Any:
return
if not hasattr(torch.cuda.Stream, "__cuda_stream__"):
- def __torch_cuda_stream__(self: torch.cuda.Stream) -> tuple[int, torch.cuda.Stream]:
+ def __torch_cuda_stream__(self: torch.cuda.Stream) -> tuple[int, int]:
"""Return the version number and the cuda stream."""
return (0, self.cuda_stream)
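
For reference, a minimal sketch of the __cuda_stream__ protocol shape the
patched method follows (the class name here is illustrative only, not part
of this change):

    class _ExampleStream:
        """Illustrative only: any object exposing __cuda_stream__ is
        handled by the protocol-based setter."""

        def __init__(self, ptr: int) -> None:
            self._ptr = ptr

        def __cuda_stream__(self) -> tuple[int, int]:
            # (protocol version, raw CUstream handle as an int)
            return (0, self._ptr)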
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 4f9fca6..acfe3e2 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -33,9 +33,15 @@ if os.environ.get("TVM_FFI_BUILD_DOCS", "0") == "0":
import numpy
except ImportError:
numpy = None
+
+ try:
+ from cuda.bindings import driver as cuda_driver
+ except ImportError:
+ cuda_driver = None
else:
torch = None
numpy = None
+ cuda_driver = None
cdef int _RELEASE_GIL_BY_DEFAULT = int(
@@ -287,7 +293,7 @@ cdef int TVMFFIPyArgSetterFFIObjectCompatible_(
return 0
-cdef int TVMFFIPyArgSetterCUDAStream_(
+cdef int TVMFFIPyArgSetterCUDAStreamProtocol_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* py_arg, TVMFFIAny* out
) except -1:
@@ -301,6 +307,19 @@ cdef int TVMFFIPyArgSetterCUDAStream_(
return 0
+cdef int TVMFFIPyArgSetterCUDADriverStreamFallback_(
+ TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+ PyObject* py_arg, TVMFFIAny* out
+) except -1:
+ """Setter for cuda.bindings.driver.CUstream as a fallback without
__cuda_stream__ protocol"""
+ cdef object arg = <object>py_arg
+ # call driver stream
+ cdef long long long_ptr = int(arg)
+ out.type_index = kTVMFFIOpaquePtr
+ out.v_ptr = <void*>long_ptr
+ return 0
+
+
cdef int TVMFFIPyArgSetterDType_(
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
PyObject* py_arg, TVMFFIAny* out
@@ -658,7 +677,11 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
return 0
if hasattr(arg_class, "__cuda_stream__"):
# cuda stream protocol
- out.func = TVMFFIPyArgSetterCUDAStream_
+ out.func = TVMFFIPyArgSetterCUDAStreamProtocol_
+ return 0
+ if cuda_driver is not None and isinstance(arg, cuda_driver.CUstream):
+ # TODO(tqchen): remove this once cuda-python supports __cuda_stream__ protocol
+ out.func = TVMFFIPyArgSetterCUDADriverStreamFallback_
return 0
if torch is not None and isinstance(arg, torch.Tensor):
out.func = TVMFFIPyArgSetterTorchFallback_
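
The fallback above relies on cuda_driver.CUstream supporting int() to expose
the raw handle; a quick sanity check of that assumption (requires
cuda-python):

    from cuda.bindings import driver as cuda_driver

    s = cuda_driver.CUstream(0x1234)
    # int() yields the raw handle value that the fallback stores as void*.
    assert int(s) == 0x1234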
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
index fbe1a0a..3b58ccb 100644
--- a/tests/python/test_stream.py
+++ b/tests/python/test_stream.py
@@ -17,6 +17,7 @@
from __future__ import annotations
+import ctypes
from types import ModuleType
import pytest
@@ -30,6 +31,24 @@ except ImportError:
torch = None
+try:
+ from cuda.bindings import driver as cuda_driver # type: ignore[import-not-found]
+except ImportError:
+ cuda_driver = None
+
+
[email protected](cuda_driver is None, reason="Requires cuda-python")
+def test_cuda_driver_stream() -> None:
+ assert cuda_driver is not None
+ echo = tvm_ffi.get_global_func("testing.echo")
+ stream = cuda_driver.CUstream(0)
+ y = echo(stream)
+ assert y is not None
+ z = echo(cuda_driver.CUstream(1))
+ assert isinstance(z, ctypes.c_void_p)
+ assert z.value == 1
+
+
def gen_check_stream_mod() -> tvm_ffi.Module:
return tvm_ffi.cpp.load_inline(
name="check_stream",