This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 585d6d2559 [CUTLASS] Fix CUTLASS kernel compilation (#18238)
585d6d2559 is described below
commit 585d6d25596f184b1a4acbd9f4aae52f2c4e5c41
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Aug 26 13:20:19 2025 -0400
[CUTLASS] Fix CUTLASS kernel compilation (#18238)
This PR fixes a few places in the current CUTLASS kernel AOT
compilation.
---
src/runtime/contrib/cutlass/fp16_group_gemm.cuh | 5 +++--
src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu | 1 +
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
index a09051a86e..cb26a0796d 100644
--- a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh
@@ -19,6 +19,7 @@
#include <cuda_fp16.h>
#include <float.h>
+#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/runtime/ndarray.h>
@@ -36,7 +37,8 @@ void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDAr
NDArray out) {
// Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
// Recommened size is 4MB.
- cudaStream_t stream = static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id));
+ cudaStream_t stream =
+ static_cast<cudaStream_t>(TVMFFIEnvGetCurrentStream(kDLCUDA, x->device.device_id));
CHECK_EQ(x->ndim, 2);
CHECK_EQ(weight->ndim, 3);
CHECK_EQ(indptr->ndim, 1);
@@ -47,7 +49,6 @@ void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDAr
int k = weight->shape[2];
float alpha = 1.0f;
float beta = 0.0f;
- cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
if (DataType(x->dtype) == DataType::Float(16)) {
CHECK(DataType(weight->dtype) == DataType::Float(16));
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
index 2745c0b1fc..b9be378a9a 100644
--- a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
@@ -19,6 +19,7 @@
#include <cuda_fp16.h>
#include <float.h>
+#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/ndarray.h>