This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 61b49bb3f9 [BugFix][Relax][Torch] Honor multi-axis dims in torch.flip converter (#19511)
61b49bb3f9 is described below

commit 61b49bb3f916e6ae8d3ae28f3f4b36420510aa4c
Author: Soowon Jeong <[email protected]>
AuthorDate: Wed May 6 19:50:40 2026 +0900

    [BugFix][Relax][Torch] Honor multi-axis dims in torch.flip converter (#19511)
    
    ## Motivation
    
    PyTorch's `torch.flip(x, dims=[...])` reverses every listed axis. The
    Relax converter `_flip` (`base_fx_graph_translator.py`) instead coerces
    the list to a single integer:
    
    ```python
    if isinstance(dims, list | tuple) and len(dims) > 0:
        dims = dims[0]
    ```
    
    Only the first axis is forwarded to `relax.op.flip`, which is itself
    single-axis. The remaining axes are silently dropped.
    
    Minimal repro (vs PyTorch eager) on a `(3, 4)` input with
    `dims=[-1, -2]`:
    
    ```
    ref: [11, 10,  9,  8,  7,  6,  5,  4, ...]   # both axes flipped
    tvm: [ 3,  2,  1,  0,  7,  6,  5,  4, ...]   # only last axis flipped
    ```
    
    max_abs_diff = 8.0. Both the `torch.export` and legacy fx paths share
    this converter, so both are affected.
    
    ## Fix
    
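    Iterate over `dims` in the converter and emit one `relax.op.flip` per
    axis (flips along distinct axes commute, so the order is irrelevant).
    A scalar `dims` is wrapped to a single-element list; arguments that are
    neither an int nor a list/tuple (e.g. `None`) still raise `TypeError`.
    An empty `dims` sequence now returns the input unchanged instead of
    raising `TypeError` as before.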
    
    `relax.op.flip` itself is unchanged: it is used elsewhere as a
    single-axis op, and widening its signature would expand the scope of
    this fix beyond the PyTorch frontend.
---
 .../frontend/torch/base_fx_graph_translator.py     | 14 +++++---
 .../relax/test_frontend_from_exported_program.py   | 41 ++++++++++++++++++++++
 tests/python/relax/test_frontend_from_fx.py        | 21 +++++++++++
 3 files changed, 71 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py 
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index c146cf6c00..0d92576c59 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1802,11 +1802,15 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
     def _flip(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", 
None)
-        if isinstance(dims, list | tuple) and len(dims) > 0:
-            dims = dims[0]
-        elif not isinstance(dims, int):
-            raise TypeError(f"flip expects an integer axis, but got 
{type(dims)}: {dims}")
-        return self.block_builder.emit(relax.op.flip(x, dims))
+        if isinstance(dims, int):
+            dims = [dims]
+        elif not isinstance(dims, list | tuple):
+            raise TypeError(f"flip expects an int or list of ints, but got 
{type(dims)}: {dims}")
+        # relax.op.flip is single-axis; iterate to honor multi-axis torch.flip 
semantics.
+        out = x
+        for d in dims:
+            out = self.block_builder.emit(relax.op.flip(out, d))
+        return out
 
     def _gather(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 6029499372..d5ed2aca7c 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7441,6 +7441,47 @@ def test_flip():
     verify_model(Flip1(), example_args, {}, Expected1)
 
 
+def test_flip_multi_axis():
+    class FlipMulti(Module):
+        def forward(self, data):
+            return torch.flip(data, [0, 1])
+
+    class FlipNegMulti(Module):
+        def forward(self, data):
+            return torch.flip(data, dims=[-1, -2])
+
+    @tvm.script.ir_module
+    class ExpectedMulti:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.flip(inp_0, axis=0)
+                lv1: R.Tensor((2, 3), dtype="float32") = R.flip(lv, axis=1)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class ExpectedNegMulti:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.flip(inp_0, axis=-1)
+                lv1: R.Tensor((2, 3), dtype="float32") = R.flip(lv, axis=-2)
+                gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 3, dtype=torch.float32),)
+
+    verify_model(FlipMulti(), example_args, {}, ExpectedMulti)
+    verify_model(FlipNegMulti(), example_args, {}, ExpectedNegMulti)
+
+
 def test_take():
     class Take(Module):
         def forward(self, data, indices):
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 4d9060bf72..890c6ef3a1 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -5862,6 +5862,27 @@ def test_flip():
     verify_model(Flip1(), [([2, 2], "float32")], {}, Expected1)
 
 
+def test_flip_multi_axis():
+    class FlipMulti(Module):
+        def forward(self, data):
+            return torch.flip(data, [0, 1])
+
+    @tvm.script.ir_module
+    class ExpectedMulti:
+        @R.function
+        def main(
+            inp_0: R.Tensor((2, 3), dtype="float32"),
+        ) -> R.Tensor((2, 3), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((2, 3), dtype="float32") = R.flip(inp_0, axis=0)
+                lv1: R.Tensor((2, 3), dtype="float32") = R.flip(lv, axis=1)
+                gv: R.Tensor((2, 3), dtype="float32") = lv1
+                R.output(gv)
+            return gv
+
+    verify_model(FlipMulti(), [([2, 3], "float32")], {}, ExpectedMulti)
+
+
 def test_take():
     class Take(Module):
         def forward(self, data, indices):

Reply via email to