This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new dd21ed37a8 [Unity][Relax]fix behaviors for importing torch.cat/concat 
(#15845)
dd21ed37a8 is described below

commit dd21ed37a89eabf53f60e83c1be99ef322a15a72
Author: Guoyao Li <[email protected]>
AuthorDate: Sun Oct 1 23:50:40 2023 -0400

    [Unity][Relax]fix behaviors for importing torch.cat/concat (#15845)
    
    Fix the behavior of importing torch.cat/torch.concat.
    With this change, Stable Diffusion XL and Real-ESRGAN can be imported more smoothly.
    
    Before this fix, only **torch.cat((x, y), dim=n_dim)** was supported: the importer required `dim` to be passed as a keyword argument.
    
    This change adds support for importing:
    1. **torch.cat((x, y))**, which uses the default dim=0
    2. **torch.cat((x, y), n_dim)**, where the dimension is passed positionally instead of as `dim=`
    3. **torch.concat**, an alias of torch.cat
---
 python/tvm/relax/frontend/torch/fx_translator.py |  4 +-
 tests/python/relax/test_frontend_from_fx.py      | 49 ++++++++++++++++++++++++
 2 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index a5c2a68cd8..d08e8858dc 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -505,7 +505,8 @@ class TorchFXImporter:
 
     def _cat(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
-        return self.block_builder.emit(relax.op.concat(args[0], 
axis=node.kwargs["dim"]))
+        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
 
     def _expand(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
@@ -1346,6 +1347,7 @@ class TorchFXImporter:
             "baddbmm": self._baddbmm,
             "bmm": self._matmul,
             "cat": self._cat,
+            "concat": self._cat,
             "expand": self._expand,
             "flatten": self._flatten,
             "permute": self._permute,
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 36ef25b025..abe6c947e4 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3089,6 +3089,55 @@ def test_rsqrt():
     verify_model(Rsqrt(), [([256, 256], "float32")], {}, Expected1)
 
 
+def test_cat():
+    class Cat0(Module):
+        def forward(self, x, y):
+            return torch.cat((x, y))
+
+    class Cat1(Module):
+        def forward(self, x, y):
+            return torch.cat((x, y), dim=1)
+
+    class Cat2(Module):
+        def forward(self, x, y):
+            return torch.cat((x, y), 1)
+
+    class Cat3(Module):
+        def forward(self, x, y):
+            return torch.concat((x, y), dim=0)
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tensor((4, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 3), dtype="float32") = R.concat((inp_0, 
inp_1), axis=0)
+                gv: R.Tensor((4, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+            inp_1: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tensor((2, 6), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 6), dtype="float32") = R.concat((inp_0, 
inp_1), axis=1)
+                gv: R.Tensor((2, 6), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Cat0(), [([2, 3], "float32"), ([2, 3], "float32")], {}, 
Expected1)
+    verify_model(Cat1(), [([2, 3], "float32"), ([2, 3], "float32")], {}, 
Expected2)
+    verify_model(Cat2(), [([2, 3], "float32"), ([2, 3], "float32")], {}, 
Expected2)
+    verify_model(Cat3(), [([2, 3], "float32"), ([2, 3], "float32")], {}, 
Expected1)
+
+
 def test_neg():
     class Neg(Module):
         def forward(self, input):

Reply via email to