gemini-code-assist[bot] commented on code in PR #18550:
URL: https://github.com/apache/tvm/pull/18550#discussion_r2593399949


##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -4900,6 +4900,67 @@ def main(
     verify_model(InterpolateBilinearAA(), example_args, {}, expected_bilinear_aa)
 
 
+def test_interpolate_scale_factor_list():
+    class InterpolateNearestScalar(Module):
+        """Nearest interpolation with scalar scale_factor."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, scale_factor=2.0, mode="nearest")
+
+    class InterpolateNearestList(Module):
+        """Nearest interpolation with list scale_factor (different H and W scaling)."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(input, scale_factor=[2.0, 3.0], mode="nearest")
+
+    class InterpolateBilinearScalar(Module):
+        """Bilinear interpolation with scalar scale_factor."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input, scale_factor=2.0, mode="bilinear", align_corners=False
+            )
+
+    class InterpolateBilinearList(Module):
+        """Bilinear interpolation with list scale_factor (different H and W scaling)."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input, scale_factor=[2.0, 3.0], mode="bilinear", align_corners=False
+            )
+
+    class InterpolateBicubicScalar(Module):
+        """Bicubic interpolation with scalar scale_factor."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input, scale_factor=2.0, mode="bicubic", align_corners=False
+            )
+
+    class InterpolateBicubicList(Module):
+        """Bicubic interpolation with list scale_factor (different H and W scaling)."""
+
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input, scale_factor=[2.0, 3.0], mode="bicubic", align_corners=False
+            )
+
+    # Test with 32x32 input
+    example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),)
+
+    # Test nearest mode with scalar and list scale_factor
+    verify_model_numerically(InterpolateNearestScalar(), example_args, rtol=1e-5, atol=1e-5)
+    verify_model_numerically(InterpolateNearestList(), example_args, rtol=1e-5, atol=1e-5)
+
+    # Test bilinear mode with scalar and list scale_factor
+    verify_model_numerically(InterpolateBilinearScalar(), example_args, rtol=1e-5, atol=1e-5)
+    verify_model_numerically(InterpolateBilinearList(), example_args, rtol=1e-5, atol=1e-5)
+
+    # Test bicubic mode with scalar and list scale_factor
+    verify_model_numerically(InterpolateBicubicScalar(), example_args, rtol=1e-5, atol=1e-5)
+    verify_model_numerically(InterpolateBicubicList(), example_args, rtol=1e-5, atol=1e-5)
+

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This test does a good job of covering the different interpolation modes with
   both scalar and list `scale_factor` values. However, the six cases repeat a lot
   of code. You can simplify this significantly by using `pytest.mark.parametrize`
   to iterate over the combinations of mode and scale factor, which makes the test
   more concise and easier to maintain or extend in the future. Note that
   `align_corners` is passed only for the non-`nearest` modes, because
   `torch.nn.functional.interpolate` raises a `ValueError` if `align_corners` is
   set together with `mode="nearest"`.
   
   ```python
   @pytest.mark.parametrize(
       "mode, scale_factor",
       [
           ("nearest", 2.0),
           ("nearest", [2.0, 3.0]),
           ("bilinear", 2.0),
           ("bilinear", [2.0, 3.0]),
           ("bicubic", 2.0),
           ("bicubic", [2.0, 3.0]),
       ],
   )
   def test_interpolate_scale_factor_list(mode, scale_factor):
       """Test interpolation with various modes and scale factors."""
   
       class InterpolateModel(Module):
           def forward(self, input_tensor):
               kwargs = {"scale_factor": scale_factor, "mode": mode}
               if mode != "nearest":
                   kwargs["align_corners"] = False
               return torch.nn.functional.interpolate(input_tensor, **kwargs)
   
       # Test with 32x32 input
       example_args = (torch.randn(1, 3, 32, 32, dtype=torch.float32),)
       verify_model_numerically(InterpolateModel(), example_args, rtol=1e-5, atol=1e-5)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to