This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
     new 082874c  [Torch][Quantized] Fix converting serialized quantized models (#5839)
082874c is described below

commit 082874c51f728d8ff12a9cd2eed4d2734e71eb8f
Author: masahi <masahi...@gmail.com>
AuthorDate: Fri Jun 19 01:24:03 2020 +0900

    [Torch][Quantized] Fix converting serialized quantized models (#5839)

    * [Torch] Fix converting serialized quantized models

    * clean up dtype check

    * comment clean up
---
 python/tvm/relay/frontend/pytorch.py      | 42 +++++++++++++++++------------
 tests/python/frontend/pytorch/qnn_test.py | 45 ++++++++++++++++++++++++++++---
 2 files changed, 67 insertions(+), 20 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index d2451cd..d3b6510 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -115,6 +115,14 @@ def _should_construct_dynamic_list(list_construct_node):
     return False
 
 
+def _is_quantized_tensor(data, prelude):
+    # If a quantized Torch module is saved and loaded back, dtype will be dropped
+    # Since dtypes from Torch tensors are not reliable in such cases, we use
+    # Relay's type inference result to decide if an input tensor is quantized
+    ty = _infer_type_with_prelude(data, prelude)
+    return ty.dtype == "uint8"
+
+
 # operator implementation
 def _elemwise(name):
     def _impl(inputs, input_types):
@@ -530,10 +538,10 @@ def _linspace():
     return _impl
 
 
-def _relu():
+def _relu(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             assert len(inputs) == 3, "Input quant param not found in op inputs"
             input_zero_point = _expr.const(inputs[2], dtype="int32")
             return qnn_torch.quantized_relu(data, input_zero_point)
@@ -595,7 +603,7 @@ def _log_sigmoid():
         return _op.log(_op.tensor.sigmoid(data))
     return _impl
 
-def _adaptive_avg_pool_2d():
+def _adaptive_avg_pool_2d(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
         output_size = _infer_shape(inputs[1])
@@ -603,7 +611,7 @@ def _adaptive_avg_pool_2d():
         def func(x):
             return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
 
-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             return qnn_torch.apply_with_upcast(data, func)
 
         return func(data)
@@ -1108,7 +1116,7 @@ def _softplus():
         return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.)) / beta
     return _impl
 
-def _avg_pool2d():
+def _avg_pool2d(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
 
@@ -1130,7 +1138,7 @@ def _avg_pool2d():
                                  ceil_mode=ceil_mode,
                                  count_include_pad=count_include_pad)
 
-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             return qnn_torch.apply_with_upcast(data, func)
 
         return func(data)
@@ -1254,7 +1262,7 @@ def _variance():
     return _impl
 
 
-def _mean():
+def _mean(prelude):
     def _impl(inputs, input_types):
         data = inputs[0]
 
@@ -1274,7 +1282,7 @@ def _mean():
         def func(x):
             return _op.mean(x, axis, keepdims, exclude)
 
-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             assert len(inputs) == 6, "Input quant param not found in op inputs"
             input_scale = _expr.const(inputs[4])
             input_zero_point = _expr.const(inputs[5])
@@ -1492,7 +1500,7 @@ def _to():
     return _impl
 
 
-def _upsample(method):
+def _upsample(method, prelude):
     def _impl(inputs, input_types):
         if isinstance(inputs[1], _expr.Var):
             out_size = _infer_shape(inputs[1])
@@ -1516,7 +1524,7 @@ def _upsample(method):
         def func(x):
             return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
 
-        if input_types[0] == "quint8":
+        if _is_quantized_tensor(data, prelude):
             import torch
             from packaging import version
 
@@ -1835,8 +1843,8 @@ def _get_convert_map(prelude):
         "aten::take"                            : _take(),
         "aten::where"                           : _where(),
         "aten::topk"                            : _topk(),
-        "aten::relu"                            : _relu(),
-        "aten::relu_"                           : _relu(),
+        "aten::relu"                            : _relu(prelude),
+        "aten::relu_"                           : _relu(prelude),
         "aten::prelu"                           : _prelu(),
         "aten::leaky_relu"                      : _leaky_relu(),
         "aten::elu"                             : _elu(),
@@ -1845,7 +1853,7 @@ def _get_convert_map(prelude):
         "aten::gelu"                            : _gelu(),
         "aten::selu"                            : _selu(),
         "aten::log_sigmoid"                     : _log_sigmoid(),
-        "aten::adaptive_avg_pool2d"             : _adaptive_avg_pool_2d(),
+        "aten::adaptive_avg_pool2d"             : _adaptive_avg_pool_2d(prelude),
         "aten::adaptive_max_pool2d"             : _adaptive_max_pool_2d(),
         "aten::max_pool2d"                      : _maxpool_2d(),
         "aten::max_pool2d_with_indices"         : _maxpool_2d_with_indices(),
@@ -1874,13 +1882,13 @@ def _get_convert_map(prelude):
         "aten::log_softmax"                     : _log_softmax(),
         "aten::sigmoid"                         : _sigmoid(),
         "aten::softplus"                        : _softplus(),
-        "aten::avg_pool2d"                      : _avg_pool2d(),
+        "aten::avg_pool2d"                      : _avg_pool2d(prelude),
         "aten::avg_pool3d"                      : _avg_pool3d(),
         "aten::dropout"                         : _dropout(),
         "aten::dropout_"                        : _dropout(),
         "aten::feature_dropout"                 : _dropout(),
         "aten::alpha_dropout"                   : _dropout(),
-        "aten::mean"                            : _mean(),
+        "aten::mean"                            : _mean(prelude),
         "aten::chunk"                           : _chunk(prelude),
         "aten::matmul"                          : _matmul(prelude),
         "aten::expand"                          : _expand(),
@@ -1932,8 +1940,8 @@ def _get_convert_map(prelude):
         "aten::isnan"                           : _unary("isnan"),
         "aten::clamp"                           : _clamp(),
         "aten::detach"                          : _identity(),
-        "aten::upsample_bilinear2d"             : _upsample("bilinear"),
-        "aten::upsample_nearest2d"              : _upsample("nearest_neighbor"),
+        "aten::upsample_bilinear2d"             : _upsample("bilinear", prelude),
+        "aten::upsample_nearest2d"              : _upsample("nearest_neighbor", prelude),
         "aten::upsample_trilinear3d"            : _upsample3d("trilinear"),
         "aten::upsample_nearest3d"              : _upsample3d("nearest_neighbor"),
         "aten::expand_as"                       : _expand_as(),
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 551cdc4..8c6c248 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -63,7 +63,7 @@ def get_qconfig(per_channel):
                        weight=default_weight_observer)
 
 
-def quantize_model(model, inp, per_channel=False, dummy=True):
+def quantize_model(model, inp, per_channel=False):
     model.fuse_model()
     model.qconfig = get_qconfig(per_channel)
     torch.quantization.prepare(model, inplace=True)
@@ -243,6 +243,18 @@ class AvgPool2d(nn.Module):
         pass
 
 
+class AdaptiveAvgPool2d(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pool = QuantWrapper(nn.AdaptiveAvgPool2d((1, 1)))
+
+    def forward(self, x):
+        return self.pool(x)
+
+    def fuse_model(self):
+        pass
+
+
 def test_quantized_modules():
     imagenet_ishape = (1, 3, 224, 224)
 
@@ -280,7 +292,7 @@ def test_quantized_modules():
         raw_module.eval()
         inp = torch.rand(ishape)
 
-        quantize_model(raw_module, inp, per_channel=per_channel, dummy=True)
+        quantize_model(raw_module, inp, per_channel=per_channel)
         script_module = torch.jit.trace(raw_module, inp).eval()
 
         with torch.no_grad():
@@ -376,7 +388,7 @@ def test_quantized_imagenet():
         inp = get_imagenet_input()
         pt_inp = torch.from_numpy(inp)
 
-        quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False)
+        quantize_model(raw_model, pt_inp, per_channel=per_channel)
         script_module = torch.jit.trace(raw_model, pt_inp).eval()
 
         with torch.no_grad():
@@ -465,3 +477,30 @@ def test_quantized_imagenet():
     mean abs_diff: 0.054197952
     558 in 1000 raw outputs identical.
     """
+
+
+def test_serialized_modules():
+    ishape = (1, 16, 64, 64)
+    raw_module = AdaptiveAvgPool2d().eval()
+    inp = torch.rand(ishape)
+
+    quantize_model(raw_module, inp)
+    script_module = torch.jit.trace(raw_module, inp).eval()
+
+    fname = "tmp.pt"
+    torch.jit.save(script_module, fname)
+    loaded = torch.jit.load(fname)
+    os.remove(fname)
+
+    with torch.no_grad():
+        pt_result = loaded(inp.clone()).numpy()
+
+    input_name = "input"
+    runtime = get_tvm_runtime(loaded, input_name, ishape)
+    runtime.set_input(input_name, inp.numpy().copy())
+    runtime.run()
+    tvm_result = runtime.get_output(0).asnumpy()
+
+    num_identical = np.sum(tvm_result == pt_result)
+    match_ratio = num_identical / float(np.prod(tvm_result.shape))
+    assert match_ratio > 0.2
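
For readers who want to reproduce the round trip this commit fixes outside the
test suite, below is a minimal sketch. It assumes a TVM build with the PyTorch
frontend plus torchvision installed; the model choice, calibration data, and
input name are illustrative, not part of this commit. The key point is that
after torch.jit.save / torch.jit.load, the loaded graph no longer carries
reliable quantized dtype annotations, which is why the frontend now consults
Relay's inferred types (via _is_quantized_tensor above) instead of the
Torch-reported input_types.

    import os
    import torch
    from torchvision.models.quantization import resnet18
    from tvm import relay

    ishape = (1, 3, 224, 224)
    inp = torch.rand(ishape)

    # Post-training quantization with PyTorch's eager-mode workflow.
    model = resnet18(pretrained=False).eval()
    model.fuse_model()
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(model, inplace=True)
    model(inp)  # calibrate on dummy data (illustrative only)
    torch.quantization.convert(model, inplace=True)

    script_module = torch.jit.trace(model, inp).eval()

    # Round-trip through serialization: this is the step that previously
    # broke conversion, because dtype information is dropped on load.
    torch.jit.save(script_module, "tmp.pt")
    loaded = torch.jit.load("tmp.pt")
    os.remove("tmp.pt")

    # With this commit, the loaded module converts as well.
    input_name = "input"
    mod, params = relay.frontend.from_pytorch(loaded, [(input_name, ishape)])

Deciding quantized-versus-float from Relay's type inference makes the check
robust to whatever the Torch serializer drops, at the modest cost of running
type inference at each quantized-aware call site.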