This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new dd21ed37a8 [Unity][Relax]fix behaviors for importing torch.cat/concat
(#15845)
dd21ed37a8 is described below
commit dd21ed37a89eabf53f60e83c1be99ef322a15a72
Author: Guoyao Li <[email protected]>
AuthorDate: Sun Oct 1 23:50:40 2023 -0400
[Unity][Relax]fix behaviors for importing torch.cat/concat (#15845)
Fix behaviors for importing torch.cat/concat
Now we can import Stable Diffusion XL and Real-ESRGAN more smoothly
Before fixing, we only support **torch.cat((x, y), dim=n_dim)**
Add support for importing:
1. **torch.cat((x, y))** with default dim=0
2. **torch.cat((x, y), n_dim)** when the user doesn't explicitly write dim=
3. **torch.concat** an alias of torch.cat
---
python/tvm/relax/frontend/torch/fx_translator.py | 4 +-
tests/python/relax/test_frontend_from_fx.py | 49 ++++++++++++++++++++++++
2 files changed, 52 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index a5c2a68cd8..d08e8858dc 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -505,7 +505,8 @@ class TorchFXImporter:
def _cat(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
- return self.block_builder.emit(relax.op.concat(args[0],
axis=node.kwargs["dim"]))
+ axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+ return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
def _expand(self, node: fx.node.Node) -> relax.Var:
args = self.retrieve_args(node)
@@ -1346,6 +1347,7 @@ class TorchFXImporter:
"baddbmm": self._baddbmm,
"bmm": self._matmul,
"cat": self._cat,
+ "concat": self._cat,
"expand": self._expand,
"flatten": self._flatten,
"permute": self._permute,
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 36ef25b025..abe6c947e4 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3089,6 +3089,55 @@ def test_rsqrt():
verify_model(Rsqrt(), [([256, 256], "float32")], {}, Expected1)
+def test_cat():
+ class Cat0(Module):
+ def forward(self, x, y):
+ return torch.cat((x, y))
+
+ class Cat1(Module):
+ def forward(self, x, y):
+ return torch.cat((x, y), dim=1)
+
+ class Cat2(Module):
+ def forward(self, x, y):
+ return torch.cat((x, y), 1)
+
+ class Cat3(Module):
+ def forward(self, x, y):
+ return torch.concat((x, y), dim=0)
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 3), dtype="float32"),
+ inp_1: R.Tensor((2, 3), dtype="float32"),
+ ) -> R.Tensor((4, 3), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 3), dtype="float32") = R.concat((inp_0,
inp_1), axis=0)
+ gv: R.Tensor((4, 3), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 3), dtype="float32"),
+ inp_1: R.Tensor((2, 3), dtype="float32"),
+ ) -> R.Tensor((2, 6), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 6), dtype="float32") = R.concat((inp_0,
inp_1), axis=1)
+ gv: R.Tensor((2, 6), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Cat0(), [([2, 3], "float32"), ([2, 3], "float32")], {},
Expected1)
+ verify_model(Cat1(), [([2, 3], "float32"), ([2, 3], "float32")], {},
Expected2)
+ verify_model(Cat2(), [([2, 3], "float32"), ([2, 3], "float32")], {},
Expected2)
+ verify_model(Cat3(), [([2, 3], "float32"), ([2, 3], "float32")], {},
Expected1)
+
+
def test_neg():
class Neg(Module):
def forward(self, input):