This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 33fa9262fa [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(7) (#18427)
33fa9262fa is described below

commit 33fa9262faf085ec0ad2d7ef0d843c7e4c2ba148
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Nov 8 22:51:19 2025 -0500

    [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(7) (#18427)
    
    * f1:
    
    * f2
    
    * f3
---
 .../relax/test_frontend_from_exported_program.py   | 186 ++++++++++++---------
 1 file changed, 103 insertions(+), 83 deletions(-)

diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 44248c1c59..c2ec57ee28 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3505,7 +3505,7 @@ def test_unbind():
     class expected1:
         @R.function
         def main(
-            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+            data: R.Tensor((3, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(
             R.Tensor((3, 10, 10), dtype="float32"),
             R.Tensor((3, 10, 10), dtype="float32"),
@@ -3513,30 +3513,38 @@ def test_unbind():
         ):
             # block 0
             with R.dataflow():
-                lv: R.Tuple(
-                    R.Tensor((1, 3, 10, 10), dtype="float32"),
-                    R.Tensor((1, 3, 10, 10), dtype="float32"),
-                    R.Tensor((1, 3, 10, 10), dtype="float32"),
-                ) = R.split(input_1, indices_or_sections=3, axis=0)
-                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
-                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, 
axis=[0])
-                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
-                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, 
axis=[0])
-                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2]
-                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, 
axis=[0])
-                lv7: R.Tuple(
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                ) = (lv2, lv4, lv6)
-                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
-                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
-                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(0),),
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(2),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(0),),
+                    (R.prim_value(2),),
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, 
axis=[0])
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, 
axis=[0])
+                lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, 
axis=[0])
                 gv: R.Tuple(
                     R.Tensor((3, 10, 10), dtype="float32"),
                     R.Tensor((3, 10, 10), dtype="float32"),
                     R.Tensor((3, 10, 10), dtype="float32"),
-                ) = (lv8, lv9, lv10)
+                ) = (lv3, lv4, lv5)
                 R.output(gv)
             return gv
 
@@ -3548,7 +3556,7 @@ def test_unbind():
     class expected2:
         @R.function
         def main(
-            input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+            data: R.Tensor((3, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(
             R.Tensor((3, 10, 10), dtype="float32"),
             R.Tensor((3, 10, 10), dtype="float32"),
@@ -3556,30 +3564,38 @@ def test_unbind():
         ):
             # block 0
             with R.dataflow():
-                lv: R.Tuple(
-                    R.Tensor((3, 1, 10, 10), dtype="float32"),
-                    R.Tensor((3, 1, 10, 10), dtype="float32"),
-                    R.Tensor((3, 1, 10, 10), dtype="float32"),
-                ) = R.split(input_1, indices_or_sections=3, axis=1)
-                lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
-                lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, 
axis=[1])
-                lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
-                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3, 
axis=[1])
-                lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2]
-                lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5, 
axis=[1])
-                lv7: R.Tuple(
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                    R.Tensor((3, 10, 10), dtype="float32"),
-                ) = (lv2, lv4, lv6)
-                lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
-                lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
-                lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+                lv: R.Tensor((3, 1, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(1),),
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(1),),
+                    (R.prim_value(1),),
+                    (R.prim_value(2),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv2: R.Tensor((3, 1, 10, 10), dtype="float32") = 
R.strided_slice(
+                    data,
+                    (R.prim_value(1),),
+                    (R.prim_value(2),),
+                    (R.prim_value(3),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv, 
axis=[1])
+                lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1, 
axis=[1])
+                lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2, 
axis=[1])
                 gv: R.Tuple(
                     R.Tensor((3, 10, 10), dtype="float32"),
                     R.Tensor((3, 10, 10), dtype="float32"),
                     R.Tensor((3, 10, 10), dtype="float32"),
-                ) = (lv8, lv9, lv10)
+                ) = (lv3, lv4, lv5)
                 R.output(gv)
             return gv
 
@@ -3590,18 +3606,24 @@ def test_unbind():
             data: R.Tensor((3, 1, 3), dtype="float32")
         ) -> R.Tuple(R.Tensor((3, 3), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((3, 3), dtype="float32") = R.squeeze(data, 
axis=[1])
-                lv1: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv,)
-                lv2: R.Tensor((3, 3), dtype="float32") = lv1[0]
-                gv: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv2,)
+                lv: R.Tensor((3, 1, 3), dtype="float32") = R.strided_slice(
+                    data,
+                    (R.prim_value(1),),
+                    (R.prim_value(0),),
+                    (R.prim_value(1),),
+                    (R.prim_value(1),),
+                    assume_inbound=False,
+                )
+                lv1: R.Tensor((3, 3), dtype="float32") = R.squeeze(lv, 
axis=[1])
+                gv: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Unbind1(), example_args, {}, expected1)
-    verify_model(Unbind2(), example_args, {}, expected2)
+    verify_model(Unbind1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Unbind2(), example_args, {}, expected2, 
run_ep_decomposition=True)
     single_dim_args = (torch.randn(3, 1, 3, dtype=torch.float32),)
-    verify_model(Unbind2(), single_dim_args, {}, expected3)
+    verify_model(Unbind2(), single_dim_args, {}, expected3, 
run_ep_decomposition=True)
 
 
 def test_interpolate():
@@ -3732,8 +3754,8 @@ def test_mean():
             return gv
 
     example_args = (torch.randn(256, 256, dtype=torch.float32),)
-    verify_model(Mean(), example_args, {}, Expected1)
-    verify_model(MeanKeepDim(), example_args, {}, Expected2)
+    verify_model(Mean(), example_args, {}, Expected1, 
run_ep_decomposition=True)
+    verify_model(MeanKeepDim(), example_args, {}, Expected2, 
run_ep_decomposition=True)
 
 
 def test_sum():
@@ -3755,7 +3777,7 @@ def test_sum():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Sum(), example_args, {}, expected1)
+    verify_model(Sum(), example_args, {}, expected1, run_ep_decomposition=True)
 
 
 def test_argmax_argmin():
@@ -3799,8 +3821,8 @@ def test_argmax_argmin():
                 R.output(gv)
             return gv
 
-    verify_model(Argmax1(), example_args, {}, expected_argmax1)
-    verify_model(Argmax2(), example_args, {}, expected_argmax2)
+    verify_model(Argmax1(), example_args, {}, expected_argmax1, 
run_ep_decomposition=True)
+    verify_model(Argmax2(), example_args, {}, expected_argmax2, 
run_ep_decomposition=True)
 
     class Argmin1(Module):
         def __init__(self) -> None:
@@ -3840,8 +3862,8 @@ def test_argmax_argmin():
                 R.output(gv)
             return gv
 
-    verify_model(Argmin1(), example_args, {}, expected_argmin1)
-    verify_model(Argmin2(), example_args, {}, expected_argmin2)
+    verify_model(Argmin1(), example_args, {}, expected_argmin1, 
run_ep_decomposition=True)
+    verify_model(Argmin2(), example_args, {}, expected_argmin2, 
run_ep_decomposition=True)
 
 
 def test_cat_concat():
@@ -3888,10 +3910,10 @@ def test_cat_concat():
             return gv
 
     example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3, 
dtype=torch.float32))
-    verify_model(Cat0(), example_args, {}, Expected1)
-    verify_model(Cat1(), example_args, {}, Expected2)
-    verify_model(Cat2(), example_args, {}, Expected2)
-    verify_model(Cat3(), example_args, {}, Expected1)
+    verify_model(Cat0(), example_args, {}, Expected1, 
run_ep_decomposition=True)
+    verify_model(Cat1(), example_args, {}, Expected2, 
run_ep_decomposition=True)
+    verify_model(Cat2(), example_args, {}, Expected2, 
run_ep_decomposition=True)
+    verify_model(Cat3(), example_args, {}, Expected1, 
run_ep_decomposition=True)
 
 
 def test_cumsum():
@@ -3913,7 +3935,7 @@ def test_cumsum():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Cumsum(), example_args, {}, expected1)
+    verify_model(Cumsum(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
 
 def test_expand():
@@ -3939,8 +3961,8 @@ def test_expand():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Expand1(), example_args, {}, expected1)
-    verify_model(Expand2(), example_args, {}, expected1)
+    verify_model(Expand1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Expand2(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
 
 def test_flatten():
@@ -3966,7 +3988,7 @@ def test_flatten():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Flatten(), example_args, {}, expected1)
+    verify_model(Flatten(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
 
 def test_meshgrid():
@@ -3985,14 +4007,13 @@ def test_meshgrid():
             input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), 
dtype="float32")
         ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32")):
             with R.dataflow():
-                lv: R.Tuple(
-                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32")
-                ) = R.meshgrid((input1, input2), indexing="ij")
-                lv1: R.Tensor((3, 3), dtype="float32") = lv[0]
-                lv2: R.Tensor((3, 3), dtype="float32") = lv[1]
+                lv: R.Tensor((3, 1), dtype="float32") = R.reshape(input1, 
R.shape([3, 1]))
+                lv1: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv, 
R.shape([3, 3]))
+                lv2: R.Tensor((1, 3), dtype="float32") = R.reshape(input2, 
R.shape([1, 3]))
+                lv3: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv2, 
R.shape([3, 3]))
                 gv: R.Tuple(
                     R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32")
-                ) = (lv1, lv2)
+                ) = (lv1, lv3)
                 R.output(gv)
             return gv
 
@@ -4003,14 +4024,13 @@ def test_meshgrid():
             input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,), 
dtype="float32")
         ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32")):
             with R.dataflow():
-                lv: R.Tuple(
-                    R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32")
-                ) = R.meshgrid((input1, input2), indexing="xy")
-                lv1: R.Tensor((3, 3), dtype="float32") = lv[0]
-                lv2: R.Tensor((3, 3), dtype="float32") = lv[1]
+                lv: R.Tensor((3, 1), dtype="float32") = R.reshape(input2, 
R.shape([3, 1]))
+                lv1: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv, 
R.shape([3, 3]))
+                lv2: R.Tensor((1, 3), dtype="float32") = R.reshape(input1, 
R.shape([1, 3]))
+                lv3: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv2, 
R.shape([3, 3]))
                 gv: R.Tuple(
                     R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32")
-                ) = (lv1, lv2)
+                ) = (lv3, lv1)
                 R.output(gv)
             return gv
 
@@ -4018,8 +4038,8 @@ def test_meshgrid():
         torch.randn(3, dtype=torch.float32),
         torch.randn(3, dtype=torch.float32),
     )
-    verify_model(Meshgrid1(), example_args, {}, expected1)
-    verify_model(Meshgrid2(), example_args, {}, expected2)
+    verify_model(Meshgrid1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Meshgrid2(), example_args, {}, expected2, 
run_ep_decomposition=True)
 
 
 def test_permute():
@@ -4045,8 +4065,8 @@ def test_permute():
             return gv
 
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
-    verify_model(Permute1(), example_args, {}, expected1)
-    verify_model(Permute2(), example_args, {}, expected1)
+    verify_model(Permute1(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(Permute2(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
 
 def test_repeat():
@@ -4083,13 +4103,13 @@ def test_repeat():
             return gv
 
     example_args = (torch.randn(3, dtype=torch.float32),)
-    verify_model(Tile1(), example_args, {}, expected1)
+    verify_model(Tile1(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
     example_args = (torch.randn(1, 3, dtype=torch.float32),)
-    verify_model(Tile2(), example_args, {}, expected2)
+    verify_model(Tile2(), example_args, {}, expected2, 
run_ep_decomposition=True)
 
     example_args = (torch.randn(1, 3, dtype=torch.float32),)
-    verify_model(Tile2(), example_args, {}, expected2)
+    verify_model(Tile2(), example_args, {}, expected2, 
run_ep_decomposition=True)
 
 
 def test_reshape():

Reply via email to