This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 99defd25c4 [Relax][PyTorch] Add support for torch.repeat (#17304)
99defd25c4 is described below

commit 99defd25c40c75b00395df1d2d58c84d2e0bd9ca
Author: Masahiro Hiramori <hiramori.masah...@ct.mitsubishielectric.co.jp>
AuthorDate: Wed Aug 28 04:37:30 2024 +0900

    [Relax][PyTorch] Add support for torch.repeat (#17304)
    
    * add test
    
    * add support for torch.repeat
    
    * remove debug print
---
 python/tvm/relax/frontend/torch/fx_translator.py |  9 ++++++
 tests/python/relax/test_frontend_from_fx.py      | 36 ++++++++++++++++++++++++
 2 files changed, 45 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 6d01283d3e..676f63b5c3 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -640,6 +640,14 @@ class TorchFXImporter:
             dim = None
         return self.block_builder.emit(relax.op.squeeze(x, dim))
 
+    def _repeat(self, node: fx.node.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        if isinstance(args[1], (torch.Size, tuple, list)):
+            return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1])))
+        return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
+
     def _tile(self, node: fx.node.Node) -> relax.Var:
         import torch  # type: ignore
 
@@ -1484,6 +1492,7 @@ class TorchFXImporter:
             "expand": self._expand,
             "flatten": self._flatten,
             "permute": self._permute,
+            "repeat": self._repeat,
             "reshape": self._reshape,
             "split": self._split,
             "tile": self._tile,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 5398fe3420..c6c4f25972 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3311,6 +3311,42 @@ def test_transpose():
     verify_model(Transpose(), input_info, {}, expected1)
 
 
+def test_repeat():
+    class Tile1(Module):
+        def forward(self, x: torch.Tensor):
+            return x.repeat(2)
+
+    class Tile2(Module):
+        def forward(self, x: torch.Tensor):
+            return x.repeat(4, 2)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((6,), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2)
+                gv: R.Tensor((6,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
+                gv: R.Tensor((4, 6), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Tile1(), [([3], "float32")], {}, expected1)
+    verify_model(Tile2(), [([1, 3], "float32")], {}, expected2)
+    verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2)
+
+
 def test_view():
     input_info = [([1, 2, 3, 4], "float32")]
 

Reply via email to