This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 02f6f7daf233769345a1ca5f0c35492f337a98ce
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Mon Dec 11 16:35:45 2023 +0000

    fix
---
 src/relax/op/tensor/index.cc                     |  2 +-
 src/runtime/contrib/cutlass/moe_gemm.cc          |  7 -----
 src/runtime/contrib/cutlass/weight_preprocess.cc | 37 ++++++++++++++++++++----
 3 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 6d9dfc86ba..6cda62ad0e 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -167,7 +167,7 @@ inline PrimExpr CanonicalizeIndex(PrimExpr index, PrimExpr extent, int64_t strid
   // include/tvm/topi/detail/strided_slice.h
   PrimExpr begin_range = stride < 0 ? -1 : 0;
   PrimExpr end_range = stride < 0 ? extent - 1 : extent;
-  index = if_then_else(index < 0, index + extent, index);
+  // index = if_then_else(index < 0, index + extent, index); // FIXME
   return assume_inbound ? index : min(max(index, begin_range), end_range);  // NOLINT
 }
 
diff --git a/src/runtime/contrib/cutlass/moe_gemm.cc b/src/runtime/contrib/cutlass/moe_gemm.cc
index a1796991c3..046247d108 100644
--- a/src/runtime/contrib/cutlass/moe_gemm.cc
+++ b/src/runtime/contrib/cutlass/moe_gemm.cc
@@ -54,19 +54,12 @@ namespace runtime {
 TVM_REGISTER_GLOBAL("cutlass.moe_gemm_f16f16")
     .set_body_typed([](NDArray x, NDArray weight, NDArray total_rows_before_expert,
                        int64_t total_rows, int64_t n, int64_t k, int64_t num_experts, NDArray out) {
-      LOG(INFO) << "GEMM MOE F16F16";
-      LOG(INFO) << "x: " << x->data << " weight: " << weight->data
-                << " total_rows_before_expert: " << total_rows_before_expert->data
-                << " total_rows: " << total_rows << " n: " << n << " k: " << k
-                << " num_experts: " << num_experts << " out: " << out->data;
-      //   using half = cutlass::half_t;
       fastertransformer::moe_gemm_bias_act<half, half>(
           reinterpret_cast<half*>(x->data), reinterpret_cast<half*>(weight->data), nullptr, nullptr,
           reinterpret_cast<half*>(out->data),
           reinterpret_cast<int64_t*>(total_rows_before_expert->data), total_rows, n, k, num_experts,
           std::nullopt,
           /*stream=*/nullptr /*FIXME*/);
-      LOG(INFO) << "MOE OK";
     });
 
 TVM_REGISTER_GLOBAL("cutlass.moe_gemm_s4f16")
diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc b/src/runtime/contrib/cutlass/weight_preprocess.cc
index eb60891ac6..adbb67fb51 100644
--- a/src/runtime/contrib/cutlass/weight_preprocess.cc
+++ b/src/runtime/contrib/cutlass/weight_preprocess.cc
@@ -17,14 +17,13 @@
  * under the License.
  */
 
+#include <cuda_fp16.h>
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
 
 #include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/cutlass_preprocessors.h"
 
-
-
 namespace tvm {
 namespace runtime {
 
@@ -42,7 +41,8 @@ TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight")
       bool is_2d = packed_weight->ndim == 2;
       int num_experts = 1;
       int rows = packed_weight->shape[is_2d ? 0 : 1];
-      int cols = packed_weight->shape[is_2d ? 0 : 1];
+      int cols = packed_weight->shape[is_2d ? 1 : 2];
+
       std::vector<int8_t> input_cpu(num_experts * rows * cols);
       std::vector<int8_t> output_cpu(num_experts * rows * cols);
       packed_weight.CopyToBytes(input_cpu.data(), input_cpu.size());
@@ -51,12 +51,39 @@ TVM_REGISTER_GLOBAL("cutlass.ft_preprocess_weight")
       if (is_int4) {
         cols *= 2;
       }
-      fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(), is_2d ? -1 : num_experts, rows, cols,
-                                            is_int4, sm);
+      fastertransformer::preprocess_weights(output_cpu.data(), input_cpu.data(),
+                                            is_2d ? -1 : num_experts, rows, cols, is_int4, sm);
       auto out = NDArray::Empty(packed_weight.Shape(), packed_weight->dtype, packed_weight->device);
       out.CopyFromBytes(output_cpu.data(), output_cpu.size());
       return out;
     });
 
+TVM_REGISTER_GLOBAL("cutlass.symmetric_quantize").set_body_typed([](NDArray weight, bool is_int4) {
+  CHECK(is_int4);
+  CHECK(weight->dtype.code == kDLFloat && weight->dtype.bits == 16);
+  CHECK(weight->ndim == 3);
+  CHECK(weight->device.device_type == kDLCPU);
+  int64_t num_experts = weight->shape[0];
+  int64_t rows = weight->shape[1];
+  int64_t cols = weight->shape[2];
+
+  ShapeTuple out_weight_shape{num_experts, rows, cols / 2};
+  ShapeTuple out_scale_shape{num_experts, cols};
+  auto out_weight = NDArray::Empty(
+      out_weight_shape, DLDataType{.code = kDLInt, .bits = 8, .lanes = 1}, weight->device);
+  auto out_scale = NDArray::Empty(
+      out_scale_shape, DLDataType{.code = kDLFloat, .bits = 16, .lanes = 1}, weight->device);
+
+  fastertransformer::symmetric_quantize<half, half>(
+      reinterpret_cast<int8_t*>(out_weight->data), reinterpret_cast<half*>(out_scale->data),
+      reinterpret_cast<const half*>(weight->data),
+      std::vector<size_t>{static_cast<size_t>(num_experts), static_cast<size_t>(rows),
+                          static_cast<size_t>(cols)},
+      true);
+  // out_weight.CopyFromBytes(output_cpu.data(), output_cpu.size());
+  // out_scale.CopyFromBytes(scales_cpu.data(), scales_cpu.size() * sizeof(half));
+  return Array<NDArray>{out_weight, out_scale};
+});
+
 }  // namespace runtime
 }  // namespace tvm

Reply via email to