This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 394f668e0d [Relax][Pytorch] Support basic range constraints (#18429)
394f668e0d is described below
commit 394f668e0d568b23930b60d7c8e3e91f0bd2d667
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Nov 12 14:57:37 2025 +0800
[Relax][Pytorch] Support basic range constraints (#18429)
* Support basic range constraints
* Apply gemini-code-assist suggestions
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Apply reviewer comments
* Fix lint error
* Refactor frontend test to use consistent size variable
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
.../frontend/torch/exported_program_translator.py | 30 +++++++++++++++++++---
.../relax/test_frontend_from_exported_program.py | 28 ++++++++++++++++++++
2 files changed, 54 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index ddd19f2b58..0dfa4cc6da 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1099,11 +1099,23 @@ class ExportedProgramImporter(BaseFXGraphImporter):
def create_input_vars(
self, exported_program: torch.export.ExportedProgram
- ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
+ ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str,
Tuple[int, int]]]:
"""Create relax input vars."""
parameters_buffers_constants = OrderedDict()
user_inputs = OrderedDict()
torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
+ range_constraints = {}
+
+ if hasattr(exported_program, "range_constraints"):
+ for symbol, value_range in
exported_program.range_constraints.items():
+ symbol_name = str(symbol)
+ if hasattr(value_range, "lower") and hasattr(value_range,
"upper"):
+ try:
+ lower = int(value_range.lower)
+ upper = int(value_range.upper)
+ range_constraints[symbol_name] = (lower, upper)
+ except (OverflowError, AttributeError, TypeError):
+ continue
for spec in exported_program.graph_signature.input_specs:
name_hint = spec.arg.name
@@ -1121,7 +1133,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
torch_shape = exported_program.state_dict[spec.target].shape
torch_dtype = exported_program.state_dict[spec.target].dtype
- # TODO(mshr-h): Support range constraints
relax_shape = [
torch_symbol_to_relax_var.setdefault(str(s),
tvm.tir.SizeVar(str(s), "int64"))
if isinstance(s, torch.SymInt)
@@ -1136,7 +1147,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
else:
parameters_buffers_constants[name_hint] = relax_var
- return parameters_buffers_constants, user_inputs
+ return parameters_buffers_constants, user_inputs, range_constraints
def from_exported_program(
self,
@@ -1149,7 +1160,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
from torch import fx # type: ignore
# Create input variables.
- parameter_buffer_constant_vars, user_input_vars =
self.create_input_vars(exported_program)
+ (
+ parameter_buffer_constant_vars,
+ user_input_vars,
+ range_constraints,
+ ) = self.create_input_vars(exported_program)
inputs_vars = user_input_vars.copy()
inputs_vars.update(parameter_buffer_constant_vars)
@@ -1157,6 +1172,13 @@ class ExportedProgramImporter(BaseFXGraphImporter):
self.block_builder = relax.BlockBuilder()
func_name = "main"
func_attrs = {"num_input": len(user_input_vars)} if
keep_params_as_input else None
+ if range_constraints:
+ if func_attrs is None:
+ func_attrs = {}
+ tir_var_upper_bound = {
+ var_name: upper for var_name, (_, upper) in
range_constraints.items()
+ }
+ func_attrs["tir_var_upper_bound"] = tir_var_upper_bound
nodes: List[fx.Node] = exported_program.graph.nodes
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index fb4f77567e..ba14356e8e 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6663,5 +6663,33 @@ def test_gru():
np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np,
rtol=1e-4, atol=1e-5)
+def test_dynamic_shape_with_range_constraints():
+ class DynamicModel(torch.nn.Module):
+ def forward(self, x1, x2):
+ return torch.ops.aten.add.Tensor(x1, x2)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4),
dtype="float32")
+ ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
+ s0 = T.int64(is_size_var=True)
+ R.func_attr({"tir_var_upper_bound": {"s0": 64}})
+ with R.dataflow():
+ lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2)
+ gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(8, 4), torch.randn(8, 4))
+ batch = torch.export.Dim("batch", min=1, max=64)
+ dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
+ exported_program = export(DynamicModel(), args=example_args,
dynamic_shapes=dynamic_shapes)
+
+ mod = from_exported_program(exported_program, run_ep_decomposition=True)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()