This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
commit f238c3b2648da993366348ccb3b79f76d502a938 Author: Wuwei Lin <wu...@apache.org> AuthorDate: Tue Dec 12 01:12:27 2023 +0000 fix --- src/runtime/contrib/cutlass/moe_gemm.cc | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/runtime/contrib/cutlass/moe_gemm.cc b/src/runtime/contrib/cutlass/moe_gemm.cc index 046247d108..4c0de113ab 100644 --- a/src/runtime/contrib/cutlass/moe_gemm.cc +++ b/src/runtime/contrib/cutlass/moe_gemm.cc @@ -54,23 +54,29 @@ namespace runtime { TVM_REGISTER_GLOBAL("cutlass.moe_gemm_f16f16") .set_body_typed([](NDArray x, NDArray weight, NDArray total_rows_before_expert, int64_t total_rows, int64_t n, int64_t k, int64_t num_experts, NDArray out) { + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*()); + fastertransformer::moe_gemm_bias_act<half, half>( reinterpret_cast<half*>(x->data), reinterpret_cast<half*>(weight->data), nullptr, nullptr, reinterpret_cast<half*>(out->data), reinterpret_cast<int64_t*>(total_rows_before_expert->data), total_rows, n, k, num_experts, - std::nullopt, - /*stream=*/nullptr /*FIXME*/); + std::nullopt, stream); }); TVM_REGISTER_GLOBAL("cutlass.moe_gemm_s4f16") .set_body_typed([](NDArray x, NDArray weight, NDArray scales, NDArray total_rows_before_expert, int64_t total_rows, int64_t n, int64_t k, int64_t num_experts, NDArray out) { + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*()); + fastertransformer::moe_gemm_bias_act<half, cutlass::uint4b_t>( reinterpret_cast<half*>(x->data), reinterpret_cast<cutlass::uint4b_t*>(weight->data), reinterpret_cast<half*>(scales->data), nullptr, reinterpret_cast<half*>(out->data), reinterpret_cast<int64_t*>(total_rows_before_expert->data), total_rows, n, k, num_experts, - std::nullopt, - /*stream=*/nullptr /*FIXME*/); + std::nullopt, stream); }); TVM_REGISTER_GLOBAL("moe_compute_rows_before") @@ -81,10 +87,14 @@ TVM_REGISTER_GLOBAL("moe_compute_rows_before") CHECK(sorted_indices->ndim == 1); CHECK(total_rows_before_expert->ndim == 1); + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*()); + int num_experts = total_rows_before_expert->shape[0]; compute_total_rows_before_expert( reinterpret_cast<int*>(sorted_indices->data), sorted_indices->shape[0], num_experts, - reinterpret_cast<int64_t*>(total_rows_before_expert->data), nullptr); + reinterpret_cast<int64_t*>(total_rows_before_expert->data), stream); }); } // namespace runtime