This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit f238c3b2648da993366348ccb3b79f76d502a938
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Tue Dec 12 01:12:27 2023 +0000

    fix
---
 src/runtime/contrib/cutlass/moe_gemm.cc | 20 +++++++++++++++-----
 1 file changed, 15 insertions(+), 5 deletions(-)

diff --git a/src/runtime/contrib/cutlass/moe_gemm.cc b/src/runtime/contrib/cutlass/moe_gemm.cc
index 046247d108..4c0de113ab 100644
--- a/src/runtime/contrib/cutlass/moe_gemm.cc
+++ b/src/runtime/contrib/cutlass/moe_gemm.cc
@@ -54,23 +54,29 @@ namespace runtime {
 TVM_REGISTER_GLOBAL("cutlass.moe_gemm_f16f16")
     .set_body_typed([](NDArray x, NDArray weight, NDArray total_rows_before_expert,
                        int64_t total_rows, int64_t n, int64_t k, int64_t num_experts, NDArray out) {
+      auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+      ICHECK(func != nullptr);
+      cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
       fastertransformer::moe_gemm_bias_act<half, half>(
           reinterpret_cast<half*>(x->data), reinterpret_cast<half*>(weight->data), nullptr, nullptr,
           reinterpret_cast<half*>(out->data),
           reinterpret_cast<int64_t*>(total_rows_before_expert->data), total_rows, n, k, num_experts,
-          std::nullopt,
-          /*stream=*/nullptr /*FIXME*/);
+          std::nullopt, stream);
     });
 
 TVM_REGISTER_GLOBAL("cutlass.moe_gemm_s4f16")
     .set_body_typed([](NDArray x, NDArray weight, NDArray scales, NDArray total_rows_before_expert,
                        int64_t total_rows, int64_t n, int64_t k, int64_t num_experts, NDArray out) {
+      auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+      ICHECK(func != nullptr);
+      cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
       fastertransformer::moe_gemm_bias_act<half, cutlass::uint4b_t>(
           reinterpret_cast<half*>(x->data), reinterpret_cast<cutlass::uint4b_t*>(weight->data),
           reinterpret_cast<half*>(scales->data), nullptr, reinterpret_cast<half*>(out->data),
           reinterpret_cast<int64_t*>(total_rows_before_expert->data), total_rows, n, k, num_experts,
-          std::nullopt,
-          /*stream=*/nullptr /*FIXME*/);
+          std::nullopt, stream);
     });
 
 TVM_REGISTER_GLOBAL("moe_compute_rows_before")
@@ -81,10 +87,14 @@ TVM_REGISTER_GLOBAL("moe_compute_rows_before")
       CHECK(sorted_indices->ndim == 1);
       CHECK(total_rows_before_expert->ndim == 1);
 
+      auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+      ICHECK(func != nullptr);
+      cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
       int num_experts = total_rows_before_expert->shape[0];
       compute_total_rows_before_expert(
           reinterpret_cast<int*>(sorted_indices->data), sorted_indices->shape[0], num_experts,
-          reinterpret_cast<int64_t*>(total_rows_before_expert->data), nullptr);
+          reinterpret_cast<int64_t*>(total_rows_before_expert->data), stream);
     });
 
 }  // namespace runtime

Reply via email to