This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 33989b48d9e23e64ea1806425393e5a18011f288
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Sun Dec 10 21:06:56 2023 +0000

    moe update
---
 3rdparty/cutlass_fpA_intB_gemm                   |  2 +-
 cmake/modules/contrib/CUTLASS.cmake              |  5 ++++
 python/tvm/dlight/gpu/fallback.py                |  2 +-
 src/runtime/contrib/cutlass/weight_preprocess.cc | 14 ++++++++----
 src/runtime/contrib/thrust/thrust.cu             | 29 ++++++++++++++++++++++++
 5 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm
index 6b744ba3a2..ce04004923 160000
--- a/3rdparty/cutlass_fpA_intB_gemm
+++ b/3rdparty/cutlass_fpA_intB_gemm
@@ -1 +1 @@
-Subproject commit 6b744ba3a272d38451235b08a873a0d8add737f0
+Subproject commit ce04004923035845f1b092a0ad3b74e760a7a014
diff --git a/cmake/modules/contrib/CUTLASS.cmake 
b/cmake/modules/contrib/CUTLASS.cmake
index bd3e3b1166..b21e3ec918 100644
--- a/cmake/modules/contrib/CUTLASS.cmake
+++ b/cmake/modules/contrib/CUTLASS.cmake
@@ -22,7 +22,12 @@ if(USE_CUDA AND USE_CUTLASS)
   set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass)
   add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm)
   add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn)
+
+  include_directories(3rdparty/cutlass_fpA_intB_gemm
+                      3rdparty/cutlass_fpA_intB_gemm/cutlass/include)  # FIXME
   list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc)
+  list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/moe_gemm.cc)
+  list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/moe_compute_rows.cu)
 
   message(STATUS "Build with CUTLASS")
 endif()
diff --git a/python/tvm/dlight/gpu/fallback.py 
b/python/tvm/dlight/gpu/fallback.py
index 3e0dbbcdaa..030e8d374e 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -63,7 +63,7 @@ class Fallback(ScheduleRule):
                 {"S": s_loops, "R": r_loops, "O": 
o_loops}[iter_type].append(loop)
 
             if not s_loops:
-                s_loops.append(sch.add_unit_loop(block.block))
+                s_loops.append(sch.add_unit_loop(block))
             sch.reorder(*s_loops, *r_loops, *o_loops)
             bx, tx = sch.split(  # pylint: disable=invalid-name
                 sch.fuse(*s_loops),
diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc 
b/src/runtime/contrib/cutlass/weight_preprocess.cc
index ef80627cc7..eb60891ac6 100644
--- a/src/runtime/contrib/cutlass/weight_preprocess.cc
+++ b/src/runtime/contrib/cutlass/weight_preprocess.cc
@@ -23,6 +23,8 @@
 
 #include 
"../../../3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/cutlass_preprocessors.h"
 
+
+
 namespace tvm {
 namespace runtime {
 
@@ -37,17 +39,19 @@ namespace runtime {
 // The preprocessing functions are defined in C++, so we need to copy the 
input weight to CPU.
 TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight")
     .set_body_typed([](NDArray packed_weight, int sm, bool is_int4) {
-      int rows = packed_weight->shape[0];
-      int cols = packed_weight->shape[1];
-      std::vector<int8_t> input_cpu(rows * cols);
-      std::vector<int8_t> output_cpu(rows * cols);
+      bool is_2d = packed_weight->ndim == 2;
+      int num_experts = is_2d ? 1 : packed_weight->shape[0];
+      int rows = packed_weight->shape[is_2d ? 0 : 1];
+      int cols = packed_weight->shape[is_2d ? 1 : 2];
+      std::vector<int8_t> input_cpu(num_experts * rows * cols);
+      std::vector<int8_t> output_cpu(num_experts * rows * cols);
       packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size());
       // multiply cols by 2 since the "col" params in preprocess_weights 
refers to the column of
       // the unpacked weight.
       if (is_int4) {
         cols *= 2;
       }
-      fastertransformer::preprocess_weights(output_cpu.data(), 
input_cpu.data(), rows, cols,
+      fastertransformer::preprocess_weights(output_cpu.data(), 
input_cpu.data(), is_2d ? -1 : num_experts, rows, cols,
                                             is_int4, sm);
       auto out = NDArray::Empty(packed_weight.Shape(), packed_weight->dtype, 
packed_weight->device);
       out.CopyFromBytes(output_cpu.data(), output_cpu.size());
diff --git a/src/runtime/contrib/thrust/thrust.cu 
b/src/runtime/contrib/thrust/thrust.cu
index d633072835..3a0699da6a 100644
--- a/src/runtime/contrib/thrust/thrust.cu
+++ b/src/runtime/contrib/thrust/thrust.cu
@@ -21,6 +21,7 @@
  * \file Use external Thrust library call
  */
 
+#include <cuda_fp16.h>
 #include <thrust/device_ptr.h>
 #include <thrust/device_vector.h>
 #include <thrust/sort.h>
@@ -140,6 +141,18 @@ void thrust_sort_common(DLTensor* input,
     } else {
       LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
     }
+  } else if (data_dtype == "float16") {
+    if (out_dtype == "int32") {
+      thrust_sort<half, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len);
+    } else if (out_dtype == "int64") {
+      thrust_sort<half, int64_t>(input, values_out, indices_out, is_ascend, 
sort_len);
+    } else if (out_dtype == "float32") {
+      thrust_sort<half, float>(input, values_out, indices_out, is_ascend, 
sort_len);
+    } else if (out_dtype == "float64") {
+      thrust_sort<half, double>(input, values_out, indices_out, is_ascend, 
sort_len);
+    } else {
+      LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
+    }
   } else if (data_dtype == "int32") {
     if (out_dtype == "int32") {
       thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, 
sort_len);
@@ -185,6 +198,22 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort")
                      data_dtype, out_dtype);
 });
 
+TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_dps")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  ICHECK_GE(args.num_args, 4);
+  DLTensor* input = args[0];
+  DLTensor* values_out = args[2];
+  DLTensor* indices_out = args[3];
+  bool is_ascend = args[1];
+
+  auto data_dtype = DLDataType2String(input->dtype);
+  auto out_dtype = DLDataType2String(indices_out->dtype);
+
+  int n_values = input->shape[input->ndim - 1];
+  thrust_sort_common(input, values_out, indices_out, is_ascend, n_values,
+                     data_dtype, out_dtype);
+});
+
 template<typename KeyType, typename ValueType>
 void thrust_stable_sort_by_key(DLTensor* keys_in,
                                DLTensor* values_in,

Reply via email to the commits mailing list.