This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9f84d4f9ef [Relax][Frontend][Torch] Fix parsing error when input dimension of unbind is 1 (#18351)
9f84d4f9ef is described below

commit 9f84d4f9ef3ab537167f3bfb33ec4cffe1149d22
Author: Ruxiao Yin <[email protected]>
AuthorDate: Mon Sep 29 03:19:37 2025 +0800

    [Relax][Frontend][Torch] Fix parsing error when input dimension of unbind is 1 (#18351)
    
    * [Relax][Frontend][Torch] Fix parsing error when input dimension of unbind is 1
    
    * reformat code
---
 .../tvm/relax/frontend/torch/base_fx_graph_translator.py | 10 +++++++---
 .../python/relax/test_frontend_from_exported_program.py  | 16 ++++++++++++++++
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 1895119e79..53b1fdd22c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1275,9 +1275,13 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
         assert isinstance(dim, int), "Expected 2nd argument of unbind as int"
         selections = self.shape_of(x)[dim].value
-        ret, split = [], self.block_builder.emit(relax.op.split(x, selections, 
dim))
-        for i in range(selections):
-            ret.append(self.block_builder.emit(relax.op.squeeze(split[i], 
axis=dim)))
+        ret = []
+        if selections == 1:
+            ret.append(self.block_builder.emit(relax.op.squeeze(x, axis=dim)))
+        else:
+            split = self.block_builder.emit(relax.op.split(x, selections, dim))
+            for i in range(selections):
+                ret.append(self.block_builder.emit(relax.op.squeeze(split[i], 
axis=dim)))
         return self.block_builder.emit(relax.Tuple(ret))
 
     ########## Statistical ##########
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ead341de28..65a7241217 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3251,9 +3251,25 @@ def test_unbind():
                 R.output(gv)
             return gv
 
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            data: R.Tensor((3, 1, 3), dtype="float32")
+        ) -> R.Tuple(R.Tensor((3, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((3, 3), dtype="float32") = R.squeeze(data, 
axis=[1])
+                lv1: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv,)
+                lv2: R.Tensor((3, 3), dtype="float32") = lv1[0]
+                gv: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv2,)
+                R.output(gv)
+            return gv
+
     example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
     verify_model(Unbind1(), example_args, {}, expected1)
     verify_model(Unbind2(), example_args, {}, expected2)
+    single_dim_args = (torch.randn(3, 1, 3, dtype=torch.float32),)
+    verify_model(Unbind2(), single_dim_args, {}, expected3)
 
 
 def test_interpolate():

Reply via email to