ashutosh-arm commented on a change in pull request #9331:
URL: https://github.com/apache/tvm/pull/9331#discussion_r733864363



##########
File path: tests/python/contrib/test_cmsisnn/test_conv2d.py
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CMSIS-NN integration tests: Conv2D"""
+import itertools
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib import cmsisnn
+
+
+from tests.python.relay.aot.aot_test_utils import (
+    AOTTestModel,
+    AOT_CORSTONE300_RUNNER,
+    AOT_DEFAULT_RUNNER,
+    generate_ref_data,
+    compile_and_run,
+)
+from utils import (
+    skip_if_no_reference_system,
+    make_module,
+    count_num_calls,
+    get_range_for_dtype_str,
+    get_same_padding,
+    get_conv2d_qnn_params,
+    make_qnn_relu,
+)
+
+
+def make_model(
+    shape,
+    kernel_shape,
+    input_zp,
+    input_sc,
+    kernel_zp,
+    kernel_sc,
+    output_zp,
+    output_sc,
+    padding,
+    strides,
+    dilation,
+    groups,
+    dtype,
+    kernel_dtype,
+    out_channels,
+    weight_format,
+    enable_bias,
+    relu_type,
+):
+    """Return a model and any parameters it may have"""
+    h_index = weight_format.index("H")
+    w_index = weight_format.index("W")
+    kernel_h = kernel_shape[h_index]
+    kernel_w = kernel_shape[w_index]
+    a = relay.var("in0", shape=shape, dtype=dtype)
+    p = (0, 0, 0, 0)
+    if padding == "SAME":
+        p = get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides)
+        a = relay.nn.pad(
+            a,
+            pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)],
+            pad_value=input_zp,
+            pad_mode="constant",
+        )
+        shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3])
+
+    weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels)
+    w = tvm.nd.array(
+        np.random.randint(
+            np.iinfo(kernel_dtype).min,
+            high=np.iinfo(kernel_dtype).max,
+            size=weight_shape,
+            dtype=kernel_dtype,
+        )
+    )
+    weights = relay.const(w, kernel_dtype)
+    conv = relay.qnn.op.conv2d(
+        a,
+        weights,
+        input_zero_point=relay.const(input_zp, "int32"),
+        kernel_zero_point=relay.const(kernel_zp, "int32"),
+        input_scale=relay.const(input_sc, "float32"),
+        kernel_scale=relay.const(kernel_sc, "float32"),
+        kernel_size=(kernel_h, kernel_w),
+        data_layout="NHWC",
+        kernel_layout=weight_format,
+        dilation=dilation,
+        strides=strides,
+        groups=groups,
+        channels=out_channels,
+        padding=p,
+        out_dtype="int32",
+    )
+    b = tvm.nd.array(np.random.randint(0, high=10, size=(out_channels,), dtype="int32"))
+    bc = relay.const(b, "int32")
+    bias = conv
+    if enable_bias:
+        bias = relay.nn.bias_add(conv, bc, axis=3)
+    requant_input_sc = [sc * input_sc for sc in kernel_sc]
+    req = relay.qnn.op.requantize(
+        bias,
+        relay.const(requant_input_sc, "float32"),
+        relay.const(0, "int32"),
+        relay.const(output_sc, "float32"),
+        relay.const(output_zp, "int32"),
+        out_dtype=dtype,
+    )
+    relu = make_qnn_relu(req, relu_type, output_sc, output_zp, dtype)
+    params = {"w": w, "b": b}
+    return relu, params
+
+
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("ifm_shape", [(1, 28, 28, 12), (1, 64, 100, 4)])
+@pytest.mark.parametrize("kernel_size", [(3, 3)])
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("strides, dilation", [((2, 2), (1, 1)), ((1, 1), (1, 
1))])
+@pytest.mark.parametrize("enable_bias", [True, False])
+@pytest.mark.parametrize("relu_type", ["NONE", "RELU"])
+@pytest.mark.parametrize(
+    "in_zp, in_sc, k_sc, out_channels",
+    [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)],
+)
+def test_op_int8(
+    ifm_shape,
+    kernel_size,
+    padding,
+    strides,
+    dilation,
+    enable_bias,
+    relu_type,
+    in_zp,
+    in_sc,
+    k_sc,
+    out_channels,
+):
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_CORSTONE300_RUNNER
+
+    k_zp = 0
+    groups = 1
+    weight_format = "HWIO"
+    kernel_h = kernel_size[0]
+    kernel_w = kernel_size[1]
+    dtype = "int8"
+    in_min, in_max = get_range_for_dtype_str(dtype)
+
+    weight_shape = None
+    if weight_format == "HWIO":
+        weight_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels)
+    else:
+        weight_shape = (kernel_h, kernel_w, ifm_shape[3], out_channels)
+
+    out_sc, out_zp = get_conv2d_qnn_params(
+        weight_shape, in_sc, in_zp, k_sc, k_zp, dtype, dtype, dtype, False
+    )
+
+    model, params = make_model(
+        ifm_shape,
+        weight_shape,
+        in_zp,
+        in_sc,
+        k_zp,
+        k_sc,
+        out_zp,
+        out_sc,
+        padding,
+        strides,
+        dilation,
+        groups,
+        dtype,
+        dtype,
+        out_channels,
+        weight_format,
+        enable_bias,
+        relu_type,
+    )
+    orig_mod = make_module(model)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)
+
+    # validate pattern matching
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert any(attrs), "At least one function with external attributes was expected."
+
+    compilers = [
+        key == "Compiler" and value == "cmsisnn" for attr in attrs for key, 
value in attr.items()
+    ]
+    assert any(compilers), "Module does not contain a function for the cmsisnn target."
+
+    assert count_num_calls(orig_mod) == count_num_calls(
+        cmsisnn_mod
+    ), "Number of calls changed during partitioning"
+
+    # validate the output
+    np.random.seed(0)

Review comment:
       Thanks! This was new to me.
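
Note: make_model in the hunk above relies on get_same_padding for the "SAME"
case. A minimal sketch of the assumed convention (TF-style SAME padding,
returned as (pad_top, pad_left, pad_bottom, pad_right) to match the pad_width
usage in make_model; the helper name and signature here are illustrative only):

    import math

    def same_padding(in_dims, kernel, dilation, strides):
        # Pad so that out_dim == ceil(in_dim / stride), accounting for dilation.
        pads = []
        for dim, k, d, s in zip(in_dims, kernel, dilation, strides):
            dilated_k = (k - 1) * d + 1                    # effective kernel extent
            total = max((math.ceil(dim / s) - 1) * s + dilated_k - dim, 0)
            pads.append((total // 2, total - total // 2))  # (before, after)
        (top, bottom), (left, right) = pads
        return top, left, bottom, right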

##########
File path: tests/python/contrib/test_cmsisnn/test_conv2d.py
##########
@@ -0,0 +1,303 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CMSIS-NN integration tests: Conv2D"""
+import itertools
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib import cmsisnn
+
+
+from tests.python.relay.aot.aot_test_utils import (
+    AOTTestModel,
+    AOT_CORSTONE300_RUNNER,
+    AOT_DEFAULT_RUNNER,
+    generate_ref_data,
+    compile_and_run,
+)
+from utils import (
+    skip_if_no_reference_system,
+    make_module,
+    count_num_calls,
+    get_range_for_dtype_str,
+    get_same_padding,
+    get_conv2d_qnn_params,
+    make_qnn_relu,
+)
+
+
+def make_model(
+    shape,
+    kernel_shape,
+    input_zp,
+    input_sc,
+    kernel_zp,
+    kernel_sc,
+    output_zp,
+    output_sc,
+    padding,
+    strides,
+    dilation,
+    groups,
+    dtype,
+    kernel_dtype,
+    out_channels,
+    weight_format,
+    enable_bias,
+    relu_type,
+):
+    """Return a model and any parameters it may have"""
+    h_index = weight_format.index("H")
+    w_index = weight_format.index("W")
+    kernel_h = kernel_shape[h_index]
+    kernel_w = kernel_shape[w_index]
+    a = relay.var("in0", shape=shape, dtype=dtype)
+    p = (0, 0, 0, 0)
+    if padding == "SAME":
+        p = get_same_padding((shape[1], shape[2]), (kernel_h, kernel_w), dilation, strides)
+        a = relay.nn.pad(
+            a,
+            pad_width=[(0, 0), (p[0], p[2]), (p[1], p[3]), (0, 0)],
+            pad_value=input_zp,
+            pad_mode="constant",
+        )
+        shape = (shape[0], shape[1] + p[0] + p[2], shape[2] + p[1] + p[3], shape[3])
+
+    weight_shape = (kernel_h, kernel_w, shape[3] // groups, out_channels)
+    w = tvm.nd.array(
+        np.random.randint(
+            np.iinfo(kernel_dtype).min,
+            high=np.iinfo(kernel_dtype).max,
+            size=weight_shape,
+            dtype=kernel_dtype,
+        )
+    )
+    weights = relay.const(w, kernel_dtype)
+    conv = relay.qnn.op.conv2d(
+        a,
+        weights,
+        input_zero_point=relay.const(input_zp, "int32"),
+        kernel_zero_point=relay.const(kernel_zp, "int32"),
+        input_scale=relay.const(input_sc, "float32"),
+        kernel_scale=relay.const(kernel_sc, "float32"),
+        kernel_size=(kernel_h, kernel_w),
+        data_layout="NHWC",
+        kernel_layout=weight_format,
+        dilation=dilation,
+        strides=strides,
+        groups=groups,
+        channels=out_channels,
+        padding=p,
+        out_dtype="int32",
+    )
+    b = tvm.nd.array(np.random.randint(0, high=10, size=(out_channels,), dtype="int32"))
+    bc = relay.const(b, "int32")
+    bias = conv
+    if enable_bias:
+        bias = relay.nn.bias_add(conv, bc, axis=3)
+    requant_input_sc = [sc * input_sc for sc in kernel_sc]
+    req = relay.qnn.op.requantize(
+        bias,
+        relay.const(requant_input_sc, "float32"),
+        relay.const(0, "int32"),
+        relay.const(output_sc, "float32"),
+        relay.const(output_zp, "int32"),
+        out_dtype=dtype,
+    )
+    relu = make_qnn_relu(req, relu_type, output_sc, output_zp, dtype)
+    params = {"w": w, "b": b}
+    return relu, params
+
+
+@tvm.testing.requires_cmsisnn
+@pytest.mark.parametrize("ifm_shape", [(1, 28, 28, 12), (1, 64, 100, 4)])
+@pytest.mark.parametrize("kernel_size", [(3, 3)])
+@pytest.mark.parametrize("padding", ["SAME", "VALID"])
+@pytest.mark.parametrize("strides, dilation", [((2, 2), (1, 1)), ((1, 1), (1, 
1))])
+@pytest.mark.parametrize("enable_bias", [True, False])
+@pytest.mark.parametrize("relu_type", ["NONE", "RELU"])
+@pytest.mark.parametrize(
+    "in_zp, in_sc, k_sc, out_channels",
+    [(10, 0.0128, [0.11, 0.22], 2), (-64, 1, [1, 0.0256, 1.37], 3)],
+)
+def test_op_int8(
+    ifm_shape,
+    kernel_size,
+    padding,
+    strides,
+    dilation,
+    enable_bias,
+    relu_type,
+    in_zp,
+    in_sc,
+    k_sc,
+    out_channels,
+):
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_CORSTONE300_RUNNER
+
+    k_zp = 0
+    groups = 1
+    weight_format = "HWIO"
+    kernel_h = kernel_size[0]
+    kernel_w = kernel_size[1]
+    dtype = "int8"
+    in_min, in_max = get_range_for_dtype_str(dtype)
+
+    weight_shape = None
+    if weight_format == "HWIO":
+        weight_shape = (kernel_h, kernel_w, ifm_shape[3] // groups, out_channels)
+    else:
+        weight_shape = (kernel_h, kernel_w, ifm_shape[3], out_channels)
+
+    out_sc, out_zp = get_conv2d_qnn_params(
+        weight_shape, in_sc, in_zp, k_sc, k_zp, dtype, dtype, dtype, False
+    )
+
+    model, params = make_model(
+        ifm_shape,
+        weight_shape,
+        in_zp,
+        in_sc,
+        k_zp,
+        k_sc,
+        out_zp,
+        out_sc,
+        padding,
+        strides,
+        dilation,
+        groups,
+        dtype,
+        dtype,
+        out_channels,
+        weight_format,
+        enable_bias,
+        relu_type,
+    )
+    orig_mod = make_module(model)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)
+
+    # validate pattern matching
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert any(attrs), "At least one function with external attributes was expected."
+
+    compilers = [
+        key == "Compiler" and value == "cmsisnn" for attr in attrs for key, 
value in attr.items()
+    ]
+    assert any(compilers), "Module does not contain a function for the cmsisnn target."
+
+    assert count_num_calls(orig_mod) == count_num_calls(
+        cmsisnn_mod
+    ), "Number of calls changed during partitioning"
+
+    # validate the output
+    np.random.seed(0)
+    inputs = {
+        "in0": np.random.randint(in_min, high=in_max, size=ifm_shape, 
dtype="int8"),
+    }
+    output_list = generate_ref_data(orig_mod["main"], inputs, params)
+    compile_and_run(
+        AOTTestModel(
+            module=cmsisnn_mod,
+            inputs=inputs,
+            outputs=output_list,
+            params=params,
+            output_tolerance=1,
+        ),
+        test_runner,
+        interface_api,
+        use_unpacked_api,
+    )
+
+
+def parameterize_for_invalid_model(test):
+    in_dtype = ["uint8", "int8"]
+    kernel_dtype = ["uint8", "int8"]
+    kernel_zero_point = [-33, 10, 0]
+    all_combinations = itertools.product(in_dtype, kernel_dtype, kernel_zero_point)
+    all_combinations = filter(
+        lambda parameters: not (
+            parameters[0] == "int8" and parameters[1] == "int8" and 
parameters[2] == 0
+        ),
+        all_combinations,
+    )
+    return pytest.mark.parametrize(
+        ["in_dtype", "kernel_dtype", "kernel_zero_point"],
+        all_combinations,
+    )(test)
+
+
+@parameterize_for_invalid_model
+def test_invalid_parameters(
+    in_dtype,
+    kernel_dtype,
+    kernel_zero_point,
+):
+    ifm_shape = (1, 28, 28, 12)
+    out_channels = 2
+    in_sc = 1
+    in_zp = 24
+    k_sc = [0.11, 0.0237]
+    in_min, in_max = get_range_for_dtype_str(in_dtype)
+
+    kernel_layout = "HWIO"
+    kernel_shape = [3, 3, ifm_shape[3], out_channels]
+    out_sc, out_zp = get_conv2d_qnn_params(
+        kernel_shape, in_sc, in_zp, k_sc, kernel_zero_point, in_dtype, kernel_dtype, in_dtype, False
+    )
+    model, params = make_model(
+        shape=ifm_shape,
+        kernel_shape=kernel_shape,
+        input_zp=in_zp,
+        input_sc=in_sc,
+        kernel_zp=kernel_zero_point,
+        kernel_sc=k_sc,
+        output_zp=out_zp,
+        output_sc=out_sc,
+        padding="SAME",
+        strides=(1, 1),
+        dilation=(1, 1),
+        groups=1,
+        dtype=in_dtype,
+        kernel_dtype=kernel_dtype,
+        out_channels=out_channels,
+        weight_format=kernel_layout,
+        enable_bias=True,
+        relu_type="NONE",
+    )
+    orig_mod = make_module(model)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)
+
+    # print(cmsisnn_mod.astext(False))

Review comment:
       Removed it.
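
Note: as a sanity check on parameterize_for_invalid_model above, the filter
drops only the one combination CMSIS-NN can legally offload, ("int8", "int8",
0), and keeps the remaining eleven. A quick sketch to enumerate what
test_invalid_parameters actually receives:

    import itertools

    combos = itertools.product(["uint8", "int8"], ["uint8", "int8"], [-33, 10, 0])
    invalid = [c for c in combos if c != ("int8", "int8", 0)]
    assert len(invalid) == 11  # 2 * 2 * 3 = 12 combinations minus the valid one
    for in_dtype, kernel_dtype, kernel_zp in invalid:
        print(in_dtype, kernel_dtype, kernel_zp)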

##########
File path: tests/python/contrib/test_cmsisnn/test_networks.py
##########
@@ -92,7 +92,6 @@ def test_cnn_small():
 
     orig_mod, params = convert_to_relay(tflite_model_buf, input_data, "input")
     cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod, params)
-

Review comment:
       No. Added back in.

##########
File path: src/relay/backend/contrib/cmsisnn/extract_constants.cc
##########
@@ -0,0 +1,158 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace cmsisnn {
+
+class ExtractConstantsMutator : public MixedModeMutator {
+ public:
+  explicit ExtractConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    Function final_func = GetRef<Function>(func);
+    ++func_nesting_level_;
+    auto new_body = VisitExpr(func->body);
+    --func_nesting_level_;
+    if (!new_body.same_as(func->body)) {
+      final_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                            FreeTypeVars(new_body, mod_), func->attrs);
+      function_to_constants_.Set(GetRef<Function>(func), constants_within_function_);
+      constants_within_function_.clear();
+    }
+    return final_func;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    // Replace Constant arguments with Vars for ML Operators
+    // Perform this for non-main Call Nodes only
+    if (func_nesting_level_ && call->op.as<OpNode>()) {
+      Array<Expr> new_args;
+      for (auto& arg : post_call->args) {
+        auto* const_arg = arg.as<ConstantNode>();
+        if (const_arg && !const_arg->is_scalar()) {
+          Var var_arg = Var(gen_var_name(), const_arg->tensor_type());
+          new_args.push_back(var_arg);
+          constants_within_function_.push_back(GetRef<Constant>(const_arg));
+        } else {
+          new_args.push_back(arg);
+        }
+      }
+      final_call = Call(call->op, new_args, call->attrs, {});
+    }
+
+    // Since the constants are kicked out of partitioned functions
+    // a new call to global function is needed
+    if (auto* glob_var_node = post_call->op.as<GlobalVarNode>()) {
+      auto glob_var = GetRef<GlobalVar>(glob_var_node);
+      auto glob_func = Downcast<Function>(mod_->Lookup(glob_var));
+      auto new_glob_func = VisitExpr(glob_func);
+      if (!new_glob_func.same_as(glob_func)) {
+        mod_->Update(glob_var, Downcast<Function>(new_glob_func));
+        Array<Expr> new_args = post_call->args;
+        ICHECK(function_to_constants_.find(glob_func) != function_to_constants_.end());
+        for (auto constant : function_to_constants_.at(glob_func)) {
+          new_args.push_back(constant);
+        }
+        final_call = Call(glob_var, new_args);
+      }
+    }
+
+    // Since the constants are kicked out of the local partitioned functions
+    // a new call to local function is needed
+    if (auto* func_node = call->op.as<FunctionNode>()) {
+      Function func = GetRef<Function>(func_node);
+      auto new_func = VisitExpr(func);
+      if (!new_func.same_as(func)) {
+        Array<Expr> new_args = post_call->args;
+        ICHECK(function_to_constants_.find(func) != function_to_constants_.end());
+        for (auto constant : function_to_constants_.at(func)) {
+          constants_within_function_.push_back(constant);
+          Var var_arg = Var(gen_var_name(), constant->tensor_type());
+          new_args.push_back(var_arg);
+        }
+        final_call = Call(new_func, new_args);
+      }
+    }
+
+    return final_call;
+  }
+
+ private:
+  /* \brief Updated module where all calls have replaced constants with new variables */
+  IRModule mod_;
+  /* \brief Maintains mapping of original function to the replaced constants */
+  Map<Function, Array<Constant>> function_to_constants_;
+  /* \brief Constants being kicked out of a function during the function visit */
+  Array<Constant> constants_within_function_;
+  /* \brief Keeps track of variables being created */
+  int var_count_ = 0;
+  /* \brief Keeps track of function scope */
+  int func_nesting_level_ = 0;
+};
+
+/*! \brief Kicks all constants out of the partitioned function into main() */
+IRModule ExtractConstants(IRModule mod) {
+  String func_name;
+  Function func;
+
+  auto extract_constants = ExtractConstantsMutator(mod);
+  Function main_func = Downcast<Function>(mod->Lookup("main"));
+  auto new_main_body = extract_constants.VisitExpr(main_func->body);
+  if (!new_main_body.same_as(main_func->body)) {
+    auto main_var = mod->GetGlobalVar("main");
+    auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
+                                  main_func->type_params, main_func->attrs);
+    mod->Update(main_var, new_main_func);
+  }
+  return mod;
+}
+
+transform::Pass ExtractConstantsFromPartitionedFunction() {
+  runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
+      [=](IRModule m, transform::PassContext pc) { return ExtractConstants(m); };
+  return tvm::transform::CreateModulePass(pass_func, 0, "ExtractConstantsFromPartitionedFunction",
+                                          {});
+}
+
+TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
+    .set_body_typed([]() { return ExtractConstantsFromPartitionedFunction(); });

Review comment:
       ACK
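
Note: for trying the pass above in isolation, a minimal sketch of driving it
from Python through the global registered at the bottom of the file (assuming
mod is an already-partitioned IRModule):

    import tvm

    # The global returns the module pass; applying it moves composite-function
    # constants out to the call sites in main().
    make_pass = tvm.get_global_func(
        "relay.ext.cmsisnn.transform.ExtractConstantsFromPartitionedFunction")
    mod = make_pass()(mod)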

##########
File path: src/relay/backend/contrib/cmsisnn/generate_constants.cc
##########
@@ -0,0 +1,230 @@
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "../../../qnn/utils.h"
+#include "../../../transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+Expr MakeTranspose(Expr data, Array<Integer> axes);
+namespace contrib {
+namespace cmsisnn {
+
+class GenerateConstantsMutator : public MixedModeMutator {
+ public:
+  explicit GenerateConstantsMutator(IRModule& mod) : mod_(mod) {}
+
+ private:
+  /*! \brief Converts kernel layout from HWIO to OHWI to align with CMSIS-NN requirements */
+  Expr ConvertKernelLayout(Expr kernel_expr, const Conv2DAttrs* conv2d_attrs, Attrs* new_attrs) {
+    auto attrs = make_object<Conv2DAttrs>();
+    attrs->strides = std::move(conv2d_attrs->strides);
+    attrs->padding = std::move(conv2d_attrs->padding);
+    attrs->dilation = std::move(conv2d_attrs->dilation);
+    attrs->groups = conv2d_attrs->groups;
+    attrs->channels = std::move(conv2d_attrs->channels);
+    attrs->kernel_size = std::move(conv2d_attrs->kernel_size);
+    attrs->data_layout = std::move(conv2d_attrs->data_layout);
+    attrs->kernel_layout = runtime::String("OHWI");
+    attrs->out_layout = std::move(conv2d_attrs->out_layout);
+    attrs->out_dtype = std::move(conv2d_attrs->out_dtype);
+    *new_attrs = tvm::Attrs{attrs};
+
+    IRModule kernel_module;
+    auto func_body = MakeTranspose(kernel_expr, {Integer(3), Integer(0), Integer(1), Integer(2)});
+    auto kernel_func =
+        Function(FreeVars(func_body), func_body, Type(), FreeTypeVars(func_body, kernel_module));
+    GlobalVar kernel_var("main");
+    kernel_module->Add(kernel_var, kernel_func);
+    kernel_module = relay::transform::FoldConstant()(kernel_module);
+    kernel_func = Downcast<Function>(kernel_module->Lookup("main"));
+    return kernel_func->body;
+  }
+
+  /*! \brief Performs weight transpose and substitutes existing constants in the composite
+   *         function for Conv2D with CMSIS-NN Requantize constants */
+  Expr GenerateConv2dRequantConstants(const Expr& expr) {
+    const CallNode* clip_call = nullptr;
+    const CallNode* requantize_call = nullptr;
+    const CallNode* bias_add_call = nullptr;
+    const CallNode* conv2d_call = nullptr;
+    auto* final_call = expr.as<CallNode>();
+    auto* final_op = final_call->op.as<OpNode>();
+    if (final_op->name == "clip") {
+      clip_call = final_call;
+      requantize_call = clip_call->args[0].as<CallNode>();
+    } else {
+      requantize_call = final_call;
+    }
+    auto* requantize_input = requantize_call->args[0].as<CallNode>();
+    auto* requantize_input_op = requantize_input->op.as<OpNode>();
+    if (requantize_input_op->name == "nn.bias_add") {
+      bias_add_call = requantize_input;
+      conv2d_call = bias_add_call->args[0].as<CallNode>();
+    } else {
+      conv2d_call = requantize_input;
+    }
+
+    // Transpose weights: HWIO -> OHWI
+    auto* conv2d_attrs = conv2d_call->attrs.as<Conv2DAttrs>();
+    tvm::Attrs new_conv2d_attrs;
+    Expr transposed_kernel =
+        ConvertKernelLayout(conv2d_call->args[1], conv2d_attrs, &new_conv2d_attrs);
+
+    // Obtain input and output scales from Relay's Requantization
+    int64_t out_channels = conv2d_attrs->channels.as<IntImmNode>()->value;
+    float output_scale = GetScalarFromConstant<float>(requantize_call->args[3]);
+    auto input_scales = tvm::relay::qnn::GetFloatVectorFromConstant(requantize_call->args[1]);
+    ICHECK(input_scales.size() == static_cast<size_t>(out_channels));
+
+    // Calculate requantization multiplier and shift
+    Device dev{DLDeviceType::kDLCPU, 0};
+    runtime::NDArray multiplier_nda =
+        runtime::NDArray::Empty({out_channels}, DataType::Int(32), dev);
+    runtime::NDArray shift_nda = runtime::NDArray::Empty({out_channels}, DataType::Int(32), dev);
+    int32_t* multiplier = static_cast<int32_t*>(multiplier_nda->data);
+    int32_t* shift = static_cast<int32_t*>(shift_nda->data);
+    for (int i = 0; i < out_channels; ++i) {
+      double effective_output_scale =
+          static_cast<double>(input_scales[i]) / static_cast<double>(output_scale);
+      std::tie(*(multiplier + i), *(shift + i)) =
+          tvm::relay::qnn::GetFixedPointMultiplierShift(effective_output_scale);
+    }
+
+    // Create constants from requantization multiplier and shift
+    Constant multiplier_const(multiplier_nda);
+    Constant shift_const(shift_nda);
+
+    // Convert scale scalars into Constants
+    // Scales are expected as Constants by following passes
+    Expr weight_scale = conv2d_call->args[5];
+    Expr req_inp_scale = requantize_call->args[1];
+    if (out_channels == 1) {
+      runtime::NDArray weight_scale_nda =
+          runtime::NDArray::Empty({out_channels}, DataType::Float(32), dev);
+      float* weight_scale_p = static_cast<float*>(weight_scale_nda->data);
+      *weight_scale_p = GetScalarFromConstant<float>(weight_scale);
+      weight_scale = Constant(weight_scale_nda);
+
+      runtime::NDArray req_inp_scale_nda =
+          runtime::NDArray::Empty({out_channels}, DataType::Float(32), dev);
+      float* req_inp_scale_p = static_cast<float*>(req_inp_scale_nda->data);
+      *req_inp_scale_p = GetScalarFromConstant<float>(req_inp_scale);
+      req_inp_scale = Constant(req_inp_scale_nda);
+    }
+
+    // Replace existing weights (HWIO) with the transposed ones (OHWI)
+    // Substitute Conv2D weight_zero_point with the CMSIS-NN multiplier
+    // Substitute Requantize input_zero_point with CMSIS-NN shift
+    // Conv2D arguments: data, weight, input_zp, weight_zp, input_sc, weight_sc
+    Array<Expr> conv2d_args = {conv2d_call->args[0], transposed_kernel,    conv2d_call->args[2],
+                               multiplier_const,     conv2d_call->args[4], weight_scale};
+    Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {});
+    if (bias_add_call) {
+      ret_call =
+          Call(bias_add_call->op, {ret_call, bias_add_call->args[1]}, 
bias_add_call->attrs, {});
+    }
+    Array<Expr> requantize_args = {ret_call, req_inp_scale, shift_const, requantize_call->args[3],
+                                   requantize_call->args[4]};
+    ret_call = Call(requantize_call->op, requantize_args, requantize_call->attrs, {});
+    if (clip_call) {
+      ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {});
+    }
+    return ret_call;
+  }
+
+  Expr Rewrite_(const CallNode* call, const Expr& post) final {
+    Expr final_call = post;
+    auto* post_call = post.as<CallNode>();
+    if (post_call == nullptr) {
+      return final_call;
+    }
+
+    auto* global_var = call->op.as<GlobalVarNode>();
+    if (global_var) {
+      // Update to global function call needed because the body changes while
+      // generating new constants
+      Function func = Downcast<Function>(mod_->Lookup(global_var->name_hint));
+      Expr new_body = VisitExpr(func->body);
+      if (!new_body.same_as(func->body)) {
+        Function new_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                                     FreeTypeVars(new_body, mod_), func->attrs);
+        mod_->Update(GetRef<GlobalVar>(global_var), new_func);
+        final_call = Call(GetRef<GlobalVar>(global_var), post_call->args);
+      }
+    }
+
+    // Recreate composite function and corresponding call
+    // Updated composite function contains CMSIS-NN quantized multiplier and shift constants
+    if (call->op.as<FunctionNode>()) {
+      auto* func = call->op.as<FunctionNode>();
+      auto func_name = func->GetAttr<String>(attr::kComposite);
+      if (func_name.defined() && func_name == "cmsisnn.qnn_conv2d") {
+        Expr new_body = GenerateConv2dRequantConstants(func->body);
+        Function new_func = Function(FreeVars(new_body), new_body, func->ret_type,
+                                     FreeTypeVars(new_body, mod_), func->attrs);
+        final_call = Call(new_func, post_call->args);
+      }
+    }
+
+    return final_call;
+  }
+
+ private:
+  IRModule mod_;
+};
+
+IRModule GenerateConstants(IRModule mod) {
+  String func_name;
+  Function func;
+
+  // Introduces CMSIS-NN constants before the call to the external Relay function
+  auto generate_constants = GenerateConstantsMutator(mod);
+  Function main_func = Downcast<Function>(mod->Lookup("main"));
+  auto new_main_body = generate_constants.VisitExpr(main_func->body);
+  if (!new_main_body.same_as(main_func->body)) {
+    auto main_var = mod->GetGlobalVar("main");
+    auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
+                                  main_func->type_params, main_func->attrs);
+    mod->Update(main_var, new_main_func);
+  }
+
+  return mod;
+}
+
+transform::Pass GenerateCMSISNNConstants() {
+  runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
+      [=](IRModule m, transform::PassContext pc) { return GenerateConstants(m); };
+  return tvm::transform::CreateModulePass(pass_func, 0, "GenerateCMSISNNConstants", {});
+}
+
+TVM_REGISTER_GLOBAL("relay.ext.cmsisnn.transform.GenerateCMSISNNConstants").set_body_typed([]()
 {
+  return GenerateCMSISNNConstants();
+});

Review comment:
       ACK
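
Note: the per-channel requantization above folds input_scales[i] /
output_scale into a Q31 multiplier plus a power-of-two shift. A minimal
Python sketch of the decomposition GetFixedPointMultiplierShift is assumed
to perform (the standard frexp-based scheme; illustrative, not the exact
TVM source):

    import math

    def fixed_point_multiplier_shift(scale):
        # Decompose scale as significand * 2^shift with significand in [0.5, 1),
        # then store the significand as a signed Q31 fixed-point integer.
        if scale == 0.0:
            return 0, 0
        significand, shift = math.frexp(scale)
        multiplier = round(significand * (1 << 31))
        if multiplier == (1 << 31):  # rounding carried over; renormalize
            multiplier //= 2
            shift += 1
        return multiplier, shift

    # e.g. fixed_point_multiplier_shift(0.0128) == (1759218604, -6)
    # check: 1759218604 / 2**31 * 2**-6 is approximately 0.0128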



