This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new d17f753 Support aten::flip (#8398) d17f753 is described below commit d17f75384d83111b9211ef0e6e0570c706a97e49 Author: delldu <31266222+del...@users.noreply.github.com> AuthorDate: Sun Jul 4 10:33:49 2021 +0800 Support aten::flip (#8398) * Support test aten::flip * Support aten::flip --- python/tvm/relay/frontend/pytorch.py | 6 ++++++ tests/python/frontend/pytorch/test_forward.py | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 118af5a..909b804 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2322,6 +2322,11 @@ class PyTorchOpConverter: weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0]) return _op.nn.nll_loss(predictions, targets, weights, reduction, ignore_index) + def flip(self, inputs, input_types): + data = inputs[0] + axis = inputs[1] + return _op.transform.reverse(data, axis=axis[0]) + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -2536,6 +2541,7 @@ class PyTorchOpConverter: "aten::_unique2": self.unique, "aten::nll_loss": self.nll_loss, "aten::nll_loss2d": self.nll_loss, + "aten::flip": self.flip, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2ec2810..f76ea9a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -3893,6 +3893,25 @@ def test_forward_nll_loss(): verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets]) +@tvm.testing.uses_gpu +def test_forward_flip(): + torch.set_grad_enabled(False) + + class Flip(Module): + def __init__(self, axis=0): + super().__init__() + self.axis = axis + + def forward(self, x): + return x.flip([self.axis]) + + input = torch.randn(2, 3, 4) + verify_model(Flip(axis=0), input_data=input) + verify_model(Flip(axis=1), input_data=input) + verify_model(Flip(axis=2), input_data=input) + verify_model(Flip(axis=-1), input_data=input) + + if __name__ == "__main__": # some structural tests test_forward_traced_function() @@ -4035,6 +4054,7 @@ if __name__ == "__main__": test_hard_swish() test_hard_sigmoid() test_forward_nll_loss() + test_forward_flip() # Model tests test_resnet18()