This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f532b89e55 [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(1) (#18402)
f532b89e55 is described below

commit f532b89e5558c27cf92c573fbc005c6b3c53b0a8
Author: Shushi Hong <[email protected]>
AuthorDate: Tue Oct 28 21:42:41 2025 -0400

    [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(1) (#18402)
    
    * finish1
    
    * finish2
---
 .../frontend/torch/exported_program_translator.py  |   2 +
 .../relax/test_frontend_from_exported_program.py   | 121 ++++++++++++++++-----
 2 files changed, 97 insertions(+), 26 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index cbf9e33a12..011e23f1df 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -837,7 +837,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "log10.default": self._log10,
             "log1p.default": self._log1p,
             "logical_not.default": self._unary_op(relax.op.logical_not),
+            "logical_and.default": self._binary_op(relax.op.logical_and, 
operator.and_),
             "log_softmax.int": self._log_softmax,
+            "_log_softmax.default": self._log_softmax,
             "neg.default": self._unary_op(relax.op.negative),
             "pad.default": self._pad,
             "pixel_shuffle.default": self._pixel_shuffle,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 9d1ef48712..9851804e2a 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -817,7 +817,7 @@ def test_hardtanh():
             return torch.ops.aten.hardtanh_(input)
 
     @tvm.script.ir_module
-    class expected1:
+    class expected_for_1_2:
         @R.function
         def main(
             inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
@@ -830,10 +830,29 @@ def test_hardtanh():
                 R.output(gv)
             return gv
 
+    @tvm.script.ir_module
+    class expected_hardtanh_for_3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 
10), dtype="float32")
+        ):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.clip(
+                    inp_0, R.prim_value(T.float64(-1.0)), 
R.prim_value(T.float64(1.0))
+                )
+                gv: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                ) = (lv, lv)
+                R.output(gv)
+            return gv
+
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Hardtanh(), example_args, {}, expected1)
-    verify_model(Hardtanh2(), example_args, {}, expected1)
-    verify_model(Hardtanh3(), example_args, {}, expected1)
+    verify_model(Hardtanh(), example_args, {}, expected_for_1_2, 
run_ep_decomposition=True)
+    verify_model(Hardtanh2(), example_args, {}, expected_for_1_2, 
run_ep_decomposition=True)
+    verify_model(Hardtanh3(), example_args, {}, expected_hardtanh_for_3, 
run_ep_decomposition=True)
 
 
 def test_softplus():
@@ -861,16 +880,26 @@ def test_softplus():
             x: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softplus(
-                    x, beta=1.0, threshold=20.0
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    x, R.const(1.0, "float32")
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv1, 
R.const(1.0, "float32"))
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log(lv2)
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.divide(
+                    lv3, R.const(1.0, "float32")
+                )
+                lv5: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
+                    lv, R.const(20.0, "float32")
+                )
+                lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv5, 
x, lv4)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softplus0(), example_args, {}, expected)
-    verify_model(Softplus1(), example_args, {}, expected)
+    verify_model(Softplus0(), example_args, {}, expected, 
run_ep_decomposition=True)
+    verify_model(Softplus1(), example_args, {}, expected, 
run_ep_decomposition=True)
 
 
 def test_leakyrelu():
@@ -896,22 +925,40 @@ def test_leakyrelu():
             return torch.ops.aten.leaky_relu_(input, 0.02)
 
     @tvm.script.ir_module
-    class expected:
+    class expected_for_1_2:
         @R.function
         def main(
             input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.leakyrelu(input_1, 0.02)
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.leakyrelu(input_1, alpha=0.02)
                 gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
                 R.output(gv)
             return gv
 
+    @tvm.script.ir_module
+    class expected_for_3:
+        @R.function
+        def main(
+            input: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(
+            R.Tensor((1, 3, 10, 10), dtype="float32"), R.Tensor((1, 3, 10, 
10), dtype="float32")
+        ):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.nn.leakyrelu(input, alpha=0.02)
+                gv: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                ) = (lv, lv)
+                R.output(gv)
+            return gv
+
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(LeakyReLU0(), example_args, {}, expected)
-    verify_model(LeakyReLU1(), example_args, {}, expected)
-    verify_model(LeakyReLU2(), example_args, {}, expected)
+    verify_model(LeakyReLU0(), example_args, {}, expected_for_1_2, 
run_ep_decomposition=True)
+    verify_model(LeakyReLU1(), example_args, {}, expected_for_1_2, 
run_ep_decomposition=True)
+    verify_model(LeakyReLU2(), example_args, {}, expected_for_3, 
run_ep_decomposition=True)
 
 
 def test_logaddexp():
@@ -923,13 +970,32 @@ def test_logaddexp():
     class expected:
         @R.function
         def main(
-            input_1: R.Tensor((1, 3, 10, 10), dtype="float32"),
-            input_2: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            input1: R.Tensor((1, 3, 10, 10), dtype="float32"),
+            input2: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.log_add_exp(input_1, input_2)
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                lv: R.Tensor((1, 3, 10, 10), dtype="bool") = 
R.greater_equal(input1, input2)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, 
input1, input2)
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv, 
input2, input1)
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.abs(input1)
+                lv4: R.Tensor((1, 3, 10, 10), dtype="bool") = R.not_equal(
+                    lv3, R.const(float("inf"), "float32")
+                )
+                lv5: R.Tensor((1, 3, 10, 10), dtype="bool") = R.equal(input1, 
input1)
+                lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.multiply(lv5, 
lv4)
+                lv7: R.Tensor((1, 3, 10, 10), dtype="bool") = 
R.logical_not(lv6)
+                lv8: R.Tensor((1, 3, 10, 10), dtype="bool") = R.equal(input1, 
input2)
+                lv9: R.Tensor((1, 3, 10, 10), dtype="bool") = 
R.logical_and(lv7, lv8)
+                lv10: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.subtract(lv2, lv1)
+                lv11: R.Tensor((1, 3, 10, 10), dtype="float32") = R.exp(lv10)
+                lv12: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
+                    lv11, R.const(1.0, "float32")
+                )
+                lv13: R.Tensor((1, 3, 10, 10), dtype="float32") = R.log(lv12)
+                lv14: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv1, 
lv13)
+                lv15: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv9, 
input1, lv14)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = 
(lv15,)
                 R.output(gv)
             return gv
 
@@ -937,7 +1003,7 @@ def test_logaddexp():
         torch.randn(1, 3, 10, 10, dtype=torch.float32),
         torch.randn(1, 3, 10, 10, dtype=torch.float32),
     )
-    verify_model(LogAddExp(), example_args, {}, expected)
+    verify_model(LogAddExp(), example_args, {}, expected, 
run_ep_decomposition=True)
 
 
 def test_logsoftmax():
@@ -967,8 +1033,8 @@ def test_logsoftmax():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(LogSoftmax(), example_args, {}, expected1)
-    verify_model(LogSoftmax2(), example_args, {}, expected1)
+    verify_model(LogSoftmax(), example_args, {}, expected1, 
run_ep_decomposition=True)
+    verify_model(LogSoftmax2(), example_args, {}, expected1, 
run_ep_decomposition=True)
 
 
 def test_prelu():
@@ -995,16 +1061,19 @@ def test_prelu():
             x: R.Tensor((1, 3, 10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu(
-                    x, R.const([0.25], dtype="float32"), axis=1
+                lv: R.Tensor((1, 1, 1, 1), dtype="float32") = R.reshape(
+                    R.const([0.25], dtype="float32"), R.shape([1, 1, 1, 1])
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(x, 
R.const(0.0, "float32"))
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = 
R.multiply(lv, x)
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv1, 
x, lv2)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Prelu1(), example_args, {}, expected)
-    verify_model(Prelu2(), example_args, {}, expected)
+    verify_model(Prelu1(), example_args, {}, expected, 
run_ep_decomposition=True)
+    verify_model(Prelu2(), example_args, {}, expected, 
run_ep_decomposition=True)
 
 
 def test_softmax():

Reply via email to