This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 082874c  [Torch][Quantized] Fix converting serialized quantized models (#5839)
082874c is described below

commit 082874c51f728d8ff12a9cd2eed4d2734e71eb8f
Author: masahi <masahi...@gmail.com>
AuthorDate: Fri Jun 19 01:24:03 2020 +0900

    [Torch][Quantized] Fix converting serialized quantized models (#5839)
    
    * [Torch] Fix converting serialized quantized models
    
    * clean up dtype check
    
    * comment clean up
---
 python/tvm/relay/frontend/pytorch.py      | 42 +++++++++++++++++------------
 tests/python/frontend/pytorch/qnn_test.py | 45 ++++++++++++++++++++++++++++---
 2 files changed, 67 insertions(+), 20 deletions(-)
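
For context on the failure mode fixed here: when a quantized Torch module is
traced, saved with torch.jit.save, and loaded back, the loaded graph no longer
carries quantized dtypes such as quint8, so the converters' old
input_types[0] == "quint8" checks stop firing. A minimal repro along these
lines (a hypothetical sketch, not part of the commit; it mirrors the
AdaptiveAvgPool2d module and quantize_model() helper added to qnn_test.py
below, and assumes an fbgemm-enabled PyTorch build):

    import torch
    import torch.nn as nn
    from torch.quantization import QuantWrapper, get_default_qconfig, prepare, convert

    # Quantize a trivial module, mirroring quantize_model() in qnn_test.py.
    model = QuantWrapper(nn.AdaptiveAvgPool2d((1, 1))).eval()
    model.qconfig = get_default_qconfig("fbgemm")
    prepare(model, inplace=True)
    inp = torch.rand(1, 16, 64, 64)
    model(inp)                      # calibration pass
    convert(model, inplace=True)

    traced = torch.jit.trace(model, inp).eval()
    torch.jit.save(traced, "tmp.pt")
    loaded = torch.jit.load("tmp.pt")

    # Per the commit message: the freshly traced graph annotates tensors with
    # quantized types, while the loaded module's graph drops the dtype and
    # only reports plain Tensor types.
    print(traced.graph)
    print(loaded.graph)

This is why the commit switches the checks over to Relay's type inference
result, via the new _is_quantized_tensor helper shown in the diff below.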

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index d2451cd..d3b6510 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -115,6 +115,14 @@ def _should_construct_dynamic_list(list_construct_node):
     return False
 
 
+def _is_quantized_tensor(data, prelude):
+    # If a quantized Torch module is saved and loaded back, dtype will be dropped
+    # Since dtypes from Torch tensors are not reliable in such cases, we use
+    # Relay's type inference result to decide if an input tensor is quantized
+    ty = _infer_type_with_prelude(data, prelude)
+    return ty.dtype == "uint8"
+
+
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
@@ -530,10 +538,10 @@ def _linspace():
     return _impl
 
 
-def _relu():
+def _relu(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             assert len(inputs) == 3, "Input quant param not found in op inputs"
             input_zero_point = _expr.const(inputs[2], dtype="int32")
             return qnn_torch.quantized_relu(data, input_zero_point)
@@ -595,7 +603,7 @@ def _log_sigmoid():
         return _op.log(_op.tensor.sigmoid(data))
     return _impl
 
-def _adaptive_avg_pool_2d():
+def _adaptive_avg_pool_2d(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
         output_size = _infer_shape(inputs[1])
@@ -603,7 +611,7 @@ def _adaptive_avg_pool_2d():
         def func(x):
             return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
 
-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             return qnn_torch.apply_with_upcast(data, func)
 
         return func(data)
@@ -1108,7 +1116,7 @@ def _softplus():
         return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.)) / beta
     return _impl
 
-def _avg_pool2d():
+def _avg_pool2d(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
 
@@ -1130,7 +1138,7 @@ def _avg_pool2d():
                                      ceil_mode=ceil_mode,
                                      count_include_pad=count_include_pad)
 
-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             return qnn_torch.apply_with_upcast(data, func)
 
         return func(data)
@@ -1254,7 +1262,7 @@ def _variance():
 
     return _impl
 
-def _mean():
+def _mean(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
 
@@ -1274,7 +1282,7 @@ def _mean():
         def func(x):
             return _op.mean(x, axis, keepdims, exclude)
 
-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             assert len(inputs) == 6, "Input quant param not found in op inputs"
             input_scale = _expr.const(inputs[4])
             input_zero_point = _expr.const(inputs[5])
@@ -1492,7 +1500,7 @@ def _to():
 
     return _impl
 
-def _upsample(method):
+def _upsample(method, prelude):
     def _impl(inputs, input_types):
         if isinstance(inputs[1], _expr.Var):
             out_size = _infer_shape(inputs[1])
@@ -1516,7 +1524,7 @@ def _upsample(method):
         def func(x):
             return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
 
-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             import torch
             from packaging import version
 
@@ -1835,8 +1843,8 @@ def _get_convert_map(prelude):
         "aten::take"                            : _take(),
         "aten::where"                           : _where(),
         "aten::topk"                            : _topk(),
-        "aten::relu"                            : _relu(),
-        "aten::relu_"                           : _relu(),
+        "aten::relu"                            : _relu(prelude),
+        "aten::relu_"                           : _relu(prelude),
         "aten::prelu"                           : _prelu(),
         "aten::leaky_relu"                      : _leaky_relu(),
         "aten::elu"                             : _elu(),
@@ -1845,7 +1853,7 @@ def _get_convert_map(prelude):
         "aten::gelu"                            : _gelu(),
         "aten::selu"                            : _selu(),
         "aten::log_sigmoid"                     : _log_sigmoid(),
-        "aten::adaptive_avg_pool2d"             : _adaptive_avg_pool_2d(),
+        "aten::adaptive_avg_pool2d"             : 
_adaptive_avg_pool_2d(prelude),
         "aten::adaptive_max_pool2d"             : _adaptive_max_pool_2d(),
         "aten::max_pool2d"                      : _maxpool_2d(),
         "aten::max_pool2d_with_indices"         : _maxpool_2d_with_indices(),
@@ -1874,13 +1882,13 @@ def _get_convert_map(prelude):
         "aten::log_softmax"                     : _log_softmax(),
         "aten::sigmoid"                         : _sigmoid(),
         "aten::softplus"                        : _softplus(),
-        "aten::avg_pool2d"                      : _avg_pool2d(),
+        "aten::avg_pool2d"                      : _avg_pool2d(prelude),
         "aten::avg_pool3d"                      : _avg_pool3d(),
         "aten::dropout"                         : _dropout(),
         "aten::dropout_"                        : _dropout(),
         "aten::feature_dropout"                 : _dropout(),
         "aten::alpha_dropout"                   : _dropout(),
-        "aten::mean"                            : _mean(),
+        "aten::mean"                            : _mean(prelude),
         "aten::chunk"                           : _chunk(prelude),
         "aten::matmul"                          : _matmul(prelude),
         "aten::expand"                          : _expand(),
@@ -1932,8 +1940,8 @@ def _get_convert_map(prelude):
         "aten::isnan"                           : _unary("isnan"),
         "aten::clamp"                           : _clamp(),
         "aten::detach"                          : _identity(),
-        "aten::upsample_bilinear2d"             : _upsample("bilinear"),
-        "aten::upsample_nearest2d"              : 
_upsample("nearest_neighbor"),
+        "aten::upsample_bilinear2d"             : _upsample("bilinear", 
prelude),
+        "aten::upsample_nearest2d"              : 
_upsample("nearest_neighbor", prelude),
         "aten::upsample_trilinear3d"            : _upsample3d("trilinear"),
         "aten::upsample_nearest3d"              : 
_upsample3d("nearest_neighbor"),
         "aten::expand_as"                       : _expand_as(),
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 551cdc4..8c6c248 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -63,7 +63,7 @@ def get_qconfig(per_channel):
                                           weight=default_weight_observer)
 
 
-def quantize_model(model, inp, per_channel=False, dummy=True):
+def quantize_model(model, inp, per_channel=False):
     model.fuse_model()
     model.qconfig = get_qconfig(per_channel)
     torch.quantization.prepare(model, inplace=True)
@@ -243,6 +243,18 @@ class AvgPool2d(nn.Module):
         pass
 
 
+class AdaptiveAvgPool2d(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pool = QuantWrapper(nn.AdaptiveAvgPool2d((1, 1)))
+
+    def forward(self, x):
+        return self.pool(x)
+
+    def fuse_model(self):
+        pass
+
+
 def test_quantized_modules():
     imagenet_ishape = (1, 3, 224, 224)
 
@@ -280,7 +292,7 @@ def test_quantized_modules():
         raw_module.eval()
         inp = torch.rand(ishape)
 
-        quantize_model(raw_module, inp, per_channel=per_channel, dummy=True)
+        quantize_model(raw_module, inp, per_channel=per_channel)
         script_module = torch.jit.trace(raw_module, inp).eval()
 
         with torch.no_grad():
@@ -376,7 +388,7 @@ def test_quantized_imagenet():
         inp = get_imagenet_input()
         pt_inp = torch.from_numpy(inp)
 
-        quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False)
+        quantize_model(raw_model, pt_inp, per_channel=per_channel)
         script_module = torch.jit.trace(raw_model, pt_inp).eval()
 
         with torch.no_grad():
@@ -465,3 +477,30 @@ def test_quantized_imagenet():
         mean abs_diff: 0.054197952
         558 in 1000 raw outputs identical.
         """
+
+
+def test_serialized_modules():
+    ishape = (1, 16, 64, 64)
+    raw_module = AdaptiveAvgPool2d().eval()
+    inp = torch.rand(ishape)
+
+    quantize_model(raw_module, inp)
+    script_module = torch.jit.trace(raw_module, inp).eval()
+
+    fname = "tmp.pt"
+    torch.jit.save(script_module, fname)
+    loaded = torch.jit.load(fname)
+    os.remove(fname)
+
+    with torch.no_grad():
+        pt_result = loaded(inp.clone()).numpy()
+
+    input_name = "input"
+    runtime = get_tvm_runtime(loaded, input_name, ishape)
+    runtime.set_input(input_name, inp.numpy().copy())
+    runtime.run()
+    tvm_result = runtime.get_output(0).asnumpy()
+
+    num_identical = np.sum(tvm_result == pt_result)
+    match_ratio = num_identical / float(np.prod(tvm_result.shape))
+    assert match_ratio > 0.2
