This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e5fb395c57 [Relax][PyTorch] Add decomposed operator support for MaxPool (#18446)
e5fb395c57 is described below
commit e5fb395c578118d2f4542585a0310b66b27dd95a
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Fri Nov 14 13:23:24 2025 +0800
[Relax][PyTorch] Add decomposed operator support for MaxPool (#18446)
Add decomposed operator support for MaxPool
---
.../frontend/torch/base_fx_graph_translator.py | 48 +++++++
.../frontend/torch/exported_program_translator.py | 2 +
.../relax/test_frontend_from_exported_program.py | 152 ++++++++++++++-------
3 files changed, 155 insertions(+), 47 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 0c8cd4b34f..33e8347fb0 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1313,6 +1313,54 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
ceil_mode = args[5] if len(args) > 5 else False
return self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+ def _max_pool1d_with_indices(self, node: fx.Node) -> relax.Var:
+ # max_pool1d_with_indices returns (output, indices)
+ # We only compute the output and create a placeholder for indices
+ args = self.retrieve_args(node)
+ x = args[0]
+ kernel_size = args[1]
+ stride = args[2] if len(args) > 2 else None
+ padding = args[3] if len(args) > 3 else 0
+ dilation = args[4] if len(args) > 4 else 1
+ ceil_mode = args[5] if len(args) > 5 else False
+
+ output = self._max_pool1d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+ # Create a placeholder for indices (empty tensor with same shape as output)
+ indices = relax.op.zeros_like(output)
+ return self.block_builder.emit(relax.Tuple([output, indices]))
+
+ def _max_pool2d_with_indices(self, node: fx.Node) -> relax.Var:
+ # max_pool2d_with_indices returns (output, indices)
+ # We only compute the output and create a placeholder for indices
+ args = self.retrieve_args(node)
+ x = args[0]
+ kernel_size = args[1]
+ stride = args[2] if len(args) > 2 else None
+ padding = args[3] if len(args) > 3 else 0
+ dilation = args[4] if len(args) > 4 else 1
+ ceil_mode = args[5] if len(args) > 5 else False
+
+ output = self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+ # Create a placeholder for indices (empty tensor with same shape as output)
+ indices = relax.op.zeros_like(output)
+ return self.block_builder.emit(relax.Tuple([output, indices]))
+
+ def _max_pool3d_with_indices(self, node: fx.Node) -> relax.Var:
+ # max_pool3d_with_indices returns (output, indices)
+ # We only compute the output and create a placeholder for indices
+ args = self.retrieve_args(node)
+ x = args[0]
+ kernel_size = args[1]
+ stride = args[2] if len(args) > 2 else None
+ padding = args[3] if len(args) > 3 else 0
+ dilation = args[4] if len(args) > 4 else 1
+ ceil_mode = args[5] if len(args) > 5 else False
+
+ output = self._max_pool3d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
+ # Create a placeholder for indices (empty tensor with same shape as output)
+ indices = relax.op.zeros_like(output)
+ return self.block_builder.emit(relax.Tuple([output, indices]))
+
def _pad(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
pad = node.args[1]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index a6da21ada8..5cddf24a89 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -986,7 +986,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"gru.input": self._gru,
"max_pool1d.default": self._max_pool1d,
"max_pool2d.default": self._max_pool2d,
+ "max_pool2d_with_indices.default": self._max_pool2d_with_indices,
"max_pool3d.default": self._max_pool3d,
+ "max_pool3d_with_indices.default": self._max_pool3d_with_indices,
"scaled_dot_product_attention.default": self._scaled_dot_product_attention,
"unbind.int": self._unbind,
"upsample_bilinear2d.vec": self._upsample_bilinear2d,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 774a50db0e..71e400a6a8 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -3163,16 +3163,24 @@ def test_maxpool1d():
input_1: R.Tensor((1, 3, 8), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
with R.dataflow():
- lv = R.nn.max_pool1d(
- input_1,
- pool_size=[2],
- strides=[2],
- dilation=[1],
- padding=[0, 0],
- layout="NCW",
- out_layout="NCW",
+ lv: R.Tensor((1, 3, 1, 8), dtype="float32") = R.expand_dims(input_1, axis=[-2])
+ lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d(
+ lv,
+ pool_size=[1, 2],
+ strides=[1, 2],
+ dilation=[1, 1],
+ padding=[0, 0, 0, 0],
+ layout="NCHW",
+ out_layout="NCHW",
)
- gv = (lv,)
+ lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1)
+ lv3: R.Tuple(
+ R.Tensor((1, 3, 1, 4), dtype="float32"),
+ R.Tensor((1, 3, 1, 4), dtype="float32"),
+ ) = (lv1, lv2)
+ lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0]
+ lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2])
+ gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,)
R.output(gv)
return gv
@@ -3183,16 +3191,24 @@ def test_maxpool1d():
input_1: R.Tensor((1, 3, 8), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
with R.dataflow():
- lv = R.nn.max_pool1d(
- input_1,
- pool_size=[2],
- strides=[2],
- dilation=[1],
- padding=[0, 0],
- layout="NCW",
- out_layout="NCW",
+ lv: R.Tensor((1, 3, 1, 8), dtype="float32") = R.expand_dims(input_1, axis=[-2])
+ lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d(
+ lv,
+ pool_size=[1, 2],
+ strides=[1, 2],
+ dilation=[1, 1],
+ padding=[0, 0, 0, 0],
+ layout="NCHW",
+ out_layout="NCHW",
)
- gv = (lv,)
+ lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1)
+ lv3: R.Tuple(
+ R.Tensor((1, 3, 1, 4), dtype="float32"),
+ R.Tensor((1, 3, 1, 4), dtype="float32"),
+ ) = (lv1, lv2)
+ lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0]
+ lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2])
+ gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,)
R.output(gv)
return gv
@@ -3203,16 +3219,24 @@ def test_maxpool1d():
input_1: R.Tensor((1, 3, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 4), dtype="float32")):
with R.dataflow():
- lv = R.nn.max_pool1d(
- input_1,
- pool_size=[3],
- strides=[2],
- dilation=[1],
- padding=[0, 0],
- layout="NCW",
- out_layout="NCW",
+ lv: R.Tensor((1, 3, 1, 10), dtype="float32") = R.expand_dims(input_1, axis=[-2])
+ lv1: R.Tensor((1, 3, 1, 4), dtype="float32") = R.nn.max_pool2d(
+ lv,
+ pool_size=[1, 3],
+ strides=[1, 2],
+ dilation=[1, 1],
+ padding=[0, 0, 0, 0],
+ layout="NCHW",
+ out_layout="NCHW",
)
- gv = (lv,)
+ lv2: R.Tensor((1, 3, 1, 4), dtype="float32") = R.zeros_like(lv1)
+ lv3: R.Tuple(
+ R.Tensor((1, 3, 1, 4), dtype="float32"),
+ R.Tensor((1, 3, 1, 4), dtype="float32"),
+ ) = (lv1, lv2)
+ lv4: R.Tensor((1, 3, 1, 4), dtype="float32") = lv3[0]
+ lv5: R.Tensor((1, 3, 4), dtype="float32") = R.squeeze(lv4, axis=[-2])
+ gv: R.Tuple(R.Tensor((1, 3, 4), dtype="float32")) = (lv5,)
R.output(gv)
return gv
@@ -3222,9 +3246,9 @@ def test_maxpool1d():
example_args3 = (torch.randn(1, 3, 10, dtype=torch.float32),)
# Verify the models
- verify_model(MaxPool1d(), example_args1, {}, expected1)
- verify_model(MaxPool1d_functional(), example_args2, {}, expected2)
- verify_model(MaxPool1d2(), example_args3, {}, expected3)
+ verify_model(MaxPool1d(), example_args1, {}, expected1, run_ep_decomposition=True)
+ verify_model(MaxPool1d_functional(), example_args2, {}, expected2, run_ep_decomposition=True)
+ verify_model(MaxPool1d2(), example_args3, {}, expected3, run_ep_decomposition=True)
def test_maxpool2d():
@@ -3260,7 +3284,13 @@ def test_maxpool2d():
layout="NCHW",
out_layout="NCHW",
)
- gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+ lv1: R.Tensor((1, 3, 10, 10), dtype="float32") = R.zeros_like(lv)
+ lv2: R.Tuple(
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ R.Tensor((1, 3, 10, 10), dtype="float32"),
+ ) = (lv, lv1)
+ lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = lv2[0]
+ gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv3,)
R.output(gv)
return gv
@@ -3289,7 +3319,12 @@ def test_maxpool2d():
layout="NCHW",
out_layout="NCHW",
)
- gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv,)
+ lv1: R.Tensor((1, 3, 4, 4), dtype="float32") = R.zeros_like(lv)
+ lv2: R.Tuple(
+ R.Tensor((1, 3, 4, 4), dtype="float32"), R.Tensor((1, 3, 4, 4), dtype="float32")
+ ) = (lv, lv1)
+ lv3: R.Tensor((1, 3, 4, 4), dtype="float32") = lv2[0]
+ gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv3,)
R.output(gv)
return gv
@@ -3318,15 +3353,20 @@ def test_maxpool2d():
layout="NCHW",
out_layout="NCHW",
)
- gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv,)
+ lv1: R.Tensor((1, 3, 6, 6), dtype="float32") = R.zeros_like(lv)
+ lv2: R.Tuple(
+ R.Tensor((1, 3, 6, 6), dtype="float32"), R.Tensor((1, 3, 6, 6), dtype="float32")
+ ) = (lv, lv1)
+ lv3: R.Tensor((1, 3, 6, 6), dtype="float32") = lv2[0]
+ gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv3,)
R.output(gv)
return gv
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
- verify_model(MaxPool2d(), example_args, {}, expected1)
- verify_model(MaxPool2d_functional(), example_args, {}, expected1)
- verify_model(MaxPool2d2(), example_args, {}, expected2)
- verify_model(MaxPool2d3(), example_args, {}, expected3)
+ verify_model(MaxPool2d(), example_args, {}, expected1, run_ep_decomposition=True)
+ verify_model(MaxPool2d_functional(), example_args, {}, expected1, run_ep_decomposition=True)
+ verify_model(MaxPool2d2(), example_args, {}, expected2, run_ep_decomposition=True)
+ verify_model(MaxPool2d3(), example_args, {}, expected3, run_ep_decomposition=True)
def test_maxpool3d():
@@ -3352,7 +3392,7 @@ def test_maxpool3d():
input_1: R.Tensor((1, 3, 4, 4, 4), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")):
with R.dataflow():
- lv = R.nn.max_pool3d(
+ lv: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.nn.max_pool3d(
input_1,
pool_size=[1, 1, 1],
strides=[1, 1, 1],
@@ -3361,7 +3401,13 @@ def test_maxpool3d():
layout="NCDHW",
out_layout="NCDHW",
)
- gv = (lv,)
+ lv1: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = R.zeros_like(lv)
+ lv2: R.Tuple(
+ R.Tensor((1, 3, 4, 4, 4), dtype="float32"),
+ R.Tensor((1, 3, 4, 4, 4), dtype="float32"),
+ ) = (lv, lv1)
+ lv3: R.Tensor((1, 3, 4, 4, 4), dtype="float32") = lv2[0]
+ gv: R.Tuple(R.Tensor((1, 3, 4, 4, 4), dtype="float32")) = (lv3,)
R.output(gv)
return gv
@@ -3380,7 +3426,7 @@ def test_maxpool3d():
input_1: R.Tensor((1, 3, 8, 8, 8), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")):
with R.dataflow():
- lv = R.nn.max_pool3d(
+ lv: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.nn.max_pool3d(
input_1,
pool_size=[2, 2, 2],
strides=[2, 2, 2],
@@ -3389,7 +3435,13 @@ def test_maxpool3d():
layout="NCDHW",
out_layout="NCDHW",
)
- gv = (lv,)
+ lv1: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = R.zeros_like(lv)
+ lv2: R.Tuple(
+ R.Tensor((1, 3, 3, 3, 3), dtype="float32"),
+ R.Tensor((1, 3, 3, 3, 3), dtype="float32"),
+ ) = (lv, lv1)
+ lv3: R.Tensor((1, 3, 3, 3, 3), dtype="float32") = lv2[0]
+ gv: R.Tuple(R.Tensor((1, 3, 3, 3, 3), dtype="float32")) = (lv3,)
R.output(gv)
return gv
@@ -3408,7 +3460,7 @@ def test_maxpool3d():
input_1: R.Tensor((1, 3, 10, 10, 10), dtype="float32")
) -> R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")):
with R.dataflow():
- lv = R.nn.max_pool3d(
+ lv: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.nn.max_pool3d(
input_1,
pool_size=[3, 3, 3],
strides=[2, 2, 2],
@@ -3417,7 +3469,13 @@ def test_maxpool3d():
layout="NCDHW",
out_layout="NCDHW",
)
- gv = (lv,)
+ lv1: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = R.zeros_like(lv)
+ lv2: R.Tuple(
+ R.Tensor((1, 3, 5, 5, 5), dtype="float32"),
+ R.Tensor((1, 3, 5, 5, 5), dtype="float32"),
+ ) = (lv, lv1)
+ lv3: R.Tensor((1, 3, 5, 5, 5), dtype="float32") = lv2[0]
+ gv: R.Tuple(R.Tensor((1, 3, 5, 5, 5), dtype="float32")) = (lv3,)
R.output(gv)
return gv
@@ -3427,10 +3485,10 @@ def test_maxpool3d():
example_args3 = (torch.randn(1, 3, 10, 10, 10, dtype=torch.float32),)
# Verify the models with expected IR modules
- verify_model(MaxPool3d(), example_args1, {}, expected1)
- verify_model(MaxPool3d_functional(), example_args1, {}, expected1)
- verify_model(MaxPool3d2(), example_args2, {}, expected2)
- verify_model(MaxPool3d3(), example_args3, {}, expected3)
+ verify_model(MaxPool3d(), example_args1, {}, expected1, run_ep_decomposition=True)
+ verify_model(MaxPool3d_functional(), example_args1, {}, expected1, run_ep_decomposition=True)
+ verify_model(MaxPool3d2(), example_args2, {}, expected2, run_ep_decomposition=True)
+ verify_model(MaxPool3d3(), example_args3, {}, expected3, run_ep_decomposition=True)
def test_scaled_dot_product_attention():