This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 59e0a4a  [Torch] Various updates for PyTorch frontend   (#7348)
59e0a4a is described below

commit 59e0a4a46461b1a90bc24660cf25e08cfcfb7a1f
Author: masahi <masahi...@gmail.com>
AuthorDate: Thu Jan 28 04:30:08 2021 +0900

    [Torch] Various updates for PyTorch frontend   (#7348)
    
    * add conversion for DETR
    
    * remove explicit broadcast_to before batched matmul
    
    * use take with wrap mode
    
    * add test for transformer and negative indices
    
    * add sort and argsort
    
    * add logical_and
    
    * support masked_select
    
    * add GPU targets to masked_select test
    
    * improve sort conversion
---
 python/tvm/relay/frontend/pytorch.py          |  63 ++++++++++++----
 tests/python/frontend/pytorch/test_forward.py | 101 +++++++++++++++++++++++++-
 2 files changed, 150 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 991e3a8..68e68fd 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -399,10 +399,7 @@ class PyTorchOpConverter:
         begin = [0] * ndim
         dim = int(inputs[1])
         stride = int(inputs[4])
-        if isinstance(inputs[2], _expr.Call):
-            begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))
-        else:
-            begin[dim] = int(inputs[2])
+        begin[dim], _ = try_infer_value(inputs[2], lambda ret: np.asscalar(ret.astype(np.int)))
 
         # Process begin
         if not isinstance(begin[dim], int):
@@ -518,13 +515,13 @@ class PyTorchOpConverter:
         data = inputs[0]
         dim = int(inputs[1])
         index = _wrap_const(inputs[2])
-        return _op.transform.take(data, index, axis=dim)
+        return _op.transform.take(data, index, axis=dim, mode="wrap")
 
     def take(self, inputs, input_types):
         data = inputs[0]
         indices = _op.cast(inputs[1], "int32")
 
-        return _op.transform.take(data, indices=indices)
+        return _op.transform.take(data, indices=indices, mode="wrap")
 
     def topk(self, inputs, input_types):
         data = inputs[0]
@@ -551,7 +548,13 @@ class PyTorchOpConverter:
 
     def repeat(self, inputs, input_types):
         data = inputs[0]
-        reps = inputs[1]
+        reps = []
+        for r in inputs[1]:
+            if isinstance(r, int):
+                reps.append(r)
+            else:
+                reps.append(int(_infer_value(r, {}).asnumpy()))
+
         return _op.transform.tile(data, reps=reps)
 
     def repeat_interleave(self, inputs, input_types):
@@ -1520,12 +1523,6 @@ class PyTorchOpConverter:
             # Convert a and b into 3 dimensional tensors.
             a = _op.reshape(inputs_0, [-1, a_shape[-2], a_shape[-1]])
             b = _op.reshape(inputs_1, [-1, b_shape[-2], b_shape[-1]])
-            # Broadcast b to match batch size of a
-            new_b_shape = list(self.infer_shape_with_prelude(b))
-            new_a_shape = self.infer_shape_with_prelude(a)
-            if new_a_shape[0] > new_b_shape[0]:
-                new_b_shape[0] = new_a_shape[0]
-                b = _op.broadcast_to(b, new_b_shape)
             # Transpose matrix dimensions of b.
             b = _op.transpose(b, [0, 2, 1])
             # Perform a batch matmul.
@@ -2070,6 +2067,40 @@ class PyTorchOpConverter:
         src = inputs[3]
         return _op.scatter_add(data, index, src, axis=axis)
 
+    def cumsum(self, inputs, input_types):
+        data = inputs[0]
+        dim = inputs[1]
+        dtype = inputs[2]
+
+        if inputs[2] is not None:
+            dtype = _convert_dtype_value(inputs[2])
+
+        return _op.cumsum(data, axis=dim, dtype=dtype)
+
+    def masked_fill(self, inputs, input_types):
+        mask = inputs[1]
+        value = _op.cast(_wrap_const(inputs[2]), input_types[0])
+        return _op.where(mask, value, inputs[0])
+
+    def masked_select(self, inputs, input_types):
+        mask = inputs[1]
+        indices = self.nonzero([mask], input_types, is_numpy_style=True)
+        return _op.adv_index([inputs[0]] + [indices[i] for i in range(indices.size)])
+
+    def sort(self, inputs, input_types):
+        data = inputs[0]
+        dim = inputs[1]
+        is_descending = inputs[2]
+        # pytorch sort returns both sorted indices and values
+        indices = _op.argsort(data, dim, not is_descending)
+        return _op.gather(data, dim, indices), indices
+
+    def argsort(self, inputs, input_types):
+        data = inputs[0]
+        dim = inputs[1]
+        is_descending = inputs[2]
+        return _op.argsort(data, dim, not is_descending)
+
     def is_floating_point(self, inputs, input_types):
         assert len(inputs) == 1
 
@@ -2263,6 +2294,7 @@ class PyTorchOpConverter:
             "torchvision::roi_align": self.roi_align,
             "aten::unbind": self.unbind,
             "aten::__and__": self.logical_and,
+            "aten::logical_and": self.logical_and,
             "aten::_shape_as_tensor": self.shape_as_tensor,
             "aten::nonzero": self.nonzero,
             "aten::nonzero_numpy": self.nonzero_numpy,
@@ -2278,6 +2310,11 @@ class PyTorchOpConverter:
             "aten::__not__": self.logical_not,
             "aten::hardswish_": self.hard_swish,
             "aten::hardswish": self.hard_swish,
+            "aten::cumsum": self.cumsum,
+            "aten::masked_fill": self.masked_fill,
+            "aten::masked_select": self.masked_select,
+            "aten::argsort": self.argsort,
+            "aten::sort": self.sort,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 7cdd450..6d9b559 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -1147,7 +1147,7 @@ def test_forward_view():
 @tvm.testing.uses_gpu
 def test_forward_select():
     torch.set_grad_enabled(False)
-    input_shape = [1, 3, 10, 10]
+    input_shape = [5, 3, 10, 10]
 
     class Select1(Module):
         def forward(self, *args):
@@ -1167,6 +1167,9 @@ def test_forward_select():
     input_data = torch.rand(input_shape).float()
     verify_model(Select1().float().eval(), input_data=input_data)
 
+    # test negative indexing
+    verify_model(lambda x: x[-1], input_data=input_data)
+
     x = torch.randn(3, 4)
     indices = torch.tensor([0, 2])
     verify_model(IndexedSelect(x, 0).eval(), input_data=indices)
@@ -2653,6 +2656,8 @@ def test_forward_take():
     verify_model(Take1().float().eval(), input_data=input_data)
     indices = torch.tensor([[0, 0], [1, 0]])
     verify_model(Take2().float().eval(), input_data=[input_data, indices])
+    indices = torch.tensor([0, -1])
+    verify_model(Take2().float().eval(), input_data=[input_data, indices])
 
 
 @tvm.testing.uses_gpu
@@ -3452,6 +3457,93 @@ def test_hard_swish():
         verify_model(torch.nn.Hardswish(inplace=True).eval(), input_data=input)
 
 
+def test_cumsum():
+    def test_fn(dim, dtype=None):
+        return lambda x: torch.cumsum(x, dim=dim, dtype=dtype)
+
+    inp = torch.randint(0, 100, (10000,), dtype=torch.int32)
+    verify_model(test_fn(0), [inp])
+    verify_model(test_fn(0), [inp.to(torch.int64)])
+    verify_model(test_fn(0, dtype=torch.int64), [inp.to(torch.int64)])
+
+    inp = torch.randn((100, 100), dtype=torch.float32)
+    verify_model(test_fn(dim=0, dtype=torch.float64), [inp])
+    verify_model(test_fn(dim=1), [inp])
+
+    inp = torch.randn((100, 100), dtype=torch.float32) > 0.5
+    verify_model(test_fn(dim=0, dtype=torch.int32), [inp])
+
+
+def test_masked_fill():
+    def test_fn(x, mask):
+        return torch.masked_fill(x, mask, 0.0)
+
+    inp = torch.randn(100, 100)
+    verify_model(test_fn, [inp, inp > 0.5])
+    verify_model(test_fn, [inp.to(torch.float64), inp > 0.5])
+
+
+def test_transformer():
+    model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
+    model = model.eval()
+    src = torch.rand((10, 32, 256))
+    tgt = torch.rand((20, 32, 256))
+    verify_model(model.eval(), input_data=[src, tgt])
+
+
+def test_argsort():
+    def test_fn(dim, descending):
+        return lambda x: torch.argsort(x, dim=dim, descending=descending)
+
+    inp = torch.randn(100)
+    verify_model(test_fn(0, True), [inp])
+    verify_model(test_fn(0, False), [inp])
+
+    inp = torch.randn(100, 100)
+    verify_model(test_fn(0, True), [inp])
+    verify_model(test_fn(0, False), [inp])
+    verify_model(test_fn(1, True), [inp])
+    verify_model(test_fn(1, False), [inp])
+
+
+def test_sort():
+    def test_fn(dim, descending):
+        return lambda x: torch.sort(x, dim=dim, descending=descending)
+
+    inp = torch.randn(100)
+    verify_model(test_fn(0, True), [inp])
+    verify_model(test_fn(0, False), [inp])
+
+    inp = torch.randn(100, 100)
+    verify_model(test_fn(0, True), [inp])
+    verify_model(test_fn(0, False), [inp])
+    verify_model(test_fn(1, True), [inp])
+    verify_model(test_fn(1, False), [inp])
+
+
+def test_logical_and():
+    def test_fn(x, y):
+        return torch.logical_and(x, y)
+
+    a = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
+    b = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
+    verify_model(test_fn, [a, b])
+
+    a = torch.tensor([True, False, True])
+    b = torch.tensor([True, False, False])
+    verify_model(test_fn, [a, b])
+
+
+def test_masked_select():
+    def test_fn(x, mask):
+        return torch.masked_select(x, mask)
+
+    for shape in [(10,), (3, 4), (16, 32, 64)]:
+        x = torch.randn(*shape)
+        mask = x.ge(0.5)
+        verify_trace_model(test_fn, [x, mask], ["llvm", "cuda", "nvptx"])
+
+
 if __name__ == "__main__":
     # some structural tests
     test_forward_traced_function()
@@ -3580,6 +3672,13 @@ if __name__ == "__main__":
     test_forward_scatter()
     test_numel()
     test_bincount()
+    test_cumsum()
+    test_masked_fill()
+    test_transformer()
+    test_sort()
+    test_argsort()
+    test_logical_and()
+    test_masked_select()
 
     # Model tests
     test_resnet18()

Reply via email to