This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new adac5ebd feat: add cuLaunchKernelEx support (#476)
adac5ebd is described below

commit adac5ebd0ad695fc77bf5113c8127a7298c1ed52
Author: Gabriel Wu <[email protected]>
AuthorDate: Sat Feb 28 00:02:43 2026 +0800

    feat: add cuLaunchKernelEx support (#476)
---
 include/tvm/ffi/extra/cuda/cubin_launcher.h       | 14 ++++
 include/tvm/ffi/extra/cuda/internal/unified_api.h | 74 +++++++++++++++++++
 tests/python/test_cubin_launcher.py               | 87 +++++++++++++++++++++++
 3 files changed, 175 insertions(+)

diff --git a/include/tvm/ffi/extra/cuda/cubin_launcher.h b/include/tvm/ffi/extra/cuda/cubin_launcher.h
index f1aabdef..d384ce43 100644
--- a/include/tvm/ffi/extra/cuda/cubin_launcher.h
+++ b/include/tvm/ffi/extra/cuda/cubin_launcher.h
@@ -482,6 +482,20 @@ class CubinKernel {
     return cuda_api::LaunchKernel(kernel_, args, grid, block, stream, dyn_smem_bytes);
   }
 
+  /*!
+   * \brief Launch the kernel using extended launch API with a pre-built config.
+   *
+   * This enables features like cluster dimensions (SM90+) that require
+   * cuLaunchKernelEx / cudaLaunchKernelExC.
+   *
+   * \param args Array of pointers to kernel arguments.
+   * \param config The launch configuration (populated by ConstructLaunchConfig).
+   * \return Result code.
+   */
+  cuda_api::ResultType LaunchEx(void** args, const cuda_api::LaunchConfig& config) {
+    return cuda_api::LaunchKernelEx(kernel_, args, config);
+  }
+
   /*! \brief Get the underlying cudaKernel_t handle */
   cuda_api::KernelHandle GetHandle() const { return kernel_; }
 
diff --git a/include/tvm/ffi/extra/cuda/internal/unified_api.h b/include/tvm/ffi/extra/cuda/internal/unified_api.h
index 7b16150a..861262c0 100644
--- a/include/tvm/ffi/extra/cuda/internal/unified_api.h
+++ b/include/tvm/ffi/extra/cuda/internal/unified_api.h
@@ -188,6 +188,80 @@ inline ResultType SetKernelMaxDynamicSharedMem(KernelHandle kernel, int shmem,
 #endif
 }
 
+/*!
+ * \brief Launch a kernel using the extended launch API with launch attributes.
+ *
+ * This enables features like cluster dimensions (SM90+) that require
+ * cuLaunchKernelEx / cudaLaunchKernelExC.
+ *
+ * \param kernel The kernel handle.
+ * \param args Array of pointers to kernel arguments.
+ * \param config The launch configuration (grid, block, smem, stream, attributes).
+ * \return Result code.
+ */
+inline ResultType LaunchKernelEx(KernelHandle kernel, void** args, const LaunchConfig& config) {
+#if TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API
+  return cuLaunchKernelEx(&config, reinterpret_cast<CUfunction>(kernel), args, nullptr);
+#else
+  return cudaLaunchKernelExC(&config, reinterpret_cast<const void*>(kernel), args);
+#endif
+}
+
+/*!
+ * \brief Construct a launch configuration with optional cluster dimensions.
+ *
+ * \param kernel The kernel handle.
+ * \param stream The CUDA stream.
+ * \param smem_size Dynamic shared memory size in bytes.
+ * \param grid Grid dimensions.
+ * \param block Block dimensions.
+ * \param cluster_dim Cluster dimension (1 = no clustering, >1 enables cluster launch).
+ * \param[out] config The launch configuration to populate.
+ * \param[out] attr Storage for a launch attribute (must outlive the launch call).
+ * \return kSuccess.
+ */
+inline ResultType ConstructLaunchConfig(KernelHandle kernel, StreamHandle stream,
+                                        uint32_t smem_size, tvm::ffi::dim3 grid,
+                                        tvm::ffi::dim3 block, int cluster_dim, LaunchConfig& config,
+                                        LaunchAttrType& attr) {
+#if TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API
+  config.gridDimX = grid.x;
+  config.gridDimY = grid.y;
+  config.gridDimZ = grid.z;
+  config.blockDimX = block.x;
+  config.blockDimY = block.y;
+  config.blockDimZ = block.z;
+  config.sharedMemBytes = smem_size;
+  config.hStream = stream;
+  config.numAttrs = 0;
+  config.attrs = nullptr;
+
+  if (cluster_dim > 1) {
+    attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
+    attr.value.clusterDim.x = static_cast<unsigned>(cluster_dim);
+    attr.value.clusterDim.y = 1;
+    attr.value.clusterDim.z = 1;
+    config.attrs = &attr;
+    config.numAttrs = 1;
+  }
+#else
+  config.gridDim = {grid.x, grid.y, grid.z};
+  config.blockDim = {block.x, block.y, block.z};
+  config.dynamicSmemBytes = smem_size;
+  config.stream = stream;
+  config.numAttrs = 0;
+  config.attrs = nullptr;
+
+  if (cluster_dim > 1) {
+    attr.id = cudaLaunchAttributeClusterDimension;
+    attr.val.clusterDim = {static_cast<unsigned>(cluster_dim), 1, 1};
+    config.attrs = &attr;
+    config.numAttrs = 1;
+  }
+#endif
+  return kSuccess;
+}
+
 // Additional wrappers for device operations used in CubinLauncher
 inline ResultType GetDeviceCount(int* count) {
 #if TVM_FFI_CUBIN_LAUNCHER_USE_DRIVER_API
diff --git a/tests/python/test_cubin_launcher.py b/tests/python/test_cubin_launcher.py
index c953f882..4139679d 100644
--- a/tests/python/test_cubin_launcher.py
+++ b/tests/python/test_cubin_launcher.py
@@ -225,6 +225,93 @@ TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_mul_two, cubin_test::LaunchMulTwo);
     torch.testing.assert_close(y, expected)
 
 
+@pytest.mark.skipif(sys.platform != "linux", reason="CUBIN launcher only supported on Linux")
+@pytest.mark.skipif(torch is None, reason="PyTorch not installed")
+@pytest.mark.skipif(not _is_cuda_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not _is_cuda_version_greater_than_13(), reason="CUDA version must be greater than 13.0"
+)
+def test_cubin_launcher_launch_ex() -> None:
+    """Test LaunchEx with ConstructLaunchConfig (no clustering)."""
+    assert torch is not None, "PyTorch is required for this test"
+
+    cubin_bytes = _compile_kernel_to_cubin()
+
+    cpp_code = """
+#include <tvm/ffi/container/tensor.h>
+#include <tvm/ffi/error.h>
+#include <tvm/ffi/extra/c_env_api.h>
+#include <tvm/ffi/extra/cuda/cubin_launcher.h>
+#include <tvm/ffi/function.h>
+
+#include <memory>
+
+namespace cubin_test_launch_ex {
+
+static std::unique_ptr<tvm::ffi::CubinModule> g_module;
+static std::unique_ptr<tvm::ffi::CubinKernel> g_kernel_add_one;
+
+void LoadCubinData(const tvm::ffi::Bytes& cubin_data) {
+  g_module = std::make_unique<tvm::ffi::CubinModule>(cubin_data);
+  g_kernel_add_one = std::make_unique<tvm::ffi::CubinKernel>((*g_module)["add_one_cuda"]);
+}
+
+void LaunchAddOneEx(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
+  TVM_FFI_CHECK(g_module != nullptr, RuntimeError) << "CUBIN module not loaded";
+  TVM_FFI_CHECK(x.ndim() == 1, ValueError) << "Input must be 1D tensor";
+  TVM_FFI_CHECK(y.ndim() == 1, ValueError) << "Output must be 1D tensor";
+  TVM_FFI_CHECK(x.size(0) == y.size(0), ValueError) << "Sizes must match";
+
+  int64_t n = x.size(0);
+  void* x_ptr = x.data_ptr();
+  void* y_ptr = y.data_ptr();
+
+  void* args[] = {&x_ptr, &y_ptr, &n};
+
+  tvm::ffi::dim3 grid((n + 1023) / 1024);
+  tvm::ffi::dim3 block(1024);
+
+  DLDevice device = x.device();
+  auto stream = static_cast<tvm::ffi::cuda_api::StreamHandle>(
+      TVMFFIEnvGetStream(device.device_type, device.device_id));
+
+  // Use ConstructLaunchConfig + LaunchEx (cluster_dim=1 means no clustering)
+  tvm::ffi::cuda_api::LaunchConfig config;
+  tvm::ffi::cuda_api::LaunchAttrType attr;
+  auto err = tvm::ffi::cuda_api::ConstructLaunchConfig(
+      g_kernel_add_one->GetHandle(), stream, /*smem_size=*/0,
+      grid, block, /*cluster_dim=*/1, config, attr);
+  TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(err);
+
+  auto result = g_kernel_add_one->LaunchEx(args, config);
+  TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
+}
+
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(load_cubin_data, cubin_test_launch_ex::LoadCubinData);
+TVM_FFI_DLL_EXPORT_TYPED_FUNC(launch_add_one_ex, cubin_test_launch_ex::LaunchAddOneEx);
+
+}  // namespace cubin_test_launch_ex
+"""
+
+    mod = tvm_ffi.cpp.load_inline(
+        "cubin_test_launch_ex",
+        cuda_sources=cpp_code,
+        extra_ldflags=["-lcudart"],
+    )
+
+    load_fn = mod["load_cubin_data"]
+    load_fn(cubin_bytes)
+
+    launch_add_one_ex = mod["launch_add_one_ex"]
+    n = 256
+    x = torch.arange(n, dtype=torch.float32, device="cuda")
+    y = torch.empty(n, dtype=torch.float32, device="cuda")
+
+    launch_add_one_ex(x, y)
+    expected = x + 1
+    torch.testing.assert_close(y, expected)
+
+
 @pytest.mark.skipif(sys.platform != "linux", reason="CUBIN launcher only supported on Linux")
 @pytest.mark.skipif(torch is None, reason="PyTorch not installed")
 @pytest.mark.skipif(not _is_cuda_available(), reason="CUDA not available")

Reply via email to