This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new b0537f0  [CYTHON] Support cuda stream protocol (#109)
b0537f0 is described below

commit b0537f045b30334a12bf3365438aedb2c3bc7285
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Oct 13 14:25:37 2025 -0400

    [CYTHON] Support cuda stream protocol (#109)
    
    This PR updates the ABI to support cuda stream protocol
    
https://nvidia.github.io/cuda-python/cuda-core/latest/interoperability.html#cuda-stream-protocol
    
    For now we turn the stream into void* when passing it into the ABI
---
 python/tvm_ffi/cython/function.pxi | 20 +++++++++++++++++++-
 tests/python/test_device.py        | 16 ++++++++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/python/tvm_ffi/cython/function.pxi 
b/python/tvm_ffi/cython/function.pxi
index 0305fec..6fc2a7b 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -287,6 +287,20 @@ cdef int TVMFFIPyArgSetterFFITensorCompatible_(
     return 0
 
 
+cdef int TVMFFIPyArgSetterCUDAStream_(
+    TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
+    PyObject* py_arg, TVMFFIAny* out
+) except -1:
+    """Setter for cuda stream protocol"""
+    cdef object arg = <object>py_arg
+    # cuda stream is a subclass of str, so this check occurs before str
+    cdef tuple cu_stream_tuple = arg.__cuda_stream__()
+    cdef long long long_ptr = <long long>cu_stream_tuple[1]
+    out.type_index = kTVMFFIOpaquePtr
+    out.v_ptr = <void*>long_ptr
+    return 0
+
+
 cdef int TVMFFIPyArgSetterDType_(
     TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
     PyObject* py_arg, TVMFFIAny* out
@@ -605,10 +619,14 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, 
TVMFFIPyArgSetter* out) exce
             temp_ptr = arg_class.__c_dlpack_exchange_api__
             out.c_dlpack_exchange_api = <const DLPackExchangeAPI*>(<long 
long>temp_ptr)
             return 0
+    if hasattr(arg_class, "__cuda_stream__"):
+        # cuda stream protocol
+        out.func = TVMFFIPyArgSetterCUDAStream_
+        return 0
     if torch is not None and isinstance(arg, torch.Tensor):
         out.func = TVMFFIPyArgSetterTorchFallback_
         return 0
-    if hasattr(arg, "__dlpack__"):
+    if hasattr(arg_class, "__dlpack__"):
         out.func = TVMFFIPyArgSetterDLPack_
         return 0
     if isinstance(arg, bool):
diff --git a/tests/python/test_device.py b/tests/python/test_device.py
index 7a3638b..71dc0e4 100644
--- a/tests/python/test_device.py
+++ b/tests/python/test_device.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+import ctypes
 import pickle
 
 import pytest
@@ -109,3 +110,18 @@ def test_device_class_override() -> None:
     device = tvm_ffi.device("cuda", 0)
     assert isinstance(device, MyDevice)
     tvm_ffi.core._set_class_device(old_device)
+
+
+def test_cuda_stream_handling() -> None:
+    class MyDummyStream:
+        def __init__(self, stream: int) -> None:
+            self.stream = stream
+
+        def __cuda_stream__(self) -> tuple[str, int]:
+            return ("cuda", self.stream)
+
+    stream = MyDummyStream(1)
+    echo = tvm_ffi.get_global_func("testing.echo")
+    y = echo(stream)
+    assert isinstance(y, ctypes.c_void_p)
+    assert y.value == 1

Reply via email to