This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 37d0485b ROCm: fix DLPack current_work_stream and force addon override for Torch exchange API (#466)
37d0485b is described below

commit 37d0485b2058885bf4e7a486f7d7b2174a8ac1ce
Author: Hubert Lu <[email protected]>
AuthorDate: Fri Feb 20 07:07:07 2026 -0800

    ROCm: fix DLPack current_work_stream and force addon override for Torch exchange API (#466)
    
    ## Motivation
    
    On ROCm, graph-capture-sensitive callers (for example SGLang JIT
    kernels) can hit stream mismatch behavior through the Torch DLPack
    exchange API path. In practice this shows up as HIP graph empty-capture
    warnings and forced downstream workarounds in callers.
    
    We want `tvm-ffi` to provide a correct ROCm stream callback and let
    downstream projects remove ROCm-specific Python bridging code.
    
    ## Reproducer
    
    ### Minimal `tvm-ffi` reproducer
    
    - New test: `tests/python/test_current_work_stream_gpu.py`
    - It obtains `torch.Tensor.__dlpack_c_exchange_api__`, calls
    `current_work_stream`, and verifies stream identity against a
    non-default torch stream.
    - It checks:
      - `kDLCUDA` always
      - `kDLROCM` on HIP runtime
    
    ## Fix
    
    ### 1) ROCm stream callback implementation
    
    File:
    
    - `python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py`
    
    Changes:
    
    - Include `ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h` for ROCm builds.
    - In `CurrentWorkStream(...)`:
      - ROCm build path now returns:
        - `c10::hip::getCurrentHIPStreamMasqueradingAsCUDA(device_id).stream()`
      - CUDA build path remains:
        - `at::cuda::getCurrentCUDAStream(device_id).stream()`
    
    This aligns stream reporting with ROCm semantics used by PyTorch HIP
    stream handling.
    
    ### 2) Ensure ROCm uses `tvm-ffi` addon API
    
    File:
    
    - `python/tvm_ffi/_optional_torch_c_dlpack.py`
    
    Changes:
    
    - On ROCm (`torch.cuda.is_available() and torch.version.hip is not
    None`), do not early-return just because torch already exposes
    `__dlpack_c_exchange_api__`.
    - Force loading/using the `tvm-ffi` addon capsule on ROCm so the fixed
    callback is actually active.
    - NVIDIA/CUDA path remains unchanged.
    
    ## Validation
    
    ### `tvm-ffi` side
    
    - `python -m pytest -q tests/python/test_current_work_stream_gpu.py`
    - Result: pass
    
    ### Downstream impact check (SGLang)
    
    To verify this can remove divergence in SGLang JIT path, we temporarily
    removed:
    
    - `hip_ensure_tvm_ffi_stream(...)`
    - `to_tvm_tensor_cached(...)`
    
    Ref:
    
    https://github.com/sgl-project/sglang/pull/18992/changes/df80efb85eee5f5c2be4241865a77508ae0d7a69#diff-ff0d21f07b9d4c75f02741e8e63e4eee34a1bd24e8f74c7b6b96f0f19ee97bd8R204-R233
---
 python/tvm_ffi/_optional_torch_c_dlpack.py         |  5 +-
 .../utils/_build_optional_torch_c_dlpack.py        |  9 +-
 tests/python/test_current_work_stream_gpu.py       | 96 ++++++++++++++++++++++
 3 files changed, 107 insertions(+), 3 deletions(-)

diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 594abaee..70121084 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -94,7 +94,8 @@ def load_torch_c_dlpack_extension() -> Any:  # noqa: PLR0912, PLR0915
         import torch  # noqa: PLC0415
         import torch.version  # noqa: PLC0415
 
-        if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
+        prefer_rocm_override = bool(torch.cuda.is_available() and torch.version.hip is not None)
+        if _check_and_update_dlpack_c_exchange_api(torch.Tensor) and not prefer_rocm_override:
             # skip loading the extension if the __dlpack_c_exchange_api__
             # attribute is already set so we don't have to do it in
             # newer version of PyTorch
@@ -106,7 +107,7 @@ def load_torch_c_dlpack_extension() -> Any:  # noqa: PLR0912, PLR0915
     try:
         import torch_c_dlpack_ext  # noqa: PLC0415, F401
 
-        if _check_and_update_dlpack_c_exchange_api(torch.Tensor):
+        if _check_and_update_dlpack_c_exchange_api(torch.Tensor) and not prefer_rocm_override:
             return None
     except ImportError:
         pass
diff --git a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
index 8f38530c..7277568d 100644
--- a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
@@ -48,6 +48,7 @@ cpp_source = """
 #endif
 #ifdef BUILD_WITH_ROCM
 #include <c10/hip/HIPStream.h>
+#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
 #endif
 
 using namespace std;
@@ -506,8 +507,14 @@ struct TorchDLPackExchangeAPI : public DLPackExchangeAPI {
   // Get current CUDA/ROCm work stream
   static int CurrentWorkStream(DLDeviceType device_type, int32_t device_id, void** out_stream) {
     try {
+#ifdef BUILD_WITH_ROCM
+      if (device_type == kDLROCM || device_type == kDLCUDA) {
+        *out_stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA(device_id).stream();
+        return 0;
+      }
+#endif
 #ifdef BUILD_WITH_CUDA
-      if (device_type == kDLCUDA || device_type == kDLROCM) {
+      if (device_type == kDLCUDA) {
         *out_stream = at::cuda::getCurrentCUDAStream(device_id).stream();
         return 0;
       }
diff --git a/tests/python/test_current_work_stream_gpu.py b/tests/python/test_current_work_stream_gpu.py
new file mode 100644
index 00000000..fcd3b946
--- /dev/null
+++ b/tests/python/test_current_work_stream_gpu.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file to
+# you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import ctypes
+
+import pytest
+
+try:
+    import torch
+    import tvm_ffi  # noqa: F401
+    from torch.utils import cpp_extension
+    from tvm_ffi import libinfo
+except ImportError:
+    torch = None  # ty: ignore[invalid-assignment]
+
+if torch is None:
+    _HAS_TORCH = False
+    _HAS_GPU = False
+    _HAS_DLPACK_EXCHANGE_API = False
+else:
+    _HAS_TORCH = True
+    _HAS_GPU = bool(torch.cuda.is_available())
+    _HAS_DLPACK_EXCHANGE_API = bool(hasattr(torch.Tensor, "__dlpack_c_exchange_api__"))
+
+
+@pytest.mark.skipif(not _HAS_TORCH, reason="Requires torch")
+@pytest.mark.skipif(not _HAS_GPU, reason="Requires GPU runtime")
+@pytest.mark.skipif(not _HAS_DLPACK_EXCHANGE_API, reason="Requires __dlpack_c_exchange_api__")
+def test_current_work_stream_matches_torch_stream() -> None:
+    assert torch is not None
+    api_attr = torch.Tensor.__dlpack_c_exchange_api__
+
+    pythonapi = ctypes.pythonapi
+    pythonapi.PyCapsule_GetPointer.restype = ctypes.c_size_t
+    pythonapi.PyCapsule_GetPointer.argtypes = [ctypes.py_object, ctypes.c_char_p]
+    api_ptr = pythonapi.PyCapsule_GetPointer(api_attr, b"dlpack_exchange_api")
+    assert api_ptr != 0
+
+    source = r"""
+    #include <torch/extension.h>
+    #include <dlpack/dlpack.h>
+
+    void assert_current_work_stream(int64_t api_ptr_int, bool is_hip, int64_t expected_stream) {
+        DLPackExchangeAPI* api = reinterpret_cast<DLPackExchangeAPI*>(api_ptr_int);
+        TORCH_CHECK(api != nullptr, "API pointer is NULL");
+        TORCH_CHECK(api->current_work_stream != nullptr, "current_work_stream is NULL");
+
+        void* stream_cuda = nullptr;
+        int result_cuda = api->current_work_stream(kDLCUDA, 0, &stream_cuda);
+        TORCH_CHECK(result_cuda == 0, "current_work_stream(kDLCUDA) failed");
+        TORCH_CHECK(reinterpret_cast<int64_t>(stream_cuda) == expected_stream,
+                    "kDLCUDA stream mismatch");
+
+        if (is_hip) {
+            void* stream_rocm = nullptr;
+            int result_rocm = api->current_work_stream(kDLROCM, 0, &stream_rocm);
+            TORCH_CHECK(result_rocm == 0, "current_work_stream(kDLROCM) failed");
+            TORCH_CHECK(reinterpret_cast<int64_t>(stream_rocm) == expected_stream,
+                        "kDLROCM stream mismatch");
+        }
+    }
+    """
+
+    include_paths = libinfo.include_paths()
+    include_paths += cpp_extension.include_paths("cuda")
+
+    mod = cpp_extension.load_inline(
+        name="test_current_work_stream_gpu_ext",
+        cpp_sources=[source],
+        functions=["assert_current_work_stream"],
+        with_cuda=torch.cuda.is_available(),
+        extra_include_paths=include_paths,
+    )
+
+    device_id = torch.cuda.current_device()
+    is_hip = torch.version.hip is not None
+    stream = torch.cuda.Stream(device=device_id)
+    with torch.cuda.stream(stream):
+        expected_stream = int(stream.cuda_stream)
+        mod.assert_current_work_stream(api_ptr, is_hip, expected_stream)

Reply via email to