This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 37d0485b ROCm: fix DLPack current_work_stream and force addon override for Torch exchange API (#466)
37d0485b is described below
commit 37d0485b2058885bf4e7a486f7d7b2174a8ac1ce
Author: Hubert Lu <[email protected]>
AuthorDate: Fri Feb 20 07:07:07 2026 -0800
ROCm: fix DLPack current_work_stream and force addon override for Torch exchange API (#466)
## Motivation
On ROCm, graph-capture-sensitive callers (for example SGLang JIT
kernels) can hit stream mismatches through the Torch DLPack exchange
API path. In practice this shows up as HIP graph empty-capture warnings
and forces callers to carry ROCm-specific workarounds downstream.
We want `tvm-ffi` to provide a correct ROCm stream callback and let
downstream projects remove ROCm-specific Python bridging code.
## Reproducer
### Minimal `tvm-ffi` reproducer
- New test: `tests/python/test_current_work_stream_gpu.py`
- It obtains `torch.Tensor.__dlpack_c_exchange_api__`, calls `current_work_stream`, and verifies stream identity against a non-default torch stream.
- It checks:
  - `kDLCUDA` in all cases
  - `kDLROCM` additionally on the HIP runtime
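
For orientation, the capsule handling at the heart of the test reduces to
roughly the sketch below; the real test then hands `api_ptr` to an inline
C++ extension that invokes `current_work_stream` and compares it against
the active torch stream.

```python
import ctypes

import torch

# Unwrap the PyCapsule attached to torch.Tensor (by torch itself or by
# the tvm-ffi addon) into a raw DLPackExchangeAPI* address.
api_attr = torch.Tensor.__dlpack_c_exchange_api__
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_size_t
ctypes.pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
api_ptr = ctypes.pythonapi.PyCapsule_GetPointer(api_attr, b"dlpack_exchange_api")
assert api_ptr != 0  # NULL here would mean the exchange API is absent
```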
## Fix
### 1) ROCm stream callback implementation
File:
- `python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py`
Changes:
- Include `ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h` for ROCm builds.
- In `CurrentWorkStream(...)`:
- ROCm build path now returns:
- `c10::hip::getCurrentHIPStreamMasqueradingAsCUDA(device_id).stream()`
- CUDA build path remains:
- `at::cuda::getCurrentCUDAStream(device_id).stream()`
This aligns stream reporting with the ROCm semantics PyTorch itself uses
for HIP stream handling.
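
The invariant this establishes: on a ROCm build, the handle the callback
reports equals what torch reports for the stream made current via
`torch.cuda.stream(...)`. A minimal sketch of that check, where
`query_current_work_stream` is a hypothetical wrapper around the C
callback (not part of this change):

```python
import torch

def check_stream_identity(query_current_work_stream) -> None:
    # `query_current_work_stream(device_id)` is assumed to call
    # DLPackExchangeAPI::current_work_stream(kDLCUDA, device_id, &out)
    # and return the handle as an int.
    device_id = torch.cuda.current_device()
    stream = torch.cuda.Stream(device=device_id)
    with torch.cuda.stream(stream):
        # On ROCm builds, .cuda_stream is the underlying hipStream_t handle.
        assert query_current_work_stream(device_id) == int(stream.cuda_stream)
```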
### 2) Ensure ROCm uses `tvm-ffi` addon API
File:
- `python/tvm_ffi/_optional_torch_c_dlpack.py`
Changes:
- On ROCm (`torch.cuda.is_available() and torch.version.hip is not None`), do not early-return just because torch already exposes `__dlpack_c_exchange_api__`.
- Force loading/using the `tvm-ffi` addon capsule on ROCm so the fixed callback is actually active (sketched below).
- NVIDIA/CUDA path remains unchanged.
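
Condensed from the diff below, the new gate behaves roughly like this
sketch (`_should_skip_addon_load` is an illustrative name, not a function
in the codebase):

```python
import torch

def _should_skip_addon_load() -> bool:
    # Skip the addon only when torch already exposes the exchange API
    # *and* we are not on a ROCm build; on ROCm we fall through so the
    # fixed callback from the tvm-ffi addon is installed.
    prefer_rocm_override = bool(torch.cuda.is_available() and torch.version.hip is not None)
    has_api = hasattr(torch.Tensor, "__dlpack_c_exchange_api__")
    return has_api and not prefer_rocm_override
```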
## Validation
### `tvm-ffi` side
- `python -m pytest -q tests/python/test_current_work_stream_gpu.py`
- Result: pass
### Downstream impact check (SGLang)
To verify that this change removes the ROCm-specific divergence in the
SGLang JIT path, we temporarily removed:
- `hip_ensure_tvm_ffi_stream(...)`
- `to_tvm_tensor_cached(...)`
Ref:
https://github.com/sgl-project/sglang/pull/18992/changes/df80efb85eee5f5c2be4241865a77508ae0d7a69#diff-ff0d21f07b9d4c75f02741e8e63e4eee34a1bd24e8f74c7b6b96f0f19ee97bd8R204-R233
---
python/tvm_ffi/_optional_torch_c_dlpack.py | 5 +-
.../utils/_build_optional_torch_c_dlpack.py | 9 +-
tests/python/test_current_work_stream_gpu.py | 96 ++++++++++++++++++++++
3 files changed, 107 insertions(+), 3 deletions(-)
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 594abaee..70121084 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -94,7 +94,8 @@ def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912, PLR0915
import torch # noqa: PLC0415
import torch.version # noqa: PLC0415
- if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
+ prefer_rocm_override = bool(torch.cuda.is_available() and torch.version.hip is not None)
+ if _check_and_update_dlpack_c_exchange_api(torch.Tensor) and not prefer_rocm_override:
# skip loading the extension if the __dlpack_c_exchange_api__
# attribute is already set so we don't have to do it in
# newer version of PyTorch
@@ -106,7 +107,7 @@ def load_torch_c_dlpack_extension() -> Any: # noqa: PLR0912, PLR0915
try:
import torch_c_dlpack_ext # noqa: PLC0415, F401
- if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
+ if _check_and_update_dlpack_c_exchange_api(torch.Tensor) and not prefer_rocm_override:
return None
except ImportError:
pass
diff --git a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
index 8f38530c..7277568d 100644
--- a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
@@ -48,6 +48,7 @@ cpp_source = """
#endif
#ifdef BUILD_WITH_ROCM
#include <c10/hip/HIPStream.h>
+#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#endif
using namespace std;
@@ -506,8 +507,14 @@ struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
// Get current CUDA/ROCm work stream
static int CurrentWorkStream(DLDeviceType device_type, int32_t device_id, void** out_stream) {
try {
+#ifdef BUILD_WITH_ROCM
+ if (device_type == kDLROCM || device_type == kDLCUDA) {
+ *out_stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA(device_id).stream();
+ return 0;
+ }
+#endif
#ifdef BUILD_WITH_CUDA
- if (device_type == kDLCUDA || device_type == kDLROCM) {
+ if (device_type == kDLCUDA) {
*out_stream = at::cuda::getCurrentCUDAStream(device_id).stream();
return 0;
}
diff --git a/tests/python/test_current_work_stream_gpu.py b/tests/python/test_current_work_stream_gpu.py
new file mode 100644
index 00000000..fcd3b946
--- /dev/null
+++ b/tests/python/test_current_work_stream_gpu.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file to
+# you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import ctypes
+
+import pytest
+
+try:
+ import torch
+ import tvm_ffi # noqa: F401
+ from torch.utils import cpp_extension
+ from tvm_ffi import libinfo
+except ImportError:
+ torch = None # ty: ignore[invalid-assignment]
+
+if torch is None:
+ _HAS_TORCH = False
+ _HAS_GPU = False
+ _HAS_DLPACK_EXCHANGE_API = False
+else:
+ _HAS_TORCH = True
+ _HAS_GPU = bool(torch.cuda.is_available())
+ _HAS_DLPACK_EXCHANGE_API = bool(hasattr(torch.Tensor, "__dlpack_c_exchange_api__"))
+
+
[email protected](not _HAS_TORCH, reason="Requires torch")
[email protected](not _HAS_GPU, reason="Requires GPU runtime")
[email protected](not _HAS_DLPACK_EXCHANGE_API, reason="Requires __dlpack_c_exchange_api__")
+def test_current_work_stream_matches_torch_stream() -> None:
+ assert torch is not None
+ api_attr = torch.Tensor.__dlpack_c_exchange_api__
+
+ pythonapi = ctypes.pythonapi
+ pythonapi.PyCapsule_GetPointer.restype = ctypes.c_size_t
+ pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
+ api_ptr = pythonapi.PyCapsule_GetPointer(api_attr, b"dlpack_exchange_api")
+ assert api_ptr != 0
+
+ source = r"""
+ #include <torch/extension.h>
+ #include <dlpack/dlpack.h>
+
+ void assert_current_work_stream(int64_t api_ptr_int, bool is_hip, int64_t expected_stream) {
+ DLPackExchangeAPI* api = reinterpret_cast<DLPackExchangeAPI*>(api_ptr_int);
+ TORCH_CHECK(api != nullptr, "API pointer is NULL");
+ TORCH_CHECK(api->current_work_stream != nullptr, "current_work_stream is NULL");
+
+ void* stream_cuda = nullptr;
+ int result_cuda = api->current_work_stream(kDLCUDA, 0, &stream_cuda);
+ TORCH_CHECK(result_cuda == 0, "current_work_stream(kDLCUDA) failed");
+ TORCH_CHECK(reinterpret_cast<int64_t>(stream_cuda) == expected_stream,
+ "kDLCUDA stream mismatch");
+
+ if (is_hip) {
+ void* stream_rocm = nullptr;
+ int result_rocm = api->current_work_stream(kDLROCM, 0, &stream_rocm);
+ TORCH_CHECK(result_rocm == 0, "current_work_stream(kDLROCM) failed");
+ TORCH_CHECK(reinterpret_cast<int64_t>(stream_rocm) == expected_stream,
+ "kDLROCM stream mismatch");
+ }
+ }
+ """
+
+ include_paths = libinfo.include_paths()
+ include_paths += cpp_extension.include_paths("cuda")
+
+ mod = cpp_extension.load_inline(
+ name="test_current_work_stream_gpu_ext",
+ cpp_sources=[source],
+ functions=["assert_current_work_stream"],
+ with_cuda=torch.cuda.is_available(),
+ extra_include_paths=include_paths,
+ )
+
+ device_id = torch.cuda.current_device()
+ is_hip = torch.version.hip is not None
+ stream = torch.cuda.Stream(device=device_id)
+ with torch.cuda.stream(stream):
+ expected_stream = int(stream.cuda_stream)
+ mod.assert_current_work_stream(api_ptr, is_hip, expected_stream)