mshr-h commented on code in PR #18429:
URL: https://github.com/apache/tvm/pull/18429#discussion_r2513601966
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1121,13 +1136,19 @@ def create_input_vars(
                torch_shape = exported_program.state_dict[spec.target].shape
                torch_dtype = exported_program.state_dict[spec.target].dtype
-            # TODO(mshr-h): Support range constraints
-            relax_shape = [
-                torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
-                if isinstance(s, torch.SymInt)
-                else s
-                for s in torch_shape
-            ]
+            # Create TIR variables for symbolic dimensions
Review Comment:
Looks like there are no functional changes here. Any reason for the change?
##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -6663,5 +6527,29 @@ def forward(self, x):
    np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
+def test_dynamic_shape_with_range_constraints():
+    class DynamicModel(torch.nn.Module):
+        def forward(self, x1, x2):
+            return torch.ops.aten.add.Tensor(x1, x2)
+
+    example_args = (torch.randn(8, 4), torch.randn(8, 4))
+    batch = torch.export.Dim("batch", min=1, max=64)
+    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
+    exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program)
+
+    main_func = mod["main"]
+    assert hasattr(main_func, "attrs"), "Function should have attributes"
Review Comment:
Please use structural equality instead of manually checking the attributes.
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -1149,14 +1170,22 @@ def from_exported_program(
        from torch import fx  # type: ignore

        # Create input variables.
-        parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
+        (
+            parameter_buffer_constant_vars,
+            user_input_vars,
+            range_constraints,
+        ) = self.create_input_vars(exported_program)
        inputs_vars = user_input_vars.copy()
        inputs_vars.update(parameter_buffer_constant_vars)

        # Initialize the block builder with a function and a dataflow block.
        self.block_builder = relax.BlockBuilder()
        func_name = "main"
        func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None
+        if range_constraints:
+            if func_attrs is None:
+                func_attrs = {}
+            func_attrs["shape_var_constraints"] = range_constraints
Review Comment:
Please use `tir_var_upper_bound` to annotate the upper bound.
I don't think we need to keep the lower bound at the moment. If we have a real use case for it, it's fine to keep it in the Relax module.
https://github.com/apache/tvm/blob/main/src/relax/transform/static_plan_block_memory.cc#L62-L66
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]