This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new fde8dab  [CYTHON] Fix stream passing bug (#68)
fde8dab is described below

commit fde8dabbba8aa0ea8133a02fcd9ff0190d830948
Author: Tianqi Chen <[email protected]>
AuthorDate: Sat Sep 27 20:46:13 2025 -0400

    [CYTHON] Fix stream passing bug (#68)
    
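    This PR fixes a bug in stream passing that broke the expected
    stream-passing behavior: TVMFFIPyArgSetterDLPackCExporter_ tested
    `ctx.device_id != -1` instead of `ctx.device_type != -1` to decide
    whether the device had already been queried. It also adds a regression
    case via load_inline_cuda that checks the stream returned by
    TVMFFIEnvGetStream matches the raw stream handle passed from Python,
    both for the default stream and for a torch.cuda.Stream (the test
    currently requires a CUDA environment to run).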
---
 pyproject.toml                     |  2 +-
 python/tvm_ffi/__init__.py         |  2 +-
 python/tvm_ffi/cython/function.pxi |  2 +-
 tests/python/test_load_inline.py   | 15 +++++++++++++--
 4 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index f734184..4a6a734 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@
 
 [project]
 name = "apache-tvm-ffi"
-version = "0.1.0b10"
+version = "0.1.0b11"
 description = "tvm ffi"
 
 authors = [{ name = "TVM FFI team" }]
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 807f9a9..d88a35e 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -17,7 +17,7 @@
 """TVM FFI Python package."""
 
 # version
-__version__ = "0.1.0b10"
+__version__ = "0.1.0b11"
 
 # order matters here so we need to skip isort here
 # isort: skip_file
diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi
index 095e3d6..2fa75fb 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -155,7 +155,7 @@ cdef int TVMFFIPyArgSetterDLPackCExporter_(
     if this.c_dlpack_tensor_allocator != NULL:
         ctx.c_dlpack_tensor_allocator = this.c_dlpack_tensor_allocator
 
-    if ctx.device_id != -1:
+    if ctx.device_type != -1:
         # already queried device, do not do it again, pass NULL to stream
         if (this.c_dlpack_from_pyobject)(arg, &temp_managed_tensor, NULL) != 0:
             return -1
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index cd46bf5..229dc62 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -167,7 +167,7 @@ def test_load_inline_cuda() -> None:
               }
             }
 
-            void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y) {
+            void add_one_cuda(tvm::ffi::Tensor x, tvm::ffi::Tensor y, int64_t 
raw_stream) {
               // implementation of a library function
               TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
@@ -184,6 +184,8 @@ def test_load_inline_cuda() -> None:
               // with torch.Tensors
               cudaStream_t stream = static_cast<cudaStream_t>(
                   TVMFFIEnvGetStream(x->device.device_type, 
x->device.device_id));
+              TVM_FFI_ICHECK_EQ(reinterpret_cast<int64_t>(stream), raw_stream)
+                << "stream must be the same as raw_stream";
               // launch the kernel
               AddOneKernel<<<nblock, nthread_per_block, 0, 
stream>>>(static_cast<float*>(x->data),
                                                                      
static_cast<float*>(y->data), n);
@@ -193,9 +195,18 @@ def test_load_inline_cuda() -> None:
     )
 
     if torch is not None:
+        # test with raw stream
         x_cuda = torch.asarray([1, 2, 3, 4, 5], dtype=torch.float32, 
device="cuda")
         y_cuda = torch.empty_like(x_cuda)
-        mod.add_one_cuda(x_cuda, y_cuda)
+        mod.add_one_cuda(x_cuda, y_cuda, 0)
+        torch.testing.assert_close(x_cuda + 1, y_cuda)
+
+        # test with torch stream
+        y_cuda = torch.empty_like(x_cuda)
+        stream = torch.cuda.Stream()
+        with torch.cuda.stream(stream):
+            mod.add_one_cuda(x_cuda, y_cuda, stream.cuda_stream)
+        stream.synchronize()
         torch.testing.assert_close(x_cuda + 1, y_cuda)
 
 

Reply via email to