This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bfb0dd6a16 [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(8) (#18428)
bfb0dd6a16 is described below

commit bfb0dd6a161d33c58f67469988244b139366e063
Author: Shushi Hong <[email protected]>
AuthorDate: Mon Nov 10 00:35:43 2025 -0500

    [Relax][PyTorch] Add support for decomposed operators and fix IR of ops tests(8) (#18428)
---
 .../frontend/torch/base_fx_graph_translator.py     |   8 +
 .../frontend/torch/exported_program_translator.py  |   2 +
 .../relax/test_frontend_from_exported_program.py   | 222 ++++++++++++++++-----
 3 files changed, 179 insertions(+), 53 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 03e3b8d557..177e3d91f9 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1379,6 +1379,14 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))
 
+    def _any(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        # For boolean tensors, any is equivalent to max (checking if any element is True)
+        return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim))
+
     ########## Search ##########
 
     def _argmax_argmin(self, op: Callable) -> Callable:
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4f3132b8d8..ddd19f2b58 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -930,6 +930,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "remainder.Tensor": self._binary_op(relax.op.floor_mod, 
operator.mod),
             "remainder.Scalar": self._binary_op(relax.op.floor_mod, 
operator.mod),
             "mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),
+            "mul.Scalar": self._binary_op(relax.op.multiply, operator.mul),
             "mul_.Tensor": self._binary_op(relax.op.multiply, operator.mul),
             "ne.Tensor": self._binary_op(relax.op.not_equal, operator.ne),
             "ne.Scalar": self._binary_op(relax.op.not_equal, operator.ne),
@@ -988,6 +989,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "upsample_nearest2d.vec": self._upsample_nearest2d,
             "upsample_bicubic2d.vec": self._upsample_bicubic2d,
             # statistical
+            "any.dim": self._any,
             "mean.dim": self._mean,
             "prod.default": self._prod,
             "std.correction": self._std,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index c2ec57ee28..fb4f77567e 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -2778,16 +2778,26 @@ def test_pixel_shuffle():
             x: R.Tensor((1, 8, 10, 15), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((1, 2, 20, 30), dtype="float32") = R.nn.pixel_shuffle(
-                    x, upscale_factor=2
+                lv: R.Tensor((1, 2, 2, 2, 10, 15), dtype="float32") = R.reshape(
+                    x, R.shape([1, 2, 2, 2, 10, 15])
                 )
-                gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv,)
+                lv1: R.Tensor((1, 2, 10, 2, 15, 2), dtype="float32") = R.permute_dims(
+                    lv, axes=[0, 1, 4, 2, 5, 3]
+                )
+                lv2: R.Tensor((1, 2, 20, 30), dtype="float32") = R.reshape(
+                    lv1, R.shape([1, 2, 20, 30])
+                )
+                gv: R.Tuple(R.Tensor((1, 2, 20, 30), dtype="float32")) = (lv2,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(1, 8, 10, 15, dtype=torch.float32),)
-    verify_model(PixelShuffle1(upscale_factor=2), example_args, {}, expected)
-    verify_model(PixelShuffle2(upscale_factor=2), example_args, {}, expected)
+    verify_model(
+        PixelShuffle1(upscale_factor=2), example_args, {}, expected, run_ep_decomposition=True
+    )
+    verify_model(
+        PixelShuffle2(upscale_factor=2), example_args, {}, expected, run_ep_decomposition=True
+    )
 
 
 def test_einsum():
@@ -2832,10 +2842,10 @@ def test_einsum():
             return gv
 
     example_args = (torch.randn(4, 4, dtype=torch.float32),)
-    verify_model(Einsum1(), example_args, {}, Expected1)
+    verify_model(Einsum1(), example_args, {}, Expected1, run_ep_decomposition=False)
 
     example_args = (torch.randn(5, dtype=torch.float32), torch.randn(4, dtype=torch.float32))
-    verify_model(Einsum2(), example_args, {}, Expected2)
+    verify_model(Einsum2(), example_args, {}, Expected2, run_ep_decomposition=False)
 
 
 def test_outer():
@@ -2847,11 +2857,12 @@ def test_outer():
     class expected:
         @R.function
         def main(
-            a: R.Tensor((3,), dtype="float32"), b: R.Tensor((4,), dtype="float32")
+            x: R.Tensor((3,), dtype="float32"), y: R.Tensor((4,), dtype="float32")
         ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((3, 4), dtype="float32") = R.outer(a, b)
-                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
+                lv: R.Tensor((3, 1), dtype="float32") = R.reshape(x, R.shape([3, 1]))
+                lv1: R.Tensor((3, 4), dtype="float32") = R.multiply(lv, y)
+                gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv
 
@@ -2859,7 +2870,7 @@ def test_outer():
         torch.randn(3, dtype=torch.float32),
         torch.randn(4, dtype=torch.float32),
     )
-    verify_model(Outer(), example_args, {}, expected)
+    verify_model(Outer(), example_args, {}, expected, run_ep_decomposition=True)
 
 
 def test_embedding():
@@ -2889,7 +2900,7 @@ def test_embedding():
 
     model = Embedding()
     binding = {"w1": model.embedding.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
 
 def test_groupnorm():
@@ -3056,12 +3067,14 @@ def test_linear():
         ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None)
-                lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
-                    input_1, lv, out_dtype="float32"
+                lv: R.Tensor((30, 10), dtype="float32") = R.reshape(input_1, R.shape([30, 10]))
+                lv1: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=[1, 0])
+                lv2: R.Tensor((30, 7), dtype="float32") = R.matmul(lv, lv1, out_dtype="float32")
+                lv3: R.Tensor((30, 7), dtype="float32") = R.add(w2, lv2)
+                lv4: R.Tensor((1, 3, 10, 7), dtype="float32") = R.reshape(
+                    lv3, R.shape([1, 3, 10, 7])
                 )
-                lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2)
-                gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv2,)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv4,)
                 R.output(gv)
             return gv
 
@@ -3082,11 +3095,13 @@ def test_linear():
         ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")):
             # block 0
             with R.dataflow():
-                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None)
-                lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul(
-                    input_1, lv, out_dtype="float32"
+                lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=[1, 0])
+                lv1: R.Tensor((30, 10), dtype="float32") = R.reshape(input_1, R.shape([30, 10]))
+                lv2: R.Tensor((30, 7), dtype="float32") = R.matmul(lv1, lv, out_dtype="float32")
+                lv3: R.Tensor((1, 3, 10, 7), dtype="float32") = R.reshape(
+                    lv2, R.shape([1, 3, 10, 7])
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv1,)
+                gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv3,)
                 R.output(gv)
             return gv
 
@@ -3094,15 +3109,15 @@ def test_linear():
 
     model = Dense1()
     binding = {"w1": model.linear.weight.detach().numpy(), "w2": 
model.linear.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, 
run_ep_decomposition=True)
 
     model = Dense1Func()
     binding = {"w1": model.weight.detach().numpy(), "w2": 
model.bias.detach().numpy()}
-    verify_model(model, example_args, binding, expected1)
+    verify_model(model, example_args, binding, expected1, run_ep_decomposition=True)
 
     model = Dense2()
     binding = {"w1": model.linear.weight.detach().numpy()}
-    verify_model(model, example_args, binding, expected2)
+    verify_model(model, example_args, binding, expected2, run_ep_decomposition=True)
 
 
 def test_maxpool1d():
@@ -3415,27 +3430,76 @@ def test_scaled_dot_product_attention():
     class Expected1:
         @R.function
         def main(
-            inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
-            inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
-            inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            q: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            k: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            v: R.Tensor((32, 8, 128, 64), dtype="float32"),
         ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_0, axes=[0, 2, 1, 3]
+                lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply(
+                    q, R.const(0.35355338454246521, "float32")
+                )
+                lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = R.permute_dims(
+                    k, axes=[0, 1, 3, 2]
+                )
+                lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply(
+                    lv1, R.const(0.35355338454246521, "float32")
+                )
+                lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
+                    lv, R.shape([32, 8, 128, 64])
+                )
+                lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
+                    lv3, R.shape([256, 128, 64])
+                )
+                lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = R.broadcast_to(
+                    lv2, R.shape([32, 8, 64, 128])
+                )
+                lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape(
+                    lv5, R.shape([256, 64, 128])
+                )
+                lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul(
+                    lv4, lv6, out_dtype="float32"
+                )
+                lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape(
+                    lv7, R.shape([32, 8, 128, 128])
+                )
+                lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = R.nn.softmax(lv8, axis=-1)
+                lv10: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal(
+                    lv8, R.const(float("-inf"), "float32")
                 )
-                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_1, axes=[0, 2, 1, 3]
+                lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = R.logical_not(lv10)
+                lv12: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max(
+                    lv11, axis=[-1], keepdims=True
                 )
-                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_2, axes=[0, 2, 1, 3]
+                lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = R.logical_not(lv12)
+                lv14: R.Tensor((32, 8, 128, 128), dtype="float32") = R.full_like(
+                    lv9, R.const(0, "int32"), dtype="void"
                 )
-                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
-                    lv, lv1, lv2, scale=None
+                lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = R.where(lv13, lv14, lv9)
+                lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = R.broadcast_to(
+                    lv15, R.shape([32, 8, 128, 128])
                 )
-                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
-                    lv3, axes=[0, 2, 1, 3]
+                lv17: R.Tensor((256, 128, 128), dtype="float32") = R.reshape(
+                    lv16, R.shape([256, 128, 128])
                 )
-                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,)
+                lv18: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
+                    v, R.shape([32, 8, 128, 64])
+                )
+                lv19: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
+                    lv18, R.shape([256, 128, 64])
+                )
+                lv20: R.Tensor((256, 128, 64), dtype="float32") = R.matmul(
+                    lv17, lv19, out_dtype="float32"
+                )
+                lv21: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape(
+                    lv20, R.shape([32, 8, 128, 64])
+                )
+                lv22: R.Tensor((128, 32, 8, 64), dtype="float32") = R.permute_dims(
+                    lv21, axes=[2, 0, 1, 3]
+                )
+                lv23: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+                    lv22, axes=[1, 2, 0, 3]
+                )
+                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv23,)
                 R.output(gv)
             return gv
 
@@ -3447,28 +3511,78 @@ def test_scaled_dot_product_attention():
     class Expected2:
         @R.function
         def main(
-            inp_0: R.Tensor((32, 8, 128, 64), dtype="float32"),
-            inp_1: R.Tensor((32, 8, 128, 64), dtype="float32"),
-            inp_2: R.Tensor((32, 8, 128, 64), dtype="float32"),
-            inp_3: R.Tensor((32, 8, 128, 128), dtype="float32"),
+            q: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            k: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            v: R.Tensor((32, 8, 128, 64), dtype="float32"),
+            mask: R.Tensor((32, 8, 128, 128), dtype="float32"),
         ) -> R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")):
             with R.dataflow():
-                lv: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_0, axes=[0, 2, 1, 3]
+                lv: R.Tensor((32, 8, 128, 64), dtype="float32") = R.multiply(
+                    q, R.const(0.35355338454246521, "float32")
+                )
+                lv1: R.Tensor((32, 8, 64, 128), dtype="float32") = R.permute_dims(
+                    k, axes=[0, 1, 3, 2]
+                )
+                lv2: R.Tensor((32, 8, 64, 128), dtype="float32") = R.multiply(
+                    lv1, R.const(0.35355338454246521, "float32")
+                )
+                lv3: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
+                    lv, R.shape([32, 8, 128, 64])
+                )
+                lv4: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
+                    lv3, R.shape([256, 128, 64])
+                )
+                lv5: R.Tensor((32, 8, 64, 128), dtype="float32") = R.broadcast_to(
+                    lv2, R.shape([32, 8, 64, 128])
+                )
+                lv6: R.Tensor((256, 64, 128), dtype="float32") = R.reshape(
+                    lv5, R.shape([256, 64, 128])
                 )
-                lv1: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_1, axes=[0, 2, 1, 3]
+                lv7: R.Tensor((256, 128, 128), dtype="float32") = R.matmul(
+                    lv4, lv6, out_dtype="float32"
                 )
-                lv2: R.Tensor((32, 128, 8, 64), dtype="float32") = R.permute_dims(
-                    inp_2, axes=[0, 2, 1, 3]
+                lv8: R.Tensor((32, 8, 128, 128), dtype="float32") = R.reshape(
+                    lv7, R.shape([32, 8, 128, 128])
                 )
-                lv3: R.Tensor((32, 128, 8, 64), dtype="float32") = R.nn.attention(
-                    lv, lv1, lv2, inp_3, scale=None
+                lv9: R.Tensor((32, 8, 128, 128), dtype="float32") = R.add(lv8, mask)
+                lv10: R.Tensor((32, 8, 128, 128), dtype="float32") = R.nn.softmax(lv9, axis=-1)
+                lv11: R.Tensor((32, 8, 128, 128), dtype="bool") = R.equal(
+                    lv9, R.const(float("-inf"), "float32")
                 )
-                lv4: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
-                    lv3, axes=[0, 2, 1, 3]
+                lv12: R.Tensor((32, 8, 128, 128), dtype="bool") = R.logical_not(lv11)
+                lv13: R.Tensor((32, 8, 128, 1), dtype="bool") = R.max(
+                    lv12, axis=[-1], keepdims=True
                 )
-                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv4,)
+                lv14: R.Tensor((32, 8, 128, 1), dtype="bool") = R.logical_not(lv13)
+                lv15: R.Tensor((32, 8, 128, 128), dtype="float32") = R.full_like(
+                    lv10, R.const(0, "int32"), dtype="void"
+                )
+                lv16: R.Tensor((32, 8, 128, 128), dtype="float32") = R.where(lv14, lv15, lv10)
+                lv17: R.Tensor((32, 8, 128, 128), dtype="float32") = R.broadcast_to(
+                    lv16, R.shape([32, 8, 128, 128])
+                )
+                lv18: R.Tensor((256, 128, 128), dtype="float32") = R.reshape(
+                    lv17, R.shape([256, 128, 128])
+                )
+                lv19: R.Tensor((32, 8, 128, 64), dtype="float32") = R.broadcast_to(
+                    v, R.shape([32, 8, 128, 64])
+                )
+                lv20: R.Tensor((256, 128, 64), dtype="float32") = R.reshape(
+                    lv19, R.shape([256, 128, 64])
+                )
+                lv21: R.Tensor((256, 128, 64), dtype="float32") = R.matmul(
+                    lv18, lv20, out_dtype="float32"
+                )
+                lv22: R.Tensor((32, 8, 128, 64), dtype="float32") = R.reshape(
+                    lv21, R.shape([32, 8, 128, 64])
+                )
+                lv23: R.Tensor((128, 32, 8, 64), dtype="float32") = R.permute_dims(
+                    lv22, axes=[2, 0, 1, 3]
+                )
+                lv24: R.Tensor((32, 8, 128, 64), dtype="float32") = R.permute_dims(
+                    lv23, axes=[1, 2, 0, 3]
+                )
+                gv: R.Tuple(R.Tensor((32, 8, 128, 64), dtype="float32")) = (lv24,)
                 R.output(gv)
             return gv
 
@@ -3481,6 +3595,7 @@ def test_scaled_dot_product_attention():
         ),
         {},
         Expected1,
+        run_ep_decomposition=True,
     )
 
     verify_model(
@@ -3493,6 +3608,7 @@ def test_scaled_dot_product_attention():
         ),
         {},
         Expected2,
+        run_ep_decomposition=True,
     )
 
 
