This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fa905d2b69 [Compile] accelerate compilation speed using NVRTC (#18519)
fa905d2b69 is described below
commit fa905d2b693e5368ea059a72f3fd1333005f6560
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Thu Jan 8 11:08:06 2026 -0500
[Compile] accelerate compilation speed using NVRTC (#18519)
This PR adds NVRTC as an alternative to NVCC for faster runtime (JIT)
compilation of CUDA kernels, following up on
https://github.com/apache/tvm-ffi/pull/283.
It enhances the CUDA compilation backend by:
- Adding Python NVRTC support using cuda-python bindings
- Removing legacy C++ NVRTC fallback in favor of a Python-first approach
- Keeping nvcc as the default compiler with fatbin output (no behavior
change for existing users)
Users can select the compilation backend via the environment variable
`TVM_CUDA_COMPILE_MODE`, which accepts "nvcc" and "nvrtc". For example:
`TVM_CUDA_COMPILE_MODE=nvrtc python3 your_program.py`
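The backend can also be chosen per call through the new `compiler` argument
of `tvm.contrib.nvcc.compile_cuda`. A minimal sketch based on the signature
added in this PR (the kernel source below is only illustrative):

```python
from tvm.contrib import nvcc

# Illustrative kernel; any self-contained CUDA C source works here.
code = r"""
extern "C" __global__ void add_one(float* x, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] += 1.0f;
}
"""

# NVRTC emits cubin or ptx (fatbin remains nvcc-only). When `arch` is
# omitted it is auto-detected, which requires an active TVM target or a
# visible CUDA device.
cubin = nvcc.compile_cuda(code, target_format="cubin", compiler="nvrtc")
```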
Here is a short benchmark of the compilation speed of kernels in
`test_target_codegen_cuda.py`.
### NVCC vs NVRTC Compilation Time Comparison (Python-side Call)
| Test Case | Code Size | NVCC Time (ms) | NVRTC Time (ms) | Speedup |
| :--- | :--- | :--- | :--- | :--- |
| `test_crossthread_reduction1` | 1945 B | 241.27 | 51.23 | **4.7x** |
| `test_cuda_bf16_vectorize_add` | 3760 B | 342.72 | 44.50 | **7.7x** |
| `test_cuda_const_float_to_half` | 12394 B | 272.85 | 31.99 | **8.5x** |
| `test_cuda_device_func_call` | 975 B | 215.58 | 21.47 | **10.0x** |
| `test_cuda_float_const_hex_format` | 685 B | 217.39 | 20.52 | **10.6x** |
| `test_cuda_floordiv_with_vectorization` | 1050 B | 213.88 | 23.32 | **9.2x** |
| `test_cuda_inf_nan` | 673 B | 214.33 | 24.94 | **8.6x** |
| `test_cuda_tensormap` | 755 B | 213.91 | 20.74 | **10.3x** |
| `test_cuda_thread_sync_inside_condition` | 1007 B | 213.43 | 28.29 | **7.5x** |
| `test_cuda_vectorize_add` | 908 B | 226.81 | 40.39 | **5.6x** |
| `test_cuda_vectorize_load` | 734 B | 217.25 | 24.02 | **9.0x** |
| `test_device_host_call_same_func` | 924 B | 216.03 | 21.21 | **10.2x** |
| `test_vectorized_intrin1` | 847 B | 226.15 | 26.34 | **8.6x** |
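The timings above measure only the Python-side `compile_cuda` call. A hedged
sketch of how such a comparison can be reproduced (this harness is
illustrative, not the one used for the table; it reuses `code` from the
earlier snippet):

```python
import time

from tvm.contrib import nvcc

def time_compile(code, compiler, target_format, iters=10):
    # One warm-up call, then average wall-clock time in milliseconds.
    nvcc.compile_cuda(code, target_format=target_format, compiler=compiler)
    start = time.perf_counter()
    for _ in range(iters):
        nvcc.compile_cuda(code, target_format=target_format, compiler=compiler)
    return (time.perf_counter() - start) / iters * 1e3

print("nvcc  (fatbin):", time_compile(code, "nvcc", "fatbin"))
print("nvrtc (cubin) :", time_compile(code, "nvrtc", "cubin"))
```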
### NVSHMEM Support
Currently, NVSHMEM is **not** supported via NVRTC.
- Fallback behavior: when NVSHMEM is required, the compilation pipeline
automatically falls back to NVCC, even if `TVM_CUDA_COMPILE_MODE` is set
to "nvrtc" (see the dispatch sketch below).
- Future roadmap: NVRTC support for NVSHMEM is planned for follow-up
PRs.
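For reference, the dispatch that implements this fallback lives in
`compile_cuda` (full diff below); a simplified sketch, where the wrapper
name `_dispatch` is hypothetical but the logic mirrors this change:

```python
def _dispatch(code, compiler, target_format=None, arch=None, options=None, path_target=None):
    # NVSHMEM sources always take the nvcc path, regardless of the requested
    # compiler, because NVRTC support for NVSHMEM is not implemented yet.
    use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in code
    if compiler == "nvcc" or use_nvshmem:
        return _compile_cuda_nvcc(code, target_format, arch, options, path_target, use_nvshmem)
    if compiler == "nvrtc":
        return _compile_cuda_nvrtc(code, target_format, arch, options)
    raise ValueError(f"cuda compiler must be 'nvcc' or 'nvrtc', got: {compiler}")
```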
---
cmake/modules/CUDA.cmake | 2 -
cmake/utils/FindCUDA.cmake | 9 -
docker/Dockerfile.ci_gpu | 3 +
docker/install/ubuntu_install_cuda_python.sh | 23 ++
python/tvm/contrib/nvcc.py | 329 +++++++++++++++++++--
.../tvm/script/ir_builder/tir/external_kernel.py | 24 +-
src/runtime/contrib/nvshmem/init.cc | 18 +-
src/target/opt/build_cuda_on.cc | 122 ++------
src/target/source/codegen_cuda.cc | 8 +-
src/target/source/literal/cuda_half_t.h | 8 +-
tests/python/codegen/test_target_codegen_cuda.py | 32 +-
tests/python/disco/test_nvshmem.py | 30 +-
.../test_tir_transform_inject_ptx_async_copy.py | 7 +
13 files changed, 465 insertions(+), 150 deletions(-)
diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index 64f41d65fa..e9f5854901 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -54,10 +54,8 @@ if(USE_CUDA)
list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_SRCS})
list(APPEND COMPILER_SRCS src/target/opt/build_cuda_on.cc)
- list(APPEND TVM_LINKER_LIBS ${CUDA_NVRTC_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDART_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDA_LIBRARY})
- list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_NVRTC_LIBRARY})
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
if(CMAKE_VERSION VERSION_LESS "3.24")
diff --git a/cmake/utils/FindCUDA.cmake b/cmake/utils/FindCUDA.cmake
index c4c18eef0f..c62506cf41 100644
--- a/cmake/utils/FindCUDA.cmake
+++ b/cmake/utils/FindCUDA.cmake
@@ -33,7 +33,6 @@
# - CUDA_TOOLKIT_ROOT_DIR
# - CUDA_CUDA_LIBRARY
# - CUDA_CUDART_LIBRARY
-# - CUDA_NVRTC_LIBRARY
# - CUDA_CUDNN_INCLUDE_DIRS
# - CUDA_CUDNN_LIBRARY
# - CUDA_CUBLAS_LIBRARY
@@ -64,9 +63,6 @@ macro(find_cuda use_cuda use_cudnn)
find_library(CUDA_CUDA_LIBRARY cuda
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)
- find_library(CUDA_NVRTC_LIBRARY nvrtc
- ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
- ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)
find_library(CUDA_CUBLAS_LIBRARY cublas
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)
@@ -81,10 +77,6 @@ macro(find_cuda use_cuda use_cudnn)
if(_CUDA_CUDA_LIBRARY)
set(CUDA_CUDA_LIBRARY ${_CUDA_CUDA_LIBRARY})
endif()
- find_library(CUDA_NVRTC_LIBRARY nvrtc
- PATHS ${CUDA_TOOLKIT_ROOT_DIR}
-    PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu
- NO_DEFAULT_PATH)
find_library(CUDA_CURAND_LIBRARY curand
PATHS ${CUDA_TOOLKIT_ROOT_DIR}
    PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu
@@ -140,7 +132,6 @@ macro(find_cuda use_cuda use_cudnn)
message(STATUS "Found CUDA_TOOLKIT_ROOT_DIR=" ${CUDA_TOOLKIT_ROOT_DIR})
message(STATUS "Found CUDA_CUDA_LIBRARY=" ${CUDA_CUDA_LIBRARY})
message(STATUS "Found CUDA_CUDART_LIBRARY=" ${CUDA_CUDART_LIBRARY})
- message(STATUS "Found CUDA_NVRTC_LIBRARY=" ${CUDA_NVRTC_LIBRARY})
message(STATUS "Found CUDA_CUDNN_INCLUDE_DIRS=" ${CUDA_CUDNN_INCLUDE_DIRS})
message(STATUS "Found CUDA_CUDNN_LIBRARY=" ${CUDA_CUDNN_LIBRARY})
message(STATUS "Found CUDA_CUBLAS_LIBRARY=" ${CUDA_CUBLAS_LIBRARY})
diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu
index 1295c679d7..a72bd60fd7 100644
--- a/docker/Dockerfile.ci_gpu
+++ b/docker/Dockerfile.ci_gpu
@@ -60,6 +60,9 @@ RUN bash /install/ubuntu_install_opencl.sh
COPY install/ubuntu_install_python_package.sh /install/ubuntu_install_python_package.sh
RUN bash /install/ubuntu_install_python_package.sh
+COPY install/ubuntu_install_cuda_python.sh /install/ubuntu_install_cuda_python.sh
+RUN bash /install/ubuntu_install_cuda_python.sh
+
COPY install/ubuntu_install_sphinx.sh /install/ubuntu_install_sphinx.sh
RUN bash /install/ubuntu_install_sphinx.sh
diff --git a/docker/install/ubuntu_install_cuda_python.sh b/docker/install/ubuntu_install_cuda_python.sh
new file mode 100644
index 0000000000..eb4efac5c0
--- /dev/null
+++ b/docker/install/ubuntu_install_cuda_python.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+set -u
+set -o pipefail
+
+pip3 install cuda-python
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index d062714938..edf3e8af4f 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -16,14 +16,18 @@
# under the License.
# pylint: disable=invalid-name
"""Utility to invoke nvcc compiler in the system"""
+
from __future__ import absolute_import as _abs
+import glob
import os
+import platform
import subprocess
import warnings
from typing import Tuple
import tvm_ffi
+
import tvm
from tvm.target import Target
@@ -31,8 +35,10 @@ from ..base import py_str
from . import utils
-def compile_cuda(code, target_format=None, arch=None, options=None, path_target=None):
-    """Compile cuda code with NVCC from env.
+def compile_cuda(
+    code, target_format=None, arch=None, options=None, path_target=None, compiler="nvcc"
+):
+ """Compile cuda code with NVCC or NVRTC.
Parameters
----------
@@ -40,7 +46,7 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target=
The cuda code.
target_format : str
- The target format of nvcc compiler.
+ The target format of the compiler ("ptx", "cubin", or "fatbin").
arch : str
The cuda architecture.
@@ -51,14 +57,61 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target=
path_target : str, optional
Output file.
- Return
- ------
- cubin : bytearray
- The bytearray of the cubin
+ compiler : str, optional
+ Compiler backend: "nvcc" or "nvrtc".
+ This can be set by the TVM_CUDA_COMPILE_MODE environment variable.
+
+ Returns
+ -------
+ res_binary : bytearray
+ The bytearray of the compiled binary (ptx/cubin/fatbin).
+
+ Notes
+ -----
+    - NVRTC is a "runtime" compilation library and can be faster for JIT compilation.
+ - NVRTC requires cuda-python: pip install cuda-python
+ """
+    # TODO: if NVSHMEM is needed for compilation, fall back to NVCC, because
+    # NVRTC support for it is not yet implemented
+    use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in code
+    if compiler == "nvcc" or use_nvshmem:
+        return _compile_cuda_nvcc(code, target_format, arch, options, path_target, use_nvshmem)
+    elif compiler == "nvrtc":
+        return _compile_cuda_nvrtc(code, target_format, arch, options)
+    else:
+        raise ValueError(f"cuda compiler must be 'nvcc' or 'nvrtc', got: {compiler}")
+
+
+def _compile_cuda_nvcc(
+ code,
+ target_format=None,
+ arch=None,
+ options=None,
+ path_target=None,
+ use_nvshmem=False,
+):
+ """Compile CUDA code using nvcc.
+
+ Parameters
+ ----------
+ code : str
+ The CUDA source code.
+ target_format : str, optional
+ Output format: "ptx", "cubin", or "fatbin".
+ arch : str, optional
+ Target architecture. Auto-detected if None.
+ options : str or list of str, optional
+ Additional nvcc options.
+ path_target : str, optional
+ Output file path.
+
+ Returns
+ -------
+ bytearray
+ Compiled binary data.
"""
# Check for NVSHMEM dependency
nvshmem_include_path, nvshmem_lib_path = None, None
-    use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in code
if use_nvshmem:
# NOTE: we cannot check whether nvshmem is used based on whether
# the global function "runtime.nvshmem.cumodule_init" is defined.
@@ -106,8 +159,9 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target=
file_target = path_target if path_target else temp_target
if use_nvshmem:
- file_prefix = file_target.split(".")[0]
+ file_prefix = os.path.splitext(file_target)[0]
        file_target = f"{file_prefix}.o"  # in the first stage, compile to object file
+
cmd = ["nvcc"]
cmd += [f"--{target_format}", "-O3"]
if kernels_output_dir is not None:
@@ -151,14 +205,11 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target=
msg += py_str(out)
raise RuntimeError(msg)
- # start second stage of compilation
+ # Second stage for NVSHMEM
if use_nvshmem:
cmd = ["nvlink"]
cmd += [f"-arch=sm_{compute_version}"]
- cmd += [
- "-L",
- nvshmem_lib_path,
- ]
+ cmd += ["-L", nvshmem_lib_path]
cmd += ["-L", os.path.join(find_cuda_path(), "lib64")]
cmd += ["-l", "nvshmem_device"]
cmd += ["-l", "cudadevrt"]
@@ -184,6 +235,187 @@ def compile_cuda(code, target_format=None, arch=None, options=None, path_target=
return data
+def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
+ """Compile CUDA code using NVRTC (NVIDIA Runtime Compilation).
+
+ Parameters
+ ----------
+ code : str
+ The CUDA source code.
+ target_format : str, optional
+ Output format: "cubin" or "ptx". Default: "cubin"
+ arch : str, optional
+ Target architecture (e.g., "sm_80"). Auto-detected if None.
+ options : str or list of str, optional
+ Additional NVRTC options.
+
+ Returns
+ -------
+ bytearray
+ Compiled binary data.
+ """
+ try:
+        from cuda.bindings import nvrtc  # pylint: disable=import-outside-toplevel
+    except ImportError as e:
+        raise RuntimeError(
+            "Failed to compile CUDA with NVRTC because the `cuda-python` package "
+ "is not available.\n"
+ "Please install it with: pip install cuda-python\n"
+ "See: https://nvidia.github.io/cuda-python/"
+ ) from e
+
+ # Default target format
+ if target_format is None:
+ target_format = "cubin"
+
+ # Validate target_format (NVRTC doesn't support fatbin)
+ if target_format == "fatbin":
+ raise ValueError(
+ "NVRTC does not support fatbin generation yet. "
+ "Use target_format='cubin' or 'ptx' with NVRTC, "
+ "or set compiler='nvcc' for fatbin compilation."
+ )
+ if target_format not in ["cubin", "ptx"]:
+ raise ValueError(f"target_format must be 'cubin' or 'ptx', got:
{target_format}")
+
+ # Validate options
+ if options is not None and not isinstance(options, (str, list)):
+ raise ValueError("options must be str or list of str")
+
+ # Auto-detect architecture
+ if arch is None:
+        compute_version = get_target_compute_version(Target.current(allow_none=True))
+ arch = f"sm_{''.join(compute_version.split('.'))}"
+
+    # Strip host-only headers for NVRTC. NVRTC compiles device code and does not
+    # require the CUDA driver header or host C++ headers.
+    headers_to_strip = {"#include <cuda.h>"}
+    code_filtered = "\n".join(
+        line for line in code.splitlines() if line.strip() not in headers_to_strip
+ )
+
+ # NVRTC compiles device code and does not include the host-side cuda.h.
+    # CUtensorMap is a host-side structure; to reference and use it in device
+    # code, we must forward-declare it for NVRTC.
+ if "CUtensorMap" in code_filtered:
+ code_filtered = (
+ "struct __align__(128) CUtensorMap {\n"
+ " unsigned long long opaque[16];\n"
+ "};\n\n" + code_filtered
+ )
+
+ # Create NVRTC program
+ # Use "tvm_kernels.cu" for consistency with nvcc path
+ result, prog = nvrtc.nvrtcCreateProgram(
+ str.encode(code_filtered), b"tvm_kernels.cu", 0, None, None
+ )
+ if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+        raise RuntimeError(f"Failed to create NVRTC program: {nvrtc.nvrtcGetErrorString(result)}")
+
+ # Prepare compilation options
+ cuda_path = find_cuda_path()
+ compile_opts = [
+ f"--gpu-architecture={arch}".encode(),
+ b"-default-device",
+ ]
+
+    # Add CUDA include paths. NVRTC needs explicit include paths for CUDA headers.
+    # Standard installations: cuda_path/include
+    # Conda/architecture-specific installations: cuda_path/targets/<arch>/include
+ include_paths = []
+
+ # Check standard include directory
+ standard_include = os.path.join(cuda_path, "include")
+ if os.path.isdir(standard_include):
+ include_paths.append(standard_include)
+
+ # Check architecture-specific include directory
+ arch_include = os.path.join(
+ cuda_path,
+ "targets",
+ f"{platform.machine()}-{platform.system().lower()}",
+ "include",
+ )
+ if os.path.isdir(arch_include):
+ include_paths.append(arch_include)
+
+ # Verify we can find essential CUDA headers
+    if not any(os.path.isfile(os.path.join(p, "cuda_runtime.h")) for p in include_paths):
+ raise RuntimeError(
+ f"Cannot find CUDA headers in {cuda_path}. "
+ f"Searched in: {include_paths}. "
+ "Please ensure CUDA is properly installed."
+ )
+
+ # Add all valid include paths
+ for include_path in include_paths:
+ compile_opts.append(f"-I{include_path}".encode())
+
+ compile_opts.extend(
+ [
+ b"-U__CUDA_NO_HALF_OPERATORS__",
+ b"-U__CUDA_NO_HALF_CONVERSIONS__",
+ b"-U__CUDA_NO_BFLOAT16_OPERATORS__",
+ b"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+ b"-U__CUDA_NO_BFLOAT162_OPERATORS__",
+ b"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+ b"--use_fast_math",
+ ]
+ )
+
+ # Add user-provided options
+ if options:
+ if isinstance(options, str):
+ compile_opts.append(options.encode())
+ else:
+            compile_opts.extend([opt.encode() if isinstance(opt, str) else opt for opt in options])
+
+ # Compile
+    (result,) = nvrtc.nvrtcCompileProgram(prog, len(compile_opts), compile_opts)
+ if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+ # Get compilation log
+ result_log, log_size = nvrtc.nvrtcGetProgramLogSize(prog)
+ if result_log == nvrtc.nvrtcResult.NVRTC_SUCCESS and log_size > 0:
+ log_buf = bytearray(log_size)
+ (result_log,) = nvrtc.nvrtcGetProgramLog(prog, log_buf)
+ if result_log == nvrtc.nvrtcResult.NVRTC_SUCCESS:
+                error_msg = f"NVRTC compilation failed:\n{log_buf.decode('utf-8')}"
+            else:
+                error_msg = f"NVRTC compilation failed (couldn't get log): {result}"
+ else:
+ error_msg = f"NVRTC compilation failed: {result}"
+
+ nvrtc.nvrtcDestroyProgram(prog)
+ raise RuntimeError(error_msg)
+
+ # Get compiled binary
+ if target_format == "cubin":
+ result, binary_size = nvrtc.nvrtcGetCUBINSize(prog)
+ if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+ nvrtc.nvrtcDestroyProgram(prog)
+            raise RuntimeError(f"Failed to get CUBIN size: {nvrtc.nvrtcGetErrorString(result)}")
+ binary_buf = bytearray(binary_size)
+ (result,) = nvrtc.nvrtcGetCUBIN(prog, binary_buf)
+ if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+ nvrtc.nvrtcDestroyProgram(prog)
+            raise RuntimeError(f"Failed to get CUBIN: {nvrtc.nvrtcGetErrorString(result)}")
+ else: # ptx
+ result, binary_size = nvrtc.nvrtcGetPTXSize(prog)
+ if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+ nvrtc.nvrtcDestroyProgram(prog)
+            raise RuntimeError(f"Failed to get PTX size: {nvrtc.nvrtcGetErrorString(result)}")
+ binary_buf = bytearray(binary_size)
+ (result,) = nvrtc.nvrtcGetPTX(prog, binary_buf)
+ if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+ nvrtc.nvrtcDestroyProgram(prog)
+            raise RuntimeError(f"Failed to get PTX: {nvrtc.nvrtcGetErrorString(result)}")
+
+ # Clean up
+ nvrtc.nvrtcDestroyProgram(prog)
+
+ return bytearray(binary_buf)
+
+
def find_cuda_path():
"""Utility function to find cuda path
@@ -241,7 +473,7 @@ def get_cuda_version(cuda_path=None):
(out, _) = proc.communicate()
out = py_str(out)
if proc.returncode == 0:
- release_line = [l for l in out.split("\n") if "release" in l][0]
+        release_line = [line for line in out.split("\n") if "release" in line][0]
release_fields = [s.strip() for s in release_line.split(",")]
version_str = [f[1:] for f in release_fields if f.startswith("V")][0]
return tuple(int(field) for field in version_str.split("."))
@@ -280,16 +512,37 @@ def find_nvshmem_paths() -> Tuple[str, str]:
unique_candidates.append(path)
for root in unique_candidates:
- include_path = os.path.join(root, "include")
+        # Check both standard include path and versioned subdirectories (e.g., nvshmem_12)
+ include_paths_to_check = [os.path.join(root, "include")]
+
+ # Add versioned subdirectories like include/nvshmem_*
+        versioned_includes = glob.glob(os.path.join(root, "include", "nvshmem_*"))
+ include_paths_to_check.extend(versioned_includes)
+
+ # Check standard and architecture-specific lib directories
lib_paths_to_check = [
os.path.join(root, "lib64"),
os.path.join(root, "lib"),
]
- if os.path.isfile(os.path.join(include_path, "nvshmem.h")):
- for lib_path in lib_paths_to_check:
- if os.path.isfile(os.path.join(lib_path, "libnvshmem.a")):
- return include_path, lib_path
+ # Add architecture-specific lib paths (e.g., lib/x86_64-linux-gnu)
+ machine = platform.machine()
+ system = platform.system().lower()
+ lib_paths_to_check.extend(
+ [
+ os.path.join(root, "lib", f"{machine}-{system}-gnu"),
+ os.path.join(root, "lib64", f"{machine}-{system}-gnu"),
+ ]
+ )
+
+ for include_path in include_paths_to_check:
+ if os.path.isfile(os.path.join(include_path, "nvshmem.h")):
+ for lib_path in lib_paths_to_check:
+ # Check for both static (.a) and shared (.so) libraries
+                if os.path.isfile(os.path.join(lib_path, "libnvshmem.a")) or os.path.isfile(
+ os.path.join(lib_path, "libnvshmem.so")
+ ):
+ return include_path, lib_path
error_message = [
"Error: Could not find NVSHMEM installation.",
@@ -315,9 +568,39 @@ def find_nvshmem_paths() -> Tuple[str, str]:
@tvm_ffi.register_global_func
def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument
- """use nvcc to generate fatbin code for better optimization"""
- ptx = compile_cuda(code, target_format="fatbin")
- return ptx
+ """
+ Compile CUDA code using the configured backend (nvcc or nvrtc).
+
+    This callback is invoked by TVM's C++ backend during CUDA module compilation.
+    By default, it uses nvcc to generate a fatbin.
+
+ Environment Variables
+ ---------------------
+ TVM_CUDA_COMPILE_MODE : str
+ Compiler backend: "nvcc" (default) or "nvrtc"
+ - "nvcc": Use nvcc subprocess, generates fatbin
+ - "nvrtc": Use NVRTC via cuda-python for faster JIT, generates cubin
+
+ Parameters
+ ----------
+ code : str
+ CUDA source code to compile
+ target : Target
+ TVM target architecture
+
+ Returns
+ -------
+ bytes
+ Compiled binary (fatbin for nvcc, cubin for nvrtc)
+ """
+ compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc").lower()
+
+ if compiler == "nvrtc":
+ return compile_cuda(code, target_format="cubin", compiler="nvrtc")
+ if compiler == "nvcc":
+ return compile_cuda(code, target_format="fatbin", compiler="nvcc")
+
+    raise ValueError(f"Invalid TVM_CUDA_COMPILE_MODE: {compiler}. Expected 'nvcc' or 'nvrtc'.")
@tvm_ffi.register_global_func("tvm_callback_libdevice_path")
diff --git a/python/tvm/script/ir_builder/tir/external_kernel.py b/python/tvm/script/ir_builder/tir/external_kernel.py
index 405e1e6cbf..45a3d364c1 100644
--- a/python/tvm/script/ir_builder/tir/external_kernel.py
+++ b/python/tvm/script/ir_builder/tir/external_kernel.py
@@ -17,14 +17,15 @@
"""External kernel integration fro TIR"""
import json
import logging
+import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
from tvm import __version__ as tvm_version
from tvm import tir
-from tvm.runtime import Module, load_module, const
from tvm.contrib import nvcc
+from tvm.runtime import Module, const, load_module
class BaseKernel: # pylint: disable=too-few-public-methods
@@ -100,10 +101,15 @@ class SourceKernel(BaseKernel):  # pylint: disable=too-few-public-methods
self.source_code = source_code
def compile_to_device_module( # pylint: disable=arguments-differ
-        self, grid: List[List[Union[int, tir.PrimExpr]]], *args: List[Any], **kwargs: Dict[str, Any]
+ self,
+ grid: List[List[Union[int, tir.PrimExpr]]],
+ *args: List[Any],
+ **kwargs: Dict[str, Any],
) -> Tuple[str, Module, List[Any]]:
"""Compile the kernel to a device module."""
-        from tvm.relax.frontend.nn import SourceModule  # pylint: disable=import-outside-toplevel
+        from tvm.relax.frontend.nn import (  # pylint: disable=import-outside-toplevel
+ SourceModule,
+ )
kernel_name = kwargs["kernel_name"]
assert len(grid) == 2, (
@@ -134,8 +140,13 @@ class SourceKernel(BaseKernel):  # pylint: disable=too-few-public-methods
with tempfile.TemporaryDirectory() as temp_dir:
ptx_path = f"{temp_dir}/{kernel_name}.ptx"
+ compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc")
nvcc.compile_cuda(
-                source_code, target_format="ptx", options=compile_options, path_target=ptx_path
+ source_code,
+ target_format="ptx",
+ options=compile_options,
+ path_target=ptx_path,
+ compiler=compiler,
)
with open(ptx_path, "r") as f:
ptx = f.read()
@@ -171,7 +182,10 @@ def call_kernel(
kwargs : Dict[str, Any]
Additional keyword arguments to pass to the kernel or compilation.
"""
-    from ..ir import module_get_attr, module_set_attr  # pylint: disable=import-outside-toplevel
+ from ..ir import ( # pylint: disable=import-outside-toplevel
+ module_get_attr,
+ module_set_attr,
+ )
from .ir import call_packed # pylint: disable=import-outside-toplevel
kernel_type = f"{type(kernel).__module__}.{type(kernel).__qualname__}"
diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc
index 3471902bc3..d682e2cae5 100644
--- a/src/runtime/contrib/nvshmem/init.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include <cuda.h>
#include <nvshmem.h>
#include <nvshmemx.h>
#include <picojson.h>
@@ -117,7 +118,22 @@ void NVSHMEMXCumoduleInit(void* cuModule) {
// NOTE: we do not check the return value of nvshmemx_cumodule_init.
  // The reason is because that the input cuModule might not use any NVSHMEM functions,
// in which case the nvshmemx_cumodule_init will fail.
- nvshmemx_cumodule_init(mod);
+
+ // A set of guards to check if the module has NVSHMEM symbol to avoid the
+ // "gpgpu named symbol not found" error.
+ CUdeviceptr d_ptr;
+ size_t d_size;
+  const char* kNvshmemDeviceSymbols[] = {
+      "nvshmemi_device_state_d",      "nvshmem_i_device_state_d",
+      "nvshmemi_device_team_state_d", "nvshmemi_device_heap_base_d",
+      "nvshmemi_device_heap_size_d",  "nvshmemi_device_heap_d",
+  };
+ for (const char* sym : kNvshmemDeviceSymbols) {
+ if (cuModuleGetGlobal(&d_ptr, &d_size, mod, sym) == CUDA_SUCCESS) {
+ nvshmemx_cumodule_init(mod);
+ return;
+ }
+ }
}
}
diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc
index 8d2589aaec..88960594d0 100644
--- a/src/target/opt/build_cuda_on.cc
+++ b/src/target/opt/build_cuda_on.cc
@@ -28,7 +28,6 @@
#include <tvm/ffi/reflection/registry.h>
#endif
#include <cuda_runtime.h>
-#include <nvrtc.h>
#include <cstdlib>
@@ -40,91 +39,10 @@
namespace tvm {
namespace codegen {
-#define NVRTC_CALL(x)                                                                          \
-  {                                                                                            \
-    nvrtcResult result = x;                                                                    \
-    if (result != NVRTC_SUCCESS) {                                                             \
-      LOG(FATAL) << "NvrtcError: " #x " failed with error: " << nvrtcGetErrorString(result);   \
-    }                                                                                          \
-  }
-
-std::string FindCUDAIncludePath() {
-#if defined(_WIN32)
- const std::string delimiter = "\\";
-#else
- const std::string delimiter = "/";
-#endif
- std::string cuda_include_path;
- const char* cuda_path_env = std::getenv("CUDA_PATH");
- if (cuda_path_env != nullptr) {
- cuda_include_path += cuda_path_env;
- cuda_include_path += delimiter + "include";
- return cuda_include_path;
- }
-
-#if defined(__linux__)
- struct stat st;
- cuda_include_path = "/usr/local/cuda/include";
- if (stat(cuda_include_path.c_str(), &st) == 0) {
- return cuda_include_path;
- }
-
- if (stat("/usr/include/cuda.h", &st) == 0) {
- return "/usr/include";
- }
-#endif
- LOG(FATAL) << "Cannot find cuda include path."
-             << "CUDA_PATH is not set or CUDA is not installed in the default installation path."
- << "In other than linux, it is necessary to set CUDA_PATH.";
- return cuda_include_path;
-}
-
-std::string NVRTCCompile(const std::string& code, bool include_path = false) {
- std::vector<std::string> compile_params;
- std::vector<const char*> param_cstrings{};
- nvrtcProgram prog;
- std::string cc = "30";
- int major, minor;
-  cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
-  cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);
-
- if (e1 == cudaSuccess && e2 == cudaSuccess) {
- cc = std::to_string(major) + std::to_string(minor);
- } else {
- LOG(WARNING) << "cannot detect compute capability from your device, "
- << "fall back to compute_30.";
- }
-
- compile_params.push_back("-arch=compute_" + cc);
-
- if (include_path) {
- std::string include_option = "--include-path=" + FindCUDAIncludePath();
-
- compile_params.push_back(include_option);
- }
-
- for (const auto& string : compile_params) {
- param_cstrings.push_back(string.c_str());
- }
-  NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
-  nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), param_cstrings.data());
-
- size_t log_size;
- NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
- std::string log;
- log.resize(log_size);
- NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
- ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
- size_t ptx_size;
- NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size));
-
- std::string ptx;
- ptx.resize(ptx_size);
- NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
- NVRTC_CALL(nvrtcDestroyProgram(&prog));
-
- return ptx;
-}
+// Note: CUDA include path finding and NVRTC compilation are now handled
+// in Python for better maintainability and to leverage cuda-python bindings.
+// The C++ NVRTC code has been removed as part of the Python-first
+// compilation strategy.
ffi::Module BuildCUDA(IRModule mod, Target target) {
bool output_ssa = false;
@@ -157,20 +75,32 @@ ffi::Module BuildCUDA(IRModule mod, Target target) {
code = (*f)(code, target).cast<std::string>();
}
std::string fmt = "ptx";
- std::string ptx;
+ std::string compiled;
+
+ // Always use Python compilation callback (nvcc or nvrtc)
+ // The C++ NVRTC fallback has been removed in favor of Python-first approach
+ auto f_compile = ffi::Function::GetGlobal("tvm_callback_cuda_compile");
+ ICHECK(f_compile != nullptr)
+ << "tvm_callback_cuda_compile not found. "
+ << "Please ensure TVM Python runtime is properly initialized.\n"
+      << "The Python callback (tvm.contrib.nvcc.tvm_callback_cuda_compile) is required "
+ << "for CUDA compilation. The C++ NVRTC fallback has been removed.\n"
+ << "Make sure to import tvm.contrib.nvcc in your Python code.";
+
+ // Enter target scope for compilation
auto f_enter = ffi::Function::GetGlobal("target.TargetEnterScope");
(*f_enter)(target);
- if (auto f = ffi::Function::GetGlobal("tvm_callback_cuda_compile")) {
- ptx = (*f)(code, target).cast<std::string>();
- // Dirty matching to check PTX vs cubin.
- // TODO(tqchen) more reliable checks
- if (ptx[0] != '/') fmt = "cubin";
- } else {
- ptx = NVRTCCompile(code, cg.need_include_path());
- }
+
+ // Compile CUDA code via Python callback
+ compiled = (*f_compile)(code, target).cast<std::string>();
+ // Dirty matching to check PTX vs cubin.
+ // TODO(tqchen) more reliable checks
+ if (compiled[0] != '/') fmt = "cubin";
+ // Exit target scope
auto f_exit = ffi::Function::GetGlobal("target.TargetExitScope");
(*f_exit)(target);
- return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
+
+ return CUDAModuleCreate(compiled, fmt, ExtractFuncInfo(mod), code);
}
TVM_FFI_STATIC_INIT_BLOCK() {
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index a9cfad9ab6..86201a2a05 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -310,10 +310,16 @@ std::string CodeGenCUDA::Finish() {
decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
decl_stream << "#endif\n";
+ // Emit type aliases, guarding int64_t/uint64_t for compatibility
+ decl_stream << "\n#ifdef __CUDACC_RTC__\n";
+ decl_stream << "using int64_t = long long;\n";
+ decl_stream << "using uint64_t = unsigned long long;\n";
+ decl_stream << "#else\n";
decl_stream << "#include <cstdint>\n";
+ decl_stream << "#endif\n";
decl_stream << "using uint = unsigned int;\n";
decl_stream << "using uchar = unsigned char;\n";
- decl_stream << "using ushort = unsigned short;\n";
+ decl_stream << "using ushort = unsigned short;\n\n";
return CodeGenC::Finish();
}
diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index 3f1fcbc2dc..682845e9e7 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -391,7 +391,9 @@ void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16
bool enable_fp8, bool enable_fp4) {
if (enable_fp16 || enable_bf16) {
stream << R"(
-#include <type_traits>
+template <typename T, typename U> struct is_same { static constexpr bool value = false; };
+template <typename T> struct is_same<T, T> { static constexpr bool value = true; };
+
template <typename T, typename TVec2>
struct __align__(8) half4_bfloat164 {
T x, y, z, w;
@@ -401,7 +403,7 @@ struct __align__(8) half4_bfloat164 {
if (enable_fp8) {
stream << R"(
__host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e4m3& fp8x4) {
- if constexpr (std::is_same_v<T, __half>) {
+ if constexpr (is_same<T, __half>::value) {
__nv_fp8x2_e4m3 lo_part, hi_part;
lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
    hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
@@ -481,7 +483,7 @@ struct __align__(8) half4_bfloat164 {
if (enable_fp4) {
stream << R"(
__host__ __device__ explicit half4_bfloat164(const __nv_fp4x4_e2m1& fp4x4) {
- if constexpr (std::is_same_v<T, __half>) {
+ if constexpr (is_same<T, __half>::value) {
    __nv_fp4x2_storage_t lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
    __nv_fp4x2_storage_t hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
TVec2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 1b31e64414..177541da08 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -20,6 +20,7 @@ import numpy as np
import pytest
import tvm
+import tvm.contrib.nvcc
import tvm.testing
from tvm import te, topi
from tvm.contrib.nvcc import have_bf16, have_fp16, have_int8
@@ -27,6 +28,31 @@ from tvm.script import ir as I
from tvm.script import tir as T
+@pytest.fixture(autouse=True, params=["nvcc", "nvrtc"])
+def setup_cuda_compile_mode(request):
+ mode = request.param
+ if mode == "nvrtc":
+ try:
+ from cuda.bindings import nvrtc
+ except ImportError:
+ pytest.skip("cuda-python not available, skipping nvrtc tests")
+
+ orig_func = tvm.contrib.nvcc.tvm_callback_cuda_compile
+
+ def compile_mode_wrapper(code, target):
+ if mode == "nvcc":
+            return tvm.contrib.nvcc.compile_cuda(code, target_format="fatbin", compiler="nvcc")
+        elif mode == "nvrtc":
+            return tvm.contrib.nvcc.compile_cuda(code, target_format="cubin", compiler="nvrtc")
+        else:
+            raise ValueError(f"Unknown mode: {mode}")
+
+    tvm.register_global_func("tvm_callback_cuda_compile", compile_mode_wrapper, override=True)
+ # yield back to the original function so that each test runs twice
+ yield
+    tvm.register_global_func("tvm_callback_cuda_compile", orig_func, override=True)
+
+
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_cuda_vectorize_add():
@@ -201,13 +227,13 @@ def test_cuda_make_int8():
fun(a)
np.testing.assert_equal(a.numpy(), np_a)
- check_cuda(64, np.int8(0xAB), 4)
+ check_cuda(64, np.uint8(0xAB).view(np.int8), 4)
check_cuda(64, 0, 4)
check_cuda(64, -3, 4)
- check_cuda(64, np.int8(0xAB), 3)
+ check_cuda(64, np.uint8(0xAB).view(np.int8), 3)
check_cuda(64, 0, 3)
check_cuda(64, -3, 3)
- check_cuda(64, np.int8(0xAB), 2)
+ check_cuda(64, np.uint8(0xAB).view(np.int8), 2)
check_cuda(64, 0, 2)
check_cuda(64, -3, 2)
diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py
index d9976e05e5..029eb8fe82 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -16,13 +16,15 @@
# under the License.
"""Basic tests for a Disco nvshmem support"""
# pylint: disable=missing-docstring
-import tempfile
-
import numpy as np
import pytest
+
+import shutil
import subprocess
-import threading
import sys
+import tempfile
+import threading
+import multiprocessing
from multiprocessing import Process
from typing import Any, Callable, List
@@ -160,7 +162,8 @@ def test_nvshmem_compile():
T.writes(B[v1, v0])
B[v1, v0] = A[v0, v1]
- with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = tempfile.mkdtemp()
+ try:
path = tmpdir + "/test.so"
A_np = np.arange(8 * 16).astype("float32").reshape([8, 16])
B_np = np.zeros((16, 8), dtype="float32")
@@ -180,9 +183,12 @@ def test_nvshmem_compile():
# finish the execution
sess._sync_all()
-        finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
-        finalize_dfunc()
-        sess.sync_worker_0()
+        finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+        finalize_dfunc()
+        sess.sync_worker_0()
+ finally:
+ sess.shutdown()
+ shutil.rmtree(tmpdir, ignore_errors=True)
if __name__ == "__main__":
@@ -190,14 +196,24 @@ if __name__ == "__main__":
    # or `nvshmem_init_thread` in the same program results in undefined behavior.
    # So we always create a new process to run the test. Then no repeated nvshmem
    # init happens in the same process, since the worker0 may share the same process.
+
+    # Use 'spawn' start method to avoid inheriting CUDA state from parent process
+    # 'fork' (default on Linux) can cause issues with CUDA contexts in child processes
+ multiprocessing.set_start_method("spawn", force=True)
+
for session_kind in [create_socket_session, di.ProcessSession]:
for num_workers in [2, 4]:
for test_func in [test_nvshmem_init_finalize, test_nvshmem_empty]:
p = Process(target=test_func, args=[session_kind, num_workers])
p.start()
p.join()
+ # Ensure the process finished successfully
+ assert (
+ p.exitcode == 0
+            ), f"Test {test_func.__name__} failed with exit code {p.exitcode}"
# testing compilation flow
p = Process(target=test_nvshmem_compile)
p.start()
p.join()
+    assert p.exitcode == 0, f"Test test_nvshmem_compile failed with exit code {p.exitcode}"
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index 0855afcfd6..aa7e2b3575 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -256,10 +256,17 @@ cast_smem_ptr_to_int(const void* const smem_ptr)
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif
+
+#ifdef __CUDACC_RTC__
+using int64_t = long long;
+using uint64_t = unsigned long long;
+#else
#include <cstdint>
+#endif
using uint = unsigned int;
using uchar = unsigned char;
using ushort = unsigned short;
+
extern "C" __global__ void __launch_bounds__(16) main_kernel(float*
__restrict__ A, float* __restrict__ B, float* __restrict__ C);
extern "C" __global__ void __launch_bounds__(16) main_kernel(float*
__restrict__ A, float* __restrict__ B, float* __restrict__ C) {
__shared__ float A_shared[64];