This is an automated email from the ASF dual-hosted git repository.

anijain2305 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new b64a843  [Torch, QNN] Add missing upcast to uint8 avg_pool conversion  (#5089)
b64a843 is described below

commit b64a843acd15ca34d2baf9fce730e81f91b3a580
Author: masahi <masahi...@gmail.com>
AuthorDate: Thu Mar 19 02:31:06 2020 +0900

    [Torch, QNN] Add missing upcast to uint8 avg_pool conversion  (#5089)
    
    * add missing upcast to avgpool
    
    * add avg pool test
---
 python/tvm/relay/frontend/pytorch.py      | 22 +++++++++++++++-------
 python/tvm/relay/frontend/qnn_torch.py    |  5 ++---
 tests/python/frontend/pytorch/qnn_test.py | 15 +++++++++++++--
 3 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 6da91c1..0c7465b 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -172,7 +172,7 @@ def _adaptive_avg_2d():
             return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
 
         if input_types[0] == "quint8":
-            return qnn_torch.quantized_adaptive_avg_2d(data, func)
+            return qnn_torch.apply_with_upcast(data, func)
 
         return func(data)
 
@@ -484,14 +484,22 @@ def _avg_pool2d():
         ceil_mode = int(inputs[4])
         count_include_pad = int(inputs[5])
 
-        return _op.nn.avg_pool2d(data,
-                                 pool_size=pool_size,
-                                 strides=strides,
-                                 padding=padding,
-                                 ceil_mode=ceil_mode,
-                                 count_include_pad=count_include_pad)
+        def func(x):
+            return _op.nn.avg_pool2d(x,
+                                     pool_size=pool_size,
+                                     strides=strides,
+                                     padding=padding,
+                                     ceil_mode=ceil_mode,
+                                     count_include_pad=count_include_pad)
+
+        if input_types[0] == "quint8":
+            return qnn_torch.apply_with_upcast(data, func)
+
+        return func(data)
+
     return _impl
 
+
 def _dropout():
     def _impl(inputs, input_types):
         data = inputs[0]
diff --git a/python/tvm/relay/frontend/qnn_torch.py b/python/tvm/relay/frontend/qnn_torch.py
index 70178be..e6a015f 100644
--- a/python/tvm/relay/frontend/qnn_torch.py
+++ b/python/tvm/relay/frontend/qnn_torch.py
@@ -359,10 +359,9 @@ def add_quant_params(params, quant_params):
             params[qparam.bias_var.name_hint] = tvm.nd.array(qparam.bias)
 
 
-def quantized_adaptive_avg_2d(data, func_fp32):
-    # this follows tflite impl
+def apply_with_upcast(data, func):
     inp = _op.cast(data, dtype="int32")
-    out = func_fp32(inp)
+    out = func(inp)
     return _op.cast(out, "uint8")
 
 
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 23fcb7c..ebc00bf 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -218,7 +218,6 @@ class MulScalarNegative(nn.Module):
 class UpsamplingBilinear(nn.Module):
     def __init__(self):
         super().__init__()
-        self.relu = QuantWrapper(nn.ReLU())
         self.quant = QuantStub()
         self.dequant = DeQuantStub()
 
@@ -233,12 +232,25 @@ class UpsamplingBilinear(nn.Module):
         pass
 
 
+class AvgPool2d(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pool = QuantWrapper(nn.AvgPool2d(kernel_size=2))
+
+    def forward(self, x):
+        return self.pool(x)
+
+    def fuse_model(self):
+        pass
+
+
 def test_quantized_modules():
     imagenet_ishape = (1, 3, 224, 224)
 
     qmodules = [
        ("relu", imagenet_ishape, ReLU(), False),
        ("upsample bilinear", (1, 3, 64, 64), UpsamplingBilinear(), False),
+       ("avgpool", imagenet_ishape, AvgPool2d(), False),
     ]
 
     for per_channel in [False, True]:
@@ -276,7 +288,6 @@ def test_quantized_modules():
             pt_result = script_module(inp.clone()).numpy()
 
         input_name = get_graph_input_names(script_module)[0]
-
         runtime = get_tvm_runtime(script_module, input_name, ishape)
         runtime.set_input(input_name, inp.numpy().copy())
         runtime.run()

Reply via email to