gemini-code-assist[bot] commented on code in PR #18550:
URL: https://github.com/apache/tvm/pull/18550#discussion_r2593399949
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -4900,6 +4900,67 @@ def main(
     verify_model(InterpolateBilinearAA(), example_args, {}, expected_bilinear_aa)
+def test_interpolate_scale_factor_list():
+    class InterpolateNearestScalar(Module):
+        """Nearest interpolation with scalar scale_factor."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest")
+
+    class InterpolateNearestList(Module):
+        """Nearest interpolation with list scale_factor (different H and W scaling)."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, scale_factor=[2.0, 3.0], mode="nearest")
+
+    class InterpolateBilinearScalar(Module):
+        """Bilinear interpolation with scalar scale_factor."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input, scale_factor=2.0, mode="bilinear", align_corners=False
+            )
+
+    class InterpolateBilinearList(Module):
+        """Bilinear interpolation with list scale_factor (different H and W scaling)."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input, scale_factor=[2.0, 3.0], mode="bilinear", align_corners=False
+            )
+
+    class InterpolateBicubicScalar(Module):
+        """Bicubic interpolation with scalar scale_factor."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input, scale_factor=2.0, mode="bicubic", align_corners=False
+            )
+
+    class InterpolateBicubicList(Module):
+        """Bicubic interpolation with list scale_factor (different H and W scaling)."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input, scale_factor=[2.0, 3.0], mode="bicubic", align_corners=False
+            )
+
+    # Test with 32x32 input
+    example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),)
+
+    # Test nearest mode with scalar and list scale_factor
+    verify_model_numerically(InterpolateNearestScalar(), example_args, rtol=1e-5, atol=1e-5)
+    verify_model_numerically(InterpolateNearestList(), example_args, rtol=1e-5, atol=1e-5)
+
+    # Test bilinear mode with scalar and list scale_factor
+    verify_model_numerically(InterpolateBilinearScalar(), example_args, rtol=1e-5, atol=1e-5)
+    verify_model_numerically(InterpolateBilinearList(), example_args, rtol=1e-5, atol=1e-5)
+
+    # Test bicubic mode with scalar and list scale_factor
+    verify_model_numerically(InterpolateBicubicScalar(), example_args, rtol=1e-5, atol=1e-5)
+    verify_model_numerically(InterpolateBicubicList(), example_args, rtol=1e-5, atol=1e-5)
+
Review Comment:

This test is great for covering the different interpolation modes with both
scalar and list `scale_factor`s. However, there's a lot of repeated code
between the different test cases. You can significantly simplify this by using
`pytest.mark.parametrize` to iterate through the different modes and scale
factors. This will make the test more concise and easier to maintain or extend
in the future.
```python
@pytest.mark.parametrize("scale_factor", [2.0, [2.0, 3.0]], ids=["scalar", "list"])
@pytest.mark.parametrize("mode", ["nearest", "bilinear", "bicubic"])
def test_interpolate_scale_factor_list(mode, scale_factor):
    """Numerically verify interpolate for each mode with scalar and list scale_factor.

    The list form ``[2.0, 3.0]`` scales H and W by different factors, which is
    the case a scalar-only conversion would get wrong.
    """
    interp_kwargs = {"scale_factor": scale_factor, "mode": mode}
    if mode != "nearest":
        # torch only accepts align_corners for the interpolating modes
        # (linear/bilinear/bicubic/trilinear); it must be omitted for "nearest".
        interp_kwargs["align_corners"] = False

    class InterpolateModel(Module):
        def forward(self, input_tensor):
            return torch.nn.functional.interpolate(input_tensor, **interp_kwargs)

    # Test with 32x32 input
    example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),)
    verify_model_numerically(InterpolateModel(), example_args, rtol=1e-5, atol=1e-5)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]