This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 752ac8e  [PyTorch] Allow tensor conversion on rocm backend (#253)
752ac8e is described below

commit 752ac8ed2a76b5dcdf1655116b9449c207b872f0
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Nov 10 15:21:24 2025 -0500

    [PyTorch] Allow tensor conversion on rocm backend (#253)
    
    This commit adds support for tensor conversion on the ROCm backend.
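    
    A minimal usage sketch, assuming a PyTorch build with ROCm/HIP enabled
    (it mirrors the new test in tests/python/test_tensor.py; note that
    PyTorch exposes ROCm devices under the "cuda" device name):
    
        import torch
        import tvm_ffi
        
        @tvm_ffi.register_global_func("testing.check_device", override=True)
        def _check_device(x: tvm_ffi.Tensor) -> str:
            # tvm-ffi reports the underlying device type of the converted tensor
            return x.device.type
        
        # On a ROCm build, device="cuda" allocates on the HIP/ROCm device
        x = torch.randn(128, device="cuda")
        assert tvm_ffi.get_global_func("testing.check_device")(x) == "rocm"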
---
 addons/torch_c_dlpack_ext/build_backend.py         |  7 ++++-
 python/tvm_ffi/_optional_torch_c_dlpack.py         | 11 +++++--
 .../utils/_build_optional_torch_c_dlpack.py        | 22 ++++++++++++--
 tests/python/test_optional_torch_c_dlpack.py       |  4 ++-
 tests/python/test_tensor.py                        | 20 +++++++++++++
 tests/scripts/benchmark_dlpack.py                  | 35 +++++++++++++++-------
 6 files changed, 82 insertions(+), 17 deletions(-)

diff --git a/addons/torch_c_dlpack_ext/build_backend.py b/addons/torch_c_dlpack_ext/build_backend.py
index 4301e43..9489c5b 100644
--- a/addons/torch_c_dlpack_ext/build_backend.py
+++ b/addons/torch_c_dlpack_ext/build_backend.py
@@ -74,6 +74,11 @@ def build_wheel(
                 "No need to build any torch c dlpackc libs."
             )
         else:
+            extra_args = []
+            if torch.version.cuda is not None:
+                extra_args.append("--build-with-cuda")
+            elif torch.version.hip is not None:
+                extra_args.append("--build-with-rocm")
             subprocess.run(
                 [
                     sys.executable,
@@ -81,7 +86,7 @@ def build_wheel(
                     "tvm_ffi.utils._build_optional_torch_c_dlpack",
                     "--output-dir",
                     str(_package_path),
-                    "--build-with-cuda" if torch.cuda.is_available() else "",
+                    *extra_args,
                 ],
                 check=True,
                 env={**os.environ, "TVM_FFI_DISABLE_TORCH_C_DLPACK": "1"},
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 29c1db3..e05fc34 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -67,7 +67,12 @@ def load_torch_c_dlpack_extension() -> Any:
         cache_dir = Path(os.environ.get("TVM_FFI_CACHE_DIR", "~/.cache/tvm-ffi")).expanduser()
         addon_output_dir = cache_dir
         major, minor = torch.__version__.split(".")[:2]
-        device = "cpu" if not torch.cuda.is_available() else "cuda"
+        if torch.version.cuda is not None:
+            device = "cuda"
+        elif torch.version.hip is not None:
+            device = "rocm"
+        else:
+            device = "cpu"
         suffix = ".dll" if sys.platform.startswith("win") else ".so"
         libname = f"libtorch_c_dlpack_addon_torch{major}{minor}-{device}{suffix}"
         lib_path = addon_output_dir / libname
@@ -83,8 +88,10 @@ def load_torch_c_dlpack_extension() -> Any:
                 "--libname",
                 libname,
             ]
-            if torch.cuda.is_available():
+            if device == "cuda":
                 args.append("--build-with-cuda")
+            elif device == "rocm":
+                args.append("--build-with-rocm")
             subprocess.run(
                 args,
                 check=True,
diff --git a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
index 5f73d63..20705be 100644
--- a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
@@ -45,6 +45,9 @@ cpp_source = """
 #ifdef BUILD_WITH_CUDA
 #include <c10/cuda/CUDAStream.h>
 #endif
+#ifdef BUILD_WITH_ROCM
+#include <c10/hip/HIPStream.h>
+#endif
 
 using namespace std;
 namespace at {
@@ -708,6 +711,11 @@ def main() -> None:  # noqa: PLR0912, PLR0915
         action="store_true",
         help="Build with CUDA support.",
     )
+    parser.add_argument(
+        "--build-with-rocm",
+        action="store_true",
+        help="Build with ROCm support.",
+    )
     parser.add_argument(
         "--libname",
         type=str,
@@ -716,6 +724,8 @@ def main() -> None:  # noqa: PLR0912, PLR0915
     )
 
     args = parser.parse_args()
+    if args.build_with_cuda and args.build_with_rocm:
+        raise ValueError("Cannot enable both CUDA and ROCm at the same time.")
 
     # resolve build directory
     if args.build_dir is None:
@@ -729,7 +739,12 @@ def main() -> None:  # noqa: PLR0912, PLR0915
     # resolve library name
     if args.libname == "auto":
         major, minor = torch.__version__.split(".")[:2]
-        device = "cpu" if not args.build_with_cuda else "cuda"
+        if args.build_with_cuda:
+            device = "cuda"
+        elif args.build_with_rocm:
+            device = "rocm"
+        else:
+            device = "cpu"
         suffix = ".dll" if IS_WINDOWS else ".so"
         libname = f"libtorch_c_dlpack_addon_torch{major}{minor}-{device}{suffix}"
     else:
@@ -759,7 +774,10 @@ def main() -> None:  # noqa: PLR0912, PLR0915
 
         if args.build_with_cuda:
             cflags.append("-DBUILD_WITH_CUDA")
-        include_paths.extend(get_torch_include_paths(args.build_with_cuda))
+        elif args.build_with_rocm:
+            cflags.extend(torch.utils.cpp_extension.COMMON_HIP_FLAGS)
+            cflags.append("-DBUILD_WITH_ROCM")
+        include_paths.extend(get_torch_include_paths(args.build_with_cuda or args.build_with_rocm))
 
         # use CXX11 ABI
         if torch.compiled_with_cxx11_abi():
diff --git a/tests/python/test_optional_torch_c_dlpack.py b/tests/python/test_optional_torch_c_dlpack.py
index fe0bac5..8aded2b 100644
--- a/tests/python/test_optional_torch_c_dlpack.py
+++ b/tests/python/test_optional_torch_c_dlpack.py
@@ -44,8 +44,10 @@ def test_build_torch_c_dlpack_extension() -> None:
         "--libname",
         "libtorch_c_dlpack_addon_test.so",
     ]
-    if torch.cuda.is_available():
+    if torch.version.cuda is not None:
         args.append("--build-with-cuda")
+    elif torch.version.hip is not None:
+        args.append("--build-with-rocm")
     subprocess.run(args, check=True)
 
     lib_path = str(Path("./output-dir/libtorch_c_dlpack_addon_test.so").resolve())
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 4ea8680..d772ab3 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -112,3 +112,23 @@ def test_tvm_ffi_tensor_compatible() -> None:
     fecho = tvm_ffi.get_global_func("testing.echo")
     z = fecho(y)
     assert z.__chandle__() == x.__chandle__()
+
+
[email protected](
+    torch is None or torch.version.hip is None, reason="ROCm is not enabled in PyTorch"
+)
+def test_tensor_from_pytorch_rocm() -> None:
+    assert torch is not None
+
+    @tvm_ffi.register_global_func("testing.check_device", override=True)
+    def _check_device(x: tvm_ffi.Tensor) -> str:
+        return x.device.type
+
+    # PyTorch uses device name "cuda" to represent ROCm device
+    x = torch.randn(128, device="cuda")
+    device_type = tvm_ffi.get_global_func("testing.check_device")(x)
+    assert device_type == "rocm"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/scripts/benchmark_dlpack.py b/tests/scripts/benchmark_dlpack.py
index a4833f2..23db3a8 100644
--- a/tests/scripts/benchmark_dlpack.py
+++ b/tests/scripts/benchmark_dlpack.py
@@ -337,17 +337,30 @@ def load_torch_get_current_cuda_stream() -> Callable[[int], int]:
     """Create a faster get_current_cuda_stream for torch through cpp 
extension."""
     from torch.utils import cpp_extension  # noqa: PLC0415
 
-    source = """
-    #include <c10/cuda/CUDAStream.h>
-
-    int64_t get_current_cuda_stream(int device_id) {
-        at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
-        // fast invariant, default stream is always 0
-        if (stream.id() == 0) return 0;
-        // convert to cudaStream_t
-        return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
-    }
-    """
+    if torch.version.cuda is not None:
+        source = """
+        #include <c10/cuda/CUDAStream.h>
+
+        int64_t get_current_cuda_stream(int device_id) {
+            at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
+            // fast invariant, default stream is always 0
+            if (stream.id() == 0) return 0;
+            // convert to cudaStream_t
+            return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
+        }
+        """
+    elif torch.version.hip is not None:
+        source = """
+        #include <c10/hip/HIPStream.h>
+
+        int64_t get_current_cuda_stream(int device_id) {
+            at::hip::HIPStream stream = at::hip::getCurrentHIPStream(device_id);
+            // fast invariant, default stream is always 0
+            if (stream.id() == 0) return 0;
+            // convert to hipStream_t
+            return reinterpret_cast<int64_t>(static_cast<hipStream_t>(stream));
+        }
+        """
     result = cpp_extension.load_inline(
         name="get_current_cuda_stream",
         cpp_sources=[source],
