This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ce5f287bdb [Relax][PyTorch] Add support for decomposed operators and 
fix IR of ops tests (#18433)
ce5f287bdb is described below

commit ce5f287bdb6fd2505c147acc9feae3585b8380ed
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Nov 12 23:09:47 2025 +0800

    [Relax][PyTorch] Add support for decomposed operators and fix IR of ops 
tests (#18433)
    
    Add decomposed operators support for conv
---
 .../frontend/torch/base_fx_graph_translator.py     | 74 ++++++++++++++++++++++
 .../frontend/torch/exported_program_translator.py  |  1 +
 .../relax/test_frontend_from_exported_program.py   |  4 +-
 3 files changed, 78 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 177e3d91f9..0c8cd4b34f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1003,6 +1003,80 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
             groups=groups,
         )
 
+    def _convolution(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        weight = args[1]
+        bias = args[2] if len(args) > 2 else None
+        stride = args[3] if len(args) > 3 else 1
+        padding = args[4] if len(args) > 4 else 0
+        dilation = args[5] if len(args) > 5 else 1
+        transposed = args[6] if len(args) > 6 else False
+        output_padding = args[7] if len(args) > 7 else 0
+        groups = args[8] if len(args) > 8 else 1
+
+        input_shape = self.shape_of(x)
+        ndim = len(input_shape)
+
+        if transposed:
+            if ndim == 3:  # 1D convolution (N, C, W)
+                return self._conv_transpose1d_impl(
+                    x,
+                    weight,
+                    bias=bias,
+                    strides=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                    output_padding=output_padding,
+                )
+            elif ndim == 4:  # 2D convolution (N, C, H, W)
+                return self._conv_transpose2d_impl(
+                    x,
+                    weight,
+                    bias=bias,
+                    strides=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                    output_padding=output_padding,
+                )
+            else:
+                raise ValueError(f"Unsupported transposed convolution dimensionality: {ndim}")
+        else:
+            if ndim == 3:  # 1D convolution (N, C, W)
+                return self._conv1d_impl(
+                    x,
+                    weight,
+                    bias=bias,
+                    strides=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+            elif ndim == 4:  # 2D convolution (N, C, H, W)
+                return self._conv2d_impl(
+                    x,
+                    weight,
+                    bias=bias,
+                    strides=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+            elif ndim == 5:  # 3D convolution (N, C, D, H, W)
+                return self._conv3d_impl(
+                    x,
+                    weight,
+                    bias=bias,
+                    strides=stride,
+                    padding=padding,
+                    dilation=dilation,
+                    groups=groups,
+                )
+            else:
+                raise ValueError(f"Unsupported convolution dimensionality: {ndim}")
+
     def _cross_entropy_loss(
         self,
         preds: relax.Expr,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 0dfa4cc6da..0d4abb0336 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -969,6 +969,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "conv1d.default": self._conv1d,
             "conv2d.default": self._conv2d,
             "conv3d.default": self._conv3d,
+            "convolution.default": self._convolution,
             "cross_entropy_loss.default": self._cross_entropy_default,
             "einsum.default": self._einsum,
             "embedding.default": lambda node: self._embedding_impl(
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index ba14356e8e..8f308e59b7 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5254,7 +5254,9 @@ def test_keep_params():
     example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
     model = Conv2D1()
     exported_program = torch.export.export(model, example_args)
-    mod = from_exported_program(exported_program, keep_params_as_input=True)
+    mod = from_exported_program(
+        exported_program, keep_params_as_input=True, run_ep_decomposition=True
+    )
     mod, params = detach_params(mod)
     tvm.ir.assert_structural_equal(mod, expected1)
     func = mod["main"]

Reply via email to