This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 77b71fc830 [CMSIS-NN] Support for Softmax Int16 operator (#15407)
77b71fc830 is described below

commit 77b71fc8304467ba6a86433066b1a86eb8c225c6
Author: Codrut-Grigore Irimie <78698310+nike...@users.noreply.github.com>
AuthorDate: Wed Aug 9 17:07:03 2023 +0300

    [CMSIS-NN] Support for Softmax Int16 operator (#15407)
    
    * Support for int16 Softmax in CMSIS-NN
    * Supporting integration test
---
 python/tvm/relay/op/contrib/cmsisnn.py             |  14 +-
 src/relay/backend/contrib/cmsisnn/compute_luts.cc  |  76 +++++++++++
 src/relay/backend/contrib/cmsisnn/compute_luts.h   |  55 ++++++++
 src/relay/backend/contrib/cmsisnn/relay_to_tir.cc  | 151 +++++++++++++++++----
 .../backend/contrib/cmsisnn/tir_to_runtime.cc      |  56 ++++++++
 tests/python/contrib/test_cmsisnn/test_softmax.py  |  43 ++++++
 6 files changed, 364 insertions(+), 31 deletions(-)

diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py
index cf32947446..ed620f0ff1 100644
--- a/python/tvm/relay/op/contrib/cmsisnn.py
+++ b/python/tvm/relay/op/contrib/cmsisnn.py
@@ -86,11 +86,21 @@ def pattern_table():
         zero_point = pattern.args[2].data.numpy().item(0)
 
         # check for dtypes of quantize and dequantize
-        return (
+        if (
             (scale == 1.0 / 256 and zero_point == -128)
             and pattern.attrs.out_dtype == "int8"
             and dequantize_call.args[0].checked_type.dtype == "int8"
-        )
+        ):
+            return True
+
+        if (
+            (scale == 1.0 / 32768 and zero_point == 0)
+            and pattern.attrs.out_dtype == "int16"
+            and dequantize_call.args[0].checked_type.dtype == "int16"
+        ):
+            return True
+
+        return False
 
     def qnn_conv2d_pattern(with_pad):
         """Create pattern for qnn.conv2D with optional pad and/or optional 
fused relu."""
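
The two accepted (scale, zero point) pairs match the values TensorFlow Lite
fixes for softmax outputs. Since real_value = scale * (quantized - zero_point),
scale 1/256 with zero point -128 maps the int8 range onto [0, 255/256], while
scale 1/32768 with zero point 0 maps the int16 range onto roughly [-1, 1).
A minimal standalone sketch of the same check (the helper name is hypothetical,
not part of this patch):

    // Hypothetical helper mirroring the Python pattern check above.
    static bool IsSupportedSoftmaxOutput(float scale, int zero_point, int bits) {
      if (bits == 8) return scale == 1.0f / 256 && zero_point == -128;
      if (bits == 16) return scale == 1.0f / 32768 && zero_point == 0;
      return false;
    }
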
diff --git a/src/relay/backend/contrib/cmsisnn/compute_luts.cc b/src/relay/backend/contrib/cmsisnn/compute_luts.cc
new file mode 100644
index 0000000000..13dcb395b3
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/compute_luts.cc
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relay/backend/contrib/cmsisnn/compute_luts.cc
+ * \brief Creates LUTs for operators in different bit formats for accelerating computations.
+ */
+
+#include "compute_luts.h"
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+void CalculateLUTInt16(int key_zero_point, float key_scale, int value_zero_point, float value_scale,
+                       float (*func)(float), const int steps, int16_t* lut) {
+  const float value_min = static_cast<float>(std::numeric_limits<int16_t>::min());
+  const float value_max = static_cast<float>(std::numeric_limits<int16_t>::max());
+  const float key_min_deq = key_scale * (std::numeric_limits<int16_t>::min() - key_zero_point);
+  const float key_max_deq = key_scale * (std::numeric_limits<int16_t>::max() - key_zero_point);
+  const float value_min_deq =
+      value_scale * (std::numeric_limits<int16_t>::min() - value_zero_point);
+  const float value_max_deq =
+      value_scale * (std::numeric_limits<int16_t>::max() - value_zero_point);
+
+  const float step_size_deq = (key_max_deq - key_min_deq) / (steps - 1);
+  const float half_step_size_deq = step_size_deq / 2;
+
+  const float value_inv_quantizing =
+      (std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min() + 1) /
+      (value_max_deq - value_min_deq);
+
+  for (int i = 0; i < steps - 1; i++) {
+    float value_deq = func(key_min_deq + i * step_size_deq);
+    float mid_value_deq = func(key_min_deq + i * step_size_deq + half_step_size_deq);
+    float next_value_deq = func(key_min_deq + (i + 1) * step_size_deq);
+
+    float value = std::round(value_deq * value_inv_quantizing);
+    float mid_value = std::round(mid_value_deq * value_inv_quantizing);
+    float next_value = std::round(next_value_deq * value_inv_quantizing);
+    float mid_interp_value = std::round((value + next_value) / 2);
+
+    float mid_err = mid_interp_value - mid_value;
+    float bias = std::round(mid_err / 2);
+
+    lut[i] = static_cast<int16_t>(std::max(std::min(value - bias, value_max), value_min));
+  }
+
+  lut[steps - 1] = static_cast<int16_t>(
+      std::max(std::min(func(value_max_deq) * value_inv_quantizing, value_max), value_min));
+}
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
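
A note on the loop above: each LUT entry is nudged by half of the midpoint
interpolation error, i.e. the stored value is value - (mid_interp - mid_true) / 2,
which spreads the error of linear interpolation across the step. A plain-float
sketch of the same idea, stripped of the quantization parameters (illustrative
only, not the TVM code):

    #include <cmath>

    // Entry for the segment [x0, x1]: bias the left endpoint by half the
    // midpoint interpolation error so interpolated lookups stay closer to f.
    float BiasedLutEntry(float (*f)(float), float x0, float x1) {
      float v0 = f(x0), v1 = f(x1);
      float mid_true = f((x0 + x1) / 2);        // exact midpoint value
      float mid_interp = (v0 + v1) / 2;         // linear-interpolated midpoint
      return v0 - (mid_interp - mid_true) / 2;  // shift to split the error
    }
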
diff --git a/src/relay/backend/contrib/cmsisnn/compute_luts.h b/src/relay/backend/contrib/cmsisnn/compute_luts.h
new file mode 100644
index 0000000000..eca4127e40
--- /dev/null
+++ b/src/relay/backend/contrib/cmsisnn/compute_luts.h
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/cmsisnn/compute_luts.h
+ * \brief CMSIS-NN LUTs calculation functions
+ */
+
+#ifndef TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_
+#define TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_
+
+#include <cstdint>
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+/*!
+ * \brief Populates an int16 LUT based on the quantization parameters of its keys, values and
+ * respective transformation function
+ *
+ * \param key_zero_point - zero point of table's keys
+ * \param key_scale - scale of the table's keys
+ * \param value_zero_point - zero point of table's values
+ * \param value_scale - scale of the table's values
+ * \param func - function pointer of the transformation performed by the LUT
+ * \param steps - number of total values inside the table
+ * \param lut - int16_t array storing the values of the LUT
+ */
+void CalculateLUTInt16(int key_zero_point, float key_scale, int value_zero_point, float value_scale,
+                       float (*func)(float), const int steps, int16_t* lut);
+
+}  // namespace cmsisnn
+}  // namespace contrib
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_BACKEND_CONTRIB_CMSISNN_COMPUTE_LUTS_H_
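
For reference, the softmax lowering below calls this with 513 steps and
captureless lambdas, which convert implicitly to the float (*)(float)
parameter. A trimmed sketch of such a call, using the exp-LUT quantization
parameters from relay_to_tir.cc:

    #include <cmath>
    #include <cstdint>
    #include "compute_luts.h"

    int16_t exp_lut[513];

    void BuildExpLut() {
      // keys: zero point 32767, scale 10/65535; values: zero point 0, scale 2/65535
      tvm::relay::contrib::cmsisnn::CalculateLUTInt16(
          32767, 10.0f / 65535, 0, 2.0f / 65535,
          [](float key) { return std::exp(key); }, 513, exp_lut);
    }
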
diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
index 33547f4bd8..49800195f6 100644
--- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
+++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -30,6 +30,7 @@
 #include "../../../transforms/pattern_utils.h"
 #include "buffer_size.h"
 #include "compiler_attrs.h"
+#include "compute_luts.h"
 #include "convolutions.h"
 
 namespace tvm {
@@ -89,11 +90,17 @@ class RelayToTIRVisitor : public MixedModeMutator {
  private:
   inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); }
 
-  void CreatePrimFuncForExtern(const GlobalVar& global_var, Array<tir::Var> func_signature,
-                               const Map<tir::Var, tir::Buffer>& buffer_map,
-                               tvm::Array<PrimExpr> call_extern_args,
-                               PrimExpr context_buffer_var = PrimExpr(),
-                               int context_buffer_size = 0, int num_bits = 8) {
+  // struct used to allocate const NDArrays
+  struct tir_input_constant_buffers {
+    tir::Var buffer_var;
+    tvm::runtime::NDArray ndarray;
+  };
+
+  void CreatePrimFuncForExtern(
+      const GlobalVar& global_var, Array<tir::Var> func_signature,
+      const Map<tir::Var, tir::Buffer>& buffer_map, tvm::Array<PrimExpr> call_extern_args,
+      PrimExpr context_buffer_var = PrimExpr(), int context_buffer_size = 0, int num_bits = 8,
+      std::vector<tir_input_constant_buffers> context_const_buffer_vars = {}) {
     Map<String, ObjectRef> dict_attrs;
     dict_attrs.Set(tvm::attr::kGlobalSymbol, global_var->name_hint);
     dict_attrs.Set(tvm::attr::kTarget, target_);
@@ -107,8 +114,22 @@ class RelayToTIRVisitor : public MixedModeMutator {
                            {context_buffer_size}, tir::const_true(), body);
     }
 
+    for (int i = 0; i < static_cast<int>(context_const_buffer_vars.size()); i++) {
+      int bits = context_const_buffer_vars[i].ndarray.DataType().bits();
+
+      Array<PrimExpr> extents;
+      for (int shape : context_const_buffer_vars[i].ndarray.Shape()) {
+        extents.push_back(PrimExpr(shape));
+      }
+
+      body = tir::AllocateConst(Downcast<tir::Var>(context_const_buffer_vars[i].buffer_var),
+                                DataType::Int(bits), extents, context_const_buffer_vars[i].ndarray,
+                                body);
+    }
+
     tir::PrimFunc replacement_func(func_signature, body, VoidType(), 
buffer_map,
                                    DictAttrs(dict_attrs));
+
     ir_module_->Add(global_var, replacement_func);
   }
 
@@ -505,6 +526,7 @@ class RelayToTIRVisitor : public MixedModeMutator {
     const CallNode* softmax_call = quantize_call->args[0].as<CallNode>();
     const CallNode* dequant_call = softmax_call->args[0].as<CallNode>();
     const float quant_scale = GetScalarFromConstant<float>(dequant_call->args[1]);
+    const auto bit_width = quantize_call->type_as<TensorTypeNode>()->dtype.bits();
 
     // assuming layout as NHWC
     auto shape = quantize_call->type_as<TensorTypeNode>()->shape;
@@ -517,36 +539,107 @@ class RelayToTIRVisitor : public MixedModeMutator {
 
     // calculate multiplier and shift for CMSIS-NN softmax API
     // Note: TensorFlow Lite Micro assumptions
-    // Output zero point and scale are fixed to -128 and 1 / 256
+    // Output zero point and scale are fixed to -128 and 1 / 256 in the case of an int8 operator
+    // or to 0 and 1 / 32768 in the case of an int16 operator
     // kScaledDiffIntegerBits, kInputBits, kBeta are described on the following github page
     // https://github.com/tensorflow/tflite-micro/blob/d97cd0908d8cf5021e9d86f05a49888bee28c2a4/tensorflow/lite/micro/kernels/softmax_common.cc#L47
-    double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits)));
-    beta_multiplier = std::min<double>(beta_multiplier, (1ll << 31) - 1.0);
-    auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier);
-    int32_t mult = std::get<0>(mult_shift_pair);
-    int32_t shift = std::get<1>(mult_shift_pair);
-    int32_t diff_min = (1 << kScaledDiffIntegerBits) - 1;
-    diff_min <<= (31 - kScaledDiffIntegerBits);
-    diff_min >>= shift;
-    diff_min *= -1;
+
+    int32_t mult;
+    int32_t shift;
+    int32_t diff_min = 0;
+
+    std::vector<tir_input_constant_buffers> softmax_params(2);
+    Device dev{DLDeviceType::kDLCPU, 0};
+
+    if (bit_width == 8) {
+      double beta_multiplier = (kBeta * quant_scale * (1 << (31 - kInputBits)));
+      beta_multiplier = std::min<double>(beta_multiplier, (1ll << 31) - 1.0);
+      auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier);
+      mult = std::get<0>(mult_shift_pair);
+      shift = std::get<1>(mult_shift_pair);
+      diff_min = (1 << kScaledDiffIntegerBits) - 1;
+      diff_min <<= (31 - kScaledDiffIntegerBits);
+      diff_min >>= shift;
+      diff_min *= -1;
+    } else {  // bit_width == 16
+      double scale_beta_rescale = quant_scale * kBeta / (10.0 / 65535.0);
+      auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(scale_beta_rescale);
+      mult = std::get<0>(mult_shift_pair);
+      shift = std::get<1>(mult_shift_pair);
+
+      const int kLUTEntries = 513;
+      int16_t softmax_s16_exp_lut[kLUTEntries];
+      int16_t softmax_s16_one_by_one_lut[kLUTEntries];
+
+      const int range_int16 =
+          std::numeric_limits<int16_t>::max() - std::numeric_limits<int16_t>::min();
+      int exp_zero_point = std::numeric_limits<int16_t>::max();
+      float exp_scale = 10.0f / range_int16;
+
+      int one_by_one_zero_point = std::numeric_limits<int16_t>::min();
+      float one_by_one_scale = 1.0f / range_int16;
+
+      int lut_value_zero_point = 0;
+      float lut_value_scale = 2.0f / range_int16;
+
+      CalculateLUTInt16(
+          exp_zero_point, exp_scale, lut_value_zero_point, lut_value_scale,
+          [](float key) { return std::exp(key); }, kLUTEntries, softmax_s16_exp_lut);
+      CalculateLUTInt16(
+          one_by_one_zero_point, one_by_one_scale, lut_value_zero_point, lut_value_scale,
+          [](float key) { return 1.0f / (1.0f + key); }, kLUTEntries, softmax_s16_one_by_one_lut);
+
+      // first LUT
+      softmax_params[0].buffer_var =
+          tir::Var("exp_lut", PointerType(PrimType(DataType::Int(bit_width)), 
"global.workspace"));
+      softmax_params[0].ndarray =
+          runtime::NDArray::Empty({kLUTEntries}, DataType::Int(bit_width), 
dev);
+      softmax_params[0].ndarray.CopyFromBytes(softmax_s16_exp_lut, 
sizeof(int16_t) * kLUTEntries);
+
+      // second LUT
+      softmax_params[1].buffer_var = tir::Var(
+          "one_by_one_lut", PointerType(PrimType(DataType::Int(bit_width)), 
"global.workspace"));
+      softmax_params[1].ndarray =
+          runtime::NDArray::Empty({kLUTEntries}, DataType::Int(bit_width), 
dev);
+      softmax_params[1].ndarray.CopyFromBytes(softmax_s16_one_by_one_lut,
+                                              sizeof(int16_t) * kLUTEntries);
+    }
 
     BufferCreator buffer_creator;
-    tir::Var in_var = buffer_creator.CreateBufferVar("input", DataType::Handle(8));
-    tir::Var out_var = buffer_creator.CreateBufferVar("output", DataType::Handle(8));
+    tir::Var in_var = buffer_creator.CreateBufferVar("input", DataType::Handle(bit_width));
+    tir::Var out_var = buffer_creator.CreateBufferVar("output", DataType::Handle(bit_width));
+
+    if (bit_width == 8) {
+      tvm::Array<PrimExpr> args = {
+          tir::StringImm("arm_softmax_s" + std::to_string(bit_width)),
+          in_var,
+          ToArg(num_rows),
+          ToArg(row_size),
+          ToArg(mult),
+          ToArg(shift),
+          ToArg(diff_min),
+          out_var,
+      };
 
-    tvm::Array<PrimExpr> args = {
-        tir::StringImm("arm_softmax_s8"),
-        in_var,
-        ToArg(num_rows),
-        ToArg(row_size),
-        ToArg(mult),
-        ToArg(shift),
-        ToArg(diff_min),
-        out_var,
-    };
+      CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
+                              buffer_creator.GetBufferMap(), args);
+    } else {  // bit_width == 16
+      tvm::Array<PrimExpr> args = {
+          tir::StringImm("arm_softmax_s" + std::to_string(bit_width)),
+          in_var,
+          ToArg(num_rows),
+          ToArg(row_size),
+          ToArg(mult),
+          ToArg(shift),
+          softmax_params[0].buffer_var,
+          softmax_params[1].buffer_var,
+          out_var,
+      };
 
-    CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
-                            buffer_creator.GetBufferMap(), args);
+      CreatePrimFuncForExtern(global_var, buffer_creator.GetPrimFuncParams(),
+                              buffer_creator.GetBufferMap(), args, PrimExpr(), 0, 16,
+                              softmax_params);
+    }
   }
 
   struct BinaryElementwiseClipPattern {
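
Both branches fold the real-valued softmax scale into a Q31 fixed-point
multiplier plus a power-of-two shift via GetFixedPointMultiplierShift. A sketch
of that decomposition, assuming the usual frexp-based approach (illustrative,
not necessarily TVM's exact implementation):

    #include <cmath>
    #include <cstdint>

    // Decompose m > 0 as m ~= mult * 2^(shift - 31),
    // with mult a Q31 significand in [2^30, 2^31).
    void QuantizeMultiplier(double m, int32_t* mult, int32_t* shift) {
      int exp = 0;
      double significand = std::frexp(m, &exp);  // m = significand * 2^exp
      int64_t q = static_cast<int64_t>(std::round(significand * (1ll << 31)));
      if (q == (1ll << 31)) {  // rounding pushed us to 2^31; renormalize
        q /= 2;
        ++exp;
      }
      *mult = static_cast<int32_t>(q);
      *shift = exp;
    }
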
diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
index ea2eabd767..6febfe3486 100644
--- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -99,6 +99,11 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     int clip_max;
   };
 
+  struct CMSISNNSoftmaxLutS16 {
+    std::string exp_lut_name;
+    std::string one_by_one_lut_name;
+  };
+
   using codegen::CodeGenCHost::VisitStmt_;
 
   /*!  * \brief Emits CMSIS-NN APIs for every call_extern */
@@ -107,6 +112,7 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
       CodeGenCHost::VisitExpr_(op, os);
       return;
     }
+
     std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value;
     if (cmsis_func_name == "arm_softmax_s8" || cmsis_func_name == "arm_elementwise_mul_s8" ||
         cmsis_func_name == "arm_elementwise_add_s8" ||
@@ -124,6 +130,8 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     } else if (cmsis_func_name == "arm_avgpool_s8" || cmsis_func_name == "arm_avgpool_s16" ||
                cmsis_func_name == "arm_max_pool_s8" || cmsis_func_name == "arm_max_pool_s16") {
       EmitPool2D(op);
+    } else if (cmsis_func_name == "arm_softmax_s16") {
+      EmitSoftmaxInt16(op);
     }
     return;
   }
@@ -220,6 +228,14 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
        << "," << dims.c << "};\n";
     return struct_name;
   }
+  /*!  * \brief Emits cmsis_nn_softmax_lut_s16 struct */
+  std::string EmitCMSISNNSoftmaxLutS16(std::ostream& os, CMSISNNSoftmaxLutS16 softmax_params) {
+    std::string struct_name = "softmax_params";
+    PrintIndent();
+    os << "cmsis_nn_softmax_lut_s16 " << struct_name << "= {" << 
softmax_params.exp_lut_name << ", "
+       << softmax_params.one_by_one_lut_name << "};\n";
+    return struct_name;
+  }
 
   /*!  * \brief Deduces variable name from call_extern argument resting at id */
   std::string VarNameFromArg(const CallNode* op, int id) {
@@ -295,6 +311,14 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     dims.c = ValueFromArg(op, ++base_pos);
     return dims;
   }
+  /*!  * \brief Extracts CMSIS-NN softmax LUTs from call_extern */
+  CMSISNNSoftmaxLutS16 extract_softmax_lut_s16(const CallNode* op, int exp_lut_pos,
+                                               int one_by_one_lut_pos) {
+    CMSISNNSoftmaxLutS16 softmax_params;
+    softmax_params.exp_lut_name = op->args[exp_lut_pos].as<VarNode>()->name_hint;
+    softmax_params.one_by_one_lut_name = op->args[one_by_one_lut_pos].as<VarNode>()->name_hint;
+    return softmax_params;
+  }
 
   /*!  * \brief Emits CMSIS-NN APIs for every call_extern comprising convolution */
   void EmitConv2D(const CallNode* op) {
@@ -472,6 +496,38 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     EmitErrorCheck();
   }
 
+  void EmitSoftmaxInt16(const CallNode* op) {
+    std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value;
+
+    // extract buffer names from call_extern
+    int arg_id = 0;
+    std::string input_data = VarNameFromArg(op, ++arg_id);
+    int num_rows = ValueFromArg(op, ++arg_id);
+    int row_size = ValueFromArg(op, ++arg_id);
+    int multiplier = ValueFromArg(op, ++arg_id);
+    int shift = ValueFromArg(op, ++arg_id);
+    // extracting LUT names from call_extern
+    CMSISNNSoftmaxLutS16 softmax_params_buffer =
+        extract_softmax_lut_s16(op, arg_id + 1, arg_id + 2);
+    arg_id += 2;
+    std::string output_data = VarNameFromArg(op, ++arg_id);
+
+    // Emit CMSIS-NN API arguments
+    std::string softmax_params = EmitCMSISNNSoftmaxLutS16(stream, softmax_params_buffer);
+
+    PrintIndent();
+    stream << "arm_cmsis_nn_status status = ";
+    stream << cmsis_func_name << "(";
+    stream << input_data << ", ";
+    stream << num_rows << ", ";
+    stream << row_size << ", ";
+    stream << multiplier << ", ";
+    stream << shift << ", ";
+    stream << "&" << softmax_params << ", ";
+    stream << output_data << ");\n";
+    EmitErrorCheck();
+  }
+
   void EmitErrorCheck() {
     auto emit_error = [&](std::string error) {
       if (this->debug_last_error) {
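
Put together, EmitSoftmaxInt16 produces a call site along these lines; the
identifiers are illustrative placeholders for the emitted buffer names and
constants:

    /* sketch of the generated C, not verbatim codegen output */
    cmsis_nn_softmax_lut_s16 softmax_params = {exp_lut, one_by_one_lut};
    arm_cmsis_nn_status status = arm_softmax_s16(input, num_rows, row_size,
                                                 mult, shift, &softmax_params,
                                                 output);
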
diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py
index 0316d567ad..82547f44f5 100644
--- a/tests/python/contrib/test_cmsisnn/test_softmax.py
+++ b/tests/python/contrib/test_cmsisnn/test_softmax.py
@@ -91,6 +91,49 @@ def test_op_int8(zero_point, scale, compiler_cpu, cpu_flags):
     )
 
 
+@skip_if_no_reference_system
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize(["zero_point", "scale"], [[0, 1.0 / 32768]])
+@pytest.mark.parametrize(
+    "compiler_cpu, cpu_flags", [("cortex-m55", "+nomve"), ("cortex-m55", ""), 
("cortex-m7", "")]
+)
+def test_op_int16(zero_point, scale, compiler_cpu, cpu_flags):
+    """Tests int16 QNN Softmax for CMSIS-NN"""
+    interface_api = "c"
+    use_unpacked_api = True
+
+    dtype = "int16"
+    shape = [1, 16, 16, 3]
+
+    # output scale and zero_point must be fixed
+    model = make_model(shape, dtype, dtype, zero_point, scale, 0, 1.0 / 32768)
+    orig_mod = make_module(model)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    # validate pattern matching
+    assert_partitioned_function(orig_mod, cmsisnn_mod)
+
+    # validate the output
+    in_min, in_max = get_dtype_range(dtype)
+    np.random.seed(0)
+    input_data = np.random.randint(in_min, high=in_max, size=shape, dtype=dtype)
+    inputs = {"in0": input_data}
+    params = {}
+    output_list = generate_ref_data(orig_mod["main"], inputs, params)
+    compile_and_run(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=inputs,
+            outputs=output_list,
+            params=params,
+            output_tolerance=2,
+        ),
+        create_test_runner(compiler_cpu, cpu_flags),
+        interface_api,
+        use_unpacked_api,
+    )
+
+
 def parameterize_for_invalid_model(test):
     """Generates parameters for non int8 input and output of Softmax"""
     in_dtype = ["uint8", "int8"]
