This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e5fb395c57 [Relax][PyTorch] Add decomposed operator support for MaxPool (#18446)
e5fb395c57 is described below

commit e5fb395c578118d2f4542585a0310b66b27dd95a
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Nov 14 13:23:24 2025 +0800

    [Relax][PyTorch] Add decomposed operator support for MaxPool (#18446)
    
    Add decomposed operator support for MaxPool
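    
    For context, a minimal sketch of how this path can be exercised (the
    module, shapes, and pooling parameters below are illustrative only;
    the sketch assumes the torch.export API and the TVM exported-program
    frontend used by the tests in this commit):
    
        import torch
        from torch import nn
        from torch.export import export
        from tvm.relax.frontend.torch import from_exported_program
    
        class Pool(nn.Module):
            def forward(self, x):
                # Running decompositions rewrites this call into
                # max_pool2d_with_indices.default, which the new converter
                # maps to relax max_pool2d plus a zeros_like placeholder
                # standing in for the indices output.
                return nn.functional.max_pool2d(x, kernel_size=2, stride=2)
    
        ep = export(Pool(), (torch.randn(1, 3, 10, 10),))
        ep = ep.run_decompositions()
        mod = from_exported_program(ep)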
---
 .../frontend/torch/base_fx_graph_translator.py     |  48 +++++++
 .../frontend/torch/exported_program_translator.py  |   2 +
 .../relax/test_frontend_from_exported_program.py   | 152 ++++++++++++++-------
 3 files changed, 155 insertions(+), 47 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 0c8cd4b34f..33e8347fb0 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1313,6 +1313,54 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
         ceil_mode = args[5] if len(args) > 5 else False
         return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
 
+    def _max_pool1d_with_indices(self, node: fx.Node) -> relax.Var:
+        # max_pool1d_with_indices returns (output, indices)
+        # We only compute the output and create a placeholder for indices
+        args = self.retrieve_args(node)
+        x = args[0]
+        kernel_size = args[1]
+        stride = args[2] if len(args) > 2 else None
+        padding = args[3] if len(args) > 3 else 0
+        dilation = args[4] if len(args) > 4 else 1
+        ceil_mode = args[5] if len(args) > 5 else False
+
+        output = self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+        # Create a placeholder for indices (empty tensor with same shape as output)
+        indices = relax.op.zeros_like(output)
+        return self.block_builder.emit(relax.Tuple([output, indices]))
+
+    def _max_pool2d_with_indices(self, node: fx.Node) -> relax.Var:
+        # max_pool2d_with_indices returns (output, indices)
+        # We only compute the output and create a placeholder for indices
+        args = self.retrieve_args(node)
+        x = args[0]
+        kernel_size = args[1]
+        stride = args[2] if len(args) > 2 else None
+        padding = args[3] if len(args) > 3 else 0
+        dilation = args[4] if len(args) > 4 else 1
+        ceil_mode = args[5] if len(args) > 5 else False
+
+        output = self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+        # Create a placeholder for indices (empty tensor with same shape as output)
+        indices = relax.op.zeros_like(output)
+        return self.block_builder.emit(relax.Tuple([output, indices]))
+
+    def _max_pool3d_with_indices(self, node: fx.Node) -> relax.Var:
+        # max_pool3d_with_indices returns (output, indices)
+        # We only compute the output and create a placeholder for indices
+        args = self.retrieve_args(node)
+        x = args[0]
+        kernel_size = args[1]
+        stride = args[2] if len(args) > 2 else None
+        padding = args[3] if len(args) > 3 else 0
+        dilation = args[4] if len(args) > 4 else 1
+        ceil_mode = args[5] if len(args) > 5 else False
+
+        output = self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+        # Create a placeholder for indices (empty tensor with same shape as output)
+        indices = relax.op.zeros_like(output)
+        return self.block_builder.emit(relax.Tuple([output, indices]))
+
     def _pad(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         pad = node.args[1]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a6da21ada8..5cddf24a89 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -986,7 +986,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "gru.input": self._gru,
             "max_pool1d.default": self._max_pool1d,
             "max_pool2d.default": self._max_pool2d,
+            "max_pool2d_with_indices.default": self._max_pool2d_with_indices,
             "max_pool3d.default": self._max_pool3d,
+            "max_pool3d_with_indices.default": self._max_pool3d_with_indices,
             "scaled_dot_product_attention.default": 
self._scaled_dot_product_attention,
             "unbind.int": self._unbind,
             "upsample_bilinear2d.vec": self._upsample_bilinear2d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 774a50db0e..71e400a6a8 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3163,16 +3163,24 @@ def test_maxpool1d():
             input_1: R.Tensor((1, 3, 8), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
             with R.dataflow():
-                lv = R.nn.max_pool1d(
-                    input_1,
-                    pool_size=[2],
-                    strides=[2],
-                    dilation=[1],
-                    padding=[0, 0],
-                    layout="NCW",
-                    out_layout="NCW",
+                lv: R.Tensor((1, 3, 1, 8), dtype="float32") = R.expand_dims(input_1, axis=[-2])
+                lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d(
+                    lv,
+                    pool_size=[1, 2],
+                    strides=[1, 2],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    layout="NCHW",
+                    out_layout="NCHW",
                 )
-                gv = (lv,)
+                lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1)
+                lv3: R.Tuple(
+                    R.Tensor((1, 3, 1, 4), dtype="float32"),
+                    R.Tensor((1, 3, 1, 4), dtype="float32"),
+                ) = (lv1, lv2)
+                lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0]
+                lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2])
+                gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,)
                 R.output(gv)
             return gv
 
@@ -3183,16 +3191,24 @@ def test_maxpool1d():
             input_1: R.Tensor((1, 3, 8), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
             with R.dataflow():
-                lv = R.nn.max_pool1d(
-                    input_1,
-                    pool_size=[2],
-                    strides=[2],
-                    dilation=[1],
-                    padding=[0, 0],
-                    layout="NCW",
-                    out_layout="NCW",
+                lv: R.Tensor((1, 3, 1, 8), dtype="float32") = R.expand_dims(input_1, axis=[-2])
+                lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d(
+                    lv,
+                    pool_size=[1, 2],
+                    strides=[1, 2],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    layout="NCHW",
+                    out_layout="NCHW",
                 )
-                gv = (lv,)
+                lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1)
+                lv3: R.Tuple(
+                    R.Tensor((1, 3, 1, 4), dtype="float32"),
+                    R.Tensor((1, 3, 1, 4), dtype="float32"),
+                ) = (lv1, lv2)
+                lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0]
+                lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2])
+                gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,)
                 R.output(gv)
             return gv
 
@@ -3203,16 +3219,24 @@ def test_maxpool1d():
             input_1: R.Tensor((1, 3, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
             with R.dataflow():
-                lv = R.nn.max_pool1d(
-                    input_1,
-                    pool_size=[3],
-                    strides=[2],
-                    dilation=[1],
-                    padding=[0, 0],
-                    layout="NCW",
-                    out_layout="NCW",
+                lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input_1, axis=[-2])
+                lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d(
+                    lv,
+                    pool_size=[1, 3],
+                    strides=[1, 2],
+                    dilation=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    layout="NCHW",
+                    out_layout="NCHW",
                 )
-                gv = (lv,)
+                lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1)
+                lv3: R.Tuple(
+                    R.Tensor((1, 3, 1, 4), dtype="float32"),
+                    R.Tensor((1, 3, 1, 4), dtype="float32"),
+                ) = (lv1, lv2)
+                lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0]
+                lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2])
+                gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,)
                 R.output(gv)
             return gv
 
@@ -3222,9 +3246,9 @@ def test_maxpool1d():
     example_args3 = (torch.randn(1, 3, 10, dtype=torch.float32),)
 
     # Verify the models
-    verify_model(MaxPool1d(), example_args1, {}, expected1)
-    verify_model(MaxPool1d_functional(), example_args2, {}, expected2)
-    verify_model(MaxPool1d2(), example_args3, {}, expected3)
+    verify_model(MaxPool1d(), example_args1, {}, expected1, run_ep_decomposition=True)
+    verify_model(MaxPool1d_functional(), example_args2, {}, expected2, run_ep_decomposition=True)
+    verify_model(MaxPool1d2(), example_args3, {}, expected3, run_ep_decomposition=True)
 
 
 def test_maxpool2d():
@@ -3260,7 +3284,13 @@ def test_maxpool2d():
                     layout="NCHW",
                     out_layout="NCHW",
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.zeros_like(lv)
+                lv2: R.Tuple(
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                    R.Tensor((1, 3, 10, 10), dtype="float32"),
+                ) = (lv, lv1)
+                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv2[0]
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
                 R.output(gv)
             return gv
 
@@ -3289,7 +3319,12 @@ def test_maxpool2d():
                     layout="NCHW",
                     out_layout="NCHW",
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv,)
+                lv1: R.Tensor((1, 3, 4, 4), dtype="float32") = R.zeros_like(lv)
+                lv2: R.Tuple(
+                    R.Tensor((1, 3, 4, 4), dtype="float32"), R.Tensor((1, 3, 4, 4), dtype="float32")
+                ) = (lv, lv1)
+                lv3: R.Tensor((1, 3, 4, 4), dtype="float32") = lv2[0]
+                gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv3,)
                 R.output(gv)
             return gv
 
@@ -3318,15 +3353,20 @@ def test_maxpool2d():
                     layout="NCHW",
                     out_layout="NCHW",
                 )
-                gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv,)
+                lv1: R.Tensor((1, 3, 6, 6), dtype="float32") = R.zeros_like(lv)
+                lv2: R.Tuple(
+                    R.Tensor((1, 3, 6, 6), dtype="float32"), R.Tensor((1, 3, 6, 6), dtype="float32")
+                ) = (lv, lv1)
+                lv3: R.Tensor((1, 3, 6, 6), dtype="float32") = lv2[0]
+                gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv3,)
                 R.output(gv)
             return gv
 
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
-    verify_model(MaxPool2d(), example_args, {}, expected1)
-    verify_model(MaxPool2d_functional(), example_args, {}, expected1)
-    verify_model(MaxPool2d2(), example_args, {}, expected2)
-    verify_model(MaxPool2d3(), example_args, {}, expected3)
+    verify_model(MaxPool2d(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(MaxPool2d_functional(), example_args, {}, expected1, run_ep_decomposition=True)
+    verify_model(MaxPool2d2(), example_args, {}, expected2, run_ep_decomposition=True)
+    verify_model(MaxPool2d3(), example_args, {}, expected3, run_ep_decomposition=True)
 
 
 def test_maxpool3d():
@@ -3352,7 +3392,7 @@ def test_maxpool3d():
             input_1: R.Tensor((1, 3, 4, 4, 4), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")):
             with R.dataflow():
-                lv = R.nn.max_pool3d(
+                lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.max_pool3d(
                     input_1,
                     pool_size=[1, 1, 1],
                     strides=[1, 1, 1],
@@ -3361,7 +3401,13 @@ def test_maxpool3d():
                     layout="NCDHW",
                     out_layout="NCDHW",
                 )
-                gv = (lv,)
+                lv1: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.zeros_like(lv)
+                lv2: R.Tuple(
+                    R.Tensor((1, 3, 4, 4, 4), dtype="float32"),
+                    R.Tensor((1, 3, 4, 4, 4), dtype="float32"),
+                ) = (lv, lv1)
+                lv3: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = lv2[0]
+                gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv3,)
                 R.output(gv)
             return gv
 
@@ -3380,7 +3426,7 @@ def test_maxpool3d():
             input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")):
             with R.dataflow():
-                lv = R.nn.max_pool3d(
+                lv: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.nn.max_pool3d(
                     input_1,
                     pool_size=[2, 2, 2],
                     strides=[2, 2, 2],
@@ -3389,7 +3435,13 @@ def test_maxpool3d():
                     layout="NCDHW",
                     out_layout="NCDHW",
                 )
-                gv = (lv,)
+                lv1: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.zeros_like(lv)
+                lv2: R.Tuple(
+                    R.Tensor((1, 3, 3, 3, 3), dtype="float32"),
+                    R.Tensor((1, 3, 3, 3, 3), dtype="float32"),
+                ) = (lv, lv1)
+                lv3: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = lv2[0]
+                gv: R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")) = (lv3,)
                 R.output(gv)
             return gv
 
@@ -3408,7 +3460,7 @@ def test_maxpool3d():
             input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
         ) -> R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")):
             with R.dataflow():
-                lv = R.nn.max_pool3d(
+                lv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.nn.max_pool3d(
                     input_1,
                     pool_size=[3, 3, 3],
                     strides=[2, 2, 2],
@@ -3417,7 +3469,13 @@ def test_maxpool3d():
                     layout="NCDHW",
                     out_layout="NCDHW",
                 )
-                gv = (lv,)
+                lv1: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.zeros_like(lv)
+                lv2: R.Tuple(
+                    R.Tensor((1, 3, 5, 5, 5), dtype="float32"),
+                    R.Tensor((1, 3, 5, 5, 5), dtype="float32"),
+                ) = (lv, lv1)
+                lv3: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = lv2[0]
+                gv: R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")) = (lv3,)
                 R.output(gv)
             return gv
 
@@ -3427,10 +3485,10 @@ def test_maxpool3d():
     example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)
 
     # Verify the models with expected IR modules
-    verify_model(MaxPool3d(), example_args1, {}, expected1)
-    verify_model(MaxPool3d_functional(), example_args1, {}, expected1)
-    verify_model(MaxPool3d2(), example_args2, {}, expected2)
-    verify_model(MaxPool3d3(), example_args3, {}, expected3)
+    verify_model(MaxPool3d(), example_args1, {}, expected1, run_ep_decomposition=True)
+    verify_model(MaxPool3d_functional(), example_args1, {}, expected1, run_ep_decomposition=True)
+    verify_model(MaxPool3d2(), example_args2, {}, expected2, run_ep_decomposition=True)
+    verify_model(MaxPool3d3(), example_args3, {}, expected3, run_ep_decomposition=True)
 
 
 def test_scaled_dot_product_attention():
