This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new be37afdf30 [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(2) (#18403)
be37afdf30 is described below

commit be37afdf300569ccb2fff5b71170985065597335
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Oct 30 12:00:02 2025 -0400

    [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(2) (#18403)
---
 .../frontend/torch/exported_program_translator.py  |  10 ++
 .../relax/test_frontend_from_exported_program.py   | 112 +++++++++++++--------
 2 files changed, 79 insertions(+), 43 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 011e23f1df..5bb7a9ea8b 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -760,6 +760,14 @@ class ExportedProgramImporter(BaseFXGraphImporter):
         )
         return self.block_builder.emit(relax.op.zeros(size, dtype))
 
+    def _scalar_tensor(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        scalar_value = args[0]
+        dtype = self._convert_data_type(
+            node.kwargs.get("dtype", torch.get_default_dtype()), self.env
+        )
+        return self.block_builder.emit(relax.const(scalar_value, dtype))
+
     def _instance_norm(self, node: fx.Node):
         import numpy as np
 
@@ -851,6 +859,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "relu6_.default": self._unary_op(relax.op.nn.relu6),
             "round.default": self._round,
             "rsqrt.default": self._unary_op(relax.op.rsqrt),
+            "scalar_tensor.default": self._scalar_tensor,
             "rsub.Tensor": self._rsub,
             "rsub.Scalar": self._rsub,
             "selu.default": self._unary_op(relax.op.nn.selu),
@@ -861,6 +870,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "sin.default": self._unary_op(relax.op.sin),
             "sinh.default": self._unary_op(relax.op.sinh),
             "softmax.int": self._softmax,
+            "_softmax.default": self._softmax,
             "softplus.default": self._softplus,
             "softshrink.default": self._softshrink,
             "softsign.default": self._softsign,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 9851804e2a..ac36c3fe8f 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1103,8 +1103,8 @@ def test_softmax():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softmax(), example_args, {}, expected1)
-    verify_model(Softmax2(), example_args, {}, expected1)
+    verify_model(Softmax(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(Softmax2(), example_args, {}, expected1, run_ep_decomposition=True)
 
 
 def test_softsign():
@@ -1135,8 +1135,8 @@ def test_softsign():
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softsign(), example_args, {}, expected_softsign)
-    verify_model(Softsign2(), example_args, {}, expected_softsign)
+    verify_model(Softsign(), example_args, {}, expected_softsign, run_ep_decomposition=True)
+    verify_model(Softsign2(), example_args, {}, expected_softsign, run_ep_decomposition=True)
 
 
 def test_softshrink():
@@ -1159,32 +1159,24 @@ def test_softshrink():
             input: R.Tensor((1, 3, 10, 10), dtype="float32"),
         ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
-                    input, R.const(0.5, "float32")
-                )
-                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
-                    input, R.const(0.5, "float32")
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.abs(input)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(lv, R.const(0.5, "float32"))
+                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.sign(input)
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    lv2, R.const(0.5, "float32")
                 )
-                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32")
-                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2)
-
-                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
-                    input, R.const(0.5, "float32")
+                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(input, lv3)
+                lv5: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(
+                    input, R.const(0.0, "float32")
                 )
-                lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32"))
-                lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5)
-                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32")
-                lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7)
-
-                lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8)
-
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,)
+                lv6: R.Tensor((1, 3, 10, 10), dtype="float32") = R.where(lv1, lv4, lv5)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv6,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(Softshrink(), example_args, {}, expected_softshrink)
-    verify_model(Softshrink2(), example_args, {}, expected_softshrink)
+    verify_model(Softshrink(), example_args, {}, expected_softshrink, run_ep_decomposition=True)
+    verify_model(Softshrink2(), example_args, {}, expected_softshrink, run_ep_decomposition=True)
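
The rewritten expectation tracks the decomposed form of softshrink: instead of
summing two masked branches, (x - 0.5) * [x > 0.5] + (x + 0.5) * [x < -0.5],
the decomposition computes where(|x| > 0.5, x - sign(x) * 0.5, x * 0). A NumPy
sketch (not part of the commit) of the same computation:

    # Sketch of the decomposed softshrink the expected IR encodes.
    import numpy as np
    import torch

    def softshrink_decomposed(x, lambd=0.5):
        keep = np.abs(x) > lambd                 # lv, lv1
        shrunk = x - np.sign(x) * lambd          # lv2, lv3, lv4
        return np.where(keep, shrunk, x * 0.0)   # lv5, lv6

    x = np.random.randn(1, 3, 10, 10).astype("float32")
    ref = torch.nn.functional.softshrink(torch.from_numpy(x)).numpy()
    assert np.allclose(softshrink_decomposed(x), ref)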
 
 
 def test_tril_triu():
@@ -1198,16 +1190,27 @@ def test_tril_triu():
     class expected_tril:
         @R.function
         def main(
-            input_1: R.Tensor((10, 10), dtype="float32")
+            input: R.Tensor((10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.tril(input_1, 1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                lv: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((1, 10), dtype="int64") = R.expand_dims(lv, axis=[-2])
+                lv2: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv3: R.Tensor((10, 1), dtype="int64") = R.expand_dims(lv2, axis=[-1])
+                lv4: R.Tensor((10, 10), dtype="int64") = R.subtract(lv1, lv3)
+                lv5: R.Tensor((10, 10), dtype="bool") = R.less_equal(lv4, R.const(1, "int64"))
+                lv6: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+                lv7: R.Tensor((10, 10), dtype="float32") = R.where(lv5, input, lv6)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv7,)
                 R.output(gv)
             return gv
 
-    verify_model(Tril(), example_args, {}, expected_tril)
+    verify_model(Tril(), example_args, {}, expected_tril, run_ep_decomposition=True)
 
     class Triu(Module):
         def forward(self, input):
@@ -1217,16 +1220,27 @@ def test_tril_triu():
     class expected_triu:
         @R.function
         def main(
-            input_1: R.Tensor((10, 10), dtype="float32")
+            input: R.Tensor((10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((10, 10), dtype="float32")):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 10), dtype="float32") = R.triu(input_1, 1)
-                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv,)
+                lv: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv1: R.Tensor((1, 10), dtype="int64") = R.expand_dims(lv, axis=[-2])
+                lv2: R.Tensor((10,), dtype="int64") = R.arange(
+                    R.prim_value(0), R.prim_value(10), R.prim_value(1), dtype="int64"
+                )
+                lv3: R.Tensor((10, 1), dtype="int64") = R.expand_dims(lv2, axis=[-1])
+                lv4: R.Tensor((10, 10), dtype="int64") = R.subtract(lv1, lv3)
+                lv5: R.Tensor((10, 10), dtype="bool") = R.greater_equal(lv4, R.const(1, "int64"))
+                lv6: R.Tensor((), dtype="float32") = R.const(0.0, "float32")
+                lv7: R.Tensor((10, 10), dtype="float32") = R.where(lv5, input, lv6)
+                gv: R.Tuple(R.Tensor((10, 10), dtype="float32")) = (lv7,)
                 R.output(gv)
             return gv
 
-    verify_model(Triu(), example_args, {}, expected_triu)
+    verify_model(Triu(), example_args, {}, expected_triu, run_ep_decomposition=True)
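
With decomposition enabled, tril and triu no longer map to R.tril/R.triu: the
expected IR builds row and column index grids with arange/expand_dims and
selects through where, keeping input[i, j] where j - i <= k for tril and
j - i >= k for triu. A NumPy sketch (not part of the commit) of the tril case:

    # Sketch of the mask construction in the expected tril IR above;
    # triu flips <= to >=.
    import numpy as np

    def tril_decomposed(x, k=1):
        cols = np.arange(x.shape[1])[None, :]   # lv, lv1: shape (1, 10)
        rows = np.arange(x.shape[0])[:, None]   # lv2, lv3: shape (10, 1)
        mask = (cols - rows) <= k               # lv4, lv5
        return np.where(mask, x, 0.0)           # lv6, lv7

    x = np.random.randn(10, 10).astype("float32")
    assert np.allclose(tril_decomposed(x, 1), np.tril(x, 1))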
 
 
 operator_binary_1 = [
@@ -1501,7 +1515,7 @@ def test_div_mode():
         torch.randn(64, 64, dtype=torch.float32),
         torch.randn(64, dtype=torch.float32),
     )
-    verify_model(DivModel(), example_args, {}, expected_div)
+    verify_model(DivModel(), example_args, {}, expected_div, run_ep_decomposition=True)
 
     # Case 2: Division with trunc rounding
     class DivTruncModel(torch.nn.Module):
@@ -1521,7 +1535,7 @@ def test_div_mode():
                 R.output(gv)
             return gv
 
-    verify_model(DivTruncModel(), example_args, {}, expected_div_trunc)
+    verify_model(DivTruncModel(), example_args, {}, expected_div_trunc, run_ep_decomposition=True)
 
     # Case 3: Division with floor rounding
     class DivFloorModel(torch.nn.Module):
@@ -1540,7 +1554,7 @@ def test_div_mode():
                 R.output(gv)
             return gv
 
-    verify_model(DivFloorModel(), example_args, {}, expected_div_floor)
+    verify_model(DivFloorModel(), example_args, {}, expected_div_floor, run_ep_decomposition=True)
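
For reference, the three cases cover torch.div semantics: no rounding mode is
true division, while "trunc" and "floor" round the true quotient towards zero
and towards negative infinity, respectively. A sketch (not part of the commit)
of the torch-level behavior these tests exercise:

    # torch.div with rounding_mode equals the true quotient passed
    # through trunc or floor.
    import torch

    a = torch.randn(64, 64)
    b = torch.randn(64)
    assert torch.allclose(torch.div(a, b, rounding_mode="trunc"), torch.trunc(a / b))
    assert torch.allclose(torch.div(a, b, rounding_mode="floor"), torch.floor(a / b))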
 
 
 def test_batchnorm2d():
@@ -1578,6 +1592,8 @@ def test_batchnorm2d():
                     epsilon=1e-05,
                     center=True,
                     scale=True,
+                    momentum=1e-05,
+                    training=False,
                 )
                 lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = lv[0]
                 gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv1,)
@@ -1593,7 +1609,7 @@ def test_batchnorm2d():
         "w3": model.bn.running_mean.detach().numpy(),
         "w4": model.bn.running_var.detach().numpy(),
     }
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
 
 def test_adaptive_avgpool1d():
@@ -1748,8 +1764,8 @@ def test_addmm():
         torch.randn(10, 10, dtype=torch.float32),
     )
 
-    verify_model(Addmm1(), example_args, {}, expected1)
-    verify_model(Addmm2(), example_args, {}, expected2)
+    verify_model(Addmm1(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(Addmm2(), example_args, {}, expected2, run_ep_decomposition=True)
 
 
 def test_avg_pool1d():
@@ -2054,8 +2070,10 @@ def test_baddbmm():
             inp_2: R.Tensor((4, 256, 512), dtype="float32"),
         ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
-                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv, inp_0)
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    inp_1, inp_2, out_dtype="float32"
+                )
+                lv1: R.Tensor((4, 128, 512), dtype="float32") = R.add(inp_0, lv)
                 gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
@@ -2076,7 +2094,9 @@ def test_baddbmm():
             inp_2: R.Tensor((4, 256, 512), dtype="float32"),
         ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    inp_1, inp_2, out_dtype="float32"
+                )
                 lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                     lv, R.const(2, "float32")
                 )
@@ -2100,14 +2120,16 @@ def test_baddbmm():
             inp_2: R.Tensor((4, 256, 512), dtype="float32"),
         ) -> R.Tuple(R.Tensor((4, 128, 512), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(inp_1, inp_2)
+                lv: R.Tensor((4, 128, 512), dtype="float32") = R.matmul(
+                    inp_1, inp_2, out_dtype="float32"
+                )
                 lv1: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                     lv, R.const(2, "float32")
                 )
                 lv2: R.Tensor((4, 128, 512), dtype="float32") = R.multiply(
                     inp_0, R.const(3, "float32")
                 )
-                lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv1, lv2)
+                lv3: R.Tensor((4, 128, 512), dtype="float32") = R.add(lv2, lv1)
                 gv: R.Tuple(R.Tensor((4, 128, 512), dtype="float32")) = (lv3,)
                 R.output(gv)
             return gv
@@ -2122,6 +2144,7 @@ def test_baddbmm():
         example_args,
         {},
         Expected1,
+        run_ep_decomposition=True,
     )
 
     verify_model(
@@ -2129,6 +2152,7 @@ def test_baddbmm():
         example_args,
         {},
         Expected2,
+        run_ep_decomposition=True,
     )
 
     verify_model(
@@ -2136,6 +2160,7 @@ def test_baddbmm():
         example_args,
         {},
         Expected3,
+        run_ep_decomposition=True,
     )
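
The updated baddbmm expectations carry an explicit out_dtype on the matmul and
a normalized operand order on the final add; the underlying formula is
beta * input + alpha * bmm(batch1, batch2). A PyTorch sketch (not part of the
commit) checking the Expected3 case (alpha=2, beta=3):

    # Decomposition behind Expected3: out = beta * input + alpha * (b1 @ b2).
    import torch

    inp = torch.randn(4, 128, 512)
    b1, b2 = torch.randn(4, 128, 256), torch.randn(4, 256, 512)
    alpha, beta = 2.0, 3.0

    decomposed = beta * inp + alpha * torch.bmm(b1, b2)
    ref = torch.baddbmm(inp, b1, b2, beta=beta, alpha=alpha)
    assert torch.allclose(decomposed, ref, atol=1e-4)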
 
 
@@ -2172,6 +2197,7 @@ def test_bmm():
         example_args,
         {},
         Expected,
+        run_ep_decomposition=True,
     )
 
 
