This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2004a8bcbf [NVRTC] Add NVSHMEM support to NVRTC compilation path (#18681)
2004a8bcbf is described below
commit 2004a8bcbfde8bc0c46995bfb4ce152d7dd4ec51
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Sat Jan 24 11:52:51 2026 -0800
[NVRTC] Add NVSHMEM support to NVRTC compilation path (#18681)
---
python/tvm/contrib/nvcc.py | 234 +++++++++++++++++++--
.../tvm/script/ir_builder/tir/external_kernel.py | 48 +++--
src/runtime/cuda/cuda_module.cc | 1 +
tests/python/disco/test_nvshmem.py | 126 ++++++++++-
4 files changed, 379 insertions(+), 30 deletions(-)
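
For reviewers, a minimal sketch of how the new path is exercised (the kernel body below is a hypothetical example; NVSHMEM and the cuda-python package are assumed to be installed):

    # Sketch only: compile_cuda now handles NVSHMEM sources on the NVRTC path.
    # It detects the nvshmem include, forces cubin output, and links
    # libnvshmem_device.a through the CUDA driver linker.
    from tvm.contrib import nvcc

    kernel = r"""
    #include <nvshmem.h>
    extern "C" __global__ void my_pe(int* out) { out[0] = nvshmem_my_pe(); }
    """

    # Previously this combination fell back to nvcc; it now stays on NVRTC.
    cubin = nvcc.compile_cuda(kernel, target_format="cubin", compiler="nvrtc")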
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index edf3e8af4f..7706f63973 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -71,16 +71,17 @@ def compile_cuda(
- NVRTC is a "runtime" compilation library and can be faster for JIT compilation.
- NVRTC requires cuda-python: pip install cuda-python
"""
- # TODO: if need NVSHMEM for compilation, fall back to NVCC because support for NVRTC
- # is not yet implemented
use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in code
- if compiler == "nvcc" or use_nvshmem:
- return _compile_cuda_nvcc(code, target_format, arch, options, path_target, use_nvshmem)
+
+ if compiler == "nvcc":
+ result = _compile_cuda_nvcc(code, target_format, arch, options, path_target, use_nvshmem)
elif compiler == "nvrtc":
- return _compile_cuda_nvrtc(code, target_format, arch, options)
+ result = _compile_cuda_nvrtc(code, target_format, arch, options, path_target, use_nvshmem)
else:
raise ValueError(f"cuda compiler must be 'nvcc' or 'nvrtc', got: {compiler}")
+ return result
+
def _compile_cuda_nvcc(
code,
@@ -235,7 +236,9 @@ def _compile_cuda_nvcc(
return data
-def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
+def _compile_cuda_nvrtc(
+ code, target_format=None, arch=None, options=None, path_target=None, use_nvshmem=False
+):
"""Compile CUDA code using NVRTC (NVIDIA Runtime Compilation).
Parameters
@@ -248,6 +251,10 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
Target architecture (e.g., "sm_80"). Auto-detected if None.
options : str or list of str, optional
Additional NVRTC options.
+ path_target : str, optional
+ Output file path. If provided, the compiled binary is written to this path.
+ use_nvshmem : bool, optional
+ Whether NVSHMEM is used. Default: False
Returns
-------
@@ -264,8 +271,20 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
"See: https://nvidia.github.io/cuda-python/"
) from e
- # Default target format
- if target_format is None:
+ # For NVSHMEM, we also need the CUDA driver API to initialize the context for linking
+ if use_nvshmem:
+ import importlib.util # pylint: disable=import-outside-toplevel
+
+ if importlib.util.find_spec("cuda.bindings.driver") is None:
+ raise RuntimeError(
+ "Failed to compile CUDA with NVRTC+NVSHMEM because the
`cuda-python` package "
+ "is not available.\n"
+ "Please install it with: pip install cuda-python\n"
+ "See: https://nvidia.github.io/cuda-python/"
+ )
+
+ # NVSHMEM requires linking with device library, which always produces cubin
+ if use_nvshmem or target_format is None:
target_format = "cubin"
# Validate target_format (NVRTC doesn't support fatbin)
@@ -287,6 +306,11 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
compute_version = get_target_compute_version(Target.current(allow_none=True))
arch = f"sm_{''.join(compute_version.split('.'))}"
+ # Get NVSHMEM paths if needed
+ nvshmem_include_path, nvshmem_lib_path = None, None
+ if use_nvshmem:
+ nvshmem_include_path, nvshmem_lib_path = find_nvshmem_paths()
+
# Strip host-only headers for NVRTC. NVRTC compiles device code and does not
# require the CUDA driver header or host C++ headers.
headers_to_strip = {"#include <cuda.h>"}
@@ -304,6 +328,47 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
"};\n\n" + code_filtered
)
+ # Add standard type definitions and compatibility macros that NVRTC doesn't provide.
+ nvrtc_preamble = """#include <cuda/std/cstdint>
+using cuda::std::uint8_t;
+using cuda::std::uint16_t;
+using cuda::std::uint32_t;
+using cuda::std::uint64_t;
+using cuda::std::int8_t;
+using cuda::std::int16_t;
+using cuda::std::int32_t;
+using cuda::std::int64_t;
+
+// NVRTC uses asm/volatile instead of __asm__/__volatile__
+#ifndef __asm__
+#define __asm__ asm
+#endif
+#ifndef __volatile__
+#define __volatile__ volatile
+#endif
+
+"""
+ code_filtered = nvrtc_preamble + code_filtered
+
+ # For NVSHMEM, add preamble to map cuda::std type traits to std namespace.
+ # NVSHMEM headers require std:: type traits but NVRTC uses cuda::std::.
+ if use_nvshmem:
+ nvshmem_preamble = """#include <cuda/std/type_traits>
+
+// Map cuda::std type traits to std namespace for NVSHMEM headers
+namespace std {
+ using cuda::std::is_integral;
+ using cuda::std::is_signed;
+ using cuda::std::is_unsigned;
+ using cuda::std::is_floating_point;
+ using cuda::std::is_same;
+ using cuda::std::enable_if;
+ using cuda::std::conditional;
+}
+
+"""
+ code_filtered = nvshmem_preamble + code_filtered
+
# Create NVRTC program
# Use "tvm_kernels.cu" for consistency with nvcc path
result, prog = nvrtc.nvrtcCreateProgram(
@@ -319,6 +384,9 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
b"-default-device",
]
+ if use_nvshmem:
+ compile_opts.extend([b"-rdc", b"true"])
+
# Add CUDA include paths. NVRTC needs explicit include paths for CUDA headers.
# Standard installations: cuda_path/include
# Conda/architecture-specific installations: cuda_path/targets/<arch>/include
@@ -339,6 +407,12 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
if os.path.isdir(arch_include):
include_paths.append(arch_include)
+ # Check for CCCL include directory (required for cuda/std/cstdint and type_traits)
+ # CCCL provides standard library functionality for device code
+ cccl_include = os.path.join(arch_include, "cccl") if os.path.isdir(arch_include) else None
+ if cccl_include and os.path.isdir(cccl_include):
+ include_paths.append(cccl_include)
+
# Verify we can find essential CUDA headers
if not any(os.path.isfile(os.path.join(p, "cuda_runtime.h")) for p in include_paths):
raise RuntimeError(
@@ -351,6 +425,26 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
for include_path in include_paths:
compile_opts.append(f"-I{include_path}".encode())
+ # Add NVSHMEM include path
+ if use_nvshmem and nvshmem_include_path:
+ compile_opts.append(f"-I{nvshmem_include_path}".encode())
+
+ # For NVSHMEM, add deprecation and type conversion macros
+ if use_nvshmem:
+ compile_opts.extend(
+ [
# Define deprecation macros as empty (not properly defined in NVRTC context)
+ b"-D__NV_SILENCE_DEPRECATION_BEGIN=",
+ b"-D__NV_SILENCE_DEPRECATION_END=",
+ b"-D__NV_SILENCE_HOST_DEPRECATION_BEGIN=",
+ b"-D__NV_SILENCE_HOST_DEPRECATION_END=",
# Disable FP8/FP6/FP4 extended types that cause issues with NVRTC
+ b"-D__CUDA_NO_FP8_CONVERSIONS__",
+ b"-D__CUDA_NO_FP6_CONVERSIONS__",
+ b"-D__CUDA_NO_FP4_CONVERSIONS__",
+ ]
+ )
+
compile_opts.extend(
[
b"-U__CUDA_NO_HALF_OPERATORS__",
@@ -363,12 +457,40 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
]
)
- # Add user-provided options
+ # Add user-provided options, filtering out nvcc-specific flags that nvrtc doesn't support
if options:
+ nvcc_only_prefixes = (
+ "-c",
+ "-O",
+ "-std",
+ "--std",
+ "-Xcompiler",
+ "-Xlinker",
+ "-Xarchive",
+ "-Xcudafe",
+ "-Xptxas",
+ "--compile",
+ "--compiler-options",
+ "--linker-options",
+ "-fPIC",
+ "-shared",
+ "-o",
+ )
if isinstance(options, str):
- compile_opts.append(options.encode())
- else:
- compile_opts.extend([opt.encode() if isinstance(opt, str) else opt for opt in options])
+ options = [options]
+ for opt in options:
+ if isinstance(opt, str):
+ opt_str = opt
+ elif isinstance(opt, bytes):
+ opt_str = opt.decode()
+ else:
+ opt_str = str(opt)
+ skip = any(
+ opt_str.startswith(prefix) or opt_str == prefix for prefix in nvcc_only_prefixes
+ )
+ if skip:
+ continue
+ compile_opts.append(opt.encode() if isinstance(opt, str) else opt)
# Compile
(result,) = nvrtc.nvrtcCompileProgram(prog, len(compile_opts), compile_opts)
@@ -410,10 +532,94 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
nvrtc.nvrtcDestroyProgram(prog)
raise RuntimeError(f"Failed to get PTX:
{nvrtc.nvrtcGetErrorString(result)}")
- # Clean up
+ # Clean up NVRTC program
nvrtc.nvrtcDestroyProgram(prog)
- return bytearray(binary_buf)
+ # Link stage for NVSHMEM
+ if use_nvshmem:
+ binary_buf = _link_nvshmem_nvrtc(binary_buf, nvshmem_lib_path)
+
+ if path_target:
+ with open(path_target, "wb") as f:
+ f.write(binary_buf)
+ return binary_buf
+
+
+def _link_nvshmem_nvrtc(binary_buf, nvshmem_lib_path):
+ """Link compiled CUBIN with NVSHMEM device library using CUDA driver
API."""
+ import ctypes # pylint: disable=import-outside-toplevel
+
+ from cuda.bindings import driver as cu # pylint: disable=import-outside-toplevel
+
+ # cuLinkCreate requires a valid CUDA context.
+ # Always create a fresh context for linking to avoid issues with stale contexts
+ # in multi-process environments like Disco workers.
+ (result,) = cu.cuInit(0)
+ if result != cu.CUresult.CUDA_SUCCESS:
+ raise RuntimeError(f"Failed to initialize CUDA: {result}")
+
+ result, device = cu.cuDeviceGet(0)
+ if result != cu.CUresult.CUDA_SUCCESS:
+ raise RuntimeError(f"Failed to get CUDA device: {result}")
+
+ result, context = cu.cuCtxCreate(None, 0, device)
+ if result != cu.CUresult.CUDA_SUCCESS:
+ raise RuntimeError(f"Failed to create CUDA context: {result}")
+
+ try:
+ # Create linker
+ result, link_state = cu.cuLinkCreate(0, [], [])
+ if result != cu.CUresult.CUDA_SUCCESS:
+ raise RuntimeError(f"Failed to create CUDA linker: {result}")
+
+ try:
+ # Add our compiled CUBIN
+ (result,) = cu.cuLinkAddData(
+ link_state,
+ cu.CUjitInputType.CU_JIT_INPUT_CUBIN,
+ binary_buf,
+ len(binary_buf),
+ b"tvm_kernels.cubin",
+ 0,
+ [],
+ [],
+ )
+ if result != cu.CUresult.CUDA_SUCCESS:
+ raise RuntimeError(f"Failed to add CUBIN to linker: {result}")
+
+ # Add NVSHMEM device library
+ nvshmem_device_lib = os.path.join(nvshmem_lib_path, "libnvshmem_device.a")
+ if not os.path.exists(nvshmem_device_lib):
+ raise RuntimeError(f"NVSHMEM device library not found:
{nvshmem_device_lib}")
+
+ (result,) = cu.cuLinkAddFile(
+ link_state,
+ cu.CUjitInputType.CU_JIT_INPUT_LIBRARY,
+ nvshmem_device_lib.encode(),
+ 0,
+ [],
+ [],
+ )
+ if result != cu.CUresult.CUDA_SUCCESS:
+ raise RuntimeError(f"Failed to add NVSHMEM device library:
{result}")
+
+ # Complete linking
+ result, linked_cubin, linked_size = cu.cuLinkComplete(link_state)
+ if result != cu.CUresult.CUDA_SUCCESS:
+ raise RuntimeError(f"Failed to complete NVSHMEM linking:
{result}")
+
+ # Copy linked binary before destroying linker
+ binary_buf = bytearray(ctypes.string_at(linked_cubin, linked_size))
+ if not binary_buf:
+ raise RuntimeError("Compilation error: empty result is
generated")
+ finally:
+ # Clean up linker
+ cu.cuLinkDestroy(link_state)
+ finally:
+ # Clean up context
+ cu.cuCtxDestroy(context)
+
+ return binary_buf
def find_cuda_path():
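
A usage note on the option handling above: the same options list can now be handed to either backend, since flags matching nvcc_only_prefixes are skipped on the NVRTC path rather than forwarded. A small sketch under that assumption (the flags and kernel are illustrative, and a CUDA device is assumed for arch auto-detection):

    from tvm.contrib import nvcc

    code = 'extern "C" __global__ void noop() {}'
    # "-O3", "-Xcompiler" and "-fPIC" match nvcc_only_prefixes and are dropped here;
    # the same call with compiler="nvcc" would forward them unchanged.
    ptx = nvcc.compile_cuda(
        code, target_format="ptx", options=["-O3", "-Xcompiler", "-fPIC"], compiler="nvrtc"
    )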
diff --git a/python/tvm/script/ir_builder/tir/external_kernel.py b/python/tvm/script/ir_builder/tir/external_kernel.py
index 45a3d364c1..d7854d7a68 100644
--- a/python/tvm/script/ir_builder/tir/external_kernel.py
+++ b/python/tvm/script/ir_builder/tir/external_kernel.py
@@ -58,14 +58,16 @@ class BaseKernel: # pylint: disable=too-few-public-methods
)
return tvm_metadata
- def _create_cuda_module(self, ptx, kernel_arg_types, launch_param_tags, kernel_name):
+ def _create_cuda_module(
+ self, binary_data, kernel_arg_types, launch_param_tags, kernel_name, fmt="ptx"
+ ):
"""
- Create a CUDA module from PTX and metadata.
+ Create a CUDA module from compiled binary (PTX or cubin) and metadata.
Parameters
----------
- ptx : str
- The PTX code of the kernel.
+ binary_data : str or bytes
+ The compiled binary data (PTX as str, cubin as bytes).
kernel_arg_types : List[str]
The types of the kernel arguments.
@@ -76,6 +78,9 @@ class BaseKernel: # pylint: disable=too-few-public-methods
kernel_name : str
The name of the kernel.
+ fmt : str
+ The format of the binary data: "ptx" or "cubin".
+
Returns
-------
kernel_module : Module
@@ -85,12 +90,16 @@ class BaseKernel: # pylint: disable=too-few-public-methods
kernel_name, kernel_arg_types, launch_param_tags
)
with tempfile.TemporaryDirectory() as temp_dir:
- ptx_path = f"{temp_dir}/{kernel_name}.ptx"
- with open(ptx_path, "w") as f:
- f.write(ptx)
+ binary_path = f"{temp_dir}/{kernel_name}.{fmt}"
+ if fmt == "ptx":
+ with open(binary_path, "w") as f:
+ f.write(binary_data)
+ else:
+ with open(binary_path, "wb") as f:
+ f.write(binary_data)
with open(f"{temp_dir}/{kernel_name}.tvm_meta.json", "w") as f:
f.write(tvm_metadata)
- kernel_module = load_module(ptx_path)
+ kernel_module = load_module(binary_path)
return kernel_module
@@ -139,20 +148,31 @@ class SourceKernel(BaseKernel): # pylint: disable=too-few-public-methods
pass
with tempfile.TemporaryDirectory() as temp_dir:
- ptx_path = f"{temp_dir}/{kernel_name}.ptx"
+ # Check if NVSHMEM is used - requires cubin output for device library linking
+ use_nvshmem = (
+ "#include <nvshmem.h>" in source_code or "#include
<nvshmemx.h>" in source_code
+ )
+ target_format = "cubin" if use_nvshmem else "ptx"
+ output_path = f"{temp_dir}/{kernel_name}.{target_format}"
+
compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc")
nvcc.compile_cuda(
source_code,
- target_format="ptx",
+ target_format=target_format,
options=compile_options,
- path_target=ptx_path,
+ path_target=output_path,
compiler=compiler,
)
- with open(ptx_path, "r") as f:
- ptx = f.read()
+
+ if target_format == "ptx":
+ with open(output_path, "r") as f:
+ binary_data = f.read()
+ else:
+ with open(output_path, "rb") as f:
+ binary_data = f.read()
kernel_module = self._create_cuda_module(
- ptx, kernel_arg_types, launch_param_tags, kernel_name
+ binary_data, kernel_arg_types, launch_param_tags, kernel_name, fmt=target_format
)
return kernel_name, kernel_module, runtime_args
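
The backend used by SourceKernel / T.call_kernel is still chosen through the TVM_CUDA_COMPILE_MODE environment variable; what changes here is that the artifact format now follows the kernel source. A minimal sketch of that selection logic (the kernel string is illustrative):

    import os

    # "nvcc" remains the default backend for source kernels.
    os.environ["TVM_CUDA_COMPILE_MODE"] = "nvrtc"

    source_code = '#include <nvshmem.h>\nextern "C" __global__ void k(int* o) { o[0] = nvshmem_my_pe(); }\n'
    # Mirrors the check added above: NVSHMEM sources are compiled to cubin so the
    # device library can be linked in; everything else keeps producing PTX.
    use_nvshmem = "#include <nvshmem.h>" in source_code or "#include <nvshmemx.h>" in source_code
    target_format = "cubin" if use_nvshmem else "ptx"
    assert target_format == "cubin"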
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index 19f4288c97..f5219ae98a 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -342,6 +342,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef()
.def("ffi.Module.load_from_file.cuda", CUDAModuleLoadFile)
.def("ffi.Module.load_from_file.ptx", CUDAModuleLoadFile)
+ .def("ffi.Module.load_from_file.cubin", CUDAModuleLoadFile)
.def("ffi.Module.load_from_bytes.cuda", CUDAModuleLoadFromBytes);
}
} // namespace runtime
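
With the cubin loader registered above, a .cubin written by compile_cuda can be loaded like a .ptx file, provided its <name>.tvm_meta.json sits next to it (as _create_cuda_module arranges). A hedged sketch with hypothetical paths, using the same module loader external_kernel.py relies on:

    import tvm

    # Assumes /tmp/kernel.cubin and /tmp/kernel.tvm_meta.json were produced as in
    # BaseKernel._create_cuda_module above.
    mod = tvm.runtime.load_module("/tmp/kernel.cubin")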
diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py
index 029eb8fe82..b98b49591d 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -28,6 +28,8 @@ import multiprocessing
from multiprocessing import Process
from typing import Any, Callable, List
+from tvm.script import ir as I
+from tvm.script import relax as R
from tvm.script import tir as T
@@ -142,7 +144,7 @@ def test_nvshmem_compile():
if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
return
- num_workers = 4
+ num_workers = 2
sess = di.ProcessSession(num_workers=num_workers)
f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
@@ -191,6 +193,121 @@ def test_nvshmem_compile():
shutil.rmtree(tmpdir, ignore_errors=True)
+NVSHMEM_QUERY_KERNEL_SOURCE = """
+#include <nvshmem.h>
+
+extern "C" __global__ void nvshmem_query_kernel(int* my_pe_out, int*
n_pes_out) {
+ my_pe_out[0] = nvshmem_my_pe();
+ n_pes_out[0] = nvshmem_n_pes();
+}
+"""
+
+
+def _test_nvshmem_kernel_compile_impl():
+ """Test compiling and running a kernel that calls NVSHMEM functions"""
+ if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
+ return
+
+ num_workers = 2
+ sess = di.ProcessSession(num_workers=num_workers)
+
+ f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+ uid = f_init_nvshmem_uid()
+ init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+ init_dfunc(uid, num_workers, 0)
+ sess.sync_worker_0()
+
+ try:
+
+ @I.ir_module
+ class NvshmemQueryModule:
+ @T.prim_func
+ def query_pe(
+ my_pe_out: T.Buffer((1,), "int32"),
+ n_pes_out: T.Buffer((1,), "int32"),
+ ):
+ with T.block("root"):
+ T.reads()
+ T.writes(my_pe_out[0:1], n_pes_out[0:1])
+ T.call_kernel(
+ NVSHMEM_QUERY_KERNEL_SOURCE,
+ ((1,), (1,)), # grid=(1,), block=(1,)
+ my_pe_out.data,
+ n_pes_out.data,
+ kernel_name="nvshmem_query_kernel",
+ )
+
+ @R.function
+ def main() -> R.Tuple(R.Tensor((1,), "int32"), R.Tensor((1,), "int32")):
+ cls = NvshmemQueryModule
+ with R.dataflow():
+ my_pe = R.call_tir(
+ cls.query_pe,
+ (),
+ out_sinfo=[
+ R.Tensor((1,), "int32"),
+ R.Tensor((1,), "int32"),
+ ],
+ )
+ R.output(my_pe)
+ return my_pe
+
+ tmpdir = tempfile.mkdtemp()
+ try:
+ path = tmpdir + "/test_nvshmem_kernel.so"
+
+ target = tvm.target.Target("cuda")
+ tvm.compile(NvshmemQueryModule, target=target).export_library(path)
+ mod = sess.load_vm_module(path)
+ result = mod["main"]()
+
+ # Verify results from each worker
+ for worker_id in range(num_workers):
+ my_pe_result, n_pes_result = result.debug_get_from_remote(worker_id)
+ my_pe_val = my_pe_result.numpy()[0]
+ n_pes_val = n_pes_result.numpy()[0]
+ assert (
+ my_pe_val == worker_id
+ ), f"Worker {worker_id} reported my_pe={my_pe_val}, expected
{worker_id}"
+ assert (
+ n_pes_val == num_workers
+ ), f"Worker {worker_id} reported n_pes={n_pes_val}, expected
{num_workers}"
+
+ # Sync all workers before cleanup
+ sess._sync_all()
+
+ finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+ finalize_dfunc()
+ sess.sync_worker_0()
+ finally:
+ shutil.rmtree(tmpdir, ignore_errors=True)
+ finally:
+ sess.shutdown()
+
+
+def test_nvshmem_kernel_compile_nvcc():
+ """Test NVSHMEM kernel compilation with nvcc."""
+ # Since this test runs in a separate process, we can safely set the env var
+ import os
+
+ os.environ["TVM_CUDA_COMPILE_MODE"] = "nvcc"
+ _test_nvshmem_kernel_compile_impl()
+
+
+def test_nvshmem_kernel_compile_nvrtc():
+ """Test NVSHMEM kernel compilation with nvrtc."""
+ try:
+ from cuda.bindings import nvrtc # noqa: F401
+ except ImportError:
+ pytest.skip("cuda-python not available, skipping nvrtc test")
+
+ # Since this test runs in a separate process, we can safely set the env var
+ import os
+
+ os.environ["TVM_CUDA_COMPILE_MODE"] = "nvrtc"
+ _test_nvshmem_kernel_compile_impl()
+
+
if __name__ == "__main__":
# After the first call to `nvshmem_init`, a subsequent call to `nvshmem_init`
# or `nvshmem_init_thread` in the same program results in undefined behavior.
@@ -212,8 +329,13 @@ if __name__ == "__main__":
p.exitcode == 0
), f"Test {test_func.__name__} failed with exit code
{p.exitcode}"
- # testing compilation flow
p = Process(target=test_nvshmem_compile)
p.start()
p.join()
assert p.exitcode == 0, f"Test test_nvshmem_compile failed with exit code {p.exitcode}"
+
+ for test_func in [test_nvshmem_kernel_compile_nvcc, test_nvshmem_kernel_compile_nvrtc]:
+ p = Process(target=test_func)
+ p.start()
+ p.join()
+ assert p.exitcode == 0, f"Test {test_func.__name__} failed with exit code {p.exitcode}"