This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 33fa9262fa [Relax][PyTorch] Add support for decomposed operators and
fix IR of ops tests(7) (#18427)
33fa9262fa is described below
commit 33fa9262faf085ec0ad2d7ef0d843c7e4c2ba148
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Nov 8 22:51:19 2025 -0500
[Relax][PyTorch] Add support for decomposed operators and fix IR of ops
tests(7) (#18427)
* f1:
* f2
* f3
---
.../relax/test_frontend_from_exported_program.py | 186 ++++++++++++---------
1 file changed, 103 insertions(+), 83 deletions(-)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 44248c1c59..c2ec57ee28 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3505,7 +3505,7 @@ def test_unbind():
class expected1:
@R.function
def main(
- input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+ data: R.Tensor((3, 3, 10, 10), dtype="float32")
) -> R.Tuple(
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
@@ -3513,30 +3513,38 @@ def test_unbind():
):
# block 0
with R.dataflow():
- lv: R.Tuple(
- R.Tensor((1, 3, 10, 10), dtype="float32"),
- R.Tensor((1, 3, 10, 10), dtype="float32"),
- R.Tensor((1, 3, 10, 10), dtype="float32"),
- ) = R.split(input_1, indices_or_sections=3, axis=0)
- lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
- lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[0])
- lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[1]
- lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3,
axis=[0])
- lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[2]
- lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5,
axis=[0])
- lv7: R.Tuple(
- R.Tensor((3, 10, 10), dtype="float32"),
- R.Tensor((3, 10, 10), dtype="float32"),
- R.Tensor((3, 10, 10), dtype="float32"),
- ) = (lv2, lv4, lv6)
- lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
- lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
- lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(0),),
+ (R.prim_value(0),),
+ (R.prim_value(1),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(0),),
+ (R.prim_value(1),),
+ (R.prim_value(2),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(0),),
+ (R.prim_value(2),),
+ (R.prim_value(3),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv,
axis=[0])
+ lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[0])
+ lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2,
axis=[0])
gv: R.Tuple(
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
- ) = (lv8, lv9, lv10)
+ ) = (lv3, lv4, lv5)
R.output(gv)
return gv
@@ -3548,7 +3556,7 @@ def test_unbind():
class expected2:
@R.function
def main(
- input_1: R.Tensor((3, 3, 10, 10), dtype="float32")
+ data: R.Tensor((3, 3, 10, 10), dtype="float32")
) -> R.Tuple(
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
@@ -3556,30 +3564,38 @@ def test_unbind():
):
# block 0
with R.dataflow():
- lv: R.Tuple(
- R.Tensor((3, 1, 10, 10), dtype="float32"),
- R.Tensor((3, 1, 10, 10), dtype="float32"),
- R.Tensor((3, 1, 10, 10), dtype="float32"),
- ) = R.split(input_1, indices_or_sections=3, axis=1)
- lv1: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[0]
- lv2: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[1])
- lv3: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[1]
- lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv3,
axis=[1])
- lv5: R.Tensor((3, 1, 10, 10), dtype="float32") = lv[2]
- lv6: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv5,
axis=[1])
- lv7: R.Tuple(
- R.Tensor((3, 10, 10), dtype="float32"),
- R.Tensor((3, 10, 10), dtype="float32"),
- R.Tensor((3, 10, 10), dtype="float32"),
- ) = (lv2, lv4, lv6)
- lv8: R.Tensor((3, 10, 10), dtype="float32") = lv7[0]
- lv9: R.Tensor((3, 10, 10), dtype="float32") = lv7[1]
- lv10: R.Tensor((3, 10, 10), dtype="float32") = lv7[2]
+ lv: R.Tensor((3, 1, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(1),),
+ (R.prim_value(0),),
+ (R.prim_value(1),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv1: R.Tensor((3, 1, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(1),),
+ (R.prim_value(1),),
+ (R.prim_value(2),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv2: R.Tensor((3, 1, 10, 10), dtype="float32") =
R.strided_slice(
+ data,
+ (R.prim_value(1),),
+ (R.prim_value(2),),
+ (R.prim_value(3),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv3: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv,
axis=[1])
+ lv4: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv1,
axis=[1])
+ lv5: R.Tensor((3, 10, 10), dtype="float32") = R.squeeze(lv2,
axis=[1])
gv: R.Tuple(
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
R.Tensor((3, 10, 10), dtype="float32"),
- ) = (lv8, lv9, lv10)
+ ) = (lv3, lv4, lv5)
R.output(gv)
return gv
@@ -3590,18 +3606,24 @@ def test_unbind():
data: R.Tensor((3, 1, 3), dtype="float32")
) -> R.Tuple(R.Tensor((3, 3), dtype="float32")):
with R.dataflow():
- lv: R.Tensor((3, 3), dtype="float32") = R.squeeze(data,
axis=[1])
- lv1: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv,)
- lv2: R.Tensor((3, 3), dtype="float32") = lv1[0]
- gv: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv2,)
+ lv: R.Tensor((3, 1, 3), dtype="float32") = R.strided_slice(
+ data,
+ (R.prim_value(1),),
+ (R.prim_value(0),),
+ (R.prim_value(1),),
+ (R.prim_value(1),),
+ assume_inbound=False,
+ )
+ lv1: R.Tensor((3, 3), dtype="float32") = R.squeeze(lv,
axis=[1])
+ gv: R.Tuple(R.Tensor((3, 3), dtype="float32")) = (lv1,)
R.output(gv)
return gv
example_args = (torch.randn(3, 3, 10, 10, dtype=torch.float32),)
- verify_model(Unbind1(), example_args, {}, expected1)
- verify_model(Unbind2(), example_args, {}, expected2)
+ verify_model(Unbind1(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Unbind2(), example_args, {}, expected2,
run_ep_decomposition=True)
single_dim_args = (torch.randn(3, 1, 3, dtype=torch.float32),)
- verify_model(Unbind2(), single_dim_args, {}, expected3)
+ verify_model(Unbind2(), single_dim_args, {}, expected3,
run_ep_decomposition=True)
def test_interpolate():
@@ -3732,8 +3754,8 @@ def test_mean():
return gv
example_args = (torch.randn(256, 256, dtype=torch.float32),)
- verify_model(Mean(), example_args, {}, Expected1)
- verify_model(MeanKeepDim(), example_args, {}, Expected2)
+ verify_model(Mean(), example_args, {}, Expected1,
run_ep_decomposition=True)
+ verify_model(MeanKeepDim(), example_args, {}, Expected2,
run_ep_decomposition=True)
def test_sum():
@@ -3755,7 +3777,7 @@ def test_sum():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(Sum(), example_args, {}, expected1)
+ verify_model(Sum(), example_args, {}, expected1, run_ep_decomposition=True)
def test_argmax_argmin():
@@ -3799,8 +3821,8 @@ def test_argmax_argmin():
R.output(gv)
return gv
- verify_model(Argmax1(), example_args, {}, expected_argmax1)
- verify_model(Argmax2(), example_args, {}, expected_argmax2)
+ verify_model(Argmax1(), example_args, {}, expected_argmax1,
run_ep_decomposition=True)
+ verify_model(Argmax2(), example_args, {}, expected_argmax2,
run_ep_decomposition=True)
class Argmin1(Module):
def __init__(self) -> None:
@@ -3840,8 +3862,8 @@ def test_argmax_argmin():
R.output(gv)
return gv
- verify_model(Argmin1(), example_args, {}, expected_argmin1)
- verify_model(Argmin2(), example_args, {}, expected_argmin2)
+ verify_model(Argmin1(), example_args, {}, expected_argmin1,
run_ep_decomposition=True)
+ verify_model(Argmin2(), example_args, {}, expected_argmin2,
run_ep_decomposition=True)
def test_cat_concat():
@@ -3888,10 +3910,10 @@ def test_cat_concat():
return gv
example_args = (torch.randn(2, 3, dtype=torch.float32), torch.randn(2, 3,
dtype=torch.float32))
- verify_model(Cat0(), example_args, {}, Expected1)
- verify_model(Cat1(), example_args, {}, Expected2)
- verify_model(Cat2(), example_args, {}, Expected2)
- verify_model(Cat3(), example_args, {}, Expected1)
+ verify_model(Cat0(), example_args, {}, Expected1,
run_ep_decomposition=True)
+ verify_model(Cat1(), example_args, {}, Expected2,
run_ep_decomposition=True)
+ verify_model(Cat2(), example_args, {}, Expected2,
run_ep_decomposition=True)
+ verify_model(Cat3(), example_args, {}, Expected1,
run_ep_decomposition=True)
def test_cumsum():
@@ -3913,7 +3935,7 @@ def test_cumsum():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(Cumsum(), example_args, {}, expected1)
+ verify_model(Cumsum(), example_args, {}, expected1,
run_ep_decomposition=True)
def test_expand():
@@ -3939,8 +3961,8 @@ def test_expand():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(Expand1(), example_args, {}, expected1)
- verify_model(Expand2(), example_args, {}, expected1)
+ verify_model(Expand1(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Expand2(), example_args, {}, expected1,
run_ep_decomposition=True)
def test_flatten():
@@ -3966,7 +3988,7 @@ def test_flatten():
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(Flatten(), example_args, {}, expected1)
+ verify_model(Flatten(), example_args, {}, expected1,
run_ep_decomposition=True)
def test_meshgrid():
@@ -3985,14 +4007,13 @@ def test_meshgrid():
input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,),
dtype="float32")
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")):
with R.dataflow():
- lv: R.Tuple(
- R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
- ) = R.meshgrid((input1, input2), indexing="ij")
- lv1: R.Tensor((3, 3), dtype="float32") = lv[0]
- lv2: R.Tensor((3, 3), dtype="float32") = lv[1]
+ lv: R.Tensor((3, 1), dtype="float32") = R.reshape(input1,
R.shape([3, 1]))
+ lv1: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv,
R.shape([3, 3]))
+ lv2: R.Tensor((1, 3), dtype="float32") = R.reshape(input2,
R.shape([1, 3]))
+ lv3: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv2,
R.shape([3, 3]))
gv: R.Tuple(
R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
- ) = (lv1, lv2)
+ ) = (lv1, lv3)
R.output(gv)
return gv
@@ -4003,14 +4024,13 @@ def test_meshgrid():
input1: R.Tensor((3,), dtype="float32"), input2: R.Tensor((3,),
dtype="float32")
) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")):
with R.dataflow():
- lv: R.Tuple(
- R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
- ) = R.meshgrid((input1, input2), indexing="xy")
- lv1: R.Tensor((3, 3), dtype="float32") = lv[0]
- lv2: R.Tensor((3, 3), dtype="float32") = lv[1]
+ lv: R.Tensor((3, 1), dtype="float32") = R.reshape(input2,
R.shape([3, 1]))
+ lv1: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv,
R.shape([3, 3]))
+ lv2: R.Tensor((1, 3), dtype="float32") = R.reshape(input1,
R.shape([1, 3]))
+ lv3: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(lv2,
R.shape([3, 3]))
gv: R.Tuple(
R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32")
- ) = (lv1, lv2)
+ ) = (lv3, lv1)
R.output(gv)
return gv
@@ -4018,8 +4038,8 @@ def test_meshgrid():
torch.randn(3, dtype=torch.float32),
torch.randn(3, dtype=torch.float32),
)
- verify_model(Meshgrid1(), example_args, {}, expected1)
- verify_model(Meshgrid2(), example_args, {}, expected2)
+ verify_model(Meshgrid1(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Meshgrid2(), example_args, {}, expected2,
run_ep_decomposition=True)
def test_permute():
@@ -4045,8 +4065,8 @@ def test_permute():
return gv
example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
- verify_model(Permute1(), example_args, {}, expected1)
- verify_model(Permute2(), example_args, {}, expected1)
+ verify_model(Permute1(), example_args, {}, expected1,
run_ep_decomposition=True)
+ verify_model(Permute2(), example_args, {}, expected1,
run_ep_decomposition=True)
def test_repeat():
@@ -4083,13 +4103,13 @@ def test_repeat():
return gv
example_args = (torch.randn(3, dtype=torch.float32),)
- verify_model(Tile1(), example_args, {}, expected1)
+ verify_model(Tile1(), example_args, {}, expected1,
run_ep_decomposition=True)
example_args = (torch.randn(1, 3, dtype=torch.float32),)
- verify_model(Tile2(), example_args, {}, expected2)
+ verify_model(Tile2(), example_args, {}, expected2,
run_ep_decomposition=True)
example_args = (torch.randn(1, 3, dtype=torch.float32),)
- verify_model(Tile2(), example_args, {}, expected2)
+ verify_model(Tile2(), example_args, {}, expected2,
run_ep_decomposition=True)
def test_reshape():