This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 65121c878a [Relay][Frontend] Add support for aten::concat (#16199) 65121c878a is described below commit 65121c878aed37adb6434b0238e36611a59881d7 Author: Jongho Choi <sweetco...@snu.ac.kr> AuthorDate: Sat Dec 9 10:34:26 2023 +0900 [Relay][Frontend] Add support for aten::concat (#16199) * Update pytorch.py * Add concat test * rm whitespace * Add diable docstring * update comment --- python/tvm/relay/frontend/pytorch.py | 1 + tests/python/frontend/pytorch/test_forward.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 9583575bfc..c507da13a7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -4051,6 +4051,7 @@ class PyTorchOpConverter: "aten::squeeze": self.squeeze, "aten::unsqueeze": self.unsqueeze, "aten::cat": self.concatenate, + "aten::concat": self.concatenate, "aten::slice": self.slice, "aten::narrow": self.narrow, "aten::split": self.split, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6109141dea..56afe72ecd 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -720,9 +720,31 @@ def test_forward_concatenate(): c = (args[0][:, :, 2] + 5) * 13 return torch.cat([t.unsqueeze(2) for t in [a, b, c]], 2) + class Concatenate3(Module): + """ + torch.concat is preserved as aten::concat only when in a nested module. + (In the most cases, It is converted to aten::cat instead of aten::concat.) + """ + + def __init__(self): + super().__init__() + + class _Concatenate(Module): + def forward(self, *args): + a = (args[0][:, :, 0] + 2) * 7 + b = (args[0][:, :, 1] + 3) * 11 + c = (args[0][:, :, 2] + 5) * 13 + return torch.concat([t.unsqueeze(2) for t in [a, b, c]], 2) + + self.mod = _Concatenate() + + def forward(self, *args): + return self.mod(*args) + input_data = torch.rand(input_shape).float() verify_model(Concatenate1().float().eval(), input_data=input_data) verify_model(Concatenate2().float().eval(), input_data=input_data) + verify_model(Concatenate3().float().eval(), input_data=input_data) @tvm.testing.uses_gpu