This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 61b49bb3f9 [BugFix][Relax][Torch] Honor multi-axis dims in torch.flip
converter (#19511)
61b49bb3f9 is described below
commit 61b49bb3f916e6ae8d3ae28f3f4b36420510aa4c
Author: Soowon Jeong <[email protected]>
AuthorDate: Wed May 6 19:50:40 2026 +0900
[BugFix][Relax][Torch] Honor multi-axis dims in torch.flip converter
(#19511)
## Motivation
PyTorch's `torch.flip(x, dims=[...])` reverses every listed axis. The
Relax converter `_flip` (`base_fx_graph_translator.py`) instead coerces
the list to a single integer:
```python
if isinstance(dims, list | tuple) and len(dims) > 0:
dims = dims[0]
```
Only the first axis is forwarded to `relax.op.flip`, which is itself
single-axis. The remaining axes are silently dropped.
Minimal repro (vs PyTorch eager) on a `(3, 4)` input with
`dims=[-1, -2]`:
```
ref: [11, 10, 9, 8, 7, 6, 5, 4, ...] # both axes flipped
tvm: [ 3, 2, 1, 0, 7, 6, 5, 4, ...] # only last axis flipped
```
max_abs_diff = 8.0. Both the `torch.export` and legacy fx paths share
this converter, so both are affected.
## Fix
Iterate over `dims` in the converter and emit one `relax.op.flip` per
axis, in the order the axes are listed (flips along distinct axes
commute, so that order does not affect the result).
A scalar `dims` is wrapped in a single-element list; arguments that are
neither an int nor a list/tuple still raise `TypeError`.
`relax.op.flip` itself is unchanged: it is used elsewhere as a
single-axis op, and widening its signature would expand the scope of
this fix beyond the PyTorch frontend.
---
.../frontend/torch/base_fx_graph_translator.py | 14 +++++---
.../relax/test_frontend_from_exported_program.py | 41 ++++++++++++++++++++++
tests/python/relax/test_frontend_from_fx.py | 21 +++++++++++
3 files changed, 71 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index c146cf6c00..0d92576c59 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1802,11 +1802,15 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
def _flip(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims",
None)
- if isinstance(dims, list | tuple) and len(dims) > 0:
- dims = dims[0]
- elif not isinstance(dims, int):
- raise TypeError(f"flip expects an integer axis, but got
{type(dims)}: {dims}")
- return self.block_builder.emit(relax.op.flip(x, dims))
+ if isinstance(dims, int):
+ dims = [dims]
+ elif not isinstance(dims, list | tuple):
+ raise TypeError(f"flip expects an int or list of ints, but got
{type(dims)}: {dims}")
+ # relax.op.flip is single-axis; iterate to honor multi-axis torch.flip
semantics.
+ out = x
+ for d in dims:
+ out = self.block_builder.emit(relax.op.flip(out, d))
+ return out
def _gather(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 6029499372..d5ed2aca7c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7441,6 +7441,47 @@ def test_flip():
verify_model(Flip1(), example_args, {}, Expected1)
+def test_flip_multi_axis():
+ class FlipMulti(Module):
+ def forward(self, data):
+ return torch.flip(data, [0, 1])
+
+ class FlipNegMulti(Module):
+ def forward(self, data):
+ return torch.flip(data, dims=[-1, -2])
+
+ @tvm.script.ir_module
+ class ExpectedMulti:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((2, 3), dtype="float32") = R.flip(inp_0, axis=0)
+ lv1: R.Tensor((2, 3), dtype="float32") = R.flip(lv, axis=1)
+ gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class ExpectedNegMulti:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 3), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((2, 3), dtype="float32") = R.flip(inp_0, axis=-1)
+ lv1: R.Tensor((2, 3), dtype="float32") = R.flip(lv, axis=-2)
+ gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(2, 3, dtype=torch.float32),)
+
+ verify_model(FlipMulti(), example_args, {}, ExpectedMulti)
+ verify_model(FlipNegMulti(), example_args, {}, ExpectedNegMulti)
+
+
def test_take():
class Take(Module):
def forward(self, data, indices):
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 4d9060bf72..890c6ef3a1 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5862,6 +5862,27 @@ def test_flip():
verify_model(Flip1(), [([2, 2], "float32")], {}, Expected1)
+def test_flip_multi_axis():
+ class FlipMulti(Module):
+ def forward(self, data):
+ return torch.flip(data, [0, 1])
+
+ @tvm.script.ir_module
+ class ExpectedMulti:
+ @R.function
+ def main(
+ inp_0: R.Tensor((2, 3), dtype="float32"),
+ ) -> R.Tensor((2, 3), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((2, 3), dtype="float32") = R.flip(inp_0, axis=0)
+ lv1: R.Tensor((2, 3), dtype="float32") = R.flip(lv, axis=1)
+ gv: R.Tensor((2, 3), dtype="float32") = lv1
+ R.output(gv)
+ return gv
+
+ verify_model(FlipMulti(), [([2, 3], "float32")], {}, ExpectedMulti)
+
+
def test_take():
class Take(Module):
def forward(self, data, indices):