This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 33989b48d9e23e64ea1806425393e5a18011f288 Author: Wuwei Lin <wu...@apache.org> AuthorDate: Sun Dec 10 21:06:56 2023 +0000 moe update --- 3rdparty/cutlass_fpA_intB_gemm | 2 +- cmake/modules/contrib/CUTLASS.cmake | 5 ++++ python/tvm/dlight/gpu/fallback.py | 2 +- src/runtime/contrib/cutlass/weight_preprocess.cc | 14 ++++++++---- src/runtime/contrib/thrust/thrust.cu | 29 ++++++++++++++++++++++++ 5 files changed, 45 insertions(+), 7 deletions(-) diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm index 6b744ba3a2..ce04004923 160000 --- a/3rdparty/cutlass_fpA_intB_gemm +++ b/3rdparty/cutlass_fpA_intB_gemm @@ -1 +1 @@ -Subproject commit 6b744ba3a272d38451235b08a873a0d8add737f0 +Subproject commit ce04004923035845f1b092a0ad3b74e760a7a014 diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index bd3e3b1166..b21e3ec918 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -22,7 +22,12 @@ if(USE_CUDA AND USE_CUTLASS) set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass) add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm) add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn) + + include_directories(3rdparty/cutlass_fpA_intB_gemm + 3rdparty/cutlass_fpA_intB_gemm/cutlass/include) # FIXME list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc) + list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/moe_gemm.cc) + list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/moe_compute_rows.cu) message(STATUS "Build with CUTLASS") endif() diff --git a/python/tvm/dlight/gpu/fallback.py b/python/tvm/dlight/gpu/fallback.py index 3e0dbbcdaa..030e8d374e 100644 --- a/python/tvm/dlight/gpu/fallback.py +++ b/python/tvm/dlight/gpu/fallback.py @@ -63,7 +63,7 @@ class Fallback(ScheduleRule): {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) if not s_loops: - s_loops.append(sch.add_unit_loop(block.block)) + 
s_loops.append(sch.add_unit_loop(block)) sch.reorder(*s_loops, *r_loops, *o_loops) bx, tx = sch.split( # pylint: disable=invalid-name sch.fuse(*s_loops), diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc b/src/runtime/contrib/cutlass/weight_preprocess.cc index ef80627cc7..eb60891ac6 100644 --- a/src/runtime/contrib/cutlass/weight_preprocess.cc +++ b/src/runtime/contrib/cutlass/weight_preprocess.cc @@ -23,6 +23,8 @@ #include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/cutlass_preprocessors.h" + + namespace tvm { namespace runtime { @@ -37,17 +39,19 @@ namespace runtime { // The preprocessing functions are defined in C++, so we need to copy the input weight to CPU. TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight") .set_body_typed([](NDArray packed_weight, int sm, bool is_int4) { - int rows = packed_weight->shape[0]; - int cols = packed_weight->shape[1]; - std::vector<int8_t> input_cpu(rows * cols); - std::vector<int8_t> output_cpu(rows * cols); + bool is_2d = packed_weight->ndim == 2; + int num_experts = is_2d ? 1 : packed_weight->shape[0]; + int rows = packed_weight->shape[is_2d ? 0 : 1]; + int cols = packed_weight->shape[is_2d ? 1 : 2]; + std::vector<int8_t> input_cpu(num_experts * rows * cols); + std::vector<int8_t> output_cpu(num_experts * rows * cols); packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size()); // multiply cols by 2 since the "col" params in preprocess_weights refers to the column of // the unpacked weight. if (is_int4) { cols *= 2; } - fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), rows, cols, + fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), is_2d ? 
-1 : num_experts, rows, cols, is_int4, sm); auto out = NDArray::Empty(packed_weight.Shape(), packed_weight->dtype, packed_weight->device); out.CopyFromBytes(output_cpu.data(), output_cpu.size()); diff --git a/src/runtime/contrib/thrust/thrust.cu b/src/runtime/contrib/thrust/thrust.cu index d633072835..3a0699da6a 100644 --- a/src/runtime/contrib/thrust/thrust.cu +++ b/src/runtime/contrib/thrust/thrust.cu @@ -21,6 +21,7 @@ * \file Use external Thrust library call */ +#include <cuda_fp16.h> #include <thrust/device_ptr.h> #include <thrust/device_vector.h> #include <thrust/sort.h> @@ -140,6 +141,18 @@ void thrust_sort_common(DLTensor* input, } else { LOG(FATAL) << "Unsupported output dtype: " << out_dtype; } + } else if (data_dtype == "float16") { + if (out_dtype == "int32") { + thrust_sort<half, int32_t>(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "int64") { + thrust_sort<half, int64_t>(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "float32") { + thrust_sort<half, float>(input, values_out, indices_out, is_ascend, sort_len); + } else if (out_dtype == "float64") { + thrust_sort<half, double>(input, values_out, indices_out, is_ascend, sort_len); + } else { + LOG(FATAL) << "Unsupported output dtype: " << out_dtype; + } } else if (data_dtype == "int32") { if (out_dtype == "int32") { thrust_sort<int32_t, int32_t>(input, values_out, indices_out, is_ascend, sort_len); @@ -185,6 +198,22 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort") data_dtype, out_dtype); }); +TVM_REGISTER_GLOBAL("tvm.contrib.thrust.sort_dps") +.set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK_GE(args.num_args, 4); + DLTensor* input = args[0]; + DLTensor* values_out = args[2]; + DLTensor* indices_out = args[3]; + bool is_ascend = args[1]; + + auto data_dtype = DLDataType2String(input->dtype); + auto out_dtype = DLDataType2String(indices_out->dtype); + + int n_values = input->shape[input->ndim - 1]; + thrust_sort_common(input, 
values_out, indices_out, is_ascend, n_values, + data_dtype, out_dtype); +}); + template<typename KeyType, typename ValueType> void thrust_stable_sort_by_key(DLTensor* keys_in, DLTensor* values_in,