comaniac commented on a change in pull request #6355:
URL: https://github.com/apache/incubator-tvm/pull/6355#discussion_r483878188



##########
File path: src/runtime/contrib/ethosn/ethosn_runtime.cc
##########
@@ -120,6 +120,14 @@ Module EthosnModule::LoadFromBinary(void* strm) {
   return Module(n);
 }
 
+void EthosnModule::SaveToFile(const std::string& path, const std::string& 
format) {

Review comment:
       It looks to me like this still saves a binary file. Are there specific use
cases?

##########
File path: tests/python/contrib/test_ethosn/test_addition.py
##########
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration addition tests"""
+
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib.ethosn import ethosn_available
+from . import infrastructure as tei
+import numpy as np
+
+
+def _get_model(input_shape, lhs_zp, lhs_sc, rhs_zp, rhs_sc, out_zp, out_sc, 
dtype):
+    """Return a model and any parameters it may have"""
+
+    a = relay.var("a", shape=input_shape, dtype=dtype)
+    b = relay.var("b", shape=input_shape, dtype=dtype)
+    model = relay.qnn.op.add(lhs=a, rhs=b,
+                             lhs_scale=relay.const(lhs_sc, 'float32'),
+                             lhs_zero_point=relay.const(lhs_zp, 'int32'),
+                             rhs_scale=relay.const(rhs_sc, 'float32'),
+                             rhs_zero_point=relay.const(rhs_zp, 'int32'),
+                             output_scale=relay.const(out_sc, 'float32'),
+                             output_zero_point=relay.const(out_zp, 'int32'))
+    return model
+
+
+def _get_addition_qnn_params(input1_zp, input1_sc, input2_zp, input2_sc):
+    input1_max = input1_sc * (255 - input1_zp)
+    input1_min = - input1_sc * input1_zp
+    input2_max = input2_sc * (255 - input2_zp)
+    input2_min = - input2_sc * input2_zp
+    output_max = input1_max + input2_max
+    output_min = input1_min + input2_min
+    output_sc = (output_max - output_min) / 255
+    output_zp = - int(output_min / output_sc)
+    return output_zp, output_sc
+
+
+def test_addition():
+    if not ethosn_available():
+        return
+
+    num_trials = 5
+    np.random.seed(0)
+    for _ in range(num_trials):

Review comment:
       Random shapes could make the test flaky, so we should avoid them as much as
possible.

##########
File path: tests/python/contrib/test_ethosn/test_networks.py
##########
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration end-to-end network tests"""
+
+import pytest
+pytest.importorskip('tflite')
+pytest.importorskip('tensorflow')
+
+from tvm import relay
+from tvm.relay.op.contrib.ethosn import ethosn_available, Available
+from tvm.contrib import download
+import tvm.relay.testing.tf as tf_testing
+import tflite.Model
+from . import infrastructure as tei
+
+
+def _get_tflite_model(tflite_model_path, inputs_dict, dtype):
+    with open(tflite_model_path, 'rb') as f:
+        tflite_model_buffer = f.read()
+
+    try:
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buffer, 
0)
+    except AttributeError:
+        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buffer, 0)
+    shape_dict = {}
+    dtype_dict = {}
+    for input in inputs_dict:
+        input_shape = inputs_dict[input]
+        shape_dict[input] = input_shape
+        dtype_dict[input] = dtype
+
+    return relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict=shape_dict,
+        dtype_dict=dtype_dict,
+    )
+
+
+def _test_image_network(model_url, model_sub_path, input_dict, compile_hash, 
output_count, run=True, host_ops=0, npu_partitions=1):
+    if not ethosn_available():
+        return
+
+    def get_model():
+        if model_url[-3:] in ("tgz", "zip"):
+            model_path = tf_testing.get_workload_official(
+                model_url,
+                model_sub_path,
+            )
+        else:
+            model_path = download.download_testdata(
+                model_url,
+                model_sub_path,
+            )
+        return _get_tflite_model(model_path, input_dict, 'uint8')
+
+    outputs = []
+    inputs = {}
+    for input_name in input_dict:
+        input_shape = input_dict[input_name]
+        inputs[input_name] = tei.get_real_image(input_shape[1], input_shape[2])
+
+    for npu in [False, True]:
+        mod, params = get_model()
+        graph, lib, params = tei.build(mod, params, npu=npu, 
expected_host_ops=host_ops, npu_partitions=npu_partitions)
+        if npu:
+            tei.assert_lib_hash(lib, compile_hash)
+        if run:
+            outputs.append(tei.run(graph, lib, params, inputs, output_count, 
npu=npu))
+
+    if run:
+        tei.verify(outputs, 1, verify_saturation=False)
+
+
+def test_mobilenet_v1():
+    hw = ethosn_available()
+    _test_image_network(
+        model_url="https://storage.googleapis.com/download.tensorflow.org/"; \
+                  
"models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz",
+        model_sub_path="mobilenet_v1_1.0_224_quant.tflite",
+        input_dict={"input": (1, 224, 224, 3)},
+        compile_hash="81637c89339201a07dc96e3b5dbf836a",

Review comment:
       Where does this hash come from? If someone changes the Ethos-N codegen
in the future, how should this hash be updated?

##########
File path: python/tvm/relay/op/contrib/ethosn.py
##########
@@ -54,18 +54,107 @@ def qnn_conv_pattern():
             pattern, is_constant(), is_constant(), is_constant(), 
is_constant())
         return pattern
 
+    def qnn_fc_pattern():
+        pattern = is_op('qnn.dense')(
+            wildcard(), is_constant(), is_constant(), is_constant(), 
is_constant(), is_constant())
+        pattern = is_op('nn.bias_add')(pattern, is_constant())
+        pattern = is_op('qnn.requantize')(
+            pattern, is_constant(), is_constant(), is_constant(), 
is_constant())
+        return pattern
+
+    def qnn_avg_pool2d_pattern():
+        pattern = is_op('cast')(wildcard())
+        pattern = is_op('nn.avg_pool2d')(pattern)
+        pattern = is_op('cast')(pattern)
+        return pattern
+
+    def qnn_sigmoid_pattern():
+        pattern = is_op('qnn.dequantize')(wildcard(), is_constant(), 
is_constant())
+        pattern = is_op('sigmoid')(pattern)
+        pattern = is_op('qnn.quantize')(pattern, is_constant(), is_constant())
+        return pattern
+
     def check_conv2d(extract):
         """Check if a conv2d is supported by Ethos-N."""
         if not ethosn_available():
             return False
 
         return support.conv2d(extract)
 
+    def check_fc(extract):
+        """Check if a fully connected is supported by Ethos-N."""
+        if not ethosn_available():
+            return False
+
+        return support.fc(extract)
+
+    def check_avg_pool2d(extract):
+        """Check if a avg pool2d is supported by Ethos-N."""
+        if not ethosn_available():
+            return False
+
+        return support.avg_pool2d(extract)
+
+    def check_sigmoid(extract):
+        """Check if a sigmoid is supported by Ethos-N."""
+        if not ethosn_available():
+            return False
+
+        if extract.attrs.out_dtype != 'uint8':
+            return False
+
+        return support.sigmoid(extract)
+
     return [
         ("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d),
+        ("ethos-n.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_avg_pool2d),
+        ("ethos-n.qnn_sigmoid", qnn_sigmoid_pattern(), check_sigmoid),
+        ("ethos-n.qnn_fc", qnn_fc_pattern(), check_fc),
     ]
 
 
+def _is_ethos_composite(node):

Review comment:
       Very minor point: should this be `_is_ethosn_composite`?

##########
File path: tests/python/contrib/test_ethosn/test_fullyconnected.py
##########
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Ethos-N integration fully connected tests"""
+
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib.ethosn import ethosn_available
+from . import infrastructure as tei
+
+
+def _get_model(shape, weight_shape,
+              input_zp, input_sc,
+              kernel_zp, kernel_sc,
+              output_zp, output_sc,
+              dtype):
+    """Return a model an any parameters it may have"""
+    a = relay.var('a', shape=shape, dtype=dtype)
+    w = tvm.nd.array(np.ones(weight_shape, dtype))
+    weights = relay.const(w, dtype)
+    fc = relay.qnn.op.dense(
+        a,
+        weights,
+        input_zero_point=relay.const(input_zp, "int32"),
+        kernel_zero_point=relay.const(kernel_zp, "int32"),
+        input_scale=relay.const(input_sc, "float32"),
+        kernel_scale=relay.const(kernel_sc, "float32"),
+        units=weight_shape[0],
+        out_dtype='int32'
+    )
+    b = tvm.nd.array(np.random.randint(0, high=255, size=(shape[0],), 
dtype="int32"))
+    biasc = relay.const(b, "int32")
+    bias = relay.nn.bias_add(fc, biasc, axis=0)
+    req = relay.qnn.op.requantize(
+        bias,
+        relay.const(input_sc * kernel_sc, 'float32'),  # input zero scale
+        relay.const(input_zp * kernel_zp, 'int32'),    # input zero point
+        relay.const(output_sc, 'float32'),             # output zero scale
+        relay.const(output_zp, 'int32'),               # output zero point
+        out_dtype="uint8"
+    )
+    params = {"w": w,
+              "b": b}
+    return req, params
+
+
+def test_fullyconnected():
+    if not ethosn_available():
+        return
+
+    dtype = "uint8"
+    np.random.seed(0)
+    for shape in [(1, 32*32), (1, 64*64), (1, 128*128)]:
+        inputs = {
+            "a": tvm.nd.array(np.random.randint(0, high=255, size=shape, 
dtype=dtype)),
+        }
+        outputs = []
+        input_zp = np.random.randint(0, 255)
+        input_sc = np.random.random() * 2
+        kernel_zp = np.random.randint(0, 255)
+        kernel_sc = np.random.random() * 2

Review comment:
       Ditto: randomly chosen quantization parameters could also make this test flaky.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to