This is an automated email from the ASF dual-hosted git repository.

yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ffa82b30ec [Relax][Frontend]Fix: Output tensor with zero dimension after torch.u… (#17990)
ffa82b30ec is described below

commit ffa82b30eccab8099bd4819ddc23af0f28e2d7e7
Author: yin ru xiao <[email protected]>
AuthorDate: Tue May 20 10:03:03 2025 +0800

    [Relax][Frontend]Fix: Output tensor with zero dimension after torch.u… (#17990)
    
    * [Relax][Frontend]Fix: Output tensor with zero dimension after torch.unbind to relax conversion
    
    * Fix: Output tensor with zero dimension after torch.unbind to relax conversion
---
 python/tvm/relax/frontend/torch/base_fx_graph_translator.py |  3 +--
 tests/python/relax/test_frontend_from_exported_program.py   | 12 ++++--------
 tests/python/relax/test_frontend_from_fx.py                 |  6 ++----
 3 files changed, 7 insertions(+), 14 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 5c8d7095e5..50969e85a5 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1269,8 +1269,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
         assert isinstance(dim, int), "Expected 2nd argument of unbind as int"
         selections = self.shape_of(x)[dim].value
-        n_section = list(range(1, selections + 1))
-        ret, split = [], self.block_builder.emit(relax.op.split(x, n_section, dim))
+        ret, split = [], self.block_builder.emit(relax.op.split(x, selections, dim))
         for i in range(selections):
             ret.append(self.block_builder.emit(relax.op.squeeze(split[i], axis=dim)))
         return self.block_builder.emit(relax.Tuple(ret))
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 80da6fcf19..aaaf7e6eac 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3108,8 +3108,7 @@ def test_unbind():
                     R.Tensor((1, 3, 10, 10), dtype="float32"),
                     R.Tensor((1, 3, 10, 10), dtype="float32"),
                     R.Tensor((1, 3, 10, 10), dtype="float32"),
-                    R.Tensor((0, 3, 10, 10), dtype="float32"),
-                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0)
+                ) = R.split(input_1, indices_or_sections=3, axis=0)
                 lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
                 lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0])
                 lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
@@ -3152,8 +3151,7 @@ def test_unbind():
                     R.Tensor((3, 1, 10, 10), dtype="float32"),
                     R.Tensor((3, 1, 10, 10), dtype="float32"),
                     R.Tensor((3, 1, 10, 10), dtype="float32"),
-                    R.Tensor((3, 0, 10, 10), dtype="float32"),
-                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1)
+                ) = R.split(input_1, indices_or_sections=3, axis=1)
                 lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
                 lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1])
                 lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
@@ -3978,8 +3976,7 @@ def test_split():
                     R.Tensor((1, 3, 10, 10), dtype="float32"),
                     R.Tensor((1, 3, 10, 10), dtype="float32"),
                     R.Tensor((1, 3, 10, 10), dtype="float32"),
-                    R.Tensor((0, 3, 10, 10), dtype="float32"),
-                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0)
+                ) = R.split(input_1, indices_or_sections=3, axis=0)
                 lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
                 lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0])
                 lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
@@ -4022,8 +4019,7 @@ def test_split():
                     R.Tensor((3, 1, 10, 10), dtype="float32"),
                     R.Tensor((3, 1, 10, 10), dtype="float32"),
                     R.Tensor((3, 1, 10, 10), dtype="float32"),
-                    R.Tensor((3, 0, 10, 10), dtype="float32"),
-                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1)
+                ) = R.split(input_1, indices_or_sections=3, axis=1)
                 lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
                 lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1])
                 lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 7fb2bed328..789c5649e6 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3746,8 +3746,7 @@ def test_unbind():
                     R.Tensor((1, 3, 10, 10), dtype="float32"),
                     R.Tensor((1, 3, 10, 10), dtype="float32"),
                     R.Tensor((1, 3, 10, 10), dtype="float32"),
-                    R.Tensor((0, 3, 10, 10), dtype="float32"),
-                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=0)
+                ) = R.split(input_1, indices_or_sections=3, axis=0)
                 lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
                 lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[0])
                 lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
@@ -3783,8 +3782,7 @@ def test_unbind():
                     R.Tensor((3, 1, 10, 10), dtype="float32"),
                     R.Tensor((3, 1, 10, 10), dtype="float32"),
                     R.Tensor((3, 1, 10, 10), dtype="float32"),
-                    R.Tensor((3, 0, 10, 10), dtype="float32"),
-                ) = R.split(input_1, indices_or_sections=[1, 2, 3], axis=1)
+                ) = R.split(input_1, indices_or_sections=3, axis=1)
                 lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
                 lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, axis=[1])
                 lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]

Reply via email to