This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 22dcf4490d [PyTorch] Fix pad_common for float pad_value (#12134) 22dcf4490d is described below commit 22dcf4490dacc7813f5ef3d700ab0b64171c7662 Author: Yuanjing Shi <yuanj...@octoml.ai> AuthorDate: Thu Aug 11 21:02:48 2022 -1000 [PyTorch] Fix pad_common for float pad_value (#12134) * fix pad * fix constant padding and handle float infinity * revert change to pad_width * fix constant pad value --- python/tvm/relay/frontend/pytorch.py | 11 ++++----- tests/python/frontend/pytorch/test_forward.py | 32 +++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 0fe8d57464..ffe4b313c5 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1905,7 +1905,7 @@ class PyTorchOpConverter: # initialize paddings based on input len pad_len = len(self.infer_shape(data)) * 2 - paddings = [pad_value] * pad_len + paddings = [0] * pad_len if len(pad_list) >= 2: paddings[-1] = pad_list[1] @@ -1925,8 +1925,10 @@ class PyTorchOpConverter: for pad in paddings: const_paddings.append([]) for p in pad: - if not isinstance(p, int): + if isinstance(p, _expr.Expr): p = int(_infer_value(p, {}).numpy()) + elif not isinstance(p, int): + raise NotImplementedError("pad width should be int/expr") const_paddings[-1].append(p) if p != 0: non_zero_found = True @@ -1934,12 +1936,11 @@ class PyTorchOpConverter: if not non_zero_found: return data elif mode == "constant": - return _op.nn.pad(data, const_paddings, pad_value=inputs[2], pad_mode=mode) + return _op.nn.pad(data, const_paddings, pad_value=pad_value, pad_mode=mode) else: return _op.nn.pad(data, const_paddings, pad_mode=mode) def pad(self, inputs, input_types): - # mode: Optional default "constant" if len(inputs) > 2 and inputs[2] is not None: mode = inputs[2] @@ -1960,7 +1961,7 @@ class PyTorchOpConverter: return self.pad_common(mode, pad_value, inputs, input_types) def constant_pad_nd(self, inputs, input_types): - return self.pad_common("constant", 0, inputs, input_types) + return self.pad_common("constant", _expr.const(inputs[2]), inputs, input_types) def reflection_pad1d(self, inputs, input_types): return self.pad_common("reflect", 0, inputs, input_types) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index bc848f90b3..6b1eb30a56 100755 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2010,6 +2010,34 @@ def test_forward_functional_pad(): pad = (0, 1, 2, 1, 3, 3) verify_model(Pad1().float().eval(), input_data=input_data) + class Pad2(Module): + def forward(self, *args): + return torch.nn.functional.pad(args[0], pad, "constant", 1) + + input_data = torch.rand((3, 3, 4, 2)) + pad = (1, 1) + verify_model(Pad2().float().eval(), input_data=input_data) + + pad = (1, 1, 2, 2) + verify_model(Pad2().float().eval(), input_data=input_data) + + pad = (0, 1, 2, 1, 3, 3) + verify_model(Pad2().float().eval(), input_data=input_data) + + class Pad3(Module): + def forward(self, *args): + return torch.nn.functional.pad(args[0], pad, "constant", 1.0) + + input_data = torch.rand((3, 3, 4, 2)) + pad = (1, 1) + verify_model(Pad3().float().eval(), input_data=input_data) + + pad = (1, 1, 2, 2) + verify_model(Pad3().float().eval(), input_data=input_data) + + pad = (0, 1, 2, 1, 3, 3) + verify_model(Pad3().float().eval(), input_data=input_data) + @tvm.testing.uses_gpu def test_forward_zero_pad2d(): @@ -2021,10 +2049,10 @@ def test_forward_zero_pad2d(): @tvm.testing.uses_gpu def test_forward_constant_pad1d(): inp = torch.rand((1, 2, 4)) - verify_model(torch.nn.ConstantPad2d(2, 3.5).eval(), inp) + verify_model(torch.nn.ConstantPad1d(2, 3.5).eval(), inp) inp = torch.rand((1, 2, 3)) - verify_model(torch.nn.ConstantPad2d((3, 1), 3.5).eval(), inp) + verify_model(torch.nn.ConstantPad1d((3, 1), 3.5).eval(), inp) @tvm.testing.uses_gpu