This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 68854d63d0 [Relax][PyTorch] Enable decomposition for unary ops and refactor tests (#18401)
68854d63d0 is described below
commit 68854d63d073bc7c78b317a45e5cd82e457c101a
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Oct 27 22:16:24 2025 -0400
[Relax][PyTorch] Enable decomposition for unary ops and refactor tests (#18401)
* finish1
* finish2
* finish4
* finish5
---
.../relax/test_frontend_from_exported_program.py | 157 +++++++++++++++++++--
1 file changed, 149 insertions(+), 8 deletions(-)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 3382141567..9d1ef48712 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -61,18 +61,13 @@ operator_basic_unary = [
(torch.log, R.log),
(torch.neg, R.negative),
(torch.relu, R.nn.relu),
- (torch.relu_, R.nn.relu),
(torch.round, R.round),
(torch.rsqrt, R.rsqrt),
- (torch.selu, R.nn.selu),
(torch.sigmoid, R.sigmoid),
- (torch.ops.aten.silu, R.nn.silu),
- (torch.ops.aten.silu_, R.nn.silu),
(torch.sin, R.sin),
(torch.sinh, R.sinh),
(torch.sign, R.sign),
(torch.sqrt, R.sqrt),
- (torch.square, R.square),
(torch.tan, R.tan),
(torch.tanh, R.tanh),
(torch.trunc, R.trunc),
@@ -99,11 +94,10 @@ def test_basic_unary_ops(pytorch_op, relax_op):
R.output(gv)
return gv
- verify_model(UnaryOp(), example_args, {}, expected)
+ verify_model(UnaryOp(), example_args, {}, expected, run_ep_decomposition=True)
operator_bool_unary = [
- (torch.isfinite, R.isfinite),
(torch.isinf, R.isinf),
(torch.isnan, R.isnan),
]
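
The change above routes every basic unary test through decomposition. The internals of verify_model are not shown in this diff, so the following is only a sketch of what run_ep_decomposition=True presumably enables, using the public torch.export API:

    import torch

    # Export the test module, then lower composite ops (selu, silu, square, ...)
    # to aten primitives before conversion to Relax.
    # ExportedProgram.run_decompositions() is standard torch.export API.
    ep = torch.export.export(UnaryOp(), example_args)
    ep = ep.run_decompositions()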
@@ -129,7 +123,7 @@ def test_bool_unary_ops(pytorch_op, relax_op):
R.output(gv)
return gv
- verify_model(UnaryOp(), example_args, {}, expected)
+ verify_model(UnaryOp(), example_args, {}, expected, run_ep_decomposition=True)
def test_extended_unary_ops():
@@ -467,6 +461,30 @@ def test_extended_unary_ops():
Hardswish3(), example_args, {}, expected_hardswish_for_3, run_ep_decomposition=True
)
+ # isfinite
+ class IsFinite(Module):
+ def forward(self, input):
+ return torch.isfinite(input)
+
+ @tvm.script.ir_module
+ class expected_isfinite:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.abs(input)
+ lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal(
+ lv, R.const(float("inf"), "float32")
+ )
+ lv2: R.Tensor((1, 3, 10, 10), dtype="bool") = R.equal(input, input)
+ lv3: R.Tensor((1, 3, 10, 10), dtype="bool") = R.multiply(lv2, lv1)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv3,)
+ R.output(gv)
+ return gv
+
+ verify_model(IsFinite(), example_args, {}, expected_isfinite, run_ep_decomposition=True)
+
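
The expected graph above encodes the standard decomposition isfinite(x) = (x == x) & (|x| != inf): R.equal(input, input) is false only for NaN, and R.not_equal(R.abs(input), inf) is false only for +/-inf. A quick PyTorch check of that identity (illustrative, not part of the test):

    import torch

    x = torch.tensor([0.0, float("inf"), float("-inf"), float("nan")])
    decomposed = (x == x) & (x.abs() != float("inf"))
    assert torch.equal(decomposed, torch.isfinite(x))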
# log2
class Log2(Module):
def forward(self, x):
@@ -657,6 +675,129 @@ def test_extended_unary_ops():
verify_model(ReLU6_2(), example_args, {}, expected_relu6_2, run_ep_decomposition=True)
verify_model(ReLU6_3(), example_args, {}, expected_relu6_3, run_ep_decomposition=True)
+ # selu
+ class SELU(Module):
+ def forward(self, input):
+ return torch.nn.functional.selu(input)
+
+ @tvm.script.ir_module
+ class expected_selu:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
+ input, R.const(0.0, "float32")
+ )
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+ input, R.const(1.0507010221481323, "float32")
+ )
+ lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+ input, R.const(1.0, "float32")
+ )
+ lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv2)
+ lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
+ lv3, R.const(1.0, "float32")
+ )
+ lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+ lv4, R.const(1.7580993175506592, "float32")
+ )
+ lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, lv1, lv5)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,)
+ R.output(gv)
+ return gv
+
+ verify_model(SELU(), example_args, {}, expected_selu, run_ep_decomposition=True)
+
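
The constants in expected_selu are float32 roundings of the SELU definition selu(x) = scale * (x if x > 0 else alpha * (exp(x) - 1)), with scale = 1.0507009873554805 and scale * alpha = 1.7580993408473766. A quick check of the identity (illustrative only):

    import torch

    x = torch.randn(8)
    scale, scale_alpha = 1.0507009873554805, 1.7580993408473766
    decomposed = torch.where(x > 0, scale * x, scale_alpha * (torch.exp(x) - 1))
    assert torch.allclose(decomposed, torch.nn.functional.selu(x))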
+ # silu
+ class SiLU(Module):
+ def forward(self, input):
+ return torch.nn.functional.silu(input)
+
+ @tvm.script.ir_module
+ class expected_silu:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input)
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
+ R.output(gv)
+ return gv
+
+ verify_model(SiLU(), example_args, {}, expected_silu, run_ep_decomposition=True)
+
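
expected_silu is the textbook decomposition silu(x) = x * sigmoid(x), which is exactly the two-op graph above. For reference (illustrative only):

    import torch

    x = torch.randn(8)
    assert torch.allclose(x * torch.sigmoid(x), torch.nn.functional.silu(x))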
+ # silu_
+ class SiLU_(Module):
+ def forward(self, input):
+ return torch.ops.aten.silu_(input)
+
+ @tvm.script.ir_module
+ class expected_silu_:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 10), dtype="float32")
+ ):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sigmoid(input)
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(input, lv)
+ gv: R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) = (
+ lv1,
+ lv1,
+ )
+ R.output(gv)
+ return gv
+
+ verify_model(SiLU_(), example_args, {}, expected_silu_, run_ep_decomposition=True)
+
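
The two-element return tuple in expected_silu_ comes from functionalization: torch.export rewrites the in-place aten.silu_ into an out-of-place silu and returns the result twice, once as the mutated input's new value and once as forward()'s return, hence (lv1, lv1). One way to observe this, assuming torch.export accepts the input mutation as the test implies:

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.ops.aten.silu_(x)

    ep = torch.export.export(M(), (torch.randn(2),))
    print(ep.graph_signature)  # lists the mutated input alongside the user output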
+ # square
+ class Square(Module):
+ def forward(self, input):
+ return torch.square(input)
+
+ @tvm.script.ir_module
+ class expected_square:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.power(
+ input, R.const(2.0, "float32")
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(Square(), example_args, {}, expected_square, run_ep_decomposition=True)
+
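
expected_square reflects the decomposition torch.square(x) -> pow(x, 2), which surfaces in Relax as R.power with a constant exponent. Quick sanity check (illustrative only):

    import torch

    x = torch.randn(8)
    assert torch.allclose(torch.pow(x, 2.0), torch.square(x))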
+ # relu_
+ class ReLU_(Module):
+ def forward(self, input):
+ return torch.relu_(input.clone())
+
+ @tvm.script.ir_module
+ class expected_relu_:
+ @R.function
+ def main(
+ input: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input)
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(ReLU_(), example_args, {}, expected_relu_, run_ep_decomposition=True)
+
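
Unlike SiLU_ above, ReLU_ clones its input first, so the in-place relu_ mutates a fresh tensor rather than a graph input; functionalization then yields a plain single-output R.nn.relu graph. The identity being relied on (illustrative only):

    import torch

    x = torch.randn(4)
    assert torch.equal(torch.relu_(x.clone()), torch.relu(x))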
def test_hardtanh():
class Hardtanh(torch.nn.Module):