This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 354c5f1008 [Unity] [Bugfix] Fix bug in interpolate operator's default mode parameter in PyTorch frontend (#15933)
354c5f1008 is described below

commit 354c5f100832733d809a37daec3fec2a4c115d06
Author: Thrsu <[email protected]>
AuthorDate: Mon Oct 16 20:29:52 2023 +0800

    [Unity] [Bugfix] Fix bug in interpolate operator's default mode parameter in PyTorch frontend (#15933)
    
    * Fix wrong attribute name of interpolate
    
    * Add regression test case.
    
    * Reformat test_frontend_from_fx.py
    
    * Update test_frontend_from_fx.py
---
 python/tvm/relax/frontend/torch/fx_translator.py |  2 +-
 tests/python/relax/test_frontend_from_fx.py      | 37 ++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 6062280b9d..7fa0358dc6 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1094,7 +1094,7 @@ class TorchFXImporter:
         method = (
             node.args[3]
             if len(node.args) > 3
-            else (node.kwargs["method"] if "method" in node.kwargs else 
"nearest")
+            else (node.kwargs["mode"] if "mode" in node.kwargs else "nearest")
         )
         align_corners = (
             node.args[4]
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index a1acff4974..d7ad0d83dd 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -2376,6 +2376,43 @@ def test_interpolate():
 
     verify_model(Interpolate(), input_info, {}, expected1)
 
+    class Interpolate2(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input,
+                size=None,
+                scale_factor=2.0,
+                mode="bilinear",
+                align_corners=False,
+            )
+
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 20, 20), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 20, 20), dtype="float32") = 
R.image.resize2d(
+                    input_1,
+                    (20, 20),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000],
+                    layout="NCHW",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="round",
+                    cubic_alpha=-0.5,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 3, 20, 20), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Interpolate2(), input_info, {}, expected2)
+
 
 def test_addmm():
     input_info = [

Reply via email to