This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 02f6f7daf233769345a1ca5f0c35492f337a98ce
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Mon Dec 11 16:35:45 2023 +0000

    fix
---
 src/relax/op/tensor/index.cc                     |  2 +-
 src/runtime/contrib/cutlass/moe_gemm.cc          |  7 -----
 src/runtime/contrib/cutlass/weight_preprocess.cc | 37 ++++++++++++++++++++----
 3 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 6d9dfc86ba..6cda62ad0e 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -167,7 +167,7 @@ inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t strid
   // include/tvm/topi/detail/strided_slice.h
   PrimExpr begin_range = stride < 0 ? -1 : 0;
   PrimExpr end_range = stride < 0 ? extent - 1 : extent;
-  index = if_then_else(index < 0, index + extent, index);
+  // index = if_then_else(index < 0, index + extent, index);  // FIXME
   return assume_inbound ? index : min(max(index, begin_range), end_range);  // NOLINT
 }
 
diff --git a/src/runtime/contrib/cutlass/moe_gemm.cc b/src/runtime/contrib/cutlass/moe_gemm.cc
index a1796991c3..046247d108 100644
--- a/src/runtime/contrib/cutlass/moe_gemm.cc
+++ b/src/runtime/contrib/cutlass/moe_gemm.cc
@@ -54,19 +54,12 @@ namespace runtime {
 TVM_REGISTER_GLOBAL("cutlass.moe_gemm_f16f16")
     .set_body_typed([](NDArray x, NDArray weight, NDArray total_rows_before_expert,
                        int64_t total_rows, int64_t n, int64_t k, int64_t num_experts, NDArray out) {
-      LOG(INFO) << "GEMM MOE F16F16";
-      LOG(INFO) << "x: " << x->data << " weight: " << weight->data
-                << " total_rows_before_expert: " << total_rows_before_expert->data
-                << " total_rows: " << total_rows << " n: " << n << " k: " << k
-                << " num_experts: " << num_experts << " out: " << out->data;
-
       // using half = cutlass::half_t;
       fastertransformer::moe_gemm_bias_act<half, half>(
           reinterpret_cast<half*>(x->data), reinterpret_cast<half*>(weight->data), nullptr, nullptr,
           reinterpret_cast<half*>(out->data),
           reinterpret_cast<int64_t*>(total_rows_before_expert->data), total_rows, n, k, num_experts,
           std::nullopt, /*stream=*/nullptr /*FIXME*/);
-      LOG(INFO) << "MOE OK";
     });
 
 TVM_REGISTER_GLOBAL("cutlass.moe_gemm_s4f16")
diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc b/src/runtime/contrib/cutlass/weight_preprocess.cc
index eb60891ac6..adbb67fb51 100644
--- a/src/runtime/contrib/cutlass/weight_preprocess.cc
+++ b/src/runtime/contrib/cutlass/weight_preprocess.cc
@@ -17,14 +17,13 @@
  * under the License.
  */
 
+#include <cuda_fp16.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
 
 #include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/cutlass_preprocessors.h"
 
-
-
 namespace tvm {
 namespace runtime {
 
@@ -42,7 +41,8 @@ TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight")
       bool is_2d = packed_weight->ndim == 2;
       int num_experts = 1;
       int rows = packed_weight->shape[is_2d ? 0 : 1];
-      int cols = packed_weight->shape[is_2d ? 0 : 1];
+      int cols = packed_weight->shape[is_2d ? 1 : 2];
+
       std::vector<int8_t> input_cpu(num_experts * rows * cols);
       std::vector<int8_t> output_cpu(num_experts * rows * cols);
       packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size());
@@ -51,12 +51,39 @@ TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight")
       if (is_int4) {
         cols *= 2;
       }
-      fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), is_2d ? -1 : num_experts, rows, cols,
-                                            is_int4, sm);
+      fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(),
+                                            is_2d ? -1 : num_experts, rows, cols, is_int4, sm);
       auto out = NDArray::Empty(packed_weight.Shape(), packed_weight->dtype, packed_weight->device);
       out.CopyFromBytes(output_cpu.data(), output_cpu.size());
       return out;
     });
 
+TVM_REGISTER_GLOBAL("cutlass.symmetric_quantize").set_body_typed([](NDArray weight, bool is_int4) {
+  CHECK(is_int4);
+  CHECK(weight->dtype.code == kDLFloat && weight->dtype.bits == 16);
+  CHECK(weight->ndim == 3);
+  CHECK(weight->device.device_type == kDLCPU);
+  int64_t num_experts = weight->shape[0];
+  int64_t rows = weight->shape[1];
+  int64_t cols = weight->shape[2];
+
+  ShapeTuple out_weight_shape{num_experts, rows, cols / 2};
+  ShapeTuple out_scale_shape{num_experts, cols};
+  auto out_weight = NDArray::Empty(
+      out_weight_shape, DLDataType{.code = kDLInt, .bits = 8, .lanes = 1}, weight->device);
+  auto out_scale = NDArray::Empty(
+      out_scale_shape, DLDataType{.code = kDLFloat, .bits = 16, .lanes = 1}, weight->device);
+
+  fastertransformer::symmetric_quantize<half, half>(
+      reinterpret_cast<int8_t*>(out_weight->data), reinterpret_cast<half*>(out_scale->data),
+      reinterpret_cast<const half*>(weight->data),
+      std::vector<size_t>{static_cast<size_t>(num_experts), static_cast<size_t>(rows),
+                          static_cast<size_t>(cols)},
+      true);
+  // out_weight.CopyFromBytes(output_cpu.data(), output_cpu.size());
+  // out_scale.CopyFromBytes(scales_cpu.data(), scales_cpu.size() * sizeof(half));
+  return Array<NDArray>{out_weight, out_scale};
+});
+
 }  // namespace runtime
 }  // namespace tvm
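Not part of the commit, for context only: a minimal C++ sketch of how the newly registered "cutlass.symmetric_quantize" packed function could be looked up and invoked. The QuantizeExpertWeights wrapper is a hypothetical helper; the requirement that the caller pass a float16 tensor of shape (num_experts, rows, cols) on CPU with is_int4=true simply mirrors the CHECKs in the diff above.

// Hypothetical caller of the "cutlass.symmetric_quantize" global registered above.
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

using tvm::runtime::Array;
using tvm::runtime::NDArray;
using tvm::runtime::PackedFunc;
using tvm::runtime::Registry;

// `weight` must be float16, shape (num_experts, rows, cols), resident on CPU.
Array<NDArray> QuantizeExpertWeights(NDArray weight) {
  const PackedFunc* quantize = Registry::Get("cutlass.symmetric_quantize");
  ICHECK(quantize != nullptr) << "cutlass.symmetric_quantize is not registered";
  // Element [0] is the int4-packed int8 weight of shape (num_experts, rows, cols / 2);
  // element [1] holds the float16 scales of shape (num_experts, cols).
  Array<NDArray> result = (*quantize)(weight, /*is_int4=*/true);
  return result;
}

From Python, the same entry point can be fetched with tvm.get_global_func("cutlass.symmetric_quantize").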