This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 99defd25c4 [Relax][PyTorch] Add support for torch.repeat (#17304)
99defd25c4 is described below

commit 99defd25c40c75b00395df1d2d58c84d2e0bd9ca
Author: Masahiro Hiramori <hiramori.masah...@ct.mitsubishielectric.co.jp>
AuthorDate: Wed Aug 28 04:37:30 2024 +0900

    [Relax][PyTorch] Add support for torch.repeat (#17304)

    * add test

    * add support for torch.repeat

    * remove debug print
---
 python/tvm/relax/frontend/torch/fx_translator.py |  9 ++++++
 tests/python/relax/test_frontend_from_fx.py      | 36 ++++++++++++++++++++++++
 2 files changed, 45 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 6d01283d3e..676f63b5c3 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -640,6 +640,14 @@ class TorchFXImporter:
             dim = None
         return self.block_builder.emit(relax.op.squeeze(x, dim))
 
+    def _repeat(self, node: fx.node.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        if isinstance(args[1], (torch.Size, tuple, list)):
+            return self.block_builder.emit(relax.op.tile(args[0], tuple(args[1])))
+        return self.block_builder.emit(relax.op.tile(args[0], args[1:]))
+
     def _tile(self, node: fx.node.Node) -> relax.Var:
         import torch  # type: ignore
 
@@ -1484,6 +1492,7 @@ class TorchFXImporter:
             "expand": self._expand,
             "flatten": self._flatten,
             "permute": self._permute,
+            "repeat": self._repeat,
             "reshape": self._reshape,
             "split": self._split,
             "tile": self._tile,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 5398fe3420..c6c4f25972 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3311,6 +3311,42 @@ def test_transpose():
     verify_model(Transpose(), input_info, {}, expected1)
 
 
+def test_repeat():
+    class Tile1(Module):
+        def forward(self, x: torch.Tensor):
+            return x.repeat(2)
+
+    class Tile2(Module):
+        def forward(self, x: torch.Tensor):
+            return x.repeat(4, 2)
+
+    @tvm.script.ir_module
+    class expected1:
+        @R.function
+        def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((6,), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((6,), dtype="float32") = R.tile(x, 2)
+                gv: R.Tensor((6,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(x: R.Tensor((1, 3), dtype="float32")) -> R.Tensor((4, 6), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((4, 6), dtype="float32") = R.tile(x, [4, 2])
+                gv: R.Tensor((4, 6), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Tile1(), [([3], "float32")], {}, expected1)
+    verify_model(Tile2(), [([1, 3], "float32")], {}, expected2)
+    verify_model(Tile2(), [(torch.Size([1, 3]), "float32")], {}, expected2)
+
+
 def test_view():
     input_info = [([1, 2, 3, 4], "float32")]