This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 752ac8e [PyTorch] Allow tensor conversion on rocm backend (#253)
752ac8e is described below
commit 752ac8ed2a76b5dcdf1655116b9449c207b872f0
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Nov 10 15:21:24 2025 -0500
[PyTorch] Allow tensor conversion on rocm backend (#253)
This commit adds support for tensor conversion on the ROCm backend.
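In practice this lets a PyTorch tensor allocated on a ROCm GPU cross the tvm-ffi boundary and show up with a "rocm" device type. A minimal sketch adapted from the test added below in tests/python/test_tensor.py, assuming a ROCm build of PyTorch with a visible GPU:

    import torch
    import tvm_ffi

    # On a ROCm build of PyTorch, torch.version.hip is set and GPU tensors
    # are still addressed with the "cuda" device name.
    assert torch.version.hip is not None

    @tvm_ffi.register_global_func("testing.check_device", override=True)
    def _check_device(x: tvm_ffi.Tensor) -> str:
        return x.device.type

    x = torch.randn(128, device="cuda")
    # After conversion, the tvm-ffi side reports the tensor on a "rocm" device.
    assert tvm_ffi.get_global_func("testing.check_device")(x) == "rocm"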
---
addons/torch_c_dlpack_ext/build_backend.py | 7 ++++-
python/tvm_ffi/_optional_torch_c_dlpack.py | 11 +++++--
.../utils/_build_optional_torch_c_dlpack.py | 22 ++++++++++++--
tests/python/test_optional_torch_c_dlpack.py | 4 ++-
tests/python/test_tensor.py | 20 +++++++++++++
tests/scripts/benchmark_dlpack.py | 35 +++++++++++++++-------
6 files changed, 82 insertions(+), 17 deletions(-)
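The build scripts below now pick the target backend from how PyTorch itself was built (torch.version.cuda / torch.version.hip) rather than from torch.cuda.is_available(); a condensed sketch of the selection logic, with the flag names as they appear in this patch:

    import torch

    # torch.version.cuda / torch.version.hip reflect the PyTorch build,
    # independent of whether a GPU is visible when the addon is built.
    if torch.version.cuda is not None:
        device, extra_args = "cuda", ["--build-with-cuda"]
    elif torch.version.hip is not None:
        device, extra_args = "rocm", ["--build-with-rocm"]
    else:
        device, extra_args = "cpu", []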
diff --git a/addons/torch_c_dlpack_ext/build_backend.py b/addons/torch_c_dlpack_ext/build_backend.py
index 4301e43..9489c5b 100644
--- a/addons/torch_c_dlpack_ext/build_backend.py
+++ b/addons/torch_c_dlpack_ext/build_backend.py
@@ -74,6 +74,11 @@ def build_wheel(
"No need to build any torch c dlpackc libs."
)
else:
+ extra_args = []
+ if torch.version.cuda is not None:
+ extra_args.append("--build-with-cuda")
+ elif torch.version.hip is not None:
+ extra_args.append("--build-with-rocm")
subprocess.run(
[
sys.executable,
@@ -81,7 +86,7 @@ def build_wheel(
"tvm_ffi.utils._build_optional_torch_c_dlpack",
"--output-dir",
str(_package_path),
- "--build-with-cuda" if torch.cuda.is_available() else "",
+ *extra_args,
],
check=True,
env={**os.environ, "TVM_FFI_DISABLE_TORCH_C_DLPACK": "1"},
diff --git a/python/tvm_ffi/_optional_torch_c_dlpack.py b/python/tvm_ffi/_optional_torch_c_dlpack.py
index 29c1db3..e05fc34 100644
--- a/python/tvm_ffi/_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/_optional_torch_c_dlpack.py
@@ -67,7 +67,12 @@ def load_torch_c_dlpack_extension() -> Any:
cache_dir = Path(os.environ.get("TVM_FFI_CACHE_DIR", "~/.cache/tvm-ffi")).expanduser()
addon_output_dir = cache_dir
major, minor = torch.__version__.split(".")[:2]
- device = "cpu" if not torch.cuda.is_available() else "cuda"
+ if torch.version.cuda is not None:
+ device = "cuda"
+ elif torch.version.hip is not None:
+ device = "rocm"
+ else:
+ device = "cpu"
suffix = ".dll" if sys.platform.startswith("win") else ".so"
libname = f"libtorch_c_dlpack_addon_torch{major}{minor}-{device}{suffix}"
lib_path = addon_output_dir / libname
@@ -83,8 +88,10 @@ def load_torch_c_dlpack_extension() -> Any:
"--libname",
libname,
]
- if torch.cuda.is_available():
+ if device == "cuda":
args.append("--build-with-cuda")
+ elif device == "rocm":
+ args.append("--build-with-rocm")
subprocess.run(
args,
check=True,
diff --git a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
index 5f73d63..20705be 100644
--- a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
@@ -45,6 +45,9 @@ cpp_source = """
#ifdef BUILD_WITH_CUDA
#include <c10/cuda/CUDAStream.h>
#endif
+#ifdef BUILD_WITH_ROCM
+#include <c10/hip/HIPStream.h>
+#endif
using namespace std;
namespace at {
@@ -708,6 +711,11 @@ def main() -> None: # noqa: PLR0912, PLR0915
action="store_true",
help="Build with CUDA support.",
)
+ parser.add_argument(
+ "--build-with-rocm",
+ action="store_true",
+ help="Build with ROCm support.",
+ )
parser.add_argument(
"--libname",
type=str,
@@ -716,6 +724,8 @@ def main() -> None: # noqa: PLR0912, PLR0915
)
args = parser.parse_args()
+ if args.build_with_cuda and args.build_with_rocm:
+ raise ValueError("Cannot enable both CUDA and ROCm at the same time.")
# resolve build directory
if args.build_dir is None:
@@ -729,7 +739,12 @@ def main() -> None: # noqa: PLR0912, PLR0915
# resolve library name
if args.libname == "auto":
major, minor = torch.__version__.split(".")[:2]
- device = "cpu" if not args.build_with_cuda else "cuda"
+ if args.build_with_cuda:
+ device = "cuda"
+ elif args.build_with_rocm:
+ device = "rocm"
+ else:
+ device = "cpu"
suffix = ".dll" if IS_WINDOWS else ".so"
libname = f"libtorch_c_dlpack_addon_torch{major}{minor}-{device}{suffix}"
else:
@@ -759,7 +774,10 @@ def main() -> None: # noqa: PLR0912, PLR0915
if args.build_with_cuda:
cflags.append("-DBUILD_WITH_CUDA")
- include_paths.extend(get_torch_include_paths(args.build_with_cuda))
+ elif args.build_with_rocm:
+ cflags.extend(torch.utils.cpp_extension.COMMON_HIP_FLAGS)
+ cflags.append("-DBUILD_WITH_ROCM")
+ include_paths.extend(get_torch_include_paths(args.build_with_cuda or args.build_with_rocm))
# use CXX11 ABI
if torch.compiled_with_cxx11_abi():
diff --git a/tests/python/test_optional_torch_c_dlpack.py b/tests/python/test_optional_torch_c_dlpack.py
index fe0bac5..8aded2b 100644
--- a/tests/python/test_optional_torch_c_dlpack.py
+++ b/tests/python/test_optional_torch_c_dlpack.py
@@ -44,8 +44,10 @@ def test_build_torch_c_dlpack_extension() -> None:
"--libname",
"libtorch_c_dlpack_addon_test.so",
]
- if torch.cuda.is_available():
+ if torch.version.cuda is not None:
args.append("--build-with-cuda")
+ elif torch.version.hip is not None:
+ args.append("--build-with-rocm")
subprocess.run(args, check=True)
lib_path = str(Path("./output-dir/libtorch_c_dlpack_addon_test.so").resolve())
diff --git a/tests/python/test_tensor.py b/tests/python/test_tensor.py
index 4ea8680..d772ab3 100644
--- a/tests/python/test_tensor.py
+++ b/tests/python/test_tensor.py
@@ -112,3 +112,23 @@ def test_tvm_ffi_tensor_compatible() -> None:
fecho = tvm_ffi.get_global_func("testing.echo")
z = fecho(y)
assert z.__chandle__() == x.__chandle__()
+
+
+@pytest.mark.skipif(
+    torch is None or torch.version.hip is None, reason="ROCm is not enabled in PyTorch"
+)
+def test_tensor_from_pytorch_rocm() -> None:
+ assert torch is not None
+
+ @tvm_ffi.register_global_func("testing.check_device", override=True)
+ def _check_device(x: tvm_ffi.Tensor) -> str:
+ return x.device.type
+
+ # PyTorch uses device name "cuda" to represent ROCm device
+ x = torch.randn(128, device="cuda")
+ device_type = tvm_ffi.get_global_func("testing.check_device")(x)
+ assert device_type == "rocm"
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/tests/scripts/benchmark_dlpack.py b/tests/scripts/benchmark_dlpack.py
index a4833f2..23db3a8 100644
--- a/tests/scripts/benchmark_dlpack.py
+++ b/tests/scripts/benchmark_dlpack.py
@@ -337,17 +337,30 @@ def load_torch_get_current_cuda_stream() -> Callable[[int], int]:
"""Create a faster get_current_cuda_stream for torch through cpp extension."""
from torch.utils import cpp_extension # noqa: PLC0415
- source = """
- #include <c10/cuda/CUDAStream.h>
-
- int64_t get_current_cuda_stream(int device_id) {
- at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
- // fast invariant, default stream is always 0
- if (stream.id() == 0) return 0;
- // convert to cudaStream_t
- return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
- }
- """
+ if torch.version.cuda is not None:
+ source = """
+ #include <c10/cuda/CUDAStream.h>
+
+ int64_t get_current_cuda_stream(int device_id) {
+ at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(device_id);
+ // fast invariant, default stream is always 0
+ if (stream.id() == 0) return 0;
+ // convert to cudaStream_t
+ return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
+ }
+ """
+ elif torch.version.hip is not None:
+ source = """
+ #include <c10/hip/HIPStream.h>
+
+ int64_t get_current_cuda_stream(int device_id) {
+ at::hip::HIPStream stream = at::hip::getCurrentHIPStream(device_id);
+ // fast invariant, default stream is always 0
+ if (stream.id() == 0) return 0;
+ // convert to hipStream_t
+ return reinterpret_cast<int64_t>(static_cast<hipStream_t>(stream));
+ }
+ """
result = cpp_extension.load_inline(
name="get_current_cuda_stream",
cpp_sources=[source],