This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6cf49e6ee3 [Relax][PyTorch] Enhance scale_factor handling in interpolation (#18550)
6cf49e6ee3 is described below
commit 6cf49e6ee3ba5209766a7aeff4000c00e7c4f58c
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sat Dec 6 17:35:27 2025 +0800
[Relax][PyTorch] Enhance scale_factor handling in interpolation (#18550)
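## Why
Fixes interpolation so that it supports different scaling factors for height and width (e.g., `scale_factor=[2.0, 3.0]`). Previously the importer used only the first element of a `scale_factor` list, so a non-uniform scale was silently applied as a uniform one.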
## How
- Stop extracting only the first element (`[0]`) from `scale_factor` lists in the ExportedProgram importer.
- Pass the full `scale_factor` value (scalar or list/tuple) through to `_upsample_impl`, which already handles both forms correctly.
---
.../frontend/torch/exported_program_translator.py | 18 ++++----
.../relax/test_frontend_from_exported_program.py | 51 ++++++++++++++++++++++
2 files changed, 60 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2ec61796c3..641e16f599 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -337,11 +337,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
)
else:
- # TODO figure out why pytorch export passes a list such as
- # [scale_factor,scale_factor] instead of just an int for
- # scale_factor. Using first element for now
+ # PyTorch export passes scale_factor as either a scalar or a
list/tuple
+ # (e.g., [2.0, 3.0] for different H and W scaling).
+ # Pass it as-is to _upsample_impl which handles both cases
correctly.
scale_factor = (
- node.args[2][0] if len(node.args) > 2 else
node.kwargs.get("scale_factor", 1)
+ node.args[2] if len(node.args) > 2 else
node.kwargs.get("scale_factor", 1)
)
align_corners = (
node.args[3] if len(node.args) > 3 else
node.kwargs.get("align_corners", None)
@@ -364,11 +364,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
if size is not None:
scale_factor = None
else:
- scale_arg = node.args[3] if len(node.args) > 3 else
node.kwargs.get("scale_factor", 1)
- if isinstance(scale_arg, (list, tuple)):
- scale_factor = scale_arg[0]
- else:
- scale_factor = scale_arg
+ # PyTorch export passes scale_factor as either a scalar or a
list/tuple.
+ # Pass it as-is to _upsample_impl which handles both cases
correctly.
+ scale_factor = (
+ node.args[3] if len(node.args) > 3 else
node.kwargs.get("scale_factor", 1)
+ )
return self._upsample_impl(
x,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 010bd026a8..68567e1fc8 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -8542,5 +8542,56 @@ def test_grid_sample():
verify_model(GridSample(), example_args, {}, expected)
+def test_upsample_nearest2d():
+ class UpsampleNearest2dScale(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(input, scale_factor=2.0,
mode="nearest")
+
+ class UpsampleNearest2dSize(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(input, size=(20, 20),
mode="nearest")
+
+ example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+
+ @tvm.script.ir_module
+ class expected_scale:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 20, 20), dtype="float32") =
R.image.resize2d(
+ input_1,
+ size=(20, 20),
+ layout="NCHW",
+ method="nearest_neighbor",
+ coordinate_transformation_mode="half_pixel",
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class expected_size:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 10, 10), dtype="float32")
+ ) -> R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 20, 20), dtype="float32") =
R.image.resize2d(
+ input_1,
+ size=(20, 20),
+ layout="NCHW",
+ method="nearest_neighbor",
+ coordinate_transformation_mode="half_pixel",
+ )
+ gv: R.Tuple(R.Tensor((1, 3, 20, 20), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ verify_model(UpsampleNearest2dScale(), example_args, {}, expected_scale)
+ verify_model(UpsampleNearest2dSize(), example_args, {}, expected_size)
+
+
if __name__ == "__main__":
tvm.testing.main()