This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d17f753  Support aten::flip (#8398)
d17f753 is described below

commit d17f75384d83111b9211ef0e6e0570c706a97e49
Author: delldu <31266222+del...@users.noreply.github.com>
AuthorDate: Sun Jul 4 10:33:49 2021 +0800

    Support aten::flip (#8398)
    
    * Add test for aten::flip
    
    * Support aten::flip
---
 python/tvm/relay/frontend/pytorch.py          |  6 ++++++
 tests/python/frontend/pytorch/test_forward.py | 20 ++++++++++++++++++++
 2 files changed, 26 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 118af5a..909b804 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -2322,6 +2322,11 @@ class PyTorchOpConverter:
             weights = _op.full(_expr.const(1), (num_class,), dtype=input_types[0])
         return _op.nn.nll_loss(predictions, targets, weights, reduction, ignore_index)
 
+    def flip(self, inputs, input_types):
+        data = inputs[0]
+        axis = inputs[1]
+        return _op.transform.reverse(data, axis=axis[0])
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -2536,6 +2541,7 @@ class PyTorchOpConverter:
             "aten::_unique2": self.unique,
             "aten::nll_loss": self.nll_loss,
             "aten::nll_loss2d": self.nll_loss,
+            "aten::flip": self.flip,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 2ec2810..f76ea9a 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -3893,6 +3893,25 @@ def test_forward_nll_loss():
     verify_model(torch.nn.NLLLoss(reduction="none").eval(), input_data=[predictions, targets])
 
 
+@tvm.testing.uses_gpu
+def test_forward_flip():
+    torch.set_grad_enabled(False)
+
+    class Flip(Module):
+        def __init__(self, axis=0):
+            super().__init__()
+            self.axis = axis
+
+        def forward(self, x):
+            return x.flip([self.axis])
+
+    input = torch.randn(2, 3, 4)
+    verify_model(Flip(axis=0), input_data=input)
+    verify_model(Flip(axis=1), input_data=input)
+    verify_model(Flip(axis=2), input_data=input)
+    verify_model(Flip(axis=-1), input_data=input)
+
+
 if __name__ == "__main__":
     # some structural tests
     test_forward_traced_function()
@@ -4035,6 +4054,7 @@ if __name__ == "__main__":
     test_hard_swish()
     test_hard_sigmoid()
     test_forward_nll_loss()
+    test_forward_flip()
 
     # Model tests
     test_resnet18()
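
For reference (not part of the committed diff): a minimal sketch of exercising the new aten::flip
support end to end through relay.frontend.from_pytorch. The module, input shape, input name "x",
and llvm target below are illustrative assumptions, not taken from the commit.

import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor

class Flip(torch.nn.Module):
    def forward(self, x):
        # x.flip([1]) emits aten::flip, which the new converter lowers to a
        # relay reverse along the requested axis.
        return x.flip([1])

data = torch.randn(2, 3, 4)
traced = torch.jit.trace(Flip().eval(), data)
mod, params = relay.frontend.from_pytorch(traced, [("x", (2, 3, 4))])

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

runtime = graph_executor.GraphModule(lib["default"](tvm.cpu()))
runtime.set_input("x", tvm.nd.array(data.numpy()))
runtime.run()
# Expected to match the PyTorch result, data.flip([1]).
result = runtime.get_output(0).numpy()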
