This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e693452ad8 fix: Complete CHECK update across contrib runtime (#18861)
e693452ad8 is described below

commit e693452ad8afdfb75abab66f73cce890a4e979b7
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 1 14:15:13 2026 -0500

    fix: Complete CHECK update across contrib runtime (#18861)
    
    Replace all references to the old `CHECK` / `CHECK_EQ` / `CHECK_GE`
    macro names with their `TVM_FFI_CHECK` equivalents across CUTLASS and
    NVSHMEM contrib sources, completing the macro rename introduced by the
    TVM FFI refactor. User-facing input validation checks use `ValueError`;
    CUTLASS status and workspace checks use `RuntimeError`.
---
 python/tvm/contrib/cutlass/attention_operation.py  | 14 ++---
 src/runtime/contrib/cutlass/fp16_group_gemm.cuh    | 18 +++---
 .../cutlass/fp16_group_gemm_runner_sm100.cuh       | 12 ++--
 .../cutlass/fp16_group_gemm_runner_sm90.cuh        | 12 ++--
 src/runtime/contrib/cutlass/fp8_gemm.cu            | 19 +++---
 src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu | 18 +++---
 .../contrib/cutlass/fp8_groupwise_scaled_gemm.cuh  | 73 +++++++++++-----------
 .../fp8_groupwise_scaled_gemm_runner_sm100.cuh     | 12 ++--
 .../fp8_groupwise_scaled_gemm_runner_sm90.cuh      | 12 ++--
 ...p8_groupwise_scaled_group_gemm_runner_sm100.cuh | 12 ++--
 .../fp8_groupwise_scaled_group_gemm_sm100.cu       | 36 +++++------
 src/runtime/contrib/cutlass/gemm_runner.cuh        | 12 ++--
 src/runtime/contrib/nvshmem/kv_transfer.cu         | 71 ++++++++++++---------
 13 files changed, 166 insertions(+), 155 deletions(-)

diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index f4f059daf0..560da4e60e 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -26,7 +26,7 @@ def instantiate_attention_template(attrs):
     based on a template and the provided attribute map."""
 
     bias_template = """
-  CHECK(${bias}->ndim == 4); // B, N, S, S'
+  TVM_FFI_CHECK(${bias}->ndim == 4, ValueError); // B, N, S, S'
 
   p.attn_bias_ptr = reinterpret_cast<T *>(${bias}->data);
   p.bias_strideM = ${bias_strideM};
@@ -46,9 +46,9 @@ def instantiate_attention_template(attrs):
   p.query_ptr = reinterpret_cast<T *>(${query}->data);
   p.key_ptr = reinterpret_cast<T *>(${key}->data);
   p.value_ptr = reinterpret_cast<T *>(${value}->data);
-  CHECK(${query}->ndim == 4); // B, S, N, H
-  CHECK(${key}->ndim == 4); // B, S', N, H
-  CHECK(${value}->ndim == 4); // B, S', N, H'
+  TVM_FFI_CHECK(${query}->ndim == 4, ValueError); // B, S, N, H
+  TVM_FFI_CHECK(${key}->ndim == 4, ValueError); // B, S', N, H
+  TVM_FFI_CHECK(${value}->ndim == 4, ValueError); // B, S', N, H'
 
   // stride for N
   p.q_strideH = p.head_dim; // H
@@ -69,7 +69,7 @@ def instantiate_attention_template(attrs):
   p.query_ptr = reinterpret_cast<T *>(${qkv}->data);
   p.key_ptr = reinterpret_cast<T *>(${qkv}->data) + p.head_dim * p.num_heads;
  p.value_ptr = reinterpret_cast<T *>(${qkv}->data) + p.head_dim * p.num_heads * 2;
-  CHECK(${qkv}->ndim == 3); // B, S, NH + NH + NH'
+  TVM_FFI_CHECK(${qkv}->ndim == 3, ValueError); // B, S, NH + NH + NH'
 
   // stride for N
   p.q_strideH = p.head_dim; // H
@@ -132,7 +132,7 @@ def instantiate_attention_template(attrs):
 
 
   p.o_strideM = p.head_dim_value * p.num_heads; // H' * N
-  CHECK(out0->ndim == 4); // B, S, N, H'
+  TVM_FFI_CHECK(out0->ndim == 4, ValueError); // B, S, N, H'
 
   ${qkv_template}
   ${bias_template}
@@ -148,7 +148,7 @@ def instantiate_attention_template(attrs):
     }();
   }
 
-  CHECK(Attention::check_supported(p));
+  TVM_FFI_CHECK(Attention::check_supported(p), RuntimeError);
  cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, ${query}->device.device_id));
 
   kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
index 9e1a17eaef..c2a2dd372b 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
@@ -38,11 +38,11 @@ void tvm_cutlass_group_gemm_impl(Tensor x, Tensor weight, Tensor indptr, Tensor
   // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
   // Recommened size is 4MB.
   cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id));
-  CHECK_EQ(x->ndim, 2);
-  CHECK_EQ(weight->ndim, 3);
-  CHECK_EQ(indptr->ndim, 1);
-  CHECK_EQ(workspace->ndim, 1);
-  CHECK_EQ(out->ndim, 2);
+  TVM_FFI_CHECK_EQ(x->ndim, 2, ValueError);
+  TVM_FFI_CHECK_EQ(weight->ndim, 3, ValueError);
+  TVM_FFI_CHECK_EQ(indptr->ndim, 1, ValueError);
+  TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+  TVM_FFI_CHECK_EQ(out->ndim, 2, ValueError);
   int num_groups = weight->shape[0];
   int n = weight->shape[1];
   int k = weight->shape[2];
@@ -50,16 +50,16 @@ void tvm_cutlass_group_gemm_impl(Tensor x, Tensor weight, Tensor indptr, Tensor
   float beta = 0.0f;
 
   if (DataType(x->dtype) == DataType::Float(16)) {
-    CHECK(DataType(weight->dtype) == DataType::Float(16));
-    CHECK(DataType(out->dtype) == DataType::Float(16));
+    TVM_FFI_CHECK(DataType(weight->dtype) == DataType::Float(16), ValueError);
+    TVM_FFI_CHECK(DataType(out->dtype) == DataType::Float(16), ValueError);
     using Dtype = cutlass::half_t;
     CutlassGroupGemm<Arch, Dtype, Dtype, Dtype>::run(
         static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
         static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
         workspace->shape[0], n, k, num_groups, alpha, beta, static_cast<Dtype*>(out->data), stream);
   } else if (DataType(x->dtype) == DataType::BFloat(16)) {
-    CHECK(DataType(weight->dtype) == DataType::BFloat(16));
-    CHECK(DataType(out->dtype) == DataType::BFloat(16));
+    TVM_FFI_CHECK(DataType(weight->dtype) == DataType::BFloat(16), ValueError);
+    TVM_FFI_CHECK(DataType(out->dtype) == DataType::BFloat(16), ValueError);
     using Dtype = cutlass::bfloat16_t;
     CutlassGroupGemm<Arch, Dtype, Dtype, Dtype>::run(
         static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
index 34a50fdc54..22a9bea646 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
@@ -40,11 +40,11 @@
 #include "cutlass/gemm/kernel/gemm_universal.hpp"
 // clang-format on
 
-#define CUTLASS_CHECK(status)                                      \
-  {                                                                \
-    cutlass::Status error = status;                                \
-    CHECK(error == cutlass::Status::kSuccess)                      \
-        << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status)                                       \
+  {                                                                 \
+    cutlass::Status error = status;                                 \
+    TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+        << "Got cutlass error: " << cutlassGetStatusString(error);  \
   }
 
 using namespace cute;
@@ -156,7 +156,7 @@ struct CutlassGroupGemmRunner {
                                          hw_info};
     Gemm gemm_op;
     CUTLASS_CHECK(gemm_op.can_implement(arguments));
-    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments), RuntimeError);
     CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
     CUTLASS_CHECK(gemm_op.run(stream));
   }
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
index b894cbc84e..4fc513e3db 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh
@@ -40,11 +40,11 @@
 #include "cutlass/gemm/kernel/gemm_universal.hpp"
 // clang-format on
 
-#define CUTLASS_CHECK(status)                                      \
-  {                                                                \
-    cutlass::Status error = status;                                \
-    CHECK(error == cutlass::Status::kSuccess)                      \
-        << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status)                                       \
+  {                                                                 \
+    cutlass::Status error = status;                                 \
+    TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+        << "Got cutlass error: " << cutlassGetStatusString(error);  \
   }
 
 using namespace cute;
@@ -156,7 +156,7 @@ struct CutlassGroupGemmRunner {
                                        hw_info};
     Gemm gemm_op;
     CUTLASS_CHECK(gemm_op.can_implement(arguments));
-    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments), RuntimeError);
     CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
     CUTLASS_CHECK(gemm_op.run(stream));
   }
diff --git a/src/runtime/contrib/cutlass/fp8_gemm.cu b/src/runtime/contrib/cutlass/fp8_gemm.cu
index d41064efba..02fd34aa10 100644
--- a/src/runtime/contrib/cutlass/fp8_gemm.cu
+++ b/src/runtime/contrib/cutlass/fp8_gemm.cu
@@ -44,20 +44,21 @@ void tvm_cutlass_fp8_gemm(Tensor x, Tensor weight, Tensor workspace, Tensor alph
   // Recommened size is 4MB.
   cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id));
 
-  CHECK_GE(x->ndim, 2);
-  CHECK_EQ(weight->ndim, 2);
-  CHECK_EQ(workspace->ndim, 1);
-  CHECK_GE(out->ndim, 2);
-  CHECK_EQ(alpha->dtype.code, kDLFloat);
-  CHECK_EQ(alpha->dtype.bits, 32);
-  CHECK_EQ(alpha->ndim, 1);
-  CHECK_EQ(alpha->shape[0], 1);
+  TVM_FFI_CHECK_GE(x->ndim, 2, ValueError);
+  TVM_FFI_CHECK_EQ(weight->ndim, 2, ValueError);
+  TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+  TVM_FFI_CHECK_GE(out->ndim, 2, ValueError);
+  TVM_FFI_CHECK_EQ(alpha->dtype.code, kDLFloat, ValueError);
+  TVM_FFI_CHECK_EQ(alpha->dtype.bits, 32, ValueError);
+  TVM_FFI_CHECK_EQ(alpha->ndim, 1, ValueError);
+  TVM_FFI_CHECK_EQ(alpha->shape[0], 1, ValueError);
   int64_t m = 1;
   for (int i = 0; i < x->ndim - 1; ++i) {
     m *= x->shape[i];
   }
   int64_t n = weight->shape[0];
-  CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1]) << "Only col-major weight is supported now.";
+  TVM_FFI_CHECK_EQ(x->shape[x->ndim - 1], weight->shape[1], ValueError)
+      << "Only col-major weight is supported now.";
   int64_t k = x->shape[x->ndim - 1];
   const float* beta = nullptr;
   if (m <= 64) {
diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
index b2e08b7570..adfcaed0c0 100644
--- a/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
+++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu
@@ -47,15 +47,15 @@ void tvm_cutlass_fp8_group_gemm(Tensor x, Tensor weight, Tensor indptr, Tensor w
   // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
   // Recommened size is 4MB.
   cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, x->device.device_id));
-  CHECK_EQ(x->ndim, 2);
-  CHECK_EQ(weight->ndim, 3);
-  CHECK_EQ(indptr->ndim, 1);
-  CHECK_EQ(workspace->ndim, 1);
-  CHECK_EQ(out->ndim, 2);
-  CHECK_EQ(alpha->dtype.code, kDLFloat);
-  CHECK_EQ(alpha->dtype.bits, 32);
-  CHECK_EQ(alpha->ndim, 1);
-  CHECK_EQ(alpha->shape[0], 1);
+  TVM_FFI_CHECK_EQ(x->ndim, 2, ValueError);
+  TVM_FFI_CHECK_EQ(weight->ndim, 3, ValueError);
+  TVM_FFI_CHECK_EQ(indptr->ndim, 1, ValueError);
+  TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+  TVM_FFI_CHECK_EQ(out->ndim, 2, ValueError);
+  TVM_FFI_CHECK_EQ(alpha->dtype.code, kDLFloat, ValueError);
+  TVM_FFI_CHECK_EQ(alpha->dtype.bits, 32, ValueError);
+  TVM_FFI_CHECK_EQ(alpha->ndim, 1, ValueError);
+  TVM_FFI_CHECK_EQ(alpha->shape[0], 1, ValueError);
   int num_groups = weight->shape[0];
   int n = weight->shape[1];
   int k = x->shape[1];
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
index 07f70f35e0..338a96c8b7 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh
@@ -42,35 +42,36 @@ void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(Tensor a, Tensor b, Tensor scale
   // Recommened size is 4MB.
   cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id));
 
-  CHECK_GE(a->ndim, 2);
-  CHECK_EQ(scales_a->ndim, a->ndim);
-  CHECK_EQ(b->ndim, 2);
-  CHECK_EQ(scales_b->ndim, 2);
-  CHECK_EQ(workspace->ndim, 1);
-  CHECK_EQ(out->ndim, a->ndim);
+  TVM_FFI_CHECK_GE(a->ndim, 2, ValueError);
+  TVM_FFI_CHECK_EQ(scales_a->ndim, a->ndim, ValueError);
+  TVM_FFI_CHECK_EQ(b->ndim, 2, ValueError);
+  TVM_FFI_CHECK_EQ(scales_b->ndim, 2, ValueError);
+  TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+  TVM_FFI_CHECK_EQ(out->ndim, a->ndim, ValueError);
   int64_t m = 1;
   for (int64_t i = 0; i < a->ndim - 1; ++i) {
     m *= a->shape[i];
   }
   int64_t n = b->shape[0];
-  CHECK_EQ(a->shape[a->ndim - 1], b->shape[1]) << "Only col-major B is supported now.";
+  TVM_FFI_CHECK_EQ(a->shape[a->ndim - 1], b->shape[1], ValueError)
+      << "Only col-major B is supported now.";
   int64_t k = a->shape[a->ndim - 1];
 
   // scales_a is col-major of (*a_shape[:-1], k / block_size)
-  CHECK_EQ(scales_a->shape[0] * block_size_1, k);
+  TVM_FFI_CHECK_EQ(scales_a->shape[0] * block_size_1, k, ValueError);
   for (int64_t i = 1; i < scales_a->ndim; ++i) {
-    CHECK_EQ(scales_a->shape[i], a->shape[i - 1]);
+    TVM_FFI_CHECK_EQ(scales_a->shape[i], a->shape[i - 1], ValueError);
   }
   // scales_b is col-major of (k / block_size, n / block_size)
-  CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[0]);
-  CHECK_EQ(scales_b->shape[1] * block_size_1, k);
+  TVM_FFI_CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[0], ValueError);
+  TVM_FFI_CHECK_EQ(scales_b->shape[1] * block_size_1, k, ValueError);
 
   using tvm::runtime::DataType;
-  CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
-  CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
-  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
-  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
-  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+  TVM_FFI_CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN(), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN(), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8), ValueError);
 
   if (DataType(out->dtype) == DataType::Float(16)) {
     CutlassFP8GroupwiseGemm<Arch, TileShape, ClusterShape, cutlass::float_e4m3_t,
@@ -107,35 +108,35 @@ void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(Tensor a, Tensor b, Tensor scales
   // Recommened size is 4MB.
   cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id));
 
-  CHECK_EQ(a->ndim, 3);
-  CHECK_EQ(scales_a->ndim, 3);
-  CHECK_EQ(b->ndim, 3);
-  CHECK_EQ(scales_b->ndim, 3);
-  CHECK_EQ(workspace->ndim, 1);
-  CHECK_EQ(out->ndim, 3);
+  TVM_FFI_CHECK_EQ(a->ndim, 3, ValueError);
+  TVM_FFI_CHECK_EQ(scales_a->ndim, 3, ValueError);
+  TVM_FFI_CHECK_EQ(b->ndim, 3, ValueError);
+  TVM_FFI_CHECK_EQ(scales_b->ndim, 3, ValueError);
+  TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+  TVM_FFI_CHECK_EQ(out->ndim, 3, ValueError);
   int64_t batch_size = a->shape[0];
   int64_t m = a->shape[1];
   int64_t n = b->shape[1];
-  CHECK_EQ(a->shape[2], b->shape[2]) << "Only col-major B is supported now.";
+  TVM_FFI_CHECK_EQ(a->shape[2], b->shape[2], ValueError) << "Only col-major B is supported now.";
   int64_t k = a->shape[2];
-  CHECK_EQ(b->shape[0], batch_size);
-  CHECK_EQ(scales_a->shape[0], batch_size);
-  CHECK_EQ(scales_b->shape[0], batch_size);
-  CHECK_EQ(out->shape[0], batch_size);
+  TVM_FFI_CHECK_EQ(b->shape[0], batch_size, ValueError);
+  TVM_FFI_CHECK_EQ(scales_a->shape[0], batch_size, ValueError);
+  TVM_FFI_CHECK_EQ(scales_b->shape[0], batch_size, ValueError);
+  TVM_FFI_CHECK_EQ(out->shape[0], batch_size, ValueError);
 
   // scales_a is col-major of (batch_size, m, k / block_size)
-  CHECK_EQ(scales_a->shape[1] * block_size_1, k);
-  CHECK_EQ(scales_a->shape[2], m);
+  TVM_FFI_CHECK_EQ(scales_a->shape[1] * block_size_1, k, ValueError);
+  TVM_FFI_CHECK_EQ(scales_a->shape[2], m, ValueError);
   // scales_b is col-major of (k / block_size, n / block_size)
-  CHECK_EQ(scales_b->shape[1] * block_size_0, n);
-  CHECK_EQ(scales_b->shape[2] * block_size_1, k);
+  TVM_FFI_CHECK_EQ(scales_b->shape[1] * block_size_0, n, ValueError);
+  TVM_FFI_CHECK_EQ(scales_b->shape[2] * block_size_1, k, ValueError);
 
   using tvm::runtime::DataType;
-  CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
-  CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
-  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
-  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
-  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+  TVM_FFI_CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN(), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN(), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8), ValueError);
 
   if (DataType(out->dtype) == DataType::Float(16)) {
     CutlassFP8GroupwiseGemm<Arch, TileShape, ClusterShape, cutlass::float_e4m3_t,
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
index 87cd8108f9..4c02cdf46e 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh
@@ -45,11 +45,11 @@
 #include "cutlass/tensor_ref.h"
 // clang-format on
 
-#define CUTLASS_CHECK(status)                                      \
-  {                                                                \
-    cutlass::Status error = status;                                \
-    CHECK(error == cutlass::Status::kSuccess)                      \
-        << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status)                                       \
+  {                                                                 \
+    cutlass::Status error = status;                                 \
+    TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+        << "Got cutlass error: " << cutlassGetStatusString(error);  \
   }
 
 using namespace cute;
@@ -137,7 +137,7 @@ struct CutlassFP8ScaledGroupwiseGemmRunnerSM100 {
 
     Gemm gemm_op;
     CUTLASS_CHECK(gemm_op.can_implement(arguments));
-    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments), RuntimeError);
     CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
     CUTLASS_CHECK(gemm_op.run(stream));
   }
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
index d5321d157c..69445a780f 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh
@@ -45,11 +45,11 @@
 #include "cutlass_extensions/gemm/dispatch_policy.hpp"
 // clang-format on
 
-#define CUTLASS_CHECK(status)                                      \
-  {                                                                \
-    cutlass::Status error = status;                                \
-    CHECK(error == cutlass::Status::kSuccess)                      \
-        << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status)                                       \
+  {                                                                 \
+    cutlass::Status error = status;                                 \
+    TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+        << "Got cutlass error: " << cutlassGetStatusString(error);  \
   }
 
 using namespace cute;
@@ -141,7 +141,7 @@ struct CutlassFP8GroupwiseScaledGemmRunner {
 
     Gemm gemm_op;
     CUTLASS_CHECK(gemm_op.can_implement(arguments));
-    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments), RuntimeError);
     CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
     CUTLASS_CHECK(gemm_op.run(stream));
   }
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
index 19c6b699aa..e6729b9d96 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh
@@ -40,11 +40,11 @@
 #include "cutlass/gemm/kernel/gemm_universal.hpp"
 // clang-format on
 
-#define CUTLASS_CHECK(status)                                      \
-  {                                                                \
-    cutlass::Status error = status;                                \
-    CHECK(error == cutlass::Status::kSuccess)                      \
-        << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status)                                       \
+  {                                                                 \
+    cutlass::Status error = status;                                 \
+    TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+        << "Got cutlass error: " << cutlassGetStatusString(error);  \
   }
 
 using namespace cute;
@@ -127,7 +127,7 @@ struct CutlassFP8ScaledGroupwiseGroupGemmRunnerSM100 {
 
     Gemm gemm_op;
     CUTLASS_CHECK(gemm_op.can_implement(arguments));
-    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments), RuntimeError);
     CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
     CUTLASS_CHECK(gemm_op.run(stream));
   }
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
index 955a01765c..1ab5b41f1d 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
@@ -38,32 +38,32 @@ void tvm_fp8_groupwise_scaled_group_gemm_sm100(Tensor a, Tensor b, Tensor scales
   // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
   // Recommended size is 4MB.
   cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetStream(kDLCUDA, a->device.device_id));
-  CHECK_EQ(a->ndim, 2);
-  CHECK_EQ(b->ndim, 3);
-  CHECK_EQ(indptr->ndim, 1);
-  CHECK_EQ(workspace->ndim, 1);
-  CHECK_EQ(out->ndim, 2);
+  TVM_FFI_CHECK_EQ(a->ndim, 2, ValueError);
+  TVM_FFI_CHECK_EQ(b->ndim, 3, ValueError);
+  TVM_FFI_CHECK_EQ(indptr->ndim, 1, ValueError);
+  TVM_FFI_CHECK_EQ(workspace->ndim, 1, ValueError);
+  TVM_FFI_CHECK_EQ(out->ndim, 2, ValueError);
   int num_groups = b->shape[0];
   int n = b->shape[1];
   int k = b->shape[2];
 
-  CHECK_EQ(scales_a->ndim, a->ndim);
-  CHECK_EQ(scales_b->ndim, b->ndim);
+  TVM_FFI_CHECK_EQ(scales_a->ndim, a->ndim, ValueError);
+  TVM_FFI_CHECK_EQ(scales_b->ndim, b->ndim, ValueError);
   // scales_a is row-major of (m, k / block_size)
-  CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_a->shape[1]);
-  CHECK_EQ(scales_a->shape[0], a->shape[0]);
+  TVM_FFI_CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_a->shape[1], ValueError);
+  TVM_FFI_CHECK_EQ(scales_a->shape[0], a->shape[0], ValueError);
   // scales_b is col-major of (k / block_size, n / block_size)
-  CHECK_EQ(scales_b->shape[0], num_groups);
-  CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[1]);
-  CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_b->shape[2]);
+  TVM_FFI_CHECK_EQ(scales_b->shape[0], num_groups, ValueError);
+  TVM_FFI_CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[1], ValueError);
+  TVM_FFI_CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_b->shape[2], ValueError);
 
   using tvm::runtime::DataType;
-  CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
-  CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
-  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
-  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
-  CHECK_EQ(DataType(indptr->dtype), DataType::Int(64));
-  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+  TVM_FFI_CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN(), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN(), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(indptr->dtype), DataType::Int(64), ValueError);
+  TVM_FFI_CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8), ValueError);
 
   if (DataType(out->dtype) == DataType::Float(16)) {
     using Dtype = cutlass::half_t;
diff --git a/src/runtime/contrib/cutlass/gemm_runner.cuh b/src/runtime/contrib/cutlass/gemm_runner.cuh
index 0ca1d1be02..b0907bfe29 100644
--- a/src/runtime/contrib/cutlass/gemm_runner.cuh
+++ b/src/runtime/contrib/cutlass/gemm_runner.cuh
@@ -40,11 +40,11 @@
 #include "cutlass/gemm/kernel/gemm_universal.hpp"
 // clang-format on
 
-#define CUTLASS_CHECK(status)                                      \
-  {                                                                \
-    cutlass::Status error = status;                                \
-    CHECK(error == cutlass::Status::kSuccess)                      \
-        << "Got cutlass error: " << cutlassGetStatusString(error); \
+#define CUTLASS_CHECK(status)                                       \
+  {                                                                 \
+    cutlass::Status error = status;                                 \
+    TVM_FFI_CHECK(error == cutlass::Status::kSuccess, RuntimeError) \
+        << "Got cutlass error: " << cutlassGetStatusString(error);  \
   }
 
 using namespace cute;
@@ -130,7 +130,7 @@ struct CutlassGemmRunner {
 
     Gemm gemm_op;
     CUTLASS_CHECK(gemm_op.can_implement(arguments));
-    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    TVM_FFI_CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments), RuntimeError);
     CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
     CUTLASS_CHECK(gemm_op.run(stream));
   }
diff --git a/src/runtime/contrib/nvshmem/kv_transfer.cu b/src/runtime/contrib/nvshmem/kv_transfer.cu
index 34916a614a..1338ea3e6e 100644
--- a/src/runtime/contrib/nvshmem/kv_transfer.cu
+++ b/src/runtime/contrib/nvshmem/kv_transfer.cu
@@ -180,41 +180,48 @@ __global__ void KVTransferPageToPage(T* remote_pages, T* local_pages, int32_t* r
 
 int _KVTransfer(DLTensor* remote_pages, DLTensor* k, DLTensor* v, DLTensor* remote_position_map,
                 DLTensor* remote_tp_group_pe_offset, TVMStreamHandle transfer_stream) {
-  CHECK_EQ(remote_pages->device.device_type, kDLCUDA)
+  TVM_FFI_CHECK_EQ(remote_pages->device.device_type, kDLCUDA, ValueError)
       << "The device of remote_pages matrix must be CUDA.";
-  CHECK_EQ(k->device.device_type, kDLCUDA) << "The device of k matrix must be CUDA.";
-  CHECK_EQ(v->device.device_type, kDLCUDA) << "The device of v matrix must be CUDA.";
-  CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
+  TVM_FFI_CHECK_EQ(k->device.device_type, kDLCUDA, ValueError)
+      << "The device of k matrix must be CUDA.";
+  TVM_FFI_CHECK_EQ(v->device.device_type, kDLCUDA, ValueError)
+      << "The device of v matrix must be CUDA.";
+  TVM_FFI_CHECK_EQ(remote_position_map->device.device_type, kDLCUDA, ValueError)
       << "The device of remote_position_map matrix must be CUDA.";
   size_t dev_id = remote_pages->device.device_id;
-  CHECK_EQ(k->device.device_id, dev_id)
+  TVM_FFI_CHECK_EQ(k->device.device_id, dev_id, ValueError)
       << "The device id of remote_pages and k matrix doesn't match.";
-  CHECK_EQ(v->device.device_id, dev_id)
+  TVM_FFI_CHECK_EQ(v->device.device_id, dev_id, ValueError)
       << "The device id of remote_pages and v matrix doesn't match.";
-  CHECK_EQ(remote_position_map->device.device_id, dev_id)
+  TVM_FFI_CHECK_EQ(remote_position_map->device.device_id, dev_id, ValueError)
       << "The device id of remote_pages and remote_position_map matrix doesn't match.";
-  CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id)
+  TVM_FFI_CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id, ValueError)
       << "The device id of remote_pages and remote_tp_group_pe_offset matrix doesn't match.";
 
-  CHECK_EQ(remote_pages->ndim, 5);
+  TVM_FFI_CHECK_EQ(remote_pages->ndim, 5, ValueError);
   int remote_num_pages = remote_pages->shape[0];
   int remote_num_kv_head = remote_pages->shape[2];
   int page_size = remote_pages->shape[3];
   int head_dim = remote_pages->shape[4];
 
-  CHECK_GE(k->ndim, 3);
+  TVM_FFI_CHECK_GE(k->ndim, 3, ValueError);
   int kv_len = k->shape[k->ndim - 3];
   int local_num_kv_heads = k->shape[k->ndim - 2];
-  CHECK_EQ(head_dim, k->shape[k->ndim - 1]);
+  TVM_FFI_CHECK_EQ(head_dim, k->shape[k->ndim - 1], ValueError);
 
-  CHECK_GE(v->ndim, 3);
-  CHECK_EQ(kv_len, v->shape[v->ndim - 3]);
-  CHECK_EQ(local_num_kv_heads, v->shape[v->ndim - 2]);
-  CHECK_EQ(head_dim, v->shape[v->ndim - 1]);
+  TVM_FFI_CHECK_GE(v->ndim, 3, ValueError);
+  TVM_FFI_CHECK_EQ(kv_len, v->shape[v->ndim - 3], ValueError);
+  TVM_FFI_CHECK_EQ(local_num_kv_heads, v->shape[v->ndim - 2], ValueError);
+  TVM_FFI_CHECK_EQ(head_dim, v->shape[v->ndim - 1], ValueError);
 
-  CHECK(remote_pages->dtype.lanes == 1 && k->dtype.lanes == 1 && v->dtype.lanes == 1);
-  CHECK(remote_pages->dtype.bits == k->dtype.bits && remote_pages->dtype.code == k->dtype.code);
-  CHECK(remote_pages->dtype.bits == v->dtype.bits && remote_pages->dtype.code == v->dtype.code);
+  TVM_FFI_CHECK(remote_pages->dtype.lanes == 1 && k->dtype.lanes == 1 && v->dtype.lanes == 1,
+                ValueError);
+  TVM_FFI_CHECK(
+      remote_pages->dtype.bits == k->dtype.bits && remote_pages->dtype.code == k->dtype.code,
+      ValueError);
+  TVM_FFI_CHECK(
+      remote_pages->dtype.bits == v->dtype.bits && remote_pages->dtype.code == v->dtype.code,
+      ValueError);
   int local_tp_rank;
   tvm::runtime::DiscoWorker* worker = tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;
   if (worker == nullptr) {
@@ -258,34 +265,36 @@ int _KVTransfer(DLTensor* remote_pages, DLTensor* k, DLTensor* v, DLTensor* remo
 int _KVTransferPageToPage(DLTensor* remote_pages, DLTensor* local_pages,
                           DLTensor* remote_position_map, DLTensor* local_position_map,
                           DLTensor* remote_tp_group_pe_offset, TVMStreamHandle transfer_stream) {
-  CHECK_EQ(remote_pages->device.device_type, kDLCUDA)
+  TVM_FFI_CHECK_EQ(remote_pages->device.device_type, kDLCUDA, ValueError)
       << "The device of remote_pages matrix must be CUDA.";
-  CHECK_EQ(local_pages->device.device_type, kDLCUDA) << "The device of k matrix must be CUDA.";
-  CHECK_EQ(remote_position_map->device.device_type, kDLCUDA)
+  TVM_FFI_CHECK_EQ(local_pages->device.device_type, kDLCUDA, ValueError)
+      << "The device of k matrix must be CUDA.";
+  TVM_FFI_CHECK_EQ(remote_position_map->device.device_type, kDLCUDA, ValueError)
       << "The device of remote_position_map matrix must be CUDA.";
   size_t dev_id = remote_pages->device.device_id;
-  CHECK_EQ(local_pages->device.device_id, dev_id)
+  TVM_FFI_CHECK_EQ(local_pages->device.device_id, dev_id, ValueError)
       << "The device id of remote_pages and k matrix doesn't match.";
-  CHECK_EQ(remote_position_map->device.device_id, dev_id)
+  TVM_FFI_CHECK_EQ(remote_position_map->device.device_id, dev_id, ValueError)
       << "The device id of remote_pages and remote_position_map matrix doesn't match.";
-  CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id)
+  TVM_FFI_CHECK_EQ(remote_tp_group_pe_offset->device.device_id, dev_id, ValueError)
       << "The device id of remote_pages and remote_tp_group_pe_offset matrix doesn't match.";
 
-  CHECK_EQ(remote_pages->ndim, 5);
+  TVM_FFI_CHECK_EQ(remote_pages->ndim, 5, ValueError);
   int remote_num_kv_head = remote_pages->shape[2];
   int page_size = remote_pages->shape[3];
   int head_dim = remote_pages->shape[4];
 
-  CHECK_GE(local_pages->ndim, 5);
+  TVM_FFI_CHECK_GE(local_pages->ndim, 5, ValueError);
   int local_num_kv_heads = local_pages->shape[2];
-  CHECK_EQ(head_dim, local_pages->shape[4]);
+  TVM_FFI_CHECK_EQ(head_dim, local_pages->shape[4], ValueError);
 
-  CHECK_EQ(remote_position_map->ndim, 1);
+  TVM_FFI_CHECK_EQ(remote_position_map->ndim, 1, ValueError);
   int ntokens = remote_position_map->shape[0];
 
-  CHECK(remote_pages->dtype.lanes == 1 && local_pages->dtype.lanes == 1);
-  CHECK(remote_pages->dtype.bits == local_pages->dtype.bits &&
-        remote_pages->dtype.code == local_pages->dtype.code);
+  TVM_FFI_CHECK(remote_pages->dtype.lanes == 1 && local_pages->dtype.lanes == 1, ValueError);
+  TVM_FFI_CHECK(remote_pages->dtype.bits == local_pages->dtype.bits &&
+                    remote_pages->dtype.code == local_pages->dtype.code,
+                ValueError);
 
   int local_tp_rank;
   tvm::runtime::DiscoWorker* worker = tvm::runtime::ThreadLocalDiscoWorker::Get()->worker;

Reply via email to