This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 96b8257002 [Arith] Memoize IntervalSet variable relaxation to avoid exponential blowup (#19670)
96b8257002 is described below
commit 96b825700288d568ca6ea67351c0b035f7170f43
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu Jun 4 08:19:18 2026 -0400
[Arith] Memoize IntervalSet variable relaxation to avoid exponential blowup (#19670)
## Problem
`Analyzer::Bind` could hang indefinitely (>300s, ~200% CPU, no GPU work)
while binding a small expression for one variable. The root cause is not
specific to that expression: it lies in the interval-set evaluator in
`src/arith/int_set.cc`.
Diagnosis: 100% of the time was spent in `arith::Analyzer::Bind` →
`IntSetAnalyzer` → `IntervalSetEvaluator`, evaluating a **5-node** bound
expression. A counter showed **>2^20 `VisitExpr` calls at recursion
depth 67**, with the evaluation still not finished.
## Root cause
`IntervalSetEvaluator::VisitExpr_(VarNode)` relaxes a variable's bounds
by recursively evaluating **both** the `min` and `max` sub-expressions
of its mapped interval. For diamond-shaped variable dependency chains
(`a → {b, c}`, `b → {d, e}`, …) the shared sub-expressions are
re-expanded along every path, so the cost is **O(2^depth)**, where
`depth` is the length of the dependency chain. That depth is bounded
only by `dom_map_.size()` (~67 interdependent vars in the failing case).
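To make the growth concrete, here is a minimal Python sketch of the pattern described above. It is a toy model, not TVM code: `count_naive_visits` and the integer-indexed chain are illustrative assumptions, with each variable's interval taken to be `[next - 1, next + 1]` so that both bounds reference the next variable.

```python
from typing import Tuple


def count_naive_visits(depth: int) -> int:
    """Count variable visits when min and max are each relaxed separately."""
    visits = 0

    def relax(i: int) -> Tuple[int, int]:
        nonlocal visits
        visits += 1
        if i == depth:
            # Last variable in the chain has a constant interval.
            return (0, 100)
        # Evaluating the min bound relaxes the next variable once...
        lo, _ = relax(i + 1)
        # ...and evaluating the max bound relaxes it again from scratch.
        _, hi = relax(i + 1)
        return (lo - 1, hi + 1)

    relax(0)
    return visits


for d in (4, 8, 16):
    print(d, count_naive_visits(d))
```

In this model a chain of length `depth` costs `2^(depth+1) - 1` visits, so the work doubles with every extra level.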
## Fix
Memoize the fully-relaxed interval **per variable** (`relax_memo_`) and
break cyclic dependencies with an in-progress set
(`relax_in_progress_`). A variable's relaxed interval is deterministic
for a given evaluator instance (`dom_map_`/`dom_constraints_` are
fixed), so memoizing collapses the diamonds to linear cost. Short chains
— the common case, which never reached the old `recur_depth_ >=
dom_map_.size()` cutoff — are unaffected, so the change is
behavior-preserving outside the pathological case.
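The memo plus in-progress set can be sketched in Python as follows. This is a toy model under stated assumptions, not the C++ implementation: `ToyRelaxer`, the tuple-based domain map, and the use of `None` for "left symbolic" are all hypothetical.

```python
from typing import Dict, Optional, Set, Tuple

Interval = Tuple[int, int]


class ToyRelaxer:
    """Toy relaxer: entries are ("const", lo, hi) or ("var", dep, lo_off, hi_off).

    A "var" entry stands for the interval [dep + lo_off, dep + hi_off].
    """

    def __init__(self, dom_map: Dict[str, tuple]) -> None:
        self.dom_map = dom_map
        self.memo: Dict[str, Optional[Interval]] = {}
        self.in_progress: Set[str] = set()
        self.visits = 0

    def relax(self, var: str) -> Optional[Interval]:
        self.visits += 1
        if var in self.memo:
            # Already fully relaxed: reuse the stored result.
            return self.memo[var]
        if var in self.in_progress:
            # Cyclic dependency: stop here and leave the variable symbolic.
            return None
        entry = self.dom_map[var]
        if entry[0] == "const":
            result: Optional[Interval] = (entry[1], entry[2])
        else:
            _, dep, lo_off, hi_off = entry
            self.in_progress.add(var)
            lo = self.relax(dep)  # min bound: computes and memoizes dep
            hi = self.relax(dep)  # max bound: served from the memo
            self.in_progress.discard(var)
            if lo is None or hi is None:
                result = None
            else:
                result = (lo[0] + lo_off, hi[1] + hi_off)
        self.memo[var] = result
        return result


chain = {f"x{i}": ("var", f"x{i + 1}", -1, 1) for i in range(64)}
chain["x64"] = ("const", 0, 100)
print(ToyRelaxer(chain).relax("x0"))
```

On this 64-level chain the sketch returns `(-64, 164)` after visiting each variable a constant number of times, and a two-variable cycle returns `None` instead of recursing forever.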
## Tests
New regression tests in `tests/python/arith/test_arith_intset.py`:
- `test_relax_deep_variable_dependency_chain` — a 64-deep diamond
(`O(2^64)` without the fix; verified to hang on a clean build), also
asserting the relaxed result is correct (`x0 → [-n, 100+n]`).
- `test_relax_cyclic_variable_dependency` — a cyclic `x↔y` dependency
must terminate.
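
The expected value in the deep-chain test follows from widening the base interval `[0, 100]` by one on each side per level. A small Python check of that arithmetic (the helper name `expected_relaxed_bound` is illustrative and not part of the test file):

```python
from typing import Tuple


def expected_relaxed_bound(n: int, base: Tuple[int, int] = (0, 100)) -> Tuple[int, int]:
    """Widen the base interval by 1 on each side, once per chain level."""
    lo, hi = base
    for _ in range(n):
        lo, hi = lo - 1, hi + 1
    return lo, hi


print(expected_relaxed_bound(64))
```

For `n = 64` this gives `(-64, 164)`, matching the `(-n, 100 + n)` form asserted by the test.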
## Verification
- `tests/python/arith/test_arith_intset.py` — 20 passed (the deep-chain
test completes instantly).
- Full `tests/python/arith/` — 933 passed (1 pre-existing flaky
random-seed failure in `test_arith_solve_linear_equations.py` unrelated
to this change, passes on rerun).
Co-authored-by: Claude Opus 4.8 <[email protected]>
---
src/arith/int_set.cc | 32 +++++++++++++++++++++++++++++---
tests/python/arith/test_arith_intset.py | 30 ++++++++++++++++++++++++++++++
2 files changed, 59 insertions(+), 3 deletions(-)
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index 6b3e2b9532..86a2d949bc 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -33,6 +33,7 @@
#include <algorithm>
#include <unordered_map>
+#include <unordered_set>
#include <utility>
#include "constraint_extract.h"
@@ -458,9 +459,29 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
     if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
       return res;
     }
-    // recursively evaluate mapped result
-    // in case the domain contains variables to be relaxed.
-    return Eval(res);
+    // Recursively relax the mapped interval, since the domain bounds may
+    // themselves reference other variables that need to be relaxed.
+    //
+    // Memoize the fully-relaxed interval per variable, and guard against
+    // cyclic variable dependencies with an in-progress set. Without this,
+    // diamond-shaped variable dependencies (var a -> {b, c}, b -> {d, e}, ...)
+    // are re-expanded along every path: each level evaluates both the min and
+    // max sub-expressions, so the cost is exponential (2^depth) in the length
+    // of the variable dependency chain rather than linear.
+    auto memo_it = relax_memo_.find(op);
+    if (memo_it != relax_memo_.end()) {
+      return memo_it->second;
+    }
+    if (relax_in_progress_.count(op)) {
+      // Cyclic dependency among variable bounds: stop relaxing here to keep
+      // the recursion finite, keeping this variable symbolic.
+      return res;
+    }
+    relax_in_progress_.insert(op);
+    IntervalSet relaxed = Eval(res);
+    relax_in_progress_.erase(op);
+    relax_memo_[op] = relaxed;
+    return relaxed;
   }
   IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_<Add>(op); }
@@ -606,6 +627,11 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
   // recursive depth
   int recur_depth_{0};
+  // Memo of fully-relaxed interval sets per variable, to avoid exponential
+  // re-expansion of diamond-shaped variable dependencies.
+  std::unordered_map<const VarNode*, IntervalSet> relax_memo_;
+  // Variables currently being relaxed, used to break cyclic dependencies.
+  std::unordered_set<const VarNode*> relax_in_progress_;
   // analyzer
   Analyzer* analyzer_;
   const ffi::Map<Var, IntSet>& dom_map_;
diff --git a/tests/python/arith/test_arith_intset.py b/tests/python/arith/test_arith_intset.py
index 49e09191d6..a34c528e69 100644
--- a/tests/python/arith/test_arith_intset.py
+++ b/tests/python/arith/test_arith_intset.py
@@ -394,5 +394,35 @@ def test_modular_set():
)
+def test_relax_deep_variable_dependency_chain():
+    """Regression test for exponential variable-relaxation blowup.
+
+    When a variable's interval bound references another variable that is also in
+    the domain map, the evaluator relaxes it transitively. A diamond-shaped
+    chain -- where each variable's bound references the next one in *both* its
+    min and its max -- used to be re-expanded along every path, costing
+    O(2^depth) and hanging indefinitely. The relaxation is now memoized per
+    variable, so this completes in linear time.
+    """
+    ck = IntSetChecker()
+    n = 64  # 2^64 expansions without memoization; trivially fast with it.
+    xs = [tvm.tirx.Var(f"x{i}", "int32") for i in range(n + 1)]
+    dmap = {xs[i]: tvm.arith.IntervalSet(xs[i + 1] - 1, xs[i + 1] + 1) for i in range(n)}
+    dmap[xs[n]] = tvm.arith.IntervalSet(0, 100)
+    # x0 relaxes through the whole chain: [0 - n, 100 + n].
+    ck.verify(xs[0], dmap, (-n, 100 + n))
+
+
+def test_relax_cyclic_variable_dependency():
+    """A cyclic variable dependency must terminate (and stay symbolic)."""
+    ana = tvm.arith.Analyzer()
+    x = tvm.tirx.Var("x", "int32")
+    y = tvm.tirx.Var("y", "int32")
+    # x depends on y and y depends on x: relaxation must not loop forever.
+    dmap = {x: tvm.arith.IntervalSet(y, y), y: tvm.arith.IntervalSet(x, x)}
+    res = ana.int_set(x, dmap)
+    assert res is not None
+
+
if __name__ == "__main__":
tvm.testing.main()