This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a747614a83 [Relax][PyTorch] Add NHWC layout support (#18548)
a747614a83 is described below

commit a747614a83ee665a4b0765953b0e5ff098063d5b
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sat Dec 6 02:28:44 2025 +0800

    [Relax][PyTorch] Add NHWC layout support (#18548)
    
    ## Why
    
    - The interpolate operation was hardcoded to only support NCHW layout
    - Users need flexibility to choose the appropriate layout for their
    target platform
    
    ## How
    
    - Added a default_image_layout parameter (default "NCHW") to the
    TorchFXImporter constructor and used it as the resize2d layout
    - Exposed the default_image_layout parameter in the public from_fx() API
---
 python/tvm/relax/frontend/torch/fx_translator.py |  36 +++++--
 tests/python/relax/test_frontend_from_fx.py      | 115 +++++++++++++++++++++++
 2 files changed, 144 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 9c2d53a685..8b1f5de36b 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -33,11 +33,12 @@ class TorchFXImporter(BaseFXGraphImporter):
     import torch  # type: ignore
     from torch import fx
 
-    def __init__(self) -> None:
+    def __init__(self, default_image_layout: str = "NCHW") -> None:
         import torch  # type: ignore
 
         super().__init__()
         self.named_modules: Dict[str, torch.Module] = None
+        self.default_image_layout = default_image_layout
 
     ########## Utilities ##########
 
@@ -480,7 +481,6 @@ class TorchFXImporter(BaseFXGraphImporter):
         # torch.nn.functional.interpolate(
         #   input, size=None, scale_factor=None, mode='nearest', 
align_corners=None,
         #   recompute_scale_factor=None, antialias=False)
-        # (TODO) this is a temporary implementation for interpolate that only 
considers NCHW layout
         data = self.env[node.args[0]]
         size = (
             node.args[1]
@@ -523,13 +523,26 @@ class TorchFXImporter(BaseFXGraphImporter):
         if size is None:
             shape = self.shape_of(data)
             assert isinstance(shape, relax.ShapeExpr)
+            # Determine spatial dimension indices based on layout
+            # NCHW: spatial dims are [2, 3, ...] (skip batch and channel)
+            # NHWC: spatial dims are [1, 2, ...] (skip batch, before channel)
+            if self.default_image_layout == "NHWC":
+                spatial_start = 1
+                spatial_end = len(shape) - 1
+            else:  # NCHW or other layouts
+                spatial_start = 2
+                spatial_end = len(shape)
+
             if isinstance(scale_factor, tuple):
-                assert len(scale_factor) == len(shape) - 2
+                assert len(scale_factor) == spatial_end - spatial_start
                 size = tuple(
-                    int(shape[i].value * scale_factor[i - 2]) for i in 
range(2, len(shape))
+                    int(shape[i].value * scale_factor[i - spatial_start])
+                    for i in range(spatial_start, spatial_end)
                 )
             else:
-                size = tuple(int(shape[i].value * scale_factor) for i in 
range(2, len(shape)))
+                size = tuple(
+                    int(shape[i].value * scale_factor) for i in 
range(spatial_start, spatial_end)
+                )
 
         if method.startswith("nearest"):
             method = "nearest_neighbor"
@@ -545,7 +558,11 @@ class TorchFXImporter(BaseFXGraphImporter):
 
         return self.block_builder.emit(
             relax.op.image.resize2d(
-                data, size, layout="NCHW", method=method, 
coordinate_transformation_mode=coord_trans
+                data,
+                size,
+                layout=self.default_image_layout,
+                method=method,
+                coordinate_transformation_mode=coord_trans,
             )
         )
 
@@ -1150,6 +1167,7 @@ def from_fx(
     unwrap_unit_return_tuple: bool = False,
     no_bind_return_tuple: bool = False,
     custom_convert_map: dict = None,
+    default_image_layout: str = "NCHW",
 ) -> tvm.IRModule:
     """Convert a PyTorch FX GraphModule to a Relax program
 
@@ -1175,6 +1193,10 @@ def from_fx(
     custom_convert_map : Dictionary of str to Relax op
         A custom op conversion map in the same format as 
TorchFXImporter.convert_map
 
+    default_image_layout : str
+        The default layout for image operations (e.g., "NCHW" or "NHWC").
+        Default is "NCHW" which is the standard PyTorch layout.
+
     Returns
     -------
     output : tvm.IRModule
@@ -1242,7 +1264,7 @@ def from_fx(
     to print out the tabular representation of the PyTorch module, and then
     check the placeholder rows in the beginning of the tabular.
     """
-    return TorchFXImporter().from_fx(
+    return TorchFXImporter(default_image_layout=default_image_layout).from_fx(
         model,
         input_info,
         keep_params_as_input,
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index de30af01ee..b7aeea6687 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3670,6 +3670,121 @@ def test_interpolate():
     verify_model(Interpolate4(), input_info, {}, expected4)
 
 
+def test_interpolate_nhwc_layout():
+    # First verify backward compatibility - default should still be NCHW
+    input_info_nchw = [([1, 3, 10, 10], "float32")]
+
+    class InterpolateDefault(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, (5, 5))
+
+    @tvm.script.ir_module
+    class expected_default_nchw:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 5, 5), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 5), dtype="float32") = R.image.resize2d(
+                    input_1,
+                    (5, 5),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000],
+                    layout="NCHW",
+                    method="nearest_neighbor",
+                    coordinate_transformation_mode="asymmetric",
+                    rounding_method="round",
+                    cubic_alpha=-0.75,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 3, 5, 5), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    # Verify default behavior (no default_image_layout parameter) uses NCHW
+    graph_model_default = fx.symbolic_trace(InterpolateDefault())
+    with torch.no_grad():
+        mod_default = from_fx(graph_model_default, input_info_nchw)
+    tvm.ir.assert_structural_equal(mod_default, expected_default_nchw)
+
+    # Now test NHWC layout
+    input_info = [([1, 10, 10, 3], "float32")]
+
+    class InterpolateNHWC(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, (5, 5))
+
+    @tvm.script.ir_module
+    class expected_nhwc:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 10, 10, 3), dtype="float32")
+        ) -> R.Tensor((1, 5, 5, 3), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 5, 5, 3), dtype="float32") = R.image.resize2d(
+                    input_1,
+                    (5, 5),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000],
+                    layout="NHWC",
+                    method="nearest_neighbor",
+                    coordinate_transformation_mode="asymmetric",
+                    rounding_method="round",
+                    cubic_alpha=-0.75,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 5, 5, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    # Test with NHWC layout
+    graph_model = fx.symbolic_trace(InterpolateNHWC())
+    with torch.no_grad():
+        mod = from_fx(graph_model, input_info, default_image_layout="NHWC")
+    tvm.ir.assert_structural_equal(mod, expected_nhwc)
+
+    # Test with bilinear interpolation and NHWC layout
+    class InterpolateNHWC2(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input, size=None, scale_factor=2.0, mode="bilinear", 
align_corners=False
+            )
+
+    @tvm.script.ir_module
+    class expected_nhwc2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 10, 10, 3), dtype="float32")
+        ) -> R.Tensor((1, 20, 20, 3), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 20, 20, 3), dtype="float32") = 
R.image.resize2d(
+                    input_1,
+                    (20, 20),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000],
+                    layout="NHWC",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="round",
+                    cubic_alpha=-0.75,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 20, 20, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    graph_model2 = fx.symbolic_trace(InterpolateNHWC2())
+    with torch.no_grad():
+        mod2 = from_fx(graph_model2, input_info, default_image_layout="NHWC")
+    tvm.ir.assert_structural_equal(mod2, expected_nhwc2)
+
+
 def test_addmm():
     input_info = [
         ([10, 10], "float32"),

Reply via email to