This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ae839848b2 [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(6) (#18420)
ae839848b2 is described below

commit ae839848b22f16aa92adb2a83ab050ecde8ee3cc
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Nov 5 21:46:48 2025 -0500

    [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(6) (#18420)
    
    * finish1
    
    * finish2
---
 .../frontend/torch/base_fx_graph_translator.py     |  14 ++
 .../frontend/torch/exported_program_translator.py  |  20 +-
 .../relax/test_frontend_from_exported_program.py   | 229 +++++++++------------
 3 files changed, 130 insertions(+), 133 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index aedef8acf8..03e3b8d557 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1725,6 +1725,20 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         # Support both "dim" and "dims" parameters
         if dim is None:
             dim = node.kwargs.get("dims", None)
+
+        # If dims is a list, filter out axes where dimension is not 1
+        # This is needed because PyTorch decomposition may pass all axes
+        if isinstance(dim, (list, tuple)) and len(dim) > 0:
+            shape = self.shape_of(x)
+            # Filter to only include axes where the dimension is 1
+            valid_dims = []
+            for d in dim:
+                axis = d if d >= 0 else len(shape) + d
+                if axis < len(shape) and shape[axis] == 1:
+                    valid_dims.append(d)
+            # If no valid dims, use None to squeeze all size-1 dimensions
+            dim = valid_dims if valid_dims else None
+
         return self.block_builder.emit(relax.op.squeeze(x, dim))
 
     def _stack(self, node: fx.Node) -> relax.Var:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 3be255a29a..4f3132b8d8 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -701,11 +701,23 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         return self.block_builder.emit(relax.op.take(x, index, dim))
 
     def _slice(self, node: fx.Node) -> relax.Var:
+        import sys
+
         x = self.env[node.args[0]]
-        axes = [node.args[1]]
-        begin = [node.args[2]]
-        end = [node.args[3]]
-        stride = [node.args[4] if len(node.args) > 4 else 1]
+        dim = node.args[1] if len(node.args) > 1 else 0
+        start = node.args[2] if len(node.args) > 2 else None
+        end_val = node.args[3] if len(node.args) > 3 else None
+        step = node.args[4] if len(node.args) > 4 else 1
+
+        if start is None:
+            start = 0
+        if end_val is None:
+            end_val = sys.maxsize
+
+        axes = [dim]
+        begin = [start]
+        end = [end_val]
+        stride = [step]
         return self.block_builder.emit(relax.op.strided_slice(x, axes, begin, 
end, stride))
 
     def _unflatten(self, node: fx.Node) -> relax.Var:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 8a9fe66a0f..44248c1c59 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4111,7 +4111,7 @@ def test_reshape():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Reshape(), example_args, {}, expected1)
+    verify_model(Reshape(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
 
 def test_reshape_as():
@@ -4137,7 +4137,7 @@ def test_reshape_as():
         torch.randn(1, 2, 3, 4, dtype=torch.float32),
         torch.randn(2, 12, dtype=torch.float32),
     )
-    verify_model(ReshapeAs(), example_args, {}, expected1)
+    verify_model(ReshapeAs(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
 
 def test_roll():
@@ -4160,25 +4160,14 @@ def test_roll():
         def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 
2), dtype="int64")):
             with R.dataflow():
                 lv: R.Tensor((8,), dtype="int64") = R.reshape(x, R.shape([8]))
-                lv1: R.Tensor((7,), dtype="int64") = R.strided_slice(
-                    lv,
-                    axes=[0],
-                    begin=[R.prim_value(0)],
-                    end=[R.prim_value(7)],
-                    strides=[R.prim_value(1)],
-                    assume_inbound=False,
-                )
-                lv2: R.Tensor((1,), dtype="int64") = R.strided_slice(
-                    lv,
-                    axes=[0],
-                    begin=[R.prim_value(7)],
-                    end=[R.prim_value(8)],
-                    strides=[R.prim_value(1)],
-                    assume_inbound=False,
+                lv1: R.Tensor((8,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(8), R.prim_value(1), 
dtype="int64"
                 )
-                lv3: R.Tensor((8,), dtype="int64") = R.concat((lv2, lv1), 
axis=0)
-                lv4: R.Tensor((4, 2), dtype="int64") = R.reshape(lv3, 
R.shape([4, 2]))
-                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv4,)
+                lv2: R.Tensor((8,), dtype="int64") = R.add(lv1, R.const(7, 
"int64"))
+                lv3: R.Tensor((8,), dtype="int64") = R.mod(lv2, R.const(8, 
"int64"))
+                lv4: R.Tensor((8,), dtype="int64") = R.take(lv, lv3, axis=0, 
mode="fast")
+                lv5: R.Tensor((4, 2), dtype="int64") = R.reshape(lv4, 
R.shape([4, 2]))
+                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,)
                 R.output(gv)
             return gv
 
@@ -4188,24 +4177,13 @@ def test_roll():
         @R.function
         def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 
2), dtype="int64")):
             with R.dataflow():
-                lv: R.Tensor((1, 2), dtype="int64") = R.strided_slice(
-                    x,
-                    axes=[0],
-                    begin=[R.prim_value(0)],
-                    end=[R.prim_value(1)],
-                    strides=[R.prim_value(1)],
-                    assume_inbound=False,
+                lv: R.Tensor((4,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(4), R.prim_value(1), 
dtype="int64"
                 )
-                lv1: R.Tensor((3, 2), dtype="int64") = R.strided_slice(
-                    x,
-                    axes=[0],
-                    begin=[R.prim_value(1)],
-                    end=[R.prim_value(4)],
-                    strides=[R.prim_value(1)],
-                    assume_inbound=False,
-                )
-                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), 
axis=0)
-                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv2,)
+                lv1: R.Tensor((4,), dtype="int64") = R.add(lv, R.const(1, 
"int64"))
+                lv2: R.Tensor((4,), dtype="int64") = R.mod(lv1, R.const(4, 
"int64"))
+                lv3: R.Tensor((4, 2), dtype="int64") = R.take(x, lv2, axis=0, 
mode="fast")
+                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv3,)
                 R.output(gv)
             return gv
 
@@ -4216,43 +4194,20 @@ def test_roll():
         def main(x: R.Tensor((4, 2), dtype="int64")) -> R.Tuple(R.Tensor((4, 
2), dtype="int64")):
             with R.dataflow():
                 # First roll along dim=0 with shift=2
-                lv: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
-                    x,
-                    axes=[0],
-                    begin=[R.prim_value(0)],
-                    end=[R.prim_value(2)],
-                    strides=[R.prim_value(1)],
-                    assume_inbound=False,
+                lv: R.Tensor((4,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(4), R.prim_value(1), 
dtype="int64"
                 )
-                lv1: R.Tensor((2, 2), dtype="int64") = R.strided_slice(
-                    x,
-                    axes=[0],
-                    begin=[R.prim_value(2)],
-                    end=[R.prim_value(4)],
-                    strides=[R.prim_value(1)],
-                    assume_inbound=False,
-                )
-                lv2: R.Tensor((4, 2), dtype="int64") = R.concat((lv1, lv), 
axis=0)
-
+                lv1: R.Tensor((4,), dtype="int64") = R.add(lv, R.const(2, 
"int64"))
+                lv2: R.Tensor((4,), dtype="int64") = R.mod(lv1, R.const(4, 
"int64"))
+                lv3: R.Tensor((4, 2), dtype="int64") = R.take(x, lv2, axis=0, 
mode="fast")
                 # Second roll along dim=1 with shift=1
-                lv3: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
-                    lv2,
-                    axes=[1],
-                    begin=[R.prim_value(0)],
-                    end=[R.prim_value(1)],
-                    strides=[R.prim_value(1)],
-                    assume_inbound=False,
-                )
-                lv4: R.Tensor((4, 1), dtype="int64") = R.strided_slice(
-                    lv2,
-                    axes=[1],
-                    begin=[R.prim_value(1)],
-                    end=[R.prim_value(2)],
-                    strides=[R.prim_value(1)],
-                    assume_inbound=False,
+                lv4: R.Tensor((2,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(2), R.prim_value(1), 
dtype="int64"
                 )
-                lv5: R.Tensor((4, 2), dtype="int64") = R.concat((lv4, lv3), 
axis=1)
-                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv5,)
+                lv5: R.Tensor((2,), dtype="int64") = R.add(lv4, R.const(1, 
"int64"))
+                lv6: R.Tensor((2,), dtype="int64") = R.mod(lv5, R.const(2, 
"int64"))
+                lv7: R.Tensor((4, 2), dtype="int64") = R.take(lv3, lv6, 
axis=1, mode="fast")
+                gv: R.Tuple(R.Tensor((4, 2), dtype="int64")) = (lv7,)
                 R.output(gv)
             return gv
 
@@ -4260,9 +4215,9 @@ def test_roll():
     example_input = torch.randint(0, 10, (4, 2), dtype=torch.int64)
 
     # Run verification for each case
-    verify_model(Roll1(), (example_input,), {}, Expected1)
-    verify_model(Roll2(), (example_input,), {}, Expected2)
-    verify_model(Roll3(), (example_input,), {}, Expected3)
+    verify_model(Roll1(), (example_input,), {}, Expected1, 
run_ep_decomposition=True)
+    verify_model(Roll2(), (example_input,), {}, Expected2, 
run_ep_decomposition=True)
+    verify_model(Roll3(), (example_input,), {}, Expected3, 
run_ep_decomposition=True)
 
 
 def test_select_slice():
@@ -4342,10 +4297,10 @@ def test_select_slice():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Slice1(), example_args, {}, expected1)
+    verify_model(Slice1(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
     example_args = (torch.randn(8, 16, dtype=torch.float32),)
-    verify_model(Slice2(), example_args, {}, expected2)
+    verify_model(Slice2(), example_args, {}, expected2, 
run_ep_decomposition=True)
 
 
 def test_slice_scatter():
@@ -4387,10 +4342,10 @@ def test_slice_scatter():
             return gv
 
     example_args = (torch.randn(8, 8, 10, 10, dtype=torch.float32), 
torch.randn(8, 3, 10, 10))
-    verify_model(SliceScatter1(), example_args, {}, expected1)
+    verify_model(SliceScatter1(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
     example_args = (torch.randn(8, 16, dtype=torch.float32), torch.randn(6, 
16))
-    verify_model(SliceScatter2(), example_args, {}, expected2)
+    verify_model(SliceScatter2(), example_args, {}, expected2, 
run_ep_decomposition=True)
 
 
 def test_split():
@@ -4402,7 +4357,7 @@ def test_split():
     class Expected:
         @R.function
         def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+            input: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(
             R.Tensor((1, 1, 10, 10), dtype="float32"),
             R.Tensor((1, 1, 10, 10), dtype="float32"),
@@ -4414,7 +4369,7 @@ def test_split():
                     R.Tensor((1, 1, 10, 10), dtype="float32"),
                     R.Tensor((1, 1, 10, 10), dtype="float32"),
                     R.Tensor((1, 1, 10, 10), dtype="float32"),
-                ) = R.split(input_1, indices_or_sections=3, axis=1)
+                ) = R.split(input, indices_or_sections=[1, 2], axis=1)
                 lv1: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[0]
                 lv2: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[1]
                 lv3: R.Tensor((1, 1, 10, 10), dtype="float32") = lv[2]
@@ -4434,7 +4389,7 @@ def test_split():
     class expected1:
         @R.function
         def main(
-            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+            data: R.Tensor((3, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(
             R.Tensor((3, 10, 10), dtype="float32"),
             R.Tensor((3, 10, 10), dtype="float32"),
@@ -4442,30 +4397,38 @@ def test_split():
         ):
             # block 0
             with R.dataflow():
-                lv: R.Tuple(
-                    R.Tensor((1, 3, 10, 10), dtype="float32"),
-                    R.Tensor((1, 3, 10, 10), dtype="float32"),
-                    R.Tensor((1, 3, 10, 10), dtype="float32"),
-                ) = R.split(input_1, indices_or_sections=3, axis=0)
-                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
-                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, 
axis=[0])
-                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
-                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, 
axis=[0])
-                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2]
-                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, 
axis=[0])
-                lv7: R.Tuple(
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                ) = (lv2, lv4, lv6)
-                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
-                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
-                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(0),),
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(2),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(0),),
+                    (R.prim_value(2),),
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, 
axis=[0])
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, 
axis=[0])
+                lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, 
axis=[0])
                 gv: R.Tuple(
                     R.Tensor((3, 10, 10), dtype="float32"),
                     R.Tensor((3, 10, 10), dtype="float32"),
                     R.Tensor((3, 10, 10), dtype="float32"),
-                ) = (lv8, lv9, lv10)
+                ) = (lv3, lv4, lv5)
                 R.output(gv)
             return gv
 
@@ -4477,7 +4440,7 @@ def test_split():
     class expected2:
         @R.function
         def main(
-            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+            data: R.Tensor((3, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(
             R.Tensor((3, 10, 10), dtype="float32"),
             R.Tensor((3, 10, 10), dtype="float32"),
@@ -4485,39 +4448,47 @@ def test_split():
         ):
             # block 0
             with R.dataflow():
-                lv: R.Tuple(
-                    R.Tensor((3, 1, 10, 10), dtype="float32"),
-                    R.Tensor((3, 1, 10, 10), dtype="float32"),
-                    R.Tensor((3, 1, 10, 10), dtype="float32"),
-                ) = R.split(input_1, indices_or_sections=3, axis=1)
-                lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
-                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, 
axis=[1])
-                lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
-                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, 
axis=[1])
-                lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2]
-                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, 
axis=[1])
-                lv7: R.Tuple(
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                ) = (lv2, lv4, lv6)
-                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
-                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
-                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                lv: R.Tensor((3, 1, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(1),),
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(1),),
+                    (R.prim_value(1),),
+                    (R.prim_value(2),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((3, 1, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(1),),
+                    (R.prim_value(2),),
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, 
axis=[1])
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, 
axis=[1])
+                lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, 
axis=[1])
                 gv: R.Tuple(
                     R.Tensor((3, 10, 10), dtype="float32"),
                     R.Tensor((3, 10, 10), dtype="float32"),
                     R.Tensor((3, 10, 10), dtype="float32"),
-                ) = (lv8, lv9, lv10)
+                ) = (lv3, lv4, lv5)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Chunk(), example_args, {}, Expected)
+    verify_model(Chunk(), example_args, {}, Expected, 
run_ep_decomposition=True)
 
     example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Unbind1(), example_args, {}, expected1)
-    verify_model(Unbind2(), example_args, {}, expected2)
+    verify_model(Unbind1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Unbind2(), example_args, {}, expected2, 
run_ep_decomposition=True)
 
 
 def test_squeeze():
@@ -4545,18 +4516,18 @@ def test_squeeze():
     class Expected2:
         @R.function
         def main(
-            inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
+            input: R.Tensor((3, 1, 4, 1), dtype="float32")
         ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(inp_0, 
axis=None)
+                lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, 
axis=[1, 3])
                 gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)
 
-    verify_model(Squeeze1(), example_args, {}, Expected1)
-    verify_model(Squeeze2(), example_args, {}, Expected2)
+    verify_model(Squeeze1(), example_args, {}, Expected1, 
run_ep_decomposition=True)
+    verify_model(Squeeze2(), example_args, {}, Expected2, 
run_ep_decomposition=True)
 
 
 def test_stack():

Reply via email to