This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fa905d2b69 [Compile] accelerate compilation speed using NVRTC (#18519)
fa905d2b69 is described below

commit fa905d2b693e5368ea059a72f3fd1333005f6560
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Thu Jan 8 11:08:06 2026 -0500

    [Compile] accelerate compilation speed using NVRTC (#18519)
    
    This PR adds NVRTC as an alternative to NVCC for faster, device-side
    JIT compilation of CUDA kernels, following up on
    [apache/tvm-ffi#283](https://github.com/apache/tvm-ffi/pull/283).
    
    It enhances the CUDA compilation backend by:
    - Adding Python-side NVRTC support via the cuda-python bindings
    - Removing the legacy C++ NVRTC fallback in favor of a Python-first approach
    - Keeping nvcc as the default compiler with fatbin output (no behavior
    change for existing users)
    
    Users can select the compilation backend with the environment variable
    `TVM_CUDA_COMPILE_MODE`, which accepts "nvcc" (the default) or "nvrtc".
    For example:
    
    `TVM_CUDA_COMPILE_MODE=nvrtc python3 your_program.py`
    
    Here is a short benchmark of the compilation speed of kernels in
    `test_target_codegen_cuda.py`.
    
    ### NVCC vs NVRTC Compilation Time Comparison (Python-side Call)
    
    | Test Case | Code Size | NVCC Time (ms) | NVRTC Time (ms) | Speedup |
    | :--- | :--- | :--- | :--- | :--- |
    | `test_crossthread_reduction1` | 1945 B | 241.27 | 51.23 | **4.7x** |
    | `test_cuda_bf16_vectorize_add` | 3760 B | 342.72 | 44.50 | **7.7x** |
    | `test_cuda_const_float_to_half` | 12394 B | 272.85 | 31.99 | **8.5x** |
    | `test_cuda_device_func_call` | 975 B | 215.58 | 21.47 | **10.0x** |
    | `test_cuda_float_const_hex_format` | 685 B | 217.39 | 20.52 | **10.6x** |
    | `test_cuda_floordiv_with_vectorization` | 1050 B | 213.88 | 23.32 | **9.2x** |
    | `test_cuda_inf_nan` | 673 B | 214.33 | 24.94 | **8.6x** |
    | `test_cuda_tensormap` | 755 B | 213.91 | 20.74 | **10.3x** |
    | `test_cuda_thread_sync_inside_condition` | 1007 B | 213.43 | 28.29 | **7.5x** |
    | `test_cuda_vectorize_add` | 908 B | 226.81 | 40.39 | **5.6x** |
    | `test_cuda_vectorize_load` | 734 B | 217.25 | 24.02 | **9.0x** |
    | `test_device_host_call_same_func` | 924 B | 216.03 | 21.21 | **10.2x** |
    | `test_vectorized_intrin1` | 847 B | 226.15 | 26.34 | **8.6x** |
    
    ### NVSHMEM Support
    
    Currently, NVSHMEM is **not** supported via NVRTC.
    - Fallback Behavior: When NVSHMEM is required (i.e. the generated source
    contains `#include <nvshmem.h>` or `#include <nvshmemx.h>`), the
    compilation pipeline automatically falls back to NVCC, even if
    `TVM_CUDA_COMPILE_MODE` is set to nvrtc.
    - Future Roadmap: Support for NVRTC with NVSHMEM is planned for
    follow-up PRs.
---
 cmake/modules/CUDA.cmake                           |   2 -
 cmake/utils/FindCUDA.cmake                         |   9 -
 docker/Dockerfile.ci_gpu                           |   3 +
 docker/install/ubuntu_install_cuda_python.sh       |  23 ++
 python/tvm/contrib/nvcc.py                         | 329 +++++++++++++++++++--
 .../tvm/script/ir_builder/tir/external_kernel.py   |  24 +-
 src/runtime/contrib/nvshmem/init.cc                |  18 +-
 src/target/opt/build_cuda_on.cc                    | 122 ++------
 src/target/source/codegen_cuda.cc                  |   8 +-
 src/target/source/literal/cuda_half_t.h            |   8 +-
 tests/python/codegen/test_target_codegen_cuda.py   |  32 +-
 tests/python/disco/test_nvshmem.py                 |  30 +-
 .../test_tir_transform_inject_ptx_async_copy.py    |   7 +
 13 files changed, 465 insertions(+), 150 deletions(-)

diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake
index 64f41d65fa..e9f5854901 100644
--- a/cmake/modules/CUDA.cmake
+++ b/cmake/modules/CUDA.cmake
@@ -54,10 +54,8 @@ if(USE_CUDA)
   list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_SRCS})
   list(APPEND COMPILER_SRCS src/target/opt/build_cuda_on.cc)
 
-  list(APPEND TVM_LINKER_LIBS ${CUDA_NVRTC_LIBRARY})
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDART_LIBRARY})
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDA_LIBRARY})
-  list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_NVRTC_LIBRARY})
 
   if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
     if(CMAKE_VERSION VERSION_LESS "3.24")
diff --git a/cmake/utils/FindCUDA.cmake b/cmake/utils/FindCUDA.cmake
index c4c18eef0f..c62506cf41 100644
--- a/cmake/utils/FindCUDA.cmake
+++ b/cmake/utils/FindCUDA.cmake
@@ -33,7 +33,6 @@
 # - CUDA_TOOLKIT_ROOT_DIR
 # - CUDA_CUDA_LIBRARY
 # - CUDA_CUDART_LIBRARY
-# - CUDA_NVRTC_LIBRARY
 # - CUDA_CUDNN_INCLUDE_DIRS
 # - CUDA_CUDNN_LIBRARY
 # - CUDA_CUBLAS_LIBRARY
@@ -64,9 +63,6 @@ macro(find_cuda use_cuda use_cudnn)
       find_library(CUDA_CUDA_LIBRARY cuda
         ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
         ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)
-      find_library(CUDA_NVRTC_LIBRARY nvrtc
-        ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
-        ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)
       find_library(CUDA_CUBLAS_LIBRARY cublas
         ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
         ${CUDA_TOOLKIT_ROOT_DIR}/lib/Win32)
@@ -81,10 +77,6 @@ macro(find_cuda use_cuda use_cudnn)
       if(_CUDA_CUDA_LIBRARY)
         set(CUDA_CUDA_LIBRARY ${_CUDA_CUDA_LIBRARY})
       endif()
-      find_library(CUDA_NVRTC_LIBRARY nvrtc
-        PATHS ${CUDA_TOOLKIT_ROOT_DIR}
-        PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib 
targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu
-        NO_DEFAULT_PATH)
       find_library(CUDA_CURAND_LIBRARY curand
         PATHS ${CUDA_TOOLKIT_ROOT_DIR}
         PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib 
targets/x86_64-linux/lib/stubs lib64/stubs lib/x86_64-linux-gnu
@@ -140,7 +132,6 @@ macro(find_cuda use_cuda use_cudnn)
     message(STATUS "Found CUDA_TOOLKIT_ROOT_DIR=" ${CUDA_TOOLKIT_ROOT_DIR})
     message(STATUS "Found CUDA_CUDA_LIBRARY=" ${CUDA_CUDA_LIBRARY})
     message(STATUS "Found CUDA_CUDART_LIBRARY=" ${CUDA_CUDART_LIBRARY})
-    message(STATUS "Found CUDA_NVRTC_LIBRARY=" ${CUDA_NVRTC_LIBRARY})
     message(STATUS "Found CUDA_CUDNN_INCLUDE_DIRS=" ${CUDA_CUDNN_INCLUDE_DIRS})
     message(STATUS "Found CUDA_CUDNN_LIBRARY=" ${CUDA_CUDNN_LIBRARY})
     message(STATUS "Found CUDA_CUBLAS_LIBRARY=" ${CUDA_CUBLAS_LIBRARY})
diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu
index 1295c679d7..a72bd60fd7 100644
--- a/docker/Dockerfile.ci_gpu
+++ b/docker/Dockerfile.ci_gpu
@@ -60,6 +60,9 @@ RUN bash /install/ubuntu_install_opencl.sh
 COPY install/ubuntu_install_python_package.sh 
/install/ubuntu_install_python_package.sh
 RUN bash /install/ubuntu_install_python_package.sh
 
+COPY install/ubuntu_install_cuda_python.sh 
/install/ubuntu_install_cuda_python.sh
+RUN bash /install/ubuntu_install_cuda_python.sh
+
 COPY install/ubuntu_install_sphinx.sh /install/ubuntu_install_sphinx.sh
 RUN bash /install/ubuntu_install_sphinx.sh
 
diff --git a/docker/install/ubuntu_install_cuda_python.sh 
b/docker/install/ubuntu_install_cuda_python.sh
new file mode 100644
index 0000000000..eb4efac5c0
--- /dev/null
+++ b/docker/install/ubuntu_install_cuda_python.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+set -u
+set -o pipefail
+
+pip3 install cuda-python
diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index d062714938..edf3e8af4f 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -16,14 +16,18 @@
 # under the License.
 # pylint: disable=invalid-name
 """Utility to invoke nvcc compiler in the system"""
+
 from __future__ import absolute_import as _abs
 
+import glob
 import os
+import platform
 import subprocess
 import warnings
 from typing import Tuple
 
 import tvm_ffi
+
 import tvm
 from tvm.target import Target
 
@@ -31,8 +35,10 @@ from ..base import py_str
 from . import utils
 
 
-def compile_cuda(code, target_format=None, arch=None, options=None, 
path_target=None):
-    """Compile cuda code with NVCC from env.
+def compile_cuda(
+    code, target_format=None, arch=None, options=None, path_target=None, 
compiler="nvcc"
+):
+    """Compile cuda code with NVCC or NVRTC.
 
     Parameters
     ----------
@@ -40,7 +46,7 @@ def compile_cuda(code, target_format=None, arch=None, 
options=None, path_target=
         The cuda code.
 
     target_format : str
-        The target format of nvcc compiler.
+        The target format of the compiler ("ptx", "cubin", or "fatbin").
 
     arch : str
         The cuda architecture.
@@ -51,14 +57,61 @@ def compile_cuda(code, target_format=None, arch=None, 
options=None, path_target=
     path_target : str, optional
         Output file.
 
-    Return
-    ------
-    cubin : bytearray
-        The bytearray of the cubin
+    compiler : str, optional
+        Compiler backend: "nvcc" or "nvrtc".
+        This can be set by the TVM_CUDA_COMPILE_MODE environment variable.
+
+    Returns
+    -------
+    res_binary : bytearray
+        The bytearray of the compiled binary (ptx/cubin/fatbin).
+
+    Notes
+    -----
+    - NVRTC is a "runtime" compilation library and can be faster for JIT 
compilation.
+    - NVRTC requires cuda-python: pip install cuda-python
+    """
+    # TODO: if need NVSHMEM for compilation, fall back to NVCC because support 
for NVRTC
+    # is not yet implemented
+    use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in 
code
+    if compiler == "nvcc" or use_nvshmem:
+        return _compile_cuda_nvcc(code, target_format, arch, options, 
path_target, use_nvshmem)
+    elif compiler == "nvrtc":
+        return _compile_cuda_nvrtc(code, target_format, arch, options)
+    else:
+        raise ValueError(f"cuda compiler must be 'nvcc' or 'nvrtc', got: 
{compiler}")
+
+
+def _compile_cuda_nvcc(
+    code,
+    target_format=None,
+    arch=None,
+    options=None,
+    path_target=None,
+    use_nvshmem=False,
+):
+    """Compile CUDA code using nvcc.
+
+    Parameters
+    ----------
+    code : str
+        The CUDA source code.
+    target_format : str, optional
+        Output format: "ptx", "cubin", or "fatbin".
+    arch : str, optional
+        Target architecture. Auto-detected if None.
+    options : str or list of str, optional
+        Additional nvcc options.
+    path_target : str, optional
+        Output file path.
+
+    Returns
+    -------
+    bytearray
+        Compiled binary data.
     """
     # Check for NVSHMEM dependency
     nvshmem_include_path, nvshmem_lib_path = None, None
-    use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in 
code
     if use_nvshmem:
         # NOTE: we cannot check whether nvshmem is used based on whether
         # the global function "runtime.nvshmem.cumodule_init" is defined.
@@ -106,8 +159,9 @@ def compile_cuda(code, target_format=None, arch=None, 
options=None, path_target=
 
     file_target = path_target if path_target else temp_target
     if use_nvshmem:
-        file_prefix = file_target.split(".")[0]
+        file_prefix = os.path.splitext(file_target)[0]
         file_target = f"{file_prefix}.o"  # in the first stage, compile to 
object file
+
     cmd = ["nvcc"]
     cmd += [f"--{target_format}", "-O3"]
     if kernels_output_dir is not None:
@@ -151,14 +205,11 @@ def compile_cuda(code, target_format=None, arch=None, 
options=None, path_target=
         msg += py_str(out)
         raise RuntimeError(msg)
 
-    # start second stage of compilation
+    # Second stage for NVSHMEM
     if use_nvshmem:
         cmd = ["nvlink"]
         cmd += [f"-arch=sm_{compute_version}"]
-        cmd += [
-            "-L",
-            nvshmem_lib_path,
-        ]
+        cmd += ["-L", nvshmem_lib_path]
         cmd += ["-L", os.path.join(find_cuda_path(), "lib64")]
         cmd += ["-l", "nvshmem_device"]
         cmd += ["-l", "cudadevrt"]
@@ -184,6 +235,187 @@ def compile_cuda(code, target_format=None, arch=None, 
options=None, path_target=
         return data
 
 
+def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
+    """Compile CUDA code using NVRTC (NVIDIA Runtime Compilation).
+
+    Parameters
+    ----------
+    code : str
+        The CUDA source code.
+    target_format : str, optional
+        Output format: "cubin" or "ptx". Default: "cubin"
+    arch : str, optional
+        Target architecture (e.g., "sm_80"). Auto-detected if None.
+    options : str or list of str, optional
+        Additional NVRTC options.
+
+    Returns
+    -------
+    bytearray
+        Compiled binary data.
+    """
+    try:
+        from cuda.bindings import nvrtc  # pylint: 
disable=import-outside-toplevel
+    except ImportError as e:
+        raise RuntimeError(
+            "Failed to compile CUDA with NVRTC because the `cuda-python` 
package "
+            "is not available.\n"
+            "Please install it with: pip install cuda-python\n"
+            "See: https://nvidia.github.io/cuda-python/";
+        ) from e
+
+    # Default target format
+    if target_format is None:
+        target_format = "cubin"
+
+    # Validate target_format (NVRTC doesn't support fatbin)
+    if target_format == "fatbin":
+        raise ValueError(
+            "NVRTC does not support fatbin generation yet. "
+            "Use target_format='cubin' or 'ptx' with NVRTC, "
+            "or set compiler='nvcc' for fatbin compilation."
+        )
+    if target_format not in ["cubin", "ptx"]:
+        raise ValueError(f"target_format must be 'cubin' or 'ptx', got: 
{target_format}")
+
+    # Validate options
+    if options is not None and not isinstance(options, (str, list)):
+        raise ValueError("options must be str or list of str")
+
+    # Auto-detect architecture
+    if arch is None:
+        compute_version = 
get_target_compute_version(Target.current(allow_none=True))
+        arch = f"sm_{''.join(compute_version.split('.'))}"
+
+    # Strip host-only headers for NVRTC. NVRTC compiles device code and does 
not
+    # require the CUDA driver header or host C++ headers.
+    headers_to_strip = {"#include <cuda.h>"}
+    code_filtered = "\n".join(
+        line for line in code.splitlines() if line.strip() not in 
headers_to_strip
+    )
+
+    # NVRTC compiles device code and does not include the host-side cuda.h.
+    # CUtensorMap is a host-side structure, to reference and use it in device 
code,
+    # we must forward-declare it for NVRTC.
+    if "CUtensorMap" in code_filtered:
+        code_filtered = (
+            "struct __align__(128) CUtensorMap {\n"
+            "  unsigned long long opaque[16];\n"
+            "};\n\n" + code_filtered
+        )
+
+    # Create NVRTC program
+    # Use "tvm_kernels.cu" for consistency with nvcc path
+    result, prog = nvrtc.nvrtcCreateProgram(
+        str.encode(code_filtered), b"tvm_kernels.cu", 0, None, None
+    )
+    if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+        raise RuntimeError(f"Failed to create NVRTC program: 
{nvrtc.nvrtcGetErrorString(result)}")
+
+    # Prepare compilation options
+    cuda_path = find_cuda_path()
+    compile_opts = [
+        f"--gpu-architecture={arch}".encode(),
+        b"-default-device",
+    ]
+
+    # Add CUDA include paths. NVRTC needs explicit include paths for CUDA 
headers.
+    # Standard installations: cuda_path/include
+    # Conda/architecture-specific installations: 
cuda_path/targets/<arch>/include
+    include_paths = []
+
+    # Check standard include directory
+    standard_include = os.path.join(cuda_path, "include")
+    if os.path.isdir(standard_include):
+        include_paths.append(standard_include)
+
+    # Check architecture-specific include directory
+    arch_include = os.path.join(
+        cuda_path,
+        "targets",
+        f"{platform.machine()}-{platform.system().lower()}",
+        "include",
+    )
+    if os.path.isdir(arch_include):
+        include_paths.append(arch_include)
+
+    # Verify we can find essential CUDA headers
+    if not any(os.path.isfile(os.path.join(p, "cuda_runtime.h")) for p in 
include_paths):
+        raise RuntimeError(
+            f"Cannot find CUDA headers in {cuda_path}. "
+            f"Searched in: {include_paths}. "
+            "Please ensure CUDA is properly installed."
+        )
+
+    # Add all valid include paths
+    for include_path in include_paths:
+        compile_opts.append(f"-I{include_path}".encode())
+
+    compile_opts.extend(
+        [
+            b"-U__CUDA_NO_HALF_OPERATORS__",
+            b"-U__CUDA_NO_HALF_CONVERSIONS__",
+            b"-U__CUDA_NO_BFLOAT16_OPERATORS__",
+            b"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+            b"-U__CUDA_NO_BFLOAT162_OPERATORS__",
+            b"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+            b"--use_fast_math",
+        ]
+    )
+
+    # Add user-provided options
+    if options:
+        if isinstance(options, str):
+            compile_opts.append(options.encode())
+        else:
+            compile_opts.extend([opt.encode() if isinstance(opt, str) else opt 
for opt in options])
+
+    # Compile
+    (result,) = nvrtc.nvrtcCompileProgram(prog, len(compile_opts), 
compile_opts)
+    if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+        # Get compilation log
+        result_log, log_size = nvrtc.nvrtcGetProgramLogSize(prog)
+        if result_log == nvrtc.nvrtcResult.NVRTC_SUCCESS and log_size > 0:
+            log_buf = bytearray(log_size)
+            (result_log,) = nvrtc.nvrtcGetProgramLog(prog, log_buf)
+            if result_log == nvrtc.nvrtcResult.NVRTC_SUCCESS:
+                error_msg = f"NVRTC compilation 
failed:\n{log_buf.decode('utf-8')}"
+            else:
+                error_msg = f"NVRTC compilation failed (couldn't get log): 
{result}"
+        else:
+            error_msg = f"NVRTC compilation failed: {result}"
+
+        nvrtc.nvrtcDestroyProgram(prog)
+        raise RuntimeError(error_msg)
+
+    # Get compiled binary
+    if target_format == "cubin":
+        result, binary_size = nvrtc.nvrtcGetCUBINSize(prog)
+        if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+            nvrtc.nvrtcDestroyProgram(prog)
+            raise RuntimeError(f"Failed to get CUBIN size: 
{nvrtc.nvrtcGetErrorString(result)}")
+        binary_buf = bytearray(binary_size)
+        (result,) = nvrtc.nvrtcGetCUBIN(prog, binary_buf)
+        if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+            nvrtc.nvrtcDestroyProgram(prog)
+            raise RuntimeError(f"Failed to get CUBIN: 
{nvrtc.nvrtcGetErrorString(result)}")
+    else:  # ptx
+        result, binary_size = nvrtc.nvrtcGetPTXSize(prog)
+        if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+            nvrtc.nvrtcDestroyProgram(prog)
+            raise RuntimeError(f"Failed to get PTX size: 
{nvrtc.nvrtcGetErrorString(result)}")
+        binary_buf = bytearray(binary_size)
+        (result,) = nvrtc.nvrtcGetPTX(prog, binary_buf)
+        if result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
+            nvrtc.nvrtcDestroyProgram(prog)
+            raise RuntimeError(f"Failed to get PTX: 
{nvrtc.nvrtcGetErrorString(result)}")
+
+    # Clean up
+    nvrtc.nvrtcDestroyProgram(prog)
+
+    return bytearray(binary_buf)
+
+
 def find_cuda_path():
     """Utility function to find cuda path
 
@@ -241,7 +473,7 @@ def get_cuda_version(cuda_path=None):
     (out, _) = proc.communicate()
     out = py_str(out)
     if proc.returncode == 0:
-        release_line = [l for l in out.split("\n") if "release" in l][0]
+        release_line = [line for line in out.split("\n") if "release" in 
line][0]
         release_fields = [s.strip() for s in release_line.split(",")]
         version_str = [f[1:] for f in release_fields if f.startswith("V")][0]
         return tuple(int(field) for field in version_str.split("."))
@@ -280,16 +512,37 @@ def find_nvshmem_paths() -> Tuple[str, str]:
             unique_candidates.append(path)
 
     for root in unique_candidates:
-        include_path = os.path.join(root, "include")
+        # Check both standard include path and versioned subdirectories (e.g., 
nvshmem_12)
+        include_paths_to_check = [os.path.join(root, "include")]
+
+        # Add versioned subdirectories like include/nvshmem_*
+        versioned_includes = glob.glob(os.path.join(root, "include", 
"nvshmem_*"))
+        include_paths_to_check.extend(versioned_includes)
+
+        # Check standard and architecture-specific lib directories
         lib_paths_to_check = [
             os.path.join(root, "lib64"),
             os.path.join(root, "lib"),
         ]
 
-        if os.path.isfile(os.path.join(include_path, "nvshmem.h")):
-            for lib_path in lib_paths_to_check:
-                if os.path.isfile(os.path.join(lib_path, "libnvshmem.a")):
-                    return include_path, lib_path
+        # Add architecture-specific lib paths (e.g., lib/x86_64-linux-gnu)
+        machine = platform.machine()
+        system = platform.system().lower()
+        lib_paths_to_check.extend(
+            [
+                os.path.join(root, "lib", f"{machine}-{system}-gnu"),
+                os.path.join(root, "lib64", f"{machine}-{system}-gnu"),
+            ]
+        )
+
+        for include_path in include_paths_to_check:
+            if os.path.isfile(os.path.join(include_path, "nvshmem.h")):
+                for lib_path in lib_paths_to_check:
+                    # Check for both static (.a) and shared (.so) libraries
+                    if os.path.isfile(os.path.join(lib_path, "libnvshmem.a")) 
or os.path.isfile(
+                        os.path.join(lib_path, "libnvshmem.so")
+                    ):
+                        return include_path, lib_path
 
     error_message = [
         "Error: Could not find NVSHMEM installation.",
@@ -315,9 +568,39 @@ def find_nvshmem_paths() -> Tuple[str, str]:
 
 @tvm_ffi.register_global_func
 def tvm_callback_cuda_compile(code, target):  # pylint: disable=unused-argument
-    """use nvcc to generate fatbin code for better optimization"""
-    ptx = compile_cuda(code, target_format="fatbin")
-    return ptx
+    """
+    Compile CUDA code using the configured backend (nvcc or nvrtc).
+
+    This callback is invoked by TVM's C++ backend during CUDA module 
compilation.
+    By default, uses nvcc to generate fatbin.
+
+    Environment Variables
+    ---------------------
+    TVM_CUDA_COMPILE_MODE : str
+        Compiler backend: "nvcc" (default) or "nvrtc"
+        - "nvcc": Use nvcc subprocess, generates fatbin
+        - "nvrtc": Use NVRTC via cuda-python for faster JIT, generates cubin
+
+    Parameters
+    ----------
+    code : str
+        CUDA source code to compile
+    target : Target
+        TVM target architecture
+
+    Returns
+    -------
+    bytes
+        Compiled binary (fatbin for nvcc, cubin for nvrtc)
+    """
+    compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc").lower()
+
+    if compiler == "nvrtc":
+        return compile_cuda(code, target_format="cubin", compiler="nvrtc")
+    if compiler == "nvcc":
+        return compile_cuda(code, target_format="fatbin", compiler="nvcc")
+
+    raise ValueError(f"Invalid TVM_CUDA_COMPILE_MODE: {compiler}. Expected 
'nvcc' or 'nvrtc'.")
 
 
 @tvm_ffi.register_global_func("tvm_callback_libdevice_path")
diff --git a/python/tvm/script/ir_builder/tir/external_kernel.py 
b/python/tvm/script/ir_builder/tir/external_kernel.py
index 405e1e6cbf..45a3d364c1 100644
--- a/python/tvm/script/ir_builder/tir/external_kernel.py
+++ b/python/tvm/script/ir_builder/tir/external_kernel.py
@@ -17,14 +17,15 @@
 """External kernel integration fro TIR"""
 import json
 import logging
+import os
 import tempfile
 from pathlib import Path
 from typing import Any, Dict, List, Tuple, Union
 
 from tvm import __version__ as tvm_version
 from tvm import tir
-from tvm.runtime import Module, load_module, const
 from tvm.contrib import nvcc
+from tvm.runtime import Module, const, load_module
 
 
 class BaseKernel:  # pylint: disable=too-few-public-methods
@@ -100,10 +101,15 @@ class SourceKernel(BaseKernel):  # pylint: 
disable=too-few-public-methods
         self.source_code = source_code
 
     def compile_to_device_module(  # pylint: disable=arguments-differ
-        self, grid: List[List[Union[int, tir.PrimExpr]]], *args: List[Any], 
**kwargs: Dict[str, Any]
+        self,
+        grid: List[List[Union[int, tir.PrimExpr]]],
+        *args: List[Any],
+        **kwargs: Dict[str, Any],
     ) -> Tuple[str, Module, List[Any]]:
         """Compile the kernel to a device module."""
-        from tvm.relax.frontend.nn import SourceModule  # pylint: 
disable=import-outside-toplevel
+        from tvm.relax.frontend.nn import (  # pylint: 
disable=import-outside-toplevel
+            SourceModule,
+        )
 
         kernel_name = kwargs["kernel_name"]
         assert len(grid) == 2, (
@@ -134,8 +140,13 @@ class SourceKernel(BaseKernel):  # pylint: 
disable=too-few-public-methods
 
         with tempfile.TemporaryDirectory() as temp_dir:
             ptx_path = f"{temp_dir}/{kernel_name}.ptx"
+            compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc")
             nvcc.compile_cuda(
-                source_code, target_format="ptx", options=compile_options, 
path_target=ptx_path
+                source_code,
+                target_format="ptx",
+                options=compile_options,
+                path_target=ptx_path,
+                compiler=compiler,
             )
             with open(ptx_path, "r") as f:
                 ptx = f.read()
@@ -171,7 +182,10 @@ def call_kernel(
     kwargs : Dict[str, Any]
         Additional keyword arguments to pass to the kernel or compilation.
     """
-    from ..ir import module_get_attr, module_set_attr  # pylint: 
disable=import-outside-toplevel
+    from ..ir import (  # pylint: disable=import-outside-toplevel
+        module_get_attr,
+        module_set_attr,
+    )
     from .ir import call_packed  # pylint: disable=import-outside-toplevel
 
     kernel_type = f"{type(kernel).__module__}.{type(kernel).__qualname__}"
diff --git a/src/runtime/contrib/nvshmem/init.cc 
b/src/runtime/contrib/nvshmem/init.cc
index 3471902bc3..d682e2cae5 100644
--- a/src/runtime/contrib/nvshmem/init.cc
+++ b/src/runtime/contrib/nvshmem/init.cc
@@ -16,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include <cuda.h>
 #include <nvshmem.h>
 #include <nvshmemx.h>
 #include <picojson.h>
@@ -117,7 +118,22 @@ void NVSHMEMXCumoduleInit(void* cuModule) {
     // NOTE: we do not check the return value of nvshmemx_cumodule_init.
     // The reason is because that the input cuModule might not use any NVSHMEM 
functions,
     // in which case the nvshmemx_cumodule_init will fail.
-    nvshmemx_cumodule_init(mod);
+
+    // A set of guards to check if the module has NVSHMEM symbol to avoid the
+    // "gpgpu named symbol not found" error.
+    CUdeviceptr d_ptr;
+    size_t d_size;
+    const char* kNvshmemDeviceSymbols[] = {
+        "nvshmemi_device_state_d",      "nvshmem_i_device_state_d",
+        "nvshmemi_device_team_state_d", "nvshmemi_device_heap_base_d",
+        "nvshmemi_device_heap_size_d",  "nvshmemi_device_heap_d",
+    };
+    for (const char* sym : kNvshmemDeviceSymbols) {
+      if (cuModuleGetGlobal(&d_ptr, &d_size, mod, sym) == CUDA_SUCCESS) {
+        nvshmemx_cumodule_init(mod);
+        return;
+      }
+    }
   }
 }
 
diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc
index 8d2589aaec..88960594d0 100644
--- a/src/target/opt/build_cuda_on.cc
+++ b/src/target/opt/build_cuda_on.cc
@@ -28,7 +28,6 @@
 #include <tvm/ffi/reflection/registry.h>
 #endif
 #include <cuda_runtime.h>
-#include <nvrtc.h>
 
 #include <cstdlib>
 
@@ -40,91 +39,10 @@
 namespace tvm {
 namespace codegen {
 
-#define NVRTC_CALL(x)                                                          
              \
-  {                                                                            
              \
-    nvrtcResult result = x;                                                    
              \
-    if (result != NVRTC_SUCCESS) {                                             
              \
-      LOG(FATAL) << "NvrtcError: " #x " failed with error: " << 
nvrtcGetErrorString(result); \
-    }                                                                          
              \
-  }
-
-std::string FindCUDAIncludePath() {
-#if defined(_WIN32)
-  const std::string delimiter = "\\";
-#else
-  const std::string delimiter = "/";
-#endif
-  std::string cuda_include_path;
-  const char* cuda_path_env = std::getenv("CUDA_PATH");
-  if (cuda_path_env != nullptr) {
-    cuda_include_path += cuda_path_env;
-    cuda_include_path += delimiter + "include";
-    return cuda_include_path;
-  }
-
-#if defined(__linux__)
-  struct stat st;
-  cuda_include_path = "/usr/local/cuda/include";
-  if (stat(cuda_include_path.c_str(), &st) == 0) {
-    return cuda_include_path;
-  }
-
-  if (stat("/usr/include/cuda.h", &st) == 0) {
-    return "/usr/include";
-  }
-#endif
-  LOG(FATAL) << "Cannot find cuda include path."
-             << "CUDA_PATH is not set or CUDA is not installed in the default 
installation path."
-             << "In other than linux, it is necessary to set CUDA_PATH.";
-  return cuda_include_path;
-}
-
-std::string NVRTCCompile(const std::string& code, bool include_path = false) {
-  std::vector<std::string> compile_params;
-  std::vector<const char*> param_cstrings{};
-  nvrtcProgram prog;
-  std::string cc = "30";
-  int major, minor;
-  cudaError_t e1 = cudaDeviceGetAttribute(&major, 
cudaDevAttrComputeCapabilityMajor, 0);
-  cudaError_t e2 = cudaDeviceGetAttribute(&minor, 
cudaDevAttrComputeCapabilityMinor, 0);
-
-  if (e1 == cudaSuccess && e2 == cudaSuccess) {
-    cc = std::to_string(major) + std::to_string(minor);
-  } else {
-    LOG(WARNING) << "cannot detect compute capability from your device, "
-                 << "fall back to compute_30.";
-  }
-
-  compile_params.push_back("-arch=compute_" + cc);
-
-  if (include_path) {
-    std::string include_option = "--include-path=" + FindCUDAIncludePath();
-
-    compile_params.push_back(include_option);
-  }
-
-  for (const auto& string : compile_params) {
-    param_cstrings.push_back(string.c_str());
-  }
-  NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, 
nullptr));
-  nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(), 
param_cstrings.data());
-
-  size_t log_size;
-  NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
-  std::string log;
-  log.resize(log_size);
-  NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
-  ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
-  size_t ptx_size;
-  NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size));
-
-  std::string ptx;
-  ptx.resize(ptx_size);
-  NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));
-  NVRTC_CALL(nvrtcDestroyProgram(&prog));
-
-  return ptx;
-}
+// Note: CUDA include path finding and NVRTC compilation are now handled
+// in Python for better maintainability and to leverage cuda-python bindings.
+// The C++ NVRTC code has been removed as part of the Python-first
+// compilation strategy.
 
 ffi::Module BuildCUDA(IRModule mod, Target target) {
   bool output_ssa = false;
@@ -157,20 +75,32 @@ ffi::Module BuildCUDA(IRModule mod, Target target) {
     code = (*f)(code, target).cast<std::string>();
   }
   std::string fmt = "ptx";
-  std::string ptx;
+  std::string compiled;
+
+  // Always use Python compilation callback (nvcc or nvrtc)
+  // The C++ NVRTC fallback has been removed in favor of Python-first approach
+  auto f_compile = ffi::Function::GetGlobal("tvm_callback_cuda_compile");
+  ICHECK(f_compile != nullptr)
+      << "tvm_callback_cuda_compile not found. "
+      << "Please ensure TVM Python runtime is properly initialized.\n"
+      << "The Python callback (tvm.contrib.nvcc.tvm_callback_cuda_compile) is required "
+      << "for CUDA compilation. The C++ NVRTC fallback has been removed.\n"
+      << "Make sure to import tvm.contrib.nvcc in your Python code.";
+
+  // Enter target scope for compilation
   auto f_enter = ffi::Function::GetGlobal("target.TargetEnterScope");
   (*f_enter)(target);
-  if (auto f = ffi::Function::GetGlobal("tvm_callback_cuda_compile")) {
-    ptx = (*f)(code, target).cast<std::string>();
-    // Dirty matching to check PTX vs cubin.
-    // TODO(tqchen) more reliable checks
-    if (ptx[0] != '/') fmt = "cubin";
-  } else {
-    ptx = NVRTCCompile(code, cg.need_include_path());
-  }
+
+  // Compile CUDA code via Python callback
+  compiled = (*f_compile)(code, target).cast<std::string>();
+  // Dirty matching to check PTX vs cubin.
+  // TODO(tqchen) more reliable checks
+  if (compiled[0] != '/') fmt = "cubin";
+  // Exit target scope
   auto f_exit = ffi::Function::GetGlobal("target.TargetExitScope");
   (*f_exit)(target);
-  return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
+
+  return CUDAModuleCreate(compiled, fmt, ExtractFuncInfo(mod), code);
 }
 
 TVM_FFI_STATIC_INIT_BLOCK() {
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index a9cfad9ab6..86201a2a05 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -310,10 +310,16 @@ std::string CodeGenCUDA::Finish() {
   decl_stream << "#define TVM_ENABLE_L2_PREFETCH 0\n";
   decl_stream << "#endif\n";
 
+  // Emit type aliases, guarding int64_t/uint64_t for compatibility
+  decl_stream << "\n#ifdef __CUDACC_RTC__\n";
+  decl_stream << "using int64_t = long long;\n";
+  decl_stream << "using uint64_t = unsigned long long;\n";
+  decl_stream << "#else\n";
   decl_stream << "#include <cstdint>\n";
+  decl_stream << "#endif\n";
   decl_stream << "using uint = unsigned int;\n";
   decl_stream << "using uchar = unsigned char;\n";
-  decl_stream << "using ushort = unsigned short;\n";
+  decl_stream << "using ushort = unsigned short;\n\n";
 
   return CodeGenC::Finish();
 }
diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index 3f1fcbc2dc..682845e9e7 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -391,7 +391,9 @@ void declare_vector_type_extensions(std::ostringstream& stream, bool enable_fp16
                                     bool enable_fp8, bool enable_fp4) {
   if (enable_fp16 || enable_bf16) {
     stream << R"(
-#include <type_traits>
+template <typename T, typename U> struct is_same { static constexpr bool value = false; };
+template <typename T> struct is_same<T, T> { static constexpr bool value = true; };
+
 template <typename T, typename TVec2>
 struct __align__(8) half4_bfloat164 {
   T x, y, z, w;
@@ -401,7 +403,7 @@ struct __align__(8) half4_bfloat164 {
     if (enable_fp8) {
       stream << R"(
   __host__ __device__ explicit half4_bfloat164(const __nv_fp8x4_e4m3& fp8x4) {
-    if constexpr (std::is_same_v<T, __half>) {
+    if constexpr (is_same<T, __half>::value) {
       __nv_fp8x2_e4m3 lo_part, hi_part;
       lo_part.__x = static_cast<__nv_fp8x2_storage_t>(fp8x4.__x & 0xFFFF);
       hi_part.__x = static_cast<__nv_fp8x2_storage_t>((fp8x4.__x >> 16) & 0xFFFF);
@@ -481,7 +483,7 @@ struct __align__(8) half4_bfloat164 {
     if (enable_fp4) {
       stream << R"(
   __host__ __device__ explicit half4_bfloat164(const __nv_fp4x4_e2m1& fp4x4) {
-    if constexpr (std::is_same_v<T, __half>) {
+    if constexpr (is_same<T, __half>::value) {
       __nv_fp4x2_storage_t lo_part = static_cast<__nv_fp4x2_storage_t>(fp4x4.__x & 0xFF);
       __nv_fp4x2_storage_t hi_part = static_cast<__nv_fp4x2_storage_t>((fp4x4.__x >> 8) & 0xFF);
       TVec2 lo_half2 = __half2(__nv_cvt_fp4x2_to_halfraw2(lo_part, __NV_E2M1));
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index 1b31e64414..177541da08 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -20,6 +20,7 @@ import numpy as np
 import pytest
 
 import tvm
+import tvm.contrib.nvcc
 import tvm.testing
 from tvm import te, topi
 from tvm.contrib.nvcc import have_bf16, have_fp16, have_int8
@@ -27,6 +28,31 @@ from tvm.script import ir as I
 from tvm.script import tir as T
 
 
+@pytest.fixture(autouse=True, params=["nvcc", "nvrtc"])
+def setup_cuda_compile_mode(request):
+    mode = request.param
+    if mode == "nvrtc":
+        try:
+            from cuda.bindings import nvrtc
+        except ImportError:
+            pytest.skip("cuda-python not available, skipping nvrtc tests")
+
+    orig_func = tvm.contrib.nvcc.tvm_callback_cuda_compile
+
+    def compile_mode_wrapper(code, target):
+        if mode == "nvcc":
+            return tvm.contrib.nvcc.compile_cuda(code, target_format="fatbin", compiler="nvcc")
+        elif mode == "nvrtc":
+            return tvm.contrib.nvcc.compile_cuda(code, target_format="cubin", compiler="nvrtc")
+        else:
+            raise ValueError(f"Unknown mode: {mode}")
+
+    tvm.register_global_func("tvm_callback_cuda_compile", compile_mode_wrapper, override=True)
+    # yield back to the original function so that each test runs twice
+    yield
+    tvm.register_global_func("tvm_callback_cuda_compile", orig_func, override=True)
+
+
 @tvm.testing.requires_gpu
 @tvm.testing.requires_cuda
 def test_cuda_vectorize_add():
@@ -201,13 +227,13 @@ def test_cuda_make_int8():
         fun(a)
         np.testing.assert_equal(a.numpy(), np_a)
 
-    check_cuda(64, np.int8(0xAB), 4)
+    check_cuda(64, np.uint8(0xAB).view(np.int8), 4)
     check_cuda(64, 0, 4)
     check_cuda(64, -3, 4)
-    check_cuda(64, np.int8(0xAB), 3)
+    check_cuda(64, np.uint8(0xAB).view(np.int8), 3)
     check_cuda(64, 0, 3)
     check_cuda(64, -3, 3)
-    check_cuda(64, np.int8(0xAB), 2)
+    check_cuda(64, np.uint8(0xAB).view(np.int8), 2)
     check_cuda(64, 0, 2)
     check_cuda(64, -3, 2)
 
diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py
index d9976e05e5..029eb8fe82 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -16,13 +16,15 @@
 # under the License.
 """Basic tests for a Disco nvshmem support"""
 # pylint: disable=missing-docstring
-import tempfile
-
 import numpy as np
 import pytest
+
+import shutil
 import subprocess
-import threading
 import sys
+import tempfile
+import threading
+import multiprocessing
 from multiprocessing import Process
 from typing import Any, Callable, List
 
@@ -160,7 +162,8 @@ def test_nvshmem_compile():
                     T.writes(B[v1, v0])
                     B[v1, v0] = A[v0, v1]
 
-    with tempfile.TemporaryDirectory() as tmpdir:
+    tmpdir = tempfile.mkdtemp()
+    try:
         path = tmpdir + "/test.so"
         A_np = np.arange(8 * 16).astype("float32").reshape([8, 16])
         B_np = np.zeros((16, 8), dtype="float32")
@@ -180,9 +183,12 @@ def test_nvshmem_compile():
         # finish the execution
         sess._sync_all()
 
-    finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
-    finalize_dfunc()
-    sess.sync_worker_0()
+        finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+        finalize_dfunc()
+        sess.sync_worker_0()
+    finally:
+        sess.shutdown()
+        shutil.rmtree(tmpdir, ignore_errors=True)
 
 
 if __name__ == "__main__":
@@ -190,14 +196,24 @@ if __name__ == "__main__":
     # or `nvshmem_init_thread` in the same program results in undefined behavior.
     # So we always create a new process to run the test. Then no repeated nvshmem
     # init happens in the same process, since the worker0 may share the same process.
+
+    # Use 'spawn' start method to avoid inheriting CUDA state from parent process
+    # 'fork' (default on Linux) can cause issues with CUDA contexts in child processes
+    multiprocessing.set_start_method("spawn", force=True)
+
     for session_kind in [create_socket_session, di.ProcessSession]:
         for num_workers in [2, 4]:
             for test_func in [test_nvshmem_init_finalize, test_nvshmem_empty]:
                 p = Process(target=test_func, args=[session_kind, num_workers])
                 p.start()
                 p.join()
+                # Ensure the process finished successfully
+                assert (
+                    p.exitcode == 0
+                ), f"Test {test_func.__name__} failed with exit code {p.exitcode}"
 
     # testing compilation flow
     p = Process(target=test_nvshmem_compile)
     p.start()
     p.join()
+    assert p.exitcode == 0, f"Test test_nvshmem_compile failed with exit code {p.exitcode}"
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index 0855afcfd6..aa7e2b3575 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -256,10 +256,17 @@ cast_smem_ptr_to_int(const void* const smem_ptr)
 #else
 #define TVM_ENABLE_L2_PREFETCH 0
 #endif
+
+#ifdef __CUDACC_RTC__
+using int64_t = long long;
+using uint64_t = unsigned long long;
+#else
 #include <cstdint>
+#endif
 using uint = unsigned int;
 using uchar = unsigned char;
 using ushort = unsigned short;
+
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C);
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
   __shared__ float A_shared[64];


Reply via email to