mshr-h commented on code in PR #17898:
URL: https://github.com/apache/tvm/pull/17898#discussion_r2072546077
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -534,7 +543,55 @@ def create_input_vars(
else:
parameters_buffers_constants[name_hint] = relax_var
- return parameters_buffers_constants, user_inputs
+ # Extract range constraints for TIR vars
+ if hasattr(exported_program, "range_constraints") and
exported_program.range_constraints:
+ for torch_sym_expr, constraint in
exported_program.range_constraints.items():
+ # Convert sympy expression to string for mapping
+ torch_sym_expr_str = str(torch_sym_expr)
+
+ if torch_sym_expr_str in torch_symbol_to_relax_var:
+ relax_tir_var =
torch_symbol_to_relax_var[torch_sym_expr_str]
+ # TODO(sjt): Handle SymFloat, SymBool cases as well.
+ # Note: min / max could be int or SymInt objects.
+ # Need to handle symbolic shapes as well.
+ min_val = constraint.min
+ max_val = constraint.max
+ # Call helper to add/refine constraint
+ self._add_range_constraint(
+ relax_range_constraints, relax_tir_var, min_val,
max_val
+ )
+ # else:
+ # FIXED Indentation for Black:
+ # TODO: Handle complex expressions (e.g., s0 + 1) for advanced
support
+ # print(f"Skipping complex constraint expression:
{torch_sym_expr}")
+
+ return parameters_buffers_constants, user_inputs,
relax_range_constraints
+
+ # NEW HELPER METHOD
Review Comment:
Please remove the unnecessary comment.
##########
python/tvm/relax/frontend/torch/exported_program_translator.py:
##########
@@ -534,7 +543,55 @@ def create_input_vars(
else:
parameters_buffers_constants[name_hint] = relax_var
- return parameters_buffers_constants, user_inputs
+ # Extract range constraints for TIR vars
+ if hasattr(exported_program, "range_constraints") and
exported_program.range_constraints:
+ for torch_sym_expr, constraint in
exported_program.range_constraints.items():
+ # Convert sympy expression to string for mapping
+ torch_sym_expr_str = str(torch_sym_expr)
+
+ if torch_sym_expr_str in torch_symbol_to_relax_var:
+ relax_tir_var =
torch_symbol_to_relax_var[torch_sym_expr_str]
+ # TODO(sjt): Handle SymFloat, SymBool cases as well.
+ # Note: min / max could be int or SymInt objects.
+ # Need to handle symbolic shapes as well.
+ min_val = constraint.min
+ max_val = constraint.max
+ # Call helper to add/refine constraint
+ self._add_range_constraint(
+ relax_range_constraints, relax_tir_var, min_val,
max_val
+ )
+ # else:
+ # FIXED Indentation for Black:
+ # TODO: Handle complex expressions (e.g., s0 + 1) for advanced
support
+ # print(f"Skipping complex constraint expression:
{torch_sym_expr}")
+
+ return parameters_buffers_constants, user_inputs,
relax_range_constraints
+
+ # NEW HELPER METHOD
+ def _add_range_constraint(self, constraints_dict, relax_tir_var, min_val,
max_val):
Review Comment:
`@staticmethod` would be better here, since this method doesn't access any
instance variables or methods.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]