This is an automated email from the ASF dual-hosted git repository.

bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2004a8bcbf [NVRTC] Add NVSHMEM support to NVRTC compilation path (#18681)
2004a8bcbf is described below

commit 2004a8bcbfde8bc0c46995bfb4ce152d7dd4ec51
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Sat Jan 24 11:52:51 2026 -0800

    [NVRTC] Add NVSHMEM support to NVRTC compilation path (#18681)
---
 python/tvm/contrib/nvcc.py                         | 234 +++++++++++++++++++--
 .../tvm/script/ir_builder/tir/external_kernel.py   |  48 +++--
 src/runtime/cuda/cuda_module.cc                    |   1 +
 tests/python/disco/test_nvshmem.py                 | 126 ++++++++++-
 4 files changed, 379 insertions(+), 30 deletions(-)

diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py
index edf3e8af4f..7706f63973 100644
--- a/python/tvm/contrib/nvcc.py
+++ b/python/tvm/contrib/nvcc.py
@@ -71,16 +71,17 @@ def compile_cuda(
     - NVRTC is a "runtime" compilation library and can be faster for JIT compilation.
     - NVRTC requires cuda-python: pip install cuda-python
     """
-    # TODO: if need NVSHMEM for compilation, fall back to NVCC because support for NVRTC
-    # is not yet implemented
     use_nvshmem = "#include <nvshmem.h>" in code or "#include <nvshmemx.h>" in code
-    if compiler == "nvcc" or use_nvshmem:
-        return _compile_cuda_nvcc(code, target_format, arch, options, path_target, use_nvshmem)
+
+    if compiler == "nvcc":
+        result = _compile_cuda_nvcc(code, target_format, arch, options, path_target, use_nvshmem)
     elif compiler == "nvrtc":
-        return _compile_cuda_nvrtc(code, target_format, arch, options)
+        result = _compile_cuda_nvrtc(code, target_format, arch, options, path_target, use_nvshmem)
     else:
         raise ValueError(f"cuda compiler must be 'nvcc' or 'nvrtc', got: {compiler}")
 
+    return result
+
 
 def _compile_cuda_nvcc(
     code,
@@ -235,7 +236,9 @@ def _compile_cuda_nvcc(
         return data
 
 
-def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
+def _compile_cuda_nvrtc(
+    code, target_format=None, arch=None, options=None, path_target=None, use_nvshmem=False
+):
     """Compile CUDA code using NVRTC (NVIDIA Runtime Compilation).
 
     Parameters
@@ -248,6 +251,10 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
         Target architecture (e.g., "sm_80"). Auto-detected if None.
     options : str or list of str, optional
         Additional NVRTC options.
+    path_target : str, optional
+        Output file path. If provided, the compiled binary is written to this path.
+    use_nvshmem : bool, optional
+        Whether NVSHMEM is used. Default: False
 
     Returns
     -------
@@ -264,8 +271,20 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
             "See: https://nvidia.github.io/cuda-python/"
         ) from e
 
-    # Default target format
-    if target_format is None:
+    # For NVSHMEM, we also need the CUDA driver API to initialize the context for linking
+    if use_nvshmem:
+        import importlib.util  # pylint: disable=import-outside-toplevel
+
+        if importlib.util.find_spec("cuda.bindings.driver") is None:
+            raise RuntimeError(
+                "Failed to compile CUDA with NVRTC+NVSHMEM because the `cuda-python` package "
+                "is not available.\n"
+                "Please install it with: pip install cuda-python\n"
+                "See: https://nvidia.github.io/cuda-python/"
+            )
+
+    # NVSHMEM requires linking with device library, which always produces cubin
+    if use_nvshmem or target_format is None:
         target_format = "cubin"
 
     # Validate target_format (NVRTC doesn't support fatbin)
@@ -287,6 +306,11 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
         compute_version = get_target_compute_version(Target.current(allow_none=True))
         arch = f"sm_{''.join(compute_version.split('.'))}"
 
+    # Get NVSHMEM paths if needed
+    nvshmem_include_path, nvshmem_lib_path = None, None
+    if use_nvshmem:
+        nvshmem_include_path, nvshmem_lib_path = find_nvshmem_paths()
+
     # Strip host-only headers for NVRTC. NVRTC compiles device code and does not
     # require the CUDA driver header or host C++ headers.
     headers_to_strip = {"#include <cuda.h>"}
@@ -304,6 +328,47 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
             "};\n\n" + code_filtered
         )
 
+    # Add standard type definitions and compatibility macros that NVRTC doesn't provide.
+    nvrtc_preamble = """#include <cuda/std/cstdint>
+using cuda::std::uint8_t;
+using cuda::std::uint16_t;
+using cuda::std::uint32_t;
+using cuda::std::uint64_t;
+using cuda::std::int8_t;
+using cuda::std::int16_t;
+using cuda::std::int32_t;
+using cuda::std::int64_t;
+
+// NVRTC uses asm/volatile instead of __asm__/__volatile__
+#ifndef __asm__
+#define __asm__ asm
+#endif
+#ifndef __volatile__
+#define __volatile__ volatile
+#endif
+
+"""
+    code_filtered = nvrtc_preamble + code_filtered
+
+    # For NVSHMEM, add preamble to map cuda::std type traits to std namespace.
+    # NVSHMEM headers require std:: type traits but NVRTC uses cuda::std::.
+    if use_nvshmem:
+        nvshmem_preamble = """#include <cuda/std/type_traits>
+
+// Map cuda::std type traits to std namespace for NVSHMEM headers
+namespace std {
+    using cuda::std::is_integral;
+    using cuda::std::is_signed;
+    using cuda::std::is_unsigned;
+    using cuda::std::is_floating_point;
+    using cuda::std::is_same;
+    using cuda::std::enable_if;
+    using cuda::std::conditional;
+}
+
+"""
+        code_filtered = nvshmem_preamble + code_filtered
+
     # Create NVRTC program
     # Use "tvm_kernels.cu" for consistency with nvcc path
     result, prog = nvrtc.nvrtcCreateProgram(
@@ -319,6 +384,9 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
         b"-default-device",
     ]
 
+    if use_nvshmem:
+        compile_opts.extend([b"-rdc", b"true"])
+
     # Add CUDA include paths. NVRTC needs explicit include paths for CUDA headers.
     # Standard installations: cuda_path/include
     # Conda/architecture-specific installations: cuda_path/targets/<arch>/include
@@ -339,6 +407,12 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
     if os.path.isdir(arch_include):
         include_paths.append(arch_include)
 
+    # Check for CCCL include directory (required for cuda/std/cstdint and type_traits)
+    # CCCL provides standard library functionality for device code
+    cccl_include = os.path.join(arch_include, "cccl") if os.path.isdir(arch_include) else None
+    if cccl_include and os.path.isdir(cccl_include):
+        include_paths.append(cccl_include)
+
     # Verify we can find essential CUDA headers
     if not any(os.path.isfile(os.path.join(p, "cuda_runtime.h")) for p in include_paths):
         raise RuntimeError(
@@ -351,6 +425,26 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
     for include_path in include_paths:
         compile_opts.append(f"-I{include_path}".encode())
 
+    # Add NVSHMEM include path
+    if use_nvshmem and nvshmem_include_path:
+        compile_opts.append(f"-I{nvshmem_include_path}".encode())
+
+    # For NVSHMEM, add deprecation and type conversion macros
+    if use_nvshmem:
+        compile_opts.extend(
+            [
+                # Define deprecation macros as empty (not properly defined in NVRTC context)
+                b"-D__NV_SILENCE_DEPRECATION_BEGIN=",
+                b"-D__NV_SILENCE_DEPRECATION_END=",
+                b"-D__NV_SILENCE_HOST_DEPRECATION_BEGIN=",
+                b"-D__NV_SILENCE_HOST_DEPRECATION_END=",
+                # Disable FP8/FP6/FP4 extended types that cause issues with NVRTC
+                b"-D__CUDA_NO_FP8_CONVERSIONS__",
+                b"-D__CUDA_NO_FP6_CONVERSIONS__",
+                b"-D__CUDA_NO_FP4_CONVERSIONS__",
+            ]
+        )
+
     compile_opts.extend(
         [
             b"-U__CUDA_NO_HALF_OPERATORS__",
@@ -363,12 +457,40 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
         ]
     )
 
-    # Add user-provided options
+    # Add user-provided options, filtering out nvcc-specific flags that nvrtc doesn't support
     if options:
+        nvcc_only_prefixes = (
+            "-c",
+            "-O",
+            "-std",
+            "--std",
+            "-Xcompiler",
+            "-Xlinker",
+            "-Xarchive",
+            "-Xcudafe",
+            "-Xptxas",
+            "--compile",
+            "--compiler-options",
+            "--linker-options",
+            "-fPIC",
+            "-shared",
+            "-o",
+        )
         if isinstance(options, str):
-            compile_opts.append(options.encode())
-        else:
-            compile_opts.extend([opt.encode() if isinstance(opt, str) else opt for opt in options])
+            options = [options]
+        for opt in options:
+            if isinstance(opt, str):
+                opt_str = opt
+            elif isinstance(opt, bytes):
+                opt_str = opt.decode()
+            else:
+                opt_str = str(opt)
+            skip = any(
+                opt_str.startswith(prefix) or opt_str == prefix for prefix in nvcc_only_prefixes
+            )
+            if skip:
+                continue
+            compile_opts.append(opt.encode() if isinstance(opt, str) else opt)
 
     # Compile
     (result,) = nvrtc.nvrtcCompileProgram(prog, len(compile_opts), compile_opts)
@@ -410,10 +532,94 @@ def _compile_cuda_nvrtc(code, target_format=None, arch=None, options=None):
             nvrtc.nvrtcDestroyProgram(prog)
             raise RuntimeError(f"Failed to get PTX: {nvrtc.nvrtcGetErrorString(result)}")
 
-    # Clean up
+    # Clean up NVRTC program
     nvrtc.nvrtcDestroyProgram(prog)
 
-    return bytearray(binary_buf)
+    # Link stage for NVSHMEM
+    if use_nvshmem:
+        binary_buf = _link_nvshmem_nvrtc(binary_buf, nvshmem_lib_path)
+
+    if path_target:
+        with open(path_target, "wb") as f:
+            f.write(binary_buf)
+    return binary_buf
+
+
+def _link_nvshmem_nvrtc(binary_buf, nvshmem_lib_path):
+    """Link compiled CUBIN with NVSHMEM device library using CUDA driver API."""
+    import ctypes  # pylint: disable=import-outside-toplevel
+
+    from cuda.bindings import driver as cu  # pylint: disable=import-outside-toplevel
+
+    # cuLinkCreate requires a valid CUDA context.
+    # Always create a fresh context for linking to avoid issues with stale contexts
+    # in multi-process environments like Disco workers.
+    (result,) = cu.cuInit(0)
+    if result != cu.CUresult.CUDA_SUCCESS:
+        raise RuntimeError(f"Failed to initialize CUDA: {result}")
+
+    result, device = cu.cuDeviceGet(0)
+    if result != cu.CUresult.CUDA_SUCCESS:
+        raise RuntimeError(f"Failed to get CUDA device: {result}")
+
+    result, context = cu.cuCtxCreate(None, 0, device)
+    if result != cu.CUresult.CUDA_SUCCESS:
+        raise RuntimeError(f"Failed to create CUDA context: {result}")
+
+    try:
+        # Create linker
+        result, link_state = cu.cuLinkCreate(0, [], [])
+        if result != cu.CUresult.CUDA_SUCCESS:
+            raise RuntimeError(f"Failed to create CUDA linker: {result}")
+
+        try:
+            # Add our compiled CUBIN
+            (result,) = cu.cuLinkAddData(
+                link_state,
+                cu.CUjitInputType.CU_JIT_INPUT_CUBIN,
+                binary_buf,
+                len(binary_buf),
+                b"tvm_kernels.cubin",
+                0,
+                [],
+                [],
+            )
+            if result != cu.CUresult.CUDA_SUCCESS:
+                raise RuntimeError(f"Failed to add CUBIN to linker: {result}")
+
+            # Add NVSHMEM device library
+            nvshmem_device_lib = os.path.join(nvshmem_lib_path, "libnvshmem_device.a")
+            if not os.path.exists(nvshmem_device_lib):
+                raise RuntimeError(f"NVSHMEM device library not found: {nvshmem_device_lib}")
+
+            (result,) = cu.cuLinkAddFile(
+                link_state,
+                cu.CUjitInputType.CU_JIT_INPUT_LIBRARY,
+                nvshmem_device_lib.encode(),
+                0,
+                [],
+                [],
+            )
+            if result != cu.CUresult.CUDA_SUCCESS:
+                raise RuntimeError(f"Failed to add NVSHMEM device library: {result}")
+
+            # Complete linking
+            result, linked_cubin, linked_size = cu.cuLinkComplete(link_state)
+            if result != cu.CUresult.CUDA_SUCCESS:
+                raise RuntimeError(f"Failed to complete NVSHMEM linking: {result}")
+
+            # Copy linked binary before destroying linker
+            binary_buf = bytearray(ctypes.string_at(linked_cubin, linked_size))
+            if not binary_buf:
+                raise RuntimeError("Compilation error: empty result is generated")
+        finally:
+            # Clean up linker
+            cu.cuLinkDestroy(link_state)
+    finally:
+        # Clean up context
+        cu.cuCtxDestroy(context)
+
+    return binary_buf
 
 
 def find_cuda_path():
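
For reference, a minimal, hypothetical usage sketch of the updated entry point (the kernel body, variable names, and the raw-string formatting are illustrative placeholders, assuming a CUDA GPU, NVSHMEM, and cuda-python are installed). Including <nvshmem.h> in the source is what sets use_nvshmem, so the NVRTC path applies the preambles above, links libnvshmem_device.a through the CUDA driver API, and returns a cubin:

    from tvm.contrib import nvcc

    source = r"""
    #include <nvshmem.h>
    extern "C" __global__ void query_pe(int* out) { out[0] = nvshmem_my_pe(); }
    """

    # compiler="nvrtc" now routes NVSHMEM sources through _compile_cuda_nvrtc;
    # the NVSHMEM include forces cubin output and the driver-API link step.
    cubin = nvcc.compile_cuda(source, target_format="cubin", compiler="nvrtc")
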
diff --git a/python/tvm/script/ir_builder/tir/external_kernel.py b/python/tvm/script/ir_builder/tir/external_kernel.py
index 45a3d364c1..d7854d7a68 100644
--- a/python/tvm/script/ir_builder/tir/external_kernel.py
+++ b/python/tvm/script/ir_builder/tir/external_kernel.py
@@ -58,14 +58,16 @@ class BaseKernel:  # pylint: disable=too-few-public-methods
         )
         return tvm_metadata
 
-    def _create_cuda_module(self, ptx, kernel_arg_types, launch_param_tags, kernel_name):
+    def _create_cuda_module(
+        self, binary_data, kernel_arg_types, launch_param_tags, kernel_name, fmt="ptx"
+    ):
         """
-        Create a CUDA module from PTX and metadata.
+        Create a CUDA module from compiled binary (PTX or cubin) and metadata.
 
         Parameters
         ----------
-        ptx : str
-            The PTX code of the kernel.
+        binary_data : str or bytes
+            The compiled binary data (PTX as str, cubin as bytes).
 
         kernel_arg_types : List[str]
             The types of the kernel arguments.
@@ -76,6 +78,9 @@ class BaseKernel:  # pylint: disable=too-few-public-methods
         kernel_name : str
             The name of the kernel.
 
+        fmt : str
+            The format of the binary data: "ptx" or "cubin".
+
         Returns
         -------
         kernel_module : Module
@@ -85,12 +90,16 @@ class BaseKernel:  # pylint: disable=too-few-public-methods
             kernel_name, kernel_arg_types, launch_param_tags
         )
         with tempfile.TemporaryDirectory() as temp_dir:
-            ptx_path = f"{temp_dir}/{kernel_name}.ptx"
-            with open(ptx_path, "w") as f:
-                f.write(ptx)
+            binary_path = f"{temp_dir}/{kernel_name}.{fmt}"
+            if fmt == "ptx":
+                with open(binary_path, "w") as f:
+                    f.write(binary_data)
+            else:
+                with open(binary_path, "wb") as f:
+                    f.write(binary_data)
             with open(f"{temp_dir}/{kernel_name}.tvm_meta.json", "w") as f:
                 f.write(tvm_metadata)
-            kernel_module = load_module(ptx_path)
+            kernel_module = load_module(binary_path)
         return kernel_module
 
 
@@ -139,20 +148,31 @@ class SourceKernel(BaseKernel):  # pylint: disable=too-few-public-methods
             pass
 
         with tempfile.TemporaryDirectory() as temp_dir:
-            ptx_path = f"{temp_dir}/{kernel_name}.ptx"
+            # Check if NVSHMEM is used - requires cubin output for device library linking
+            use_nvshmem = (
+                "#include <nvshmem.h>" in source_code or "#include <nvshmemx.h>" in source_code
+            )
+            target_format = "cubin" if use_nvshmem else "ptx"
+            output_path = f"{temp_dir}/{kernel_name}.{target_format}"
+
             compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc")
             nvcc.compile_cuda(
                 source_code,
-                target_format="ptx",
+                target_format=target_format,
                 options=compile_options,
-                path_target=ptx_path,
+                path_target=output_path,
                 compiler=compiler,
             )
-            with open(ptx_path, "r") as f:
-                ptx = f.read()
+
+            if target_format == "ptx":
+                with open(output_path, "r") as f:
+                    binary_data = f.read()
+            else:
+                with open(output_path, "rb") as f:
+                    binary_data = f.read()
 
             kernel_module = self._create_cuda_module(
-                ptx, kernel_arg_types, launch_param_tags, kernel_name
+                binary_data, kernel_arg_types, launch_param_tags, kernel_name, fmt=target_format
             )
 
         return kernel_name, kernel_module, runtime_args
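
A small usage note on the env-var plumbing above, as a hedged sketch with nothing else assumed: SourceKernel reads TVM_CUDA_COMPILE_MODE and defaults to "nvcc", so switching an existing T.call_kernel flow onto the new NVRTC path only requires setting the variable before compilation.

    import os

    # "nvcc" remains the default when the variable is unset; "nvrtc" selects
    # the NVRTC path, including the NVSHMEM cubin handling in this diff.
    os.environ["TVM_CUDA_COMPILE_MODE"] = "nvrtc"
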
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index 19f4288c97..f5219ae98a 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -342,6 +342,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   refl::GlobalDef()
       .def("ffi.Module.load_from_file.cuda", CUDAModuleLoadFile)
       .def("ffi.Module.load_from_file.ptx", CUDAModuleLoadFile)
+      .def("ffi.Module.load_from_file.cubin", CUDAModuleLoadFile)
       .def("ffi.Module.load_from_bytes.cuda", CUDAModuleLoadFromBytes);
 }
 }  // namespace runtime
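
With the ".cubin" registration above, the generic loader accepts cubin artifacts the same way it accepts ".ptx". A sketch follows; the path is a placeholder, and a matching .tvm_meta.json is assumed to sit next to the file, as BaseKernel._create_cuda_module writes it:

    import tvm

    # The file extension selects ffi.Module.load_from_file.cubin, which now
    # maps to CUDAModuleLoadFile just like the ptx loader.
    mod = tvm.runtime.load_module("/tmp/nvshmem_query_kernel.cubin")
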
diff --git a/tests/python/disco/test_nvshmem.py b/tests/python/disco/test_nvshmem.py
index 029eb8fe82..b98b49591d 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -28,6 +28,8 @@ import multiprocessing
 from multiprocessing import Process
 from typing import Any, Callable, List
 
+from tvm.script import ir as I
+from tvm.script import relax as R
 from tvm.script import tir as T
 
 
@@ -142,7 +144,7 @@ def test_nvshmem_compile():
     if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
         return
 
-    num_workers = 4
+    num_workers = 2
     sess = di.ProcessSession(num_workers=num_workers)
 
     f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
@@ -191,6 +193,121 @@ def test_nvshmem_compile():
         shutil.rmtree(tmpdir, ignore_errors=True)
 
 
+NVSHMEM_QUERY_KERNEL_SOURCE = """
+#include <nvshmem.h>
+
+extern "C" __global__ void nvshmem_query_kernel(int* my_pe_out, int* n_pes_out) {
+    my_pe_out[0] = nvshmem_my_pe();
+    n_pes_out[0] = nvshmem_n_pes();
+}
+"""
+
+
+def _test_nvshmem_kernel_compile_impl():
+    """Test compiling and running a kernel that calls NVSHMEM functions"""
+    if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is None:
+        return
+
+    num_workers = 2
+    sess = di.ProcessSession(num_workers=num_workers)
+
+    f_init_nvshmem_uid = tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
+    uid = f_init_nvshmem_uid()
+    init_dfunc = sess.get_global_func("runtime.disco.nvshmem.init_nvshmem")
+    init_dfunc(uid, num_workers, 0)
+    sess.sync_worker_0()
+
+    try:
+
+        @I.ir_module
+        class NvshmemQueryModule:
+            @T.prim_func
+            def query_pe(
+                my_pe_out: T.Buffer((1,), "int32"),
+                n_pes_out: T.Buffer((1,), "int32"),
+            ):
+                with T.block("root"):
+                    T.reads()
+                    T.writes(my_pe_out[0:1], n_pes_out[0:1])
+                    T.call_kernel(
+                        NVSHMEM_QUERY_KERNEL_SOURCE,
+                        ((1,), (1,)),  # grid=(1,), block=(1,)
+                        my_pe_out.data,
+                        n_pes_out.data,
+                        kernel_name="nvshmem_query_kernel",
+                    )
+
+            @R.function
+            def main() -> R.Tuple(R.Tensor((1,), "int32"), R.Tensor((1,), "int32")):
+                cls = NvshmemQueryModule
+                with R.dataflow():
+                    my_pe = R.call_tir(
+                        cls.query_pe,
+                        (),
+                        out_sinfo=[
+                            R.Tensor((1,), "int32"),
+                            R.Tensor((1,), "int32"),
+                        ],
+                    )
+                    R.output(my_pe)
+                return my_pe
+
+        tmpdir = tempfile.mkdtemp()
+        try:
+            path = tmpdir + "/test_nvshmem_kernel.so"
+
+            target = tvm.target.Target("cuda")
+            tvm.compile(NvshmemQueryModule, target=target).export_library(path)
+            mod = sess.load_vm_module(path)
+            result = mod["main"]()
+
+            # Verify results from each worker
+            for worker_id in range(num_workers):
+                my_pe_result, n_pes_result = result.debug_get_from_remote(worker_id)
+                my_pe_val = my_pe_result.numpy()[0]
+                n_pes_val = n_pes_result.numpy()[0]
+                assert (
+                    my_pe_val == worker_id
+                ), f"Worker {worker_id} reported my_pe={my_pe_val}, expected {worker_id}"
+                assert (
+                    n_pes_val == num_workers
+                ), f"Worker {worker_id} reported n_pes={n_pes_val}, expected {num_workers}"
+
+            # Sync all workers before cleanup
+            sess._sync_all()
+
+            finalize_dfunc = sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
+            finalize_dfunc()
+            sess.sync_worker_0()
+        finally:
+            shutil.rmtree(tmpdir, ignore_errors=True)
+    finally:
+        sess.shutdown()
+
+
+def test_nvshmem_kernel_compile_nvcc():
+    """Test NVSHMEM kernel compilation with nvcc."""
+    # Since this test runs in a separate process, we can safely set the env var
+    import os
+
+    os.environ["TVM_CUDA_COMPILE_MODE"] = "nvcc"
+    _test_nvshmem_kernel_compile_impl()
+
+
+def test_nvshmem_kernel_compile_nvrtc():
+    """Test NVSHMEM kernel compilation with nvrtc."""
+    try:
+        from cuda.bindings import nvrtc  # noqa: F401
+    except ImportError:
+        pytest.skip("cuda-python not available, skipping nvrtc test")
+
+    # Since this test runs in a separate process, we can safely set the env var
+    import os
+
+    os.environ["TVM_CUDA_COMPILE_MODE"] = "nvrtc"
+    _test_nvshmem_kernel_compile_impl()
+
+
 if __name__ == "__main__":
     # After the first call to `nvshmem_init`, a subsequent call to `nvshmem_init`
     # or `nvshmem_init_thread` in the same program results in undefined behavior.
@@ -212,8 +329,13 @@ if __name__ == "__main__":
                     p.exitcode == 0
                 ), f"Test {test_func.__name__} failed with exit code {p.exitcode}"
 
-    # testing compilation flow
     p = Process(target=test_nvshmem_compile)
     p.start()
     p.join()
     assert p.exitcode == 0, f"Test test_nvshmem_compile failed with exit code {p.exitcode}"
+
+    for test_func in [test_nvshmem_kernel_compile_nvcc, test_nvshmem_kernel_compile_nvrtc]:
+        p = Process(target=test_func)
+        p.start()
+        p.join()
+        assert p.exitcode == 0, f"Test {test_func.__name__} failed with exit code {p.exitcode}"
