This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 22c049b  Expose get stream method (#97)
22c049b is described below

commit 22c049b8f3b64e7e2f17b28df044b065ae3fba83
Author: Yaxing Cai <[email protected]>
AuthorDate: Wed Oct 8 17:48:13 2025 -0700

    Expose get stream method (#97)
    
    This PR exposes the env get-stream method to Python (`tvm_ffi.get_raw_stream`) as a follow-up to #5.
---
 python/tvm_ffi/__init__.py     |  2 +-
 python/tvm_ffi/core.pyi        |  1 +
 python/tvm_ffi/cython/base.pxi |  6 ++++++
 python/tvm_ffi/stream.py       | 17 +++++++++++++++++
 tests/python/test_stream.py    |  3 +++
 5 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 720968d..e46a7aa 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -42,7 +42,7 @@ from ._tensor import Device, device, DLDeviceType
 from ._tensor import from_dlpack, Tensor, Shape
 from .container import Array, Map
 from .module import Module, system_lib, load_module
-from .stream import StreamContext, use_raw_stream, use_torch_stream
+from .stream import StreamContext, get_raw_stream, use_raw_stream, 
use_torch_stream
 from . import serialization
 from . import access_path
 from . import testing
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 54b44cf..45a7d28 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -259,6 +259,7 @@ def _convert_to_ffi_error(error: BaseException) -> Error: ...
 def _env_set_current_stream(
     device_type: int, device_id: int, stream: int | c_void_p
 ) -> int | c_void_p: ...
+def _env_get_current_stream(device_type: int, device_id: int) -> int: ...
 
 class DataType:
     """Internal wrapper around ``DLDataType``.
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 633ace4..a8b4212 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -261,6 +261,12 @@ def _env_set_current_stream(int device_type, int device_id, uint64_t stream):
     return <uint64_t>prev_stream
 
 
+def _env_get_current_stream(int device_type, int device_id):
+    cdef void* current_stream
+    current_stream = TVMFFIEnvGetStream(device_type, device_id)
+    return <uint64_t>current_stream
+
+
 cdef extern from "tvm_ffi_python_helpers.h":
     # no need to expose fields of the call context setter data structure
     ctypedef int (*DLPackFromPyObject)(
diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py
index 7f2dde5..e00d422 100644
--- a/python/tvm_ffi/stream.py
+++ b/python/tvm_ffi/stream.py
@@ -161,3 +161,20 @@ def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]) -> StreamC
             "try use_torch_stream when using torch.cuda.Stream or 
torch.cuda.graph"
         )
     return StreamContext(device, stream)
+
+
+def get_raw_stream(device: core.Device) -> int:
+    """Get the current ffi stream of given device.
+
+    Parameters
+    ----------
+    device : tvm_ffi.Device
+        The device to which the stream belongs.
+
+    Returns
+    -------
+    stream : int
+        The current ffi stream.
+
+    """
+    return core._env_get_current_stream(device.dlpack_device_type(), 
device.index)
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
index 9280aab..fbe1a0a 100644
--- a/tests/python/test_stream.py
+++ b/tests/python/test_stream.py
@@ -50,11 +50,14 @@ def test_raw_stream() -> None:
     stream_2 = 987654321
     with tvm_ffi.use_raw_stream(device, stream_1):
         mod.check_stream(device.dlpack_device_type(), device.index, stream_1)
+        assert tvm_ffi.get_raw_stream(device) == stream_1
 
         with tvm_ffi.use_raw_stream(device, stream_2):
             mod.check_stream(device.dlpack_device_type(), device.index, 
stream_2)
+            assert tvm_ffi.get_raw_stream(device) == stream_2
 
         mod.check_stream(device.dlpack_device_type(), device.index, stream_1)
+        assert tvm_ffi.get_raw_stream(device) == stream_1
 
 
 @pytest.mark.skipif(

Reply via email to