gemini-code-assist[bot] commented on code in PR #18429:
URL: https://github.com/apache/tvm/pull/18429#discussion_r2513643742
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1097,11 +1097,28 @@ def create_convert_map(
     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
+    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]:
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
         torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
+        range_constraints = {}
+
+        if hasattr(exported_program, "range_constraints"):
+            for symbol, value_range in exported_program.range_constraints.items():
+                symbol_name = str(symbol)
+                if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
+                    try:
+                        lower = int(value_range.lower)
+                    except (OverflowError, AttributeError, TypeError):
+                        continue
+
+                    try:
+                        upper = int(value_range.upper)
+                    except (OverflowError, AttributeError, TypeError):
+                        continue
+
+                    range_constraints[symbol_name] = (lower, upper)
Review Comment:

The two `try-except` blocks for parsing `lower` and `upper` are repetitive. You can combine them into a single `try-except` block for better readability and conciseness; since each `except` clause already does `continue`, the combined form preserves the behavior that `range_constraints` is only updated when both `lower` and `upper` parse successfully.
```suggestion
                    try:
                        lower = int(value_range.lower)
                        upper = int(value_range.upper)
                        range_constraints[symbol_name] = (lower, upper)
                    except (OverflowError, AttributeError, TypeError):
                        continue
```
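For reference, here is a minimal self-contained sketch of the combined pattern. `FakeRange` is a hypothetical stand-in for the value-range objects on `exported_program.range_constraints`; `float("inf")` models an unbounded bound, whose `int()` conversion raises `OverflowError` (presumably the case this code guards against):
```python
# Minimal standalone sketch of the combined try-except pattern.
# `FakeRange` is a hypothetical stand-in for the exported program's
# value-range objects; float("inf") models an unbounded bound, whose
# int() conversion raises OverflowError.
class FakeRange:
    def __init__(self, lower, upper):
        self.lower = lower
        self.upper = upper

reported = {"s0": FakeRange(1, 64), "s1": FakeRange(2, float("inf"))}
range_constraints = {}
for symbol_name, value_range in reported.items():
    try:
        lower = int(value_range.lower)
        upper = int(value_range.upper)
        range_constraints[symbol_name] = (lower, upper)
    except (OverflowError, AttributeError, TypeError):
        continue  # skip symbols whose bounds are not finite integers

print(range_constraints)  # {'s0': (1, 64)} -- 's1' is skipped entirely
```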
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -6527,5 +6527,29 @@ def forward(self, x):
         np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
+
+
+def test_dynamic_shape_with_range_constraints():
+    class DynamicModel(torch.nn.Module):
+        def forward(self, x1, x2):
+            return torch.ops.aten.add.Tensor(x1, x2)
+
+    example_args = (torch.randn(8, 4), torch.randn(8, 4))
+    batch = torch.export.Dim("batch", min=1, max=64)
+    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
+    exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program)
+
+    main_func = mod["main"]
+    assert hasattr(main_func, "attrs"), "Function should have attributes"
+
+    if "shape_var_constraints" in main_func.attrs:
+        constraints = main_func.attrs["shape_var_constraints"]
+        assert len(constraints) > 0, "Should have at least one constraint"
+
+        for symbol_name, (min_val, max_val) in constraints.items():
+            assert min_val == 1, f"Expected min=1 for {symbol_name}, got {min_val}"
+            assert max_val == 64, f"Expected max=64 for {symbol_name}, got {max_val}"
Review Comment:

Using `if "shape_var_constraints" in main_func.attrs:` makes the test pass
silently if the attribute is missing, which would hide a potential regression.
To make the test more robust and ensure it actually verifies the presence of
range constraints, you should use an `assert` statement instead.
```suggestion
    assert "shape_var_constraints" in main_func.attrs
    constraints = main_func.attrs["shape_var_constraints"]
    assert len(constraints) > 0, "Should have at least one constraint"

    for symbol_name, (min_val, max_val) in constraints.items():
        assert min_val == 1, f"Expected min=1 for {symbol_name}, got {min_val}"
        assert max_val == 64, f"Expected max=64 for {symbol_name}, got {max_val}"
```
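For illustration, a minimal self-contained sketch of why the `if` guard hides regressions (the `attrs_*` dicts are hypothetical stand-ins for `main_func.attrs`):
```python
# Hypothetical attrs dicts standing in for main_func.attrs.
attrs_with = {"shape_var_constraints": {"s0": (1, 64)}}
attrs_without = {}  # models a regression that drops the attribute

def check_guarded(attrs):
    # Original pattern: silently passes when the attribute is missing.
    if "shape_var_constraints" in attrs:
        assert len(attrs["shape_var_constraints"]) > 0

def check_asserted(attrs):
    # Suggested pattern: fails loudly when the attribute is missing.
    assert "shape_var_constraints" in attrs
    assert len(attrs["shape_var_constraints"]) > 0

check_guarded(attrs_without)     # passes -- the regression goes unnoticed
check_asserted(attrs_with)       # passes
# check_asserted(attrs_without)  # raises AssertionError, flagging the regression
```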
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at: [email protected]