This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9f84d4f9ef [Relax][Frontend][Torch] Fix parsing error when input dimension of unbind is 1 (#18351)
9f84d4f9ef is described below
commit 9f84d4f9ef3ab537167f3bfb33ec4cffe1149d22
Author: Ruxiao Yin <[email protected]>
AuthorDate: Mon Sep 29 03:19:37 2025 +0800
[Relax][Frontend][Torch] Fix parsing error when input dimension of unbind is 1 (#18351)
* [Relax][Frontend][Torch] Fix parsing error when input dimension of unbind is 1
* reformat code
---
.../tvm/relax/frontend/torch/base_fx_graph_translator.py | 10 +++++++---
.../python/relax/test_frontend_from_exported_program.py | 16 ++++++++++++++++
2 files changed, 23 insertions(+), 3 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 1895119e79..53b1fdd22c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1275,9 +1275,13 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
assert isinstance(dim, int), "Expected 2nd argument of unbind as int"
selections = self.shape_of(x)[dim].value
- ret, split = [], self.block_builder.emit(relax.op.split(x, selections,
dim))
- for i in range(selections):
- ret.append(self.block_builder.emit(relax.op.squeeze(split[i],
axis=dim)))
+ ret = []
+ if selections == 1:
+ ret.append(self.block_builder.emit(relax.op.squeeze(x, axis=dim)))
+ else:
+ split = self.block_builder.emit(relax.op.split(x, selections, dim))
+ for i in range(selections):
+ ret.append(self.block_builder.emit(relax.op.squeeze(split[i],
axis=dim)))
return self.block_builder.emit(relax.Tuple(ret))
########## Statistical ##########
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ead341de28..65a7241217 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3251,9 +3251,25 @@ def test_unbind():
R.output(gv)
return gv
+ @tvm.script.ir_module
+ class expected3:
+ @R.function
+ def main(
+ data: R.Tensor((3, 1, 3), dtype="float32")
+ ) -> R.Tuple(R.Tensor((3, 3), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((3, 3), dtype="float32") = R.squeeze(data,
axis=[1])
+ lv1: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv,)
+ lv2: R.Tensor((3, 3), dtype="float32") = lv1[0]
+ gv: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv2,)
+ R.output(gv)
+ return gv
+
example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
verify_model(Unbind1(), example_args, {}, expected1)
verify_model(Unbind2(), example_args, {}, expected2)
+ single_dim_args = (torch.randn(3, 1, 3, dtype=torch.float32),)
+ verify_model(Unbind2(), single_dim_args, {}, expected3)
def test_interpolate():