This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 22c049b Expose get stream method (#97)
22c049b is described below
commit 22c049b8f3b64e7e2f17b28df044b065ae3fba83
Author: Yaxing Cai <[email protected]>
AuthorDate: Wed Oct 8 17:48:13 2025 -0700
Expose get stream method (#97)
This PR exposes the method for getting the current environment stream (`get_raw_stream`) as a follow-up to #5.
---
python/tvm_ffi/__init__.py | 2 +-
python/tvm_ffi/core.pyi | 1 +
python/tvm_ffi/cython/base.pxi | 6 ++++++
python/tvm_ffi/stream.py | 17 +++++++++++++++++
tests/python/test_stream.py | 3 +++
5 files changed, 28 insertions(+), 1 deletion(-)
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 720968d..e46a7aa 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -42,7 +42,7 @@ from ._tensor import Device, device, DLDeviceType
from ._tensor import from_dlpack, Tensor, Shape
from .container import Array, Map
from .module import Module, system_lib, load_module
-from .stream import StreamContext, use_raw_stream, use_torch_stream
+from .stream import StreamContext, get_raw_stream, use_raw_stream, use_torch_stream
from . import serialization
from . import access_path
from . import testing
diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi
index 54b44cf..45a7d28 100644
--- a/python/tvm_ffi/core.pyi
+++ b/python/tvm_ffi/core.pyi
@@ -259,6 +259,7 @@ def _convert_to_ffi_error(error: BaseException) -> Error:
...
def _env_set_current_stream(
device_type: int, device_id: int, stream: int | c_void_p
) -> int | c_void_p: ...
+def _env_get_current_stream(device_type: int, device_id: int) -> int: ...
class DataType:
"""Internal wrapper around ``DLDataType``.
diff --git a/python/tvm_ffi/cython/base.pxi b/python/tvm_ffi/cython/base.pxi
index 633ace4..a8b4212 100644
--- a/python/tvm_ffi/cython/base.pxi
+++ b/python/tvm_ffi/cython/base.pxi
@@ -261,6 +261,12 @@ def _env_set_current_stream(int device_type, int device_id, uint64_t stream):
return <uint64_t>prev_stream
+def _env_get_current_stream(int device_type, int device_id):
+ cdef void* current_stream
+ current_stream = TVMFFIEnvGetStream(device_type, device_id)
+ return <uint64_t>current_stream
+
+
cdef extern from "tvm_ffi_python_helpers.h":
# no need to expose fields of the call context setter data structure
ctypedef int (*DLPackFromPyObject)(
diff --git a/python/tvm_ffi/stream.py b/python/tvm_ffi/stream.py
index 7f2dde5..e00d422 100644
--- a/python/tvm_ffi/stream.py
+++ b/python/tvm_ffi/stream.py
@@ -161,3 +161,20 @@ def use_raw_stream(device: core.Device, stream: Union[int, c_void_p]) -> StreamC
"try use_torch_stream when using torch.cuda.Stream or torch.cuda.graph"
)
return StreamContext(device, stream)
+
+
+def get_raw_stream(device: core.Device) -> int:
+ """Get the current ffi stream of given device.
+
+ Parameters
+ ----------
+ device : tvm_ffi.Device
+ The device to which the stream belongs.
+
+ Returns
+ -------
+ stream : int
+ The current ffi stream.
+
+ """
+ return core._env_get_current_stream(device.dlpack_device_type(), device.index)
diff --git a/tests/python/test_stream.py b/tests/python/test_stream.py
index 9280aab..fbe1a0a 100644
--- a/tests/python/test_stream.py
+++ b/tests/python/test_stream.py
@@ -50,11 +50,14 @@ def test_raw_stream() -> None:
stream_2 = 987654321
with tvm_ffi.use_raw_stream(device, stream_1):
mod.check_stream(device.dlpack_device_type(), device.index, stream_1)
+ assert tvm_ffi.get_raw_stream(device) == stream_1
with tvm_ffi.use_raw_stream(device, stream_2):
mod.check_stream(device.dlpack_device_type(), device.index, stream_2)
+ assert tvm_ffi.get_raw_stream(device) == stream_2
mod.check_stream(device.dlpack_device_type(), device.index, stream_1)
+ assert tvm_ffi.get_raw_stream(device) == stream_1
@pytest.mark.skipif(