This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ce5f287bdb [Relax][PyTorch] Add support for decomposed operators and
fix IR of ops tests (#18433)
ce5f287bdb is described below
commit ce5f287bdb6fd2505c147acc9feae3585b8380ed
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Nov 12 23:09:47 2025 +0800
[Relax][PyTorch] Add support for decomposed operators and fix IR of ops
tests (#18433)
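Add support for the decomposed `convolution.default` operator (1D/2D/3D convolution and 1D/2D transposed convolution), and run `test_keep_params` with `run_ep_decomposition=True`.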
---
.../frontend/torch/base_fx_graph_translator.py | 74 ++++++++++++++++++++++
.../frontend/torch/exported_program_translator.py | 1 +
.../relax/test_frontend_from_exported_program.py | 4 +-
3 files changed, 78 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 177e3d91f9..0c8cd4b34f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1003,6 +1003,80 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
groups=groups,
)
+ def _convolution(self, node: fx.Node) -> relax.Var:
+ args = self.retrieve_args(node)
+ x = args[0]
+ weight = args[1]
+ bias = args[2] if len(args) > 2 else None
+ stride = args[3] if len(args) > 3 else 1
+ padding = args[4] if len(args) > 4 else 0
+ dilation = args[5] if len(args) > 5 else 1
+ transposed = args[6] if len(args) > 6 else False
+ output_padding = args[7] if len(args) > 7 else 0
+ groups = args[8] if len(args) > 8 else 1
+
+ input_shape = self.shape_of(x)
+ ndim = len(input_shape)
+
+ if transposed:
+ if ndim == 3: # 1D convolution (N, C, W)
+ return self._conv_transpose1d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ output_padding=output_padding,
+ )
+ elif ndim == 4: # 2D convolution (N, C, H, W)
+ return self._conv_transpose2d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ output_padding=output_padding,
+ )
+ else:
+ raise ValueError(f"Unsupported transposed convolution
dimensionality: {ndim}")
+ else:
+ if ndim == 3: # 1D convolution (N, C, W)
+ return self._conv1d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+ elif ndim == 4: # 2D convolution (N, C, H, W)
+ return self._conv2d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+ elif ndim == 5: # 3D convolution (N, C, D, H, W)
+ return self._conv3d_impl(
+ x,
+ weight,
+ bias=bias,
+ strides=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ )
+ else:
+ raise ValueError(f"Unsupported convolution dimensionality:
{ndim}")
+
def _cross_entropy_loss(
self,
preds: relax.Expr,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0dfa4cc6da..0d4abb0336 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -969,6 +969,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"conv1d.default": self._conv1d,
"conv2d.default": self._conv2d,
"conv3d.default": self._conv3d,
+ "convolution.default": self._convolution,
"cross_entropy_loss.default": self._cross_entropy_default,
"einsum.default": self._einsum,
"embedding.default": lambda node: self._embedding_impl(
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index ba14356e8e..8f308e59b7 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5254,7 +5254,9 @@ def test_keep_params():
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
model = Conv2D1()
exported_program = torch.export.export(model, example_args)
- mod = from_exported_program(exported_program, keep_params_as_input=True)
+ mod = from_exported_program(
+ exported_program, keep_params_as_input=True, run_ep_decomposition=True
+ )
mod, params = detach_params(mod)
tvm.ir.assert_structural_equal(mod, expected1)
func = mod["main"]