This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e693452ad8 fix: Complete CHECK update across contrib runtime (#18861)
e693452ad8 is described below
commit e693452ad8afdfb75abab66f73cce890a4e979b7
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 1 14:15:13 2026 -0500
fix: Complete CHECK update across contrib runtime (#18861)
Replace all references to the old `CHECK` / `CHECK_EQ` / `CHECK_GE`
macro names with their `TVM_FFI_CHECK` equivalents across CUTLASS and
NVSHMEM contrib sources, completing the macro rename introduced by the
TVM FFI refactor. User-facing input validation checks use `ValueError`;
CUTLASS status and workspace checks use `RuntimeError`.
---
python/tvm/contrib/cutlass/attention_operation.py | 14 ++---
src/runtime/contrib/cutlass/fp16_group_gemm.cuh | 18 +++---
.../cutlass/fp16_group_gemm_runner_sm100.cuh | 12 ++--
.../cutlass/fp16_group_gemm_runner_sm90.cuh | 12 ++--
src/runtime/contrib/cutlass/fp8_gemm.cu | 19 +++---
src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu | 18 +++---
.../contrib/cutlass/fp8_groupwise_scaled_gemm.cuh | 73 +++++++++++-----------
.../fp8_groupwise_scaled_gemm_runner_sm100.cuh | 12 ++--
.../fp8_groupwise_scaled_gemm_runner_sm90.cuh | 12 ++--
...p8_groupwise_scaled_group_gemm_runner_sm100.cuh | 12 ++--
.../fp8_groupwise_scaled_group_gemm_sm100.cu | 36 +++++------
src/runtime/contrib/cutlass/gemm_runner.cuh | 12 ++--
src/runtime/contrib/nvshmem/kv_transfer.cu | 71 ++++++++++++---------
13 files changed, 166 insertions(+), 155 deletions(-)
diff --git a/python/tvm/contrib/cutlass/attention_operation.py
b/python/tvm/contrib/cutlass/attention_operation.py
index f4f059daf0..560da4e60e 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -26,7 +26,7 @@ def instantiate_attention_template(attrs):
based on a template and the provided attribute map."""
bias_template = """
- CHECK(${bias}->ndim == 4); // B, N, S, S'
+ TVM_FFI_CHECK(${bias}->ndim == 4, ValueError); // B, N, S, S'
p.attn_bias_ptr = reinterpret_cast<T *>(${bias}->data);
p.bias_strideM = ${bias_strideM};
@@ -46,9 +46,9 @@ def instantiate_attention_template(attrs):
p.query_ptr = reinterpret_cast<T *>(${query}->data);
p.key_ptr = reinterpret_cast<T *>(${key}->data);
p.value_ptr = reinterpret_cast<T *>(${value}->data);
- CHECK(${query}->ndim == 4); // B, S, N, H
- CHECK(${key}->ndim == 4); // B, S', N, H
- CHECK(${value}->ndim == 4); // B, S', N, H'
+ TVM_FFI_CHECK(${query}->ndim == 4, ValueError); // B, S, N, H
+ TVM_FFI_CHECK(${key}->ndim == 4, ValueError); // B, S', N, H
+ TVM_FFI_CHECK(${value}->ndim == 4, ValueError); // B, S', N, H'
// stride for N
p.q_strideH = p.head_dim; // H
@@ -69,7 +69,7 @@ def instantiate_attention_template(attrs):
p.query_ptr = reinterpret_cast<T *>(${qkv}->data);
p.key_ptr = reinterpret_cast<T *>(${qkv}->data) + p.head_dim * p.num_heads;
p.value_ptr = reinterpret_cast<T *>(${qkv}->data) + p.head_dim * p.num_heads
* 2;
- CHECK(${qkv}->ndim == 3); // B, S, NH + NH + NH'
+ TVM_FFI_CHECK(${qkv}->ndim == 3, ValueError); // B, S, NH + NH + NH'
// stride for N
p.q_strideH = p.head_dim; // H
@@ -132,7 +132,7 @@ def instantiate_attention_template(attrs):
p.o_strideM = p.head_dim_value * p.num_heads; // H' * N
- CHECK(out0->ndim == 4); // B, S, N, H'
+ TVM_FFI_CHECK(out0->ndim == 4, ValueError); // B, S, N, H'
${qkv_template}
${bias_template}
@@ -148,7 +148,7 @@ def instantiate_attention_template(attrs):
}();
}
- CHECK(Attention::check_supported(p));
+ TVM_FFI_CHECK(Attention::check_supported(p), RuntimeError);
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA,
${query}->device.device_id));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
index 9e1a17eaef..c2a2dd372b 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
@@ -38,11 +38,11 @@ void tvm_cutlass_group_gemm_impl(Tensor x, Tensor weight,
Tensor indptr, Tensor
// Workspace is used for storing device-side group gemm arguments and
cutlass internal workspace.
// Recommened size is 4MB.
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA,
x->device.device_id));
- CHECK_EQ(x->ndim, 2);
- CHECK_EQ(weight->ndim, 3);
- CHECK_EQ(indptr->ndim, 1);
- CHECK_EQ(workspace->ndim, 1);
- CHECK_EQ(out->ndim, 2);
+ TVM_FFI_CHECK_EQ(x->ndim, 2, ValueError);
+ TVM_FFI_CHECK_EQ(weight->ndim, 3, ValueError);
+ TVM_FFI_CHECK_EQ(indptr->ndim, 1, ValueError);
+ TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+ TVM_FFI_CHECK_EQ(out->ndim, 2, ValueError);
int num_groups = weight->shape[0];
int n = weight->shape[1];
int k = weight->shape[2];
@@ -50,16 +50,16 @@ void tvm_cutlass_group_gemm_impl(Tensor x, Tensor weight,
Tensor indptr, Tensor
float beta = 0.0f;
if (DataType(x->dtype) == DataType::Float(16)) {
- CHECK(DataType(weight->dtype) == DataType::Float(16));
- CHECK(DataType(out->dtype) == DataType::Float(16));
+ TVM_FFI_CHECK(DataType(weight->dtype) == DataType::Float(16), ValueError);
+ TVM_FFI_CHECK(DataType(out->dtype) == DataType::Float(16), ValueError);
using Dtype = cutlass::half_t;
CutlassGroupGemm<Arch, Dtype, Dtype, Dtype>::run(
static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
static_cast<int64_t*>(indptr->data),
static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, alpha, beta,
static_cast<Dtype*>(out->data), stream);
} else if (DataType(x->dtype) == DataType::BFloat(16)) {
- CHECK(DataType(weight->dtype) == DataType::BFloat(16));
- CHECK(DataType(out->dtype) == DataType::BFloat(16));
+ TVM_FFI_CHECK(DataType(weight->dtype) == DataType::BFloat(16), ValueError);
+ TVM_FFI_CHECK(DataType(out->dtype) == DataType::BFloat(16), ValueError);
using Dtype = cutlass::bfloat16_t;
CutlassGroupGemm<Arch, Dtype, Dtype, Dtype>::run(
static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
index 34a50fdc54..22a9bea646 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
@@ -40,11 +40,11 @@
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
-#define CUTLASS_CHECK(status) \
- { \
- cutlass::Status error = status; \
- CHECK(error == cutlass::Status::kSuccess) \
- << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+ << "Got cutlass error: " << cutlassGetStatusString(error); \
}
using namespace cute;
@@ -156,7 +156,7 @@ struct CutlassGroupGemmRunner {
hw_info};
Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(arguments));
- CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+ TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments),
RuntimeError);
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
CUTLASS_CHECK(gemm_op.run(stream));
}
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
index b894cbc84e..4fc513e3db 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
@@ -40,11 +40,11 @@
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
-#define CUTLASS_CHECK(status) \
- { \
- cutlass::Status error = status; \
- CHECK(error == cutlass::Status::kSuccess) \
- << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+ << "Got cutlass error: " << cutlassGetStatusString(error); \
}
using namespace cute;
@@ -156,7 +156,7 @@ struct CutlassGroupGemmRunner {
hw_info};
Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(arguments));
- CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+ TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments),
RuntimeError);
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
CUTLASS_CHECK(gemm_op.run(stream));
}
diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu
b/src/runtime/contrib/cutlass/fp8_gemm.cu
index d41064efba..02fd34aa10 100644
--- a/src/runtime/contrib/cutlass/fp8_gemm.cu
+++ b/src/runtime/contrib/cutlass/fp8_gemm.cu
@@ -44,20 +44,21 @@ void tvm_cutlass_fp8_gemm(Tensor x, Tensor weight, Tensor
workspace, Tensor alph
// Recommened size is 4MB.
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA,
x->device.device_id));
- CHECK_GE(x->ndim, 2);
- CHECK_EQ(weight->ndim, 2);
- CHECK_EQ(workspace->ndim, 1);
- CHECK_GE(out->ndim, 2);
- CHECK_EQ(alpha->dtype.code, kDLFloat);
- CHECK_EQ(alpha->dtype.bits, 32);
- CHECK_EQ(alpha->ndim, 1);
- CHECK_EQ(alpha->shape[0], 1);
+ TVM_FFI_CHECK_GE(x->ndim, 2, ValueError);
+ TVM_FFI_CHECK_EQ(weight->ndim, 2, ValueError);
+ TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+ TVM_FFI_CHECK_GE(out->ndim, 2, ValueError);
+ TVM_FFI_CHECK_EQ(alpha->dtype.code, kDLFloat, ValueError);
+ TVM_FFI_CHECK_EQ(alpha->dtype.bits, 32, ValueError);
+ TVM_FFI_CHECK_EQ(alpha->ndim, 1, ValueError);
+ TVM_FFI_CHECK_EQ(alpha->shape[0], 1, ValueError);
int64_t m = 1;
for (int i = 0; i < x->ndim - 1; ++i) {
m *= x->shape[i];
}
int64_t n = weight->shape[0];
- CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1]) << "Only col-major weight
is supported now.";
+ TVM_FFI_CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1], ValueError)
+ << "Only col-major weight is supported now.";
int64_t k = x->shape[x->ndim - 1];
const float* beta = nullptr;
if (m <= 64) {
diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
index b2e08b7570..adfcaed0c0 100644
--- a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
+++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
@@ -47,15 +47,15 @@ void tvm_cutlass_fp8_group_gemm(Tensor x, Tensor weight,
Tensor indptr, Tensor w
// Workspace is used for storing device-side group gemm arguments and
cutlass internal workspace.
// Recommened size is 4MB.
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA,
x->device.device_id));
- CHECK_EQ(x->ndim, 2);
- CHECK_EQ(weight->ndim, 3);
- CHECK_EQ(indptr->ndim, 1);
- CHECK_EQ(workspace->ndim, 1);
- CHECK_EQ(out->ndim, 2);
- CHECK_EQ(alpha->dtype.code, kDLFloat);
- CHECK_EQ(alpha->dtype.bits, 32);
- CHECK_EQ(alpha->ndim, 1);
- CHECK_EQ(alpha->shape[0], 1);
+ TVM_FFI_CHECK_EQ(x->ndim, 2, ValueError);
+ TVM_FFI_CHECK_EQ(weight->ndim, 3, ValueError);
+ TVM_FFI_CHECK_EQ(indptr->ndim, 1, ValueError);
+ TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+ TVM_FFI_CHECK_EQ(out->ndim, 2, ValueError);
+ TVM_FFI_CHECK_EQ(alpha->dtype.code, kDLFloat, ValueError);
+ TVM_FFI_CHECK_EQ(alpha->dtype.bits, 32, ValueError);
+ TVM_FFI_CHECK_EQ(alpha->ndim, 1, ValueError);
+ TVM_FFI_CHECK_EQ(alpha->shape[0], 1, ValueError);
int num_groups = weight->shape[0];
int n = weight->shape[1];
int k = x->shape[1];
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
index 07f70f35e0..338a96c8b7 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
@@ -42,35 +42,36 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(Tensor a,
Tensor b, Tensor scale
// Recommened size is 4MB.
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA,
a->device.device_id));
- CHECK_GE(a->ndim, 2);
- CHECK_EQ(scales_a->ndim, a->ndim);
- CHECK_EQ(b->ndim, 2);
- CHECK_EQ(scales_b->ndim, 2);
- CHECK_EQ(workspace->ndim, 1);
- CHECK_EQ(out->ndim, a->ndim);
+ TVM_FFI_CHECK_GE(a->ndim, 2, ValueError);
+ TVM_FFI_CHECK_EQ(scales_a->ndim, a->ndim, ValueError);
+ TVM_FFI_CHECK_EQ(b->ndim, 2, ValueError);
+ TVM_FFI_CHECK_EQ(scales_b->ndim, 2, ValueError);
+ TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+ TVM_FFI_CHECK_EQ(out->ndim, a->ndim, ValueError);
int64_t m = 1;
for (int64_t i = 0; i < a->ndim - 1; ++i) {
m *= a->shape[i];
}
int64_t n = b->shape[0];
- CHECK_EQ(a->shape[a->ndim - 1], b->shape[1]) << "Only col-major B is
supported now.";
+ TVM_FFI_CHECK_EQ(a->shape[a->ndim - 1], b->shape[1], ValueError)
+ << "Only col-major B is supported now.";
int64_t k = a->shape[a->ndim - 1];
// scales_a is col-major of (*a_shape[:-1], k / block_size)
- CHECK_EQ(scales_a->shape[0] * block_size_1, k);
+ TVM_FFI_CHECK_EQ(scales_a->shape[0] * block_size_1, k, ValueError);
for (int64_t i = 1; i < scales_a->ndim; ++i) {
- CHECK_EQ(scales_a->shape[i], a->shape[i - 1]);
+ TVM_FFI_CHECK_EQ(scales_a->shape[i], a->shape[i - 1], ValueError);
}
// scales_b is col-major of (k / block_size, n / block_size)
- CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[0]);
- CHECK_EQ(scales_b->shape[1] * block_size_1, k);
+ TVM_FFI_CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[0],
ValueError);
+ TVM_FFI_CHECK_EQ(scales_b->shape[1] * block_size_1, k, ValueError);
using tvm::runtime::DataType;
- CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
- CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
- CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
- CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
- CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+ TVM_FFI_CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN(), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN(), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8), ValueError);
if (DataType(out->dtype) == DataType::Float(16)) {
CutlassFP8GroupwiseGemm<Arch, TileShape, ClusterShape,
cutlass::float_e4m3_t,
@@ -107,35 +108,35 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(Tensor a,
Tensor b, Tensor scales
// Recommened size is 4MB.
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA,
a->device.device_id));
- CHECK_EQ(a->ndim, 3);
- CHECK_EQ(scales_a->ndim, 3);
- CHECK_EQ(b->ndim, 3);
- CHECK_EQ(scales_b->ndim, 3);
- CHECK_EQ(workspace->ndim, 1);
- CHECK_EQ(out->ndim, 3);
+ TVM_FFI_CHECK_EQ(a->ndim, 3, ValueError);
+ TVM_FFI_CHECK_EQ(scales_a->ndim, 3, ValueError);
+ TVM_FFI_CHECK_EQ(b->ndim, 3, ValueError);
+ TVM_FFI_CHECK_EQ(scales_b->ndim, 3, ValueError);
+ TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+ TVM_FFI_CHECK_EQ(out->ndim, 3, ValueError);
int64_t batch_size = a->shape[0];
int64_t m = a->shape[1];
int64_t n = b->shape[1];
- CHECK_EQ(a->shape[2], b->shape[2]) << "Only col-major B is supported now.";
+ TVM_FFI_CHECK_EQ(a->shape[2], b->shape[2], ValueError) << "Only col-major B
is supported now.";
int64_t k = a->shape[2];
- CHECK_EQ(b->shape[0], batch_size);
- CHECK_EQ(scales_a->shape[0], batch_size);
- CHECK_EQ(scales_b->shape[0], batch_size);
- CHECK_EQ(out->shape[0], batch_size);
+ TVM_FFI_CHECK_EQ(b->shape[0], batch_size, ValueError);
+ TVM_FFI_CHECK_EQ(scales_a->shape[0], batch_size, ValueError);
+ TVM_FFI_CHECK_EQ(scales_b->shape[0], batch_size, ValueError);
+ TVM_FFI_CHECK_EQ(out->shape[0], batch_size, ValueError);
// scales_a is col-major of (batch_size, m, k / block_size)
- CHECK_EQ(scales_a->shape[1] * block_size_1, k);
- CHECK_EQ(scales_a->shape[2], m);
+ TVM_FFI_CHECK_EQ(scales_a->shape[1] * block_size_1, k, ValueError);
+ TVM_FFI_CHECK_EQ(scales_a->shape[2], m, ValueError);
// scales_b is col-major of (k / block_size, n / block_size)
- CHECK_EQ(scales_b->shape[1] * block_size_0, n);
- CHECK_EQ(scales_b->shape[2] * block_size_1, k);
+ TVM_FFI_CHECK_EQ(scales_b->shape[1] * block_size_0, n, ValueError);
+ TVM_FFI_CHECK_EQ(scales_b->shape[2] * block_size_1, k, ValueError);
using tvm::runtime::DataType;
- CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
- CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
- CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
- CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
- CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+ TVM_FFI_CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN(), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN(), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8), ValueError);
if (DataType(out->dtype) == DataType::Float(16)) {
CutlassFP8GroupwiseGemm<Arch, TileShape, ClusterShape,
cutlass::float_e4m3_t,
diff --git
a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
index 87cd8108f9..4c02cdf46e 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
@@ -45,11 +45,11 @@
#include "cutlass/tensor_ref.h"
// clang-format on
-#define CUTLASS_CHECK(status) \
- { \
- cutlass::Status error = status; \
- CHECK(error == cutlass::Status::kSuccess) \
- << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+ << "Got cutlass error: " << cutlassGetStatusString(error); \
}
using namespace cute;
@@ -137,7 +137,7 @@ struct CutlassFP8ScaledGroupwiseGemmRunnerSM100 {
Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(arguments));
- CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+ TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments),
RuntimeError);
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
CUTLASS_CHECK(gemm_op.run(stream));
}
diff --git
a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
index d5321d157c..69445a780f 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
@@ -45,11 +45,11 @@
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
// clang-format on
-#define CUTLASS_CHECK(status) \
- { \
- cutlass::Status error = status; \
- CHECK(error == cutlass::Status::kSuccess) \
- << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+ << "Got cutlass error: " << cutlassGetStatusString(error); \
}
using namespace cute;
@@ -141,7 +141,7 @@ struct CutlassFP8GroupwiseScaledGemmRunner {
Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(arguments));
- CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+ TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments),
RuntimeError);
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
CUTLASS_CHECK(gemm_op.run(stream));
}
diff --git
a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
index 19c6b699aa..e6729b9d96 100644
---
a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
+++
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
@@ -40,11 +40,11 @@
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
-#define CUTLASS_CHECK(status) \
- { \
- cutlass::Status error = status; \
- CHECK(error == cutlass::Status::kSuccess) \
- << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+ << "Got cutlass error: " << cutlassGetStatusString(error); \
}
using namespace cute;
@@ -127,7 +127,7 @@ struct CutlassFP8ScaledGroupwiseGroupGemmRunnerSM100 {
Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(arguments));
- CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+ TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments),
RuntimeError);
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
CUTLASS_CHECK(gemm_op.run(stream));
}
diff --git
a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
index 955a01765c..1ab5b41f1d 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
@@ -38,32 +38,32 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(Tensor a,
Tensor b, Tensor scales
// Workspace is used for storing device-side group gemm arguments and
cutlass internal workspace.
// Recommended size is 4MB.
cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA,
a->device.device_id));
- CHECK_EQ(a->ndim, 2);
- CHECK_EQ(b->ndim, 3);
- CHECK_EQ(indptr->ndim, 1);
- CHECK_EQ(workspace->ndim, 1);
- CHECK_EQ(out->ndim, 2);
+ TVM_FFI_CHECK_EQ(a->ndim, 2, ValueError);
+ TVM_FFI_CHECK_EQ(b->ndim, 3, ValueError);
+ TVM_FFI_CHECK_EQ(indptr->ndim, 1, ValueError);
+ TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+ TVM_FFI_CHECK_EQ(out->ndim, 2, ValueError);
int num_groups = b->shape[0];
int n = b->shape[1];
int k = b->shape[2];
- CHECK_EQ(scales_a->ndim, a->ndim);
- CHECK_EQ(scales_b->ndim, b->ndim);
+ TVM_FFI_CHECK_EQ(scales_a->ndim, a->ndim, ValueError);
+ TVM_FFI_CHECK_EQ(scales_b->ndim, b->ndim, ValueError);
// scales_a is row-major of (m, k / block_size)
- CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_a->shape[1]);
- CHECK_EQ(scales_a->shape[0], a->shape[0]);
+ TVM_FFI_CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_a->shape[1],
ValueError);
+ TVM_FFI_CHECK_EQ(scales_a->shape[0], a->shape[0], ValueError);
// scales_b is col-major of (k / block_size, n / block_size)
- CHECK_EQ(scales_b->shape[0], num_groups);
- CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[1]);
- CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_b->shape[2]);
+ TVM_FFI_CHECK_EQ(scales_b->shape[0], num_groups, ValueError);
+ TVM_FFI_CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[1],
ValueError);
+ TVM_FFI_CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_b->shape[2],
ValueError);
using tvm::runtime::DataType;
- CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
- CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
- CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
- CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
- CHECK_EQ(DataType(indptr->dtype), DataType::Int(64));
- CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+ TVM_FFI_CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN(), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN(), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(indptr->dtype), DataType::Int(64), ValueError);
+ TVM_FFI_CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8), ValueError);
if (DataType(out->dtype) == DataType::Float(16)) {
using Dtype = cutlass::half_t;
diff --git a/src/runtime/contrib/cutlass/gemm_runner.cuh
b/src/runtime/contrib/cutlass/gemm_runner.cuh
index 0ca1d1be02..b0907bfe29 100644
--- a/src/runtime/contrib/cutlass/gemm_runner.cuh
+++ b/src/runtime/contrib/cutlass/gemm_runner.cuh
@@ -40,11 +40,11 @@
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
-#define CUTLASS_CHECK(status) \
- { \
- cutlass::Status error = status; \
- CHECK(error == cutlass::Status::kSuccess) \
- << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+ << "Got cutlass error: " << cutlassGetStatusString(error); \
}
using namespace cute;
@@ -130,7 +130,7 @@ struct CutlassGemmRunner {
Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(arguments));
- CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+ TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments),
RuntimeError);
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
CUTLASS_CHECK(gemm_op.run(stream));
}
diff --git a/src/runtime/contrib/nvshmem/kv_transfer.cu
b/src/runtime/contrib/nvshmem/kv_transfer.cu
index 34916a614a..1338ea3e6e 100644
--- a/src/runtime/contrib/nvshmem/kv_transfer.cu
+++ b/src/runtime/contrib/nvshmem/kv_transfer.cu
@@ -180,41 +180,48 @@ __global__ void KVTransferPageToPage(T* remote_pages, T*
local_pages, int32_t* r
int _KVTransfer(DLTensor* remote_pages, DLTensor* k, DLTensor* v, DLTensor*
remote_position_map,
DLTensor* remote_tp_group_pe_offset, TVMStreamHandle
transfer_stream) {
- CHECK_EQ(remote_pages->device.device_type, kDLCUDA)
+ TVM_FFI_CHECK_EQ(remote_pages->device.device_type, kDLCUDA, ValueError)
<< "The device of remote_pages matrix must be CUDA.";
- CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of k matrix must be
CUDA.";
- CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of v matrix must be
CUDA.";
- CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
+ TVM_FFI_CHECK_EQ(k->device.device_type, kDLCUDA, ValueError)
+ << "The device of k matrix must be CUDA.";
+ TVM_FFI_CHECK_EQ(v->device.device_type, kDLCUDA, ValueError)
+ << "The device of v matrix must be CUDA.";
+ TVM_FFI_CHECK_EQ(remote_position_map->device.device_type, kDLCUDA,
ValueError)
<< "The device of remote_position_map matrix must be CUDA.";
size_t dev_id = remote_pages->device.device_id;
- CHECK_EQ(k->device.device_id, dev_id)
+ TVM_FFI_CHECK_EQ(k->device.device_id, dev_id, ValueError)
<< "The device id of remote_pages and k matrix doesn't match.";
- CHECK_EQ(v->device.device_id, dev_id)
+ TVM_FFI_CHECK_EQ(v->device.device_id, dev_id, ValueError)
<< "The device id of remote_pages and v matrix doesn't match.";
- CHECK_EQ(remote_position_map->device.device_id, dev_id)
+ TVM_FFI_CHECK_EQ(remote_position_map->device.device_id, dev_id, ValueError)
<< "The device id of remote_pages and remote_position_map matrix doesn't
match.";
- CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id)
+ TVM_FFI_CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id,
ValueError)
<< "The device id of remote_pages and remote_tp_group_pe_offset matrix
doesn't match.";
- CHECK_EQ(remote_pages->ndim, 5);
+ TVM_FFI_CHECK_EQ(remote_pages->ndim, 5, ValueError);
int remote_num_pages = remote_pages->shape[0];
int remote_num_kv_head = remote_pages->shape[2];
int page_size = remote_pages->shape[3];
int head_dim = remote_pages->shape[4];
- CHECK_GE(k->ndim, 3);
+ TVM_FFI_CHECK_GE(k->ndim, 3, ValueError);
int kv_len = k->shape[k->ndim - 3];
int local_num_kv_heads = k->shape[k->ndim - 2];
- CHECK_EQ(head_dim, k->shape[k->ndim - 1]);
+ TVM_FFI_CHECK_EQ(head_dim, k->shape[k->ndim - 1], ValueError);
- CHECK_GE(v->ndim, 3);
- CHECK_EQ(kv_len, v->shape[v->ndim - 3]);
- CHECK_EQ(local_num_kv_heads, v->shape[v->ndim - 2]);
- CHECK_EQ(head_dim, v->shape[v->ndim - 1]);
+ TVM_FFI_CHECK_GE(v->ndim, 3, ValueError);
+ TVM_FFI_CHECK_EQ(kv_len, v->shape[v->ndim - 3], ValueError);
+ TVM_FFI_CHECK_EQ(local_num_kv_heads, v->shape[v->ndim - 2], ValueError);
+ TVM_FFI_CHECK_EQ(head_dim, v->shape[v->ndim - 1], ValueError);
- CHECK(remote_pages->dtype.lanes == 1 && k->dtype.lanes == 1 &&
v->dtype.lanes == 1);
- CHECK(remote_pages->dtype.bits == k->dtype.bits && remote_pages->dtype.code
== k->dtype.code);
- CHECK(remote_pages->dtype.bits == v->dtype.bits && remote_pages->dtype.code
== v->dtype.code);
+ TVM_FFI_CHECK(remote_pages->dtype.lanes == 1 && k->dtype.lanes == 1 &&
v->dtype.lanes == 1,
+ ValueError);
+ TVM_FFI_CHECK(
+ remote_pages->dtype.bits == k->dtype.bits && remote_pages->dtype.code ==
k->dtype.code,
+ ValueError);
+ TVM_FFI_CHECK(
+ remote_pages->dtype.bits == v->dtype.bits && remote_pages->dtype.code ==
v->dtype.code,
+ ValueError);
int local_tp_rank;
tvm::runtime::DiscoWorker* worker =
tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;
if (worker == nullptr) {
@@ -258,34 +265,36 @@ int _KVTransfer(DLTensor* remote_pages, DLTensor* k,
DLTensor* v, DLTensor* remo
int _KVTransferPageToPage(DLTensor* remote_pages, DLTensor* local_pages,
DLTensor* remote_position_map, DLTensor*
local_position_map,
DLTensor* remote_tp_group_pe_offset, TVMStreamHandle
transfer_stream) {
- CHECK_EQ(remote_pages->device.device_type, kDLCUDA)
+ TVM_FFI_CHECK_EQ(remote_pages->device.device_type, kDLCUDA, ValueError)
<< "The device of remote_pages matrix must be CUDA.";
- CHECK_EQ(local_pages->device.device_type, kDLCUDA) << "The device of k
matrix must be CUDA.";
- CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
+ TVM_FFI_CHECK_EQ(local_pages->device.device_type, kDLCUDA, ValueError)
+ << "The device of k matrix must be CUDA.";
+ TVM_FFI_CHECK_EQ(remote_position_map->device.device_type, kDLCUDA,
ValueError)
<< "The device of remote_position_map matrix must be CUDA.";
size_t dev_id = remote_pages->device.device_id;
- CHECK_EQ(local_pages->device.device_id, dev_id)
+ TVM_FFI_CHECK_EQ(local_pages->device.device_id, dev_id, ValueError)
<< "The device id of remote_pages and k matrix doesn't match.";
- CHECK_EQ(remote_position_map->device.device_id, dev_id)
+ TVM_FFI_CHECK_EQ(remote_position_map->device.device_id, dev_id, ValueError)
<< "The device id of remote_pages and remote_position_map matrix doesn't
match.";
- CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id)
+ TVM_FFI_CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id,
ValueError)
<< "The device id of remote_pages and remote_tp_group_pe_offset matrix
doesn't match.";
- CHECK_EQ(remote_pages->ndim, 5);
+ TVM_FFI_CHECK_EQ(remote_pages->ndim, 5, ValueError);
int remote_num_kv_head = remote_pages->shape[2];
int page_size = remote_pages->shape[3];
int head_dim = remote_pages->shape[4];
- CHECK_GE(local_pages->ndim, 5);
+ TVM_FFI_CHECK_GE(local_pages->ndim, 5, ValueError);
int local_num_kv_heads = local_pages->shape[2];
- CHECK_EQ(head_dim, local_pages->shape[4]);
+ TVM_FFI_CHECK_EQ(head_dim, local_pages->shape[4], ValueError);
- CHECK_EQ(remote_position_map->ndim, 1);
+ TVM_FFI_CHECK_EQ(remote_position_map->ndim, 1, ValueError);
int ntokens = remote_position_map->shape[0];
- CHECK(remote_pages->dtype.lanes == 1 && local_pages->dtype.lanes == 1);
- CHECK(remote_pages->dtype.bits == local_pages->dtype.bits &&
- remote_pages->dtype.code == local_pages->dtype.code);
+ TVM_FFI_CHECK(remote_pages->dtype.lanes == 1 && local_pages->dtype.lanes ==
1, ValueError);
+ TVM_FFI_CHECK(remote_pages->dtype.bits == local_pages->dtype.bits &&
+ remote_pages->dtype.code == local_pages->dtype.code,
+ ValueError);
int local_tp_rank;
tvm::runtime::DiscoWorker* worker =
tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;