This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 10cb004 [CUDA] Isolate unified api to only in cubin launcher (#408)
10cb004 is described below
commit 10cb0048cef8f3a37282f403cded9eb96aa59464
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Jan 12 21:08:36 2026 -0500
[CUDA] Isolate unified api to only in cubin launcher (#408)
This PR isolates the unified API so that it is local to the cubin launcher.
Background: mixing the driver and runtime APIs is generally error-prone.
The unified API switch was mainly meant to be used in the cubin launcher
for a narrow range of CUDA versions (around 12.8 to 13.0).
However, we would like generic macros such as
TVM_FFI_CHECK_CUDA_ERROR to be specific to the runtime API. We should
revisit whether to simply deprecate driver API usage for better
maintainability.
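For readers of this change, a minimal sketch of the intended split (illustrative only, not part of the patch; the grid/block dimension type is elided behind a template parameter since its exact name is not shown here, and header/macro names are assumed to be as they appear in this diff):

    // Sketch under assumptions: headers and macro names as in this patch.
    #include <cuda_runtime.h>
    #include <tvm/ffi/extra/cuda/base.h>            // TVM_FFI_CHECK_CUDA_ERROR (runtime API only)
    #include <tvm/ffi/extra/cuda/cubin_launcher.h>  // CubinKernel + launcher-local error macro

    template <typename Dim3>
    void ExampleLaunch(tvm::ffi::CubinKernel& kernel, void** args, Dim3 grid, Dim3 block,
                       cudaStream_t stream) {
      // Launcher call: may resolve to the driver or runtime API internally,
      // so it is checked with the launcher-local macro.
      TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));
      // Plain CUDA runtime call: checked with the generic runtime-only macro.
      TVM_FFI_CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
    }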
---------
Co-authored-by: Junru Shao <[email protected]>
---
.github/workflows/ci_test.yml | 2 +-
docs/guides/cubin_launcher.rst | 2 +-
.../dynamic_cubin/src/lib_dynamic.cc | 4 ++--
.../embedded_cubin/cpp_embed/src/lib_embedded.cc | 4 ++--
.../embed_with_tvm_ffi/src/lib_embedded.cc | 4 ++--
.../include_bin2c/src/lib_embedded.cc | 4 ++--
examples/cubin_launcher/example_nvrtc_cubin.py | 6 ++----
examples/cubin_launcher/example_triton_cubin.py | 3 +--
include/tvm/ffi/extra/cuda/base.h | 22 ++++++++++++++++++++++
include/tvm/ffi/extra/cuda/cubin_launcher.h | 17 ++++++++---------
include/tvm/ffi/extra/cuda/device_guard.h | 2 +-
include/tvm/ffi/extra/cuda/internal/unified_api.h | 8 +++++---
tests/python/test_cubin_launcher.py | 8 ++++----
13 files changed, 53 insertions(+), 33 deletions(-)
diff --git a/.github/workflows/ci_test.yml b/.github/workflows/ci_test.yml
index 4191f01..88e7157 100644
--- a/.github/workflows/ci_test.yml
+++ b/.github/workflows/ci_test.yml
@@ -51,7 +51,7 @@ jobs:
id: cpp_files
run: |
FILES=$(git diff --name-only --diff-filter=ACMR origin/${{ github.base_ref }}...HEAD -- \
- '*.c' '*.cc' '*.cpp' '*.cxx' | tr '\n' ' ')
+ src/ tests/ | grep -E '\.(c|cc|cpp|cxx)$' | tr '\n' ' ')
echo "files=$FILES" >> $GITHUB_OUTPUT
[ -n "$FILES" ] && echo "changed=true" >> $GITHUB_OUTPUT || echo
"changed=false" >> $GITHUB_OUTPUT
diff --git a/docs/guides/cubin_launcher.rst b/docs/guides/cubin_launcher.rst
index 8407eb5..386ca16 100644
--- a/docs/guides/cubin_launcher.rst
+++ b/docs/guides/cubin_launcher.rst
@@ -376,7 +376,7 @@ To use dynamic shared memory, specify the size in the :cpp:func:`tvm::ffi::Cubin
// Allocate 1KB of dynamic shared memory
uint32_t shared_mem_bytes = 1024;
- cudaError_t result = kernel.Launch(args, grid, block, stream, shared_mem_bytes);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream, shared_mem_bytes));
Integration with Different Compilers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/examples/cubin_launcher/dynamic_cubin/src/lib_dynamic.cc b/examples/cubin_launcher/dynamic_cubin/src/lib_dynamic.cc
index 1a39383..ef5933a 100644
--- a/examples/cubin_launcher/dynamic_cubin/src/lib_dynamic.cc
+++ b/examples/cubin_launcher/dynamic_cubin/src/lib_dynamic.cc
@@ -85,7 +85,7 @@ void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Launch kernel
tvm::ffi::cuda_api::ResultType result = g_add_one_kernel->Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
}
} // namespace cubin_dynamic
@@ -125,7 +125,7 @@ void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Launch kernel
tvm::ffi::cuda_api::ResultType result = g_mul_two_kernel->Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
}
// Export TVM-FFI functions
diff --git a/examples/cubin_launcher/embedded_cubin/cpp_embed/src/lib_embedded.cc b/examples/cubin_launcher/embedded_cubin/cpp_embed/src/lib_embedded.cc
index 615bdc9..418338d 100644
--- a/examples/cubin_launcher/embedded_cubin/cpp_embed/src/lib_embedded.cc
+++ b/examples/cubin_launcher/embedded_cubin/cpp_embed/src/lib_embedded.cc
@@ -73,7 +73,7 @@ void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
}
} // namespace cubin_embedded
@@ -112,7 +112,7 @@ void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
}
// Export TVM-FFI functions
diff --git a/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/src/lib_embedded.cc b/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/src/lib_embedded.cc
index ea410a6..55520d4 100644
--- a/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/src/lib_embedded.cc
+++ b/examples/cubin_launcher/embedded_cubin/embed_with_tvm_ffi/src/lib_embedded.cc
@@ -70,7 +70,7 @@ void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
}
} // namespace cubin_embedded
@@ -109,7 +109,7 @@ void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
}
// Export TVM-FFI functions
diff --git a/examples/cubin_launcher/embedded_cubin/include_bin2c/src/lib_embedded.cc b/examples/cubin_launcher/embedded_cubin/include_bin2c/src/lib_embedded.cc
index 3c87889..a3ceec8 100644
--- a/examples/cubin_launcher/embedded_cubin/include_bin2c/src/lib_embedded.cc
+++ b/examples/cubin_launcher/embedded_cubin/include_bin2c/src/lib_embedded.cc
@@ -70,7 +70,7 @@ void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
}
} // namespace cubin_embedded
@@ -109,7 +109,7 @@ void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
// Launch kernel
tvm::ffi::cuda_api::ResultType result = kernel.Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
}
// Export TVM-FFI functions
diff --git a/examples/cubin_launcher/example_nvrtc_cubin.py b/examples/cubin_launcher/example_nvrtc_cubin.py
index 04668e6..8d04e40 100644
--- a/examples/cubin_launcher/example_nvrtc_cubin.py
+++ b/examples/cubin_launcher/example_nvrtc_cubin.py
@@ -129,8 +129,7 @@ def use_cubin_kernel(cubin_bytes: bytes) -> int:
cudaStream_t stream =
static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type,
device.device_id));
// Launch kernel
- cudaError_t result = kernel.Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));
}
void MulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
@@ -158,8 +157,7 @@ def use_cubin_kernel(cubin_bytes: bytes) -> int:
cudaStream_t stream =
static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type,
device.device_id));
// Launch kernel
- cudaError_t result = kernel.Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));
}
} // namespace nvrtc_loader
diff --git a/examples/cubin_launcher/example_triton_cubin.py b/examples/cubin_launcher/example_triton_cubin.py
index 8d92440..6654289 100644
--- a/examples/cubin_launcher/example_triton_cubin.py
+++ b/examples/cubin_launcher/example_triton_cubin.py
@@ -133,8 +133,7 @@ def use_cubin_kernel(cubin_bytes: bytes) -> int:
DLDevice device = x.device();
cudaStream_t stream =
static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type,
device.device_id));
- cudaError_t result = kernel.Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));
}
} // namespace triton_loader
diff --git a/include/tvm/ffi/extra/cuda/base.h b/include/tvm/ffi/extra/cuda/base.h
index d8ea486..3ce1529 100644
--- a/include/tvm/ffi/extra/cuda/base.h
+++ b/include/tvm/ffi/extra/cuda/base.h
@@ -23,9 +23,31 @@
#ifndef TVM_FFI_EXTRA_CUDA_BASE_H_
#define TVM_FFI_EXTRA_CUDA_BASE_H_
+#include <cuda_runtime.h>
+#include <tvm/ffi/error.h>
+
namespace tvm {
namespace ffi {
+/*!
+ * \brief Macro for checking CUDA runtime API errors.
+ *
+ * This macro checks the return value of CUDA runtime API calls and throws
+ * a RuntimeError with detailed error information if the call fails.
+ *
+ * \param stmt The CUDA runtime API call to check.
+ */
+#define TVM_FFI_CHECK_CUDA_ERROR(stmt)                                              \
+  do {                                                                              \
+    cudaError_t __err = (stmt);                                                     \
+    if (__err != cudaSuccess) {                                                     \
+      const char* __err_name = cudaGetErrorName(__err);                             \
+      const char* __err_str = cudaGetErrorString(__err);                            \
+      TVM_FFI_THROW(RuntimeError) << "CUDA Runtime Error: " << __err_name << " ("   \
+                                  << static_cast<int>(__err) << "): " << __err_str; \
+    }                                                                               \
+  } while (0)
+
/*!
* \brief A simple 3D dimension type for CUDA kernel launch configuration.
*
diff --git a/include/tvm/ffi/extra/cuda/cubin_launcher.h b/include/tvm/ffi/extra/cuda/cubin_launcher.h
index c910e89..f1aabde 100644
--- a/include/tvm/ffi/extra/cuda/cubin_launcher.h
+++ b/include/tvm/ffi/extra/cuda/cubin_launcher.h
@@ -29,7 +29,7 @@
#ifndef TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_
#define TVM_FFI_EXTRA_CUDA_CUBIN_LAUNCHER_H_
-#include <cuda.h>
+#include <cuda.h> // NOLINT(clang-diagnostic-error)
#include <cuda_runtime.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/extra/c_env_api.h>
@@ -234,7 +234,7 @@ namespace ffi {
* TVMFFIEnvGetStream(device.device_type, device.device_id));
*
* cudaError_t result = kernel.Launch(args, grid, block, stream);
- * TVM_FFI_CHECK_CUDA_ERROR(result);
+ * TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
* }
* \endcode
*
@@ -295,7 +295,7 @@ class CubinModule {
* \param bytes CUBIN binary data as a Bytes object.
*/
explicit CubinModule(const Bytes& bytes) {
- TVM_FFI_CHECK_CUDA_ERROR(cuda_api::LoadLibrary(&library_, bytes.data()));
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(cuda_api::LoadLibrary(&library_,
bytes.data()));
}
/*!
@@ -305,7 +305,7 @@ class CubinModule {
* \note The `code` buffer points to an ELF image.
*/
explicit CubinModule(const char* code) {
- TVM_FFI_CHECK_CUDA_ERROR(cuda_api::LoadLibrary(&library_, code));
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(cuda_api::LoadLibrary(&library_,
code));
}
/*!
@@ -315,7 +315,7 @@ class CubinModule {
* \note The `code` buffer points to an ELF image.
*/
explicit CubinModule(const unsigned char* code) {
- TVM_FFI_CHECK_CUDA_ERROR(cuda_api::LoadLibrary(&library_, code));
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(cuda_api::LoadLibrary(&library_,
code));
}
/*! \brief Destructor unloads the library */
@@ -418,7 +418,7 @@ class CubinModule {
* // Launch on stream
* cudaStream_t stream = ...;
* cudaError_t result = kernel.Launch(args, grid, block, stream);
- * TVM_FFI_CHECK_CUDA_ERROR(result);
+ * TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
* \endcode
*
* \note This class is movable but not copyable.
@@ -434,7 +434,7 @@ class CubinKernel {
* \param name Name of the kernel function.
*/
CubinKernel(cuda_api::LibraryHandle library, const char* name) {
- TVM_FFI_CHECK_CUDA_ERROR(cuda_api::GetKernel(&kernel_, library, name));
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(cuda_api::GetKernel(&kernel_,
library, name));
}
/*! \brief Destructor (kernel handle doesn't need explicit cleanup) */
@@ -464,8 +464,7 @@ class CubinKernel {
* \par Error Checking
* Always check the returned cudaError_t:
* \code{.cpp}
- * cudaError_t result = kernel.Launch(args, grid, block, stream);
- * TVM_FFI_CHECK_CUDA_ERROR(result);
+ * TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(kernel.Launch(args, grid, block, stream));
* \endcode
*
* \param args Array of pointers to kernel arguments (must point to actual values).
diff --git a/include/tvm/ffi/extra/cuda/device_guard.h b/include/tvm/ffi/extra/cuda/device_guard.h
index 1b44cf6..0158688 100644
--- a/include/tvm/ffi/extra/cuda/device_guard.h
+++ b/include/tvm/ffi/extra/cuda/device_guard.h
@@ -23,7 +23,7 @@
#ifndef TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_
#define TVM_FFI_EXTRA_CUDA_DEVICE_GUARD_H_
-#include <tvm/ffi/extra/cuda/internal/unified_api.h>
+#include <tvm/ffi/extra/cuda/base.h>
namespace tvm {
namespace ffi {
diff --git a/include/tvm/ffi/extra/cuda/internal/unified_api.h b/include/tvm/ffi/extra/cuda/internal/unified_api.h
index 302d8ef..7b16150 100644
--- a/include/tvm/ffi/extra/cuda/internal/unified_api.h
+++ b/include/tvm/ffi/extra/cuda/internal/unified_api.h
@@ -74,7 +74,7 @@ using DeviceAttrType = CUdevice_attribute;
constexpr ResultType kSuccess = CUDA_SUCCESS;
// Driver API Functions
-#define _TVM_FFI_CUDA_FUNC(name) cu##name
+#define _TVM_FFI_CUDA_FUNC(name) cu##name  // NOLINT(bugprone-reserved-identifier)
#else
@@ -110,7 +110,9 @@ inline void GetErrorString(ResultType err, const char** name, const char** str)
#endif
}
-#define TVM_FFI_CHECK_CUDA_ERROR(stmt)                                              \
+// this macro is only used to check cuda errors in cubin launcher where
+// we might switch between driver and runtime API.
+#define TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(stmt)                               \
   do {                                                                              \
     ::tvm::ffi::cuda_api::ResultType __err = (stmt);                                \
     if (__err != ::tvm::ffi::cuda_api::kSuccess) {                                  \
@@ -143,7 +145,7 @@ inline DeviceHandle GetDeviceHandle(int device_id) {
CUdevice dev;
// Note: We use CHECK here because this conversion usually shouldn't fail if ID is valid
// and we need to return a value.
- TVM_FFI_CHECK_CUDA_ERROR(cuDeviceGet(&dev, device_id));
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(cuDeviceGet(&dev, device_id));
return dev;
#else
return device_id;
diff --git a/tests/python/test_cubin_launcher.py b/tests/python/test_cubin_launcher.py
index d2e4ff0..014511a 100644
--- a/tests/python/test_cubin_launcher.py
+++ b/tests/python/test_cubin_launcher.py
@@ -158,8 +158,8 @@ void LaunchAddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
DLDevice device = x.device();
cudaStream_t stream =
static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type,
device.device_id));
- cudaError_t result = g_kernel_add_one->Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ auto result = g_kernel_add_one->Launch(args, grid, block, stream);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
}
void LaunchMulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
@@ -184,8 +184,8 @@ void LaunchMulTwo(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
DLDevice device = x.device();
cudaStream_t stream =
static_cast<cudaStream_t>(TVMFFIEnvGetStream(device.device_type,
device.device_id));
- cudaError_t result = g_kernel_mul_two->Launch(args, grid, block, stream);
- TVM_FFI_CHECK_CUDA_ERROR(result);
+ auto result = g_kernel_mul_two->Launch(args, grid, block, stream);
+ TVM_FFI_CHECK_CUBIN_LAUNCHER_CUDA_ERROR(result);
}
TVM_FFI_DLL_EXPORT_TYPED_FUNC(load_cubin_data, cubin_test::LoadCubinData);