This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 913fc4bf63 [Arith] Gate canonical-simplify LT Case 2 on extra scale == +1 (#19669)
913fc4bf63 is described below
commit 913fc4bf63a3773a94bd415c3250272151178038
Author: Hongyi Jin
AuthorDate: Thu Jun 4 13:01:12 2026 -0400
[Arith] Gate canonical-simplify LT Case 2 on extra scale == +1 (#19669)
## Summary
`CanonicalSimplifier::Impl::VisitExpr_(LTNode)` Case 2 rewrites
`S + xn < 0  ⇔  S/d + xn // d < 0`, where `d = gcd(scales)`, `S` is the part of
the sum divisible by `d`, and `xn` is the single leftover split term.
The Case 1 derivation only works when the leftover term `scale * xn` is
non-negative. With `scale = -1` the equivalence becomes `≤` rather than `<`
(whenever `xn % d != 0`), so the rewrite silently strengthens the predicate
by dropping the boundary case `S/d == xn // d`.
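The soundness argument can be checked by brute force. The sketch below uses a hypothetical helper (not part of TVM) that enumerates small multiples `S` of `d` and non-negative leftovers `xn`, then compares `S + scale*xn < 0` with the rewritten `S//d + scale*(xn//d) < 0`. It assumes the rewrite keeps the sign of the leftover term and models only the algebra, not the simplifier's code path.

```python
from itertools import product


def rewrite_counterexamples(scale, d, k_range=range(-4, 5), xn_range=range(0, 12)):
    """Return (S, xn) pairs where `S + scale*xn < 0` and the rewritten
    `S//d + scale*(xn//d) < 0` disagree. S ranges over multiples of d."""
    bad = []
    for k, xn in product(k_range, xn_range):
        s = k * d  # divisible part S
        original = s + scale * xn < 0
        rewritten = s // d + scale * (xn // d) < 0
        if original != rewritten:
            bad.append((s, xn))
    return bad


# scale = +1: the dropped remainder xn % d lies in [0, d), so the rewrite is exact.
print(rewrite_counterexamples(scale=1, d=4))  # → []

# scale = -1: every mismatch sits on the boundary S/d == xn//d with xn % d != 0,
# and in each one the original is True while the rewrite is False.
bad = rewrite_counterexamples(scale=-1, d=4)
print(bad[:3])  # → [(0, 1), (0, 2), (0, 3)]
print(all(s // 4 == xn // 4 and xn % 4 != 0 for s, xn in bad))  # → True
print(all(s - xn < 0 for s, xn in bad))  # → True
```

Python's `//` and `%` are floor-based; since `S` is a multiple of `d` and `xn ≥ 0`, that choice does not affect the result here.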
After CSE/inlining, a `row > col` mask such as `2*(tx%4) < 16*warp +
(tx%32)//4` (where `col = 2*(tx%4)` and `row = 16*warp + (tx%32)//4` are
independent projections of the same lane id) reaches canonical_simplify in
`lhs - rhs < 0` form, so the divided projection `(tx%32)//4` carries
scale = -1. Case 2 then folds it to a plain `0 < warp`, masking off every
thread in warp 0 that should have written `val`. The same path also folds
other configurations (e.g. `0 < (tx%32) - 8*warp`) all the way to `False`.
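The failure can be reproduced outside TVM by enumerating the ranges the regression test binds (`tx` in `[0, 128)`, `warp` in `[0, 4)`) with plain Python integers. This sketch compares each original predicate with the value it was folded to before the fix; the helper names are illustrative only.

```python
def mismatches(original, folded):
    """Return the (tx, warp) points where the two predicates disagree."""
    return [
        (tx, warp)
        for warp in range(4)  # warp in [0, 4), as bound in the test
        for tx in range(128)  # tx in [0, 128), as bound in the test
        if original(tx, warp) != folded(tx, warp)
    ]


def row_gt_col(tx, warp):
    # col = 2*(tx%4) < row = 16*warp + (tx%32)//4
    return 2 * (tx % 4) < 16 * warp + (tx % 32) // 4


def prefix_row_gt_col(tx, warp):
    # The pre-fix fold of row_gt_col.
    return 0 < warp


def lane_vs_warp(tx, warp):
    return 0 < (tx % 32) - 8 * warp


def prefix_lane_vs_warp(tx, warp):
    # The pre-fix fold of lane_vs_warp.
    return False


bad = mismatches(row_gt_col, prefix_row_gt_col)
print(bad[:2])  # → [(4, 0), (8, 0)]
print(all(warp == 0 for _, warp in bad))  # → True
print(len(mismatches(lane_vs_warp, prefix_lane_vs_warp)) > 0)  # → True
```

Every disagreement for the first mask is a warp-0 lane where the true predicate holds, i.e. a store that the folded mask would skip.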
The fix gates Case 2 on `extra->args[0]->scale == 1`. The original target
shape (`yn % m` with `scale = +1` and `lower_factor = 1`) and its
generalization to `scale = +1` with `lower_factor > 1` are still simplified
as before; truly-always-true comparisons still fold to `True`.
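The guard can be read as a predicate over the leftover split term. The sketch below is a hypothetical Python transcription of the `else if` condition in the diff: `Split` and `case2_fires` are illustrative names, the fields mirror the C++ `scale`, `lower_factor`, and `upper_factor`, and `math.inf` stands in for `ConstIntBoundNode::kPosInf`. The factor values for the two test shapes are our reading of how they decompose (`8*warp - (tx%32)` with `d = 8`; `64*x1 - 120 + y1%64` with `d = 8`), not output taken from the simplifier.

```python
import math
from dataclasses import dataclass


@dataclass
class Split:
    """Leftover term scale * ((index % upper_factor) // lower_factor)."""
    scale: int
    lower_factor: int
    upper_factor: float  # math.inf when unbounded


def case2_fires(extra, gcd, require_positive_scale=True):
    """Mirror of the Case 2 guard; `extra` is the list of leftover terms.
    require_positive_scale=False reproduces the pre-fix condition."""
    if len(extra) != 1:
        return False
    term = extra[0]
    if require_positive_scale and term.scale != 1:
        return False
    return (
        term.upper_factor != math.inf
        and term.upper_factor % (gcd * term.lower_factor) == 0
    )


neg = [Split(scale=-1, lower_factor=1, upper_factor=32)]  # warp*8 < tx%32
pos = [Split(scale=1, lower_factor=1, upper_factor=64)]   # x1*64 + y1%64 < 120

print(case2_fires(neg, gcd=8, require_positive_scale=False))  # → True (pre-fix)
print(case2_fires(neg, gcd=8))  # → False (post-fix)
print(case2_fires(pos, gcd=8))  # → True
```

Gating on the sign alone keeps the divisibility checks untouched, which is why the `scale = +1` shapes behave exactly as before.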
## Test plan
- New regression test `test_simplify_le_negative_scale_extra` in
`tests/python/arith/test_arith_canonical_simplify.py`. It asserts on the
simplified `PrimExpr` directly, so no GPU is required; it fails before the
fix and passes after it. It also pins the buggy `scale = -1` shapes to their
unsimplified form, confirms that the `scale = +1` Case 2 path still
optimizes, and re-asserts that the truly-always-true variant still folds to
`True`.
- Existing `test_simplify_le` (the original Case 2 target with `scale =
+1`) still passes.
- `tests/python/arith/test_arith_canonical_simplify.py`: 16 passed.
- Full `tests/python/arith/`: 932 passed (one pre-existing, flaky
random-seed failure in `test_arith_solve_linear_equations.py` is unrelated
to this change and passes on rerun).
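The expected values in the new test can also be cross-checked by brute force without building TVM. This sketch uses plain Python integers over the test's bound ranges (`tx` in `[0, 128)`, `warp` in `[0, 4)`); `x1` and `y1` are unbounded in the test, so it only samples an arbitrary window of their values, which is an assumption of the sketch rather than a proof.

```python
from itertools import product

TX, WARP = range(128), range(4)  # ranges bound in the test


def pinned(tx, warp):
    # First expression the test pins to its unsimplified form.
    return (tx % 4) * 2 < warp * 16 + (tx % 32) // 4


def always_true(tx, warp):
    # Variant the test expects to fold to True.
    return (tx % 4) * 2 < warp * 16 + (tx % 32) // 4 + 8


# 1. The pinned expression takes both values inside warp 0, so no constant fold is valid.
warp0_values = {pinned(tx, 0) for tx in TX}
print(warp0_values == {True, False})  # → True

# 2. Always true: the lhs is at most 6 while the rhs is at least 8.
print(all(always_true(tx, w) for tx, w in product(TX, WARP)))  # → True

# 3. The kept scale = +1 rewrite agrees with the original on a sampled window.
window = range(-70, 70)
print(all(
    (x1 * 64 + y1 % 64 < 120) == (x1 * 8 + (y1 % 64) // 8 < 15)
    for x1, y1 in product(window, window)
))  # → True
```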
---
src/arith/canonical_simplify.cc | 11 +++++-
.../python/arith/test_arith_canonical_simplify.py | 44 ++++++++++++++++++++++
2 files changed, 53 insertions(+), 2 deletions(-)
diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index ac1b89f97a..0001afbdfe 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -1419,10 +1419,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
// Case 1. 0 <= xn < d
divisible.CopyOnWrite()->DivideBy(gcd);
return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype));
- } else if (extra->args.size() == 1 &&
+ } else if (extra->args.size() == 1 && extra->args[0]->scale == 1 &&
extra->args[0]->upper_factor != ConstIntBoundNode::kPosInf &&
extra->args[0]->upper_factor % (gcd * extra->args[0]->lower_factor) == 0) {
- // Case 2. xn == yn % m, where m % d == 0
+ // Case 2. xn == ((yn % m) // L), scale = +1, m % (d*L) == 0.
+ // S + xn < 0 with S divisible by d ⇔ S/d + xn // d < 0, because
+ // xn % d ∈ [0, d) lets us drop the remainder via the Case 1 argument,
+ // and xn // d = (yn // (d*L)) % (m/(d*L)).
+ // The scale must be +1: with scale = -1 the equivalence becomes ≤
+ // rather than <, so the rewrite would strengthen the predicate and
+ // silently drop the boundary S/d == xn // d (e.g. row > col where
+ // row and col are independent projections of the same lane id).
divisible.CopyOnWrite()->DivideBy(gcd);
const auto split_expr = extra->args[0];
int64_t lower_factor = gcd * extra->args[0]->lower_factor;
diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py
index 35ecf3b700..4d81f9031c 100644
--- a/tests/python/arith/test_arith_canonical_simplify.py
+++ b/tests/python/arith/test_arith_canonical_simplify.py
@@ -490,5 +490,49 @@ def test_simplify_le():
ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0)
+def test_simplify_le_negative_scale_extra():
+ """Regression: Case 2 of the LT-with-divisible-coeffs rewrite must not
+ fire when the leftover split term has a negative scale.
+
+ The rewrite ``S + xn < 0 ⇔ S/d + xn // d < 0`` is only sound when
+ the leftover ``xn`` has scale ``+1``. With scale ``-1`` the equivalence
+ becomes ``≤`` rather than ``<`` and the rewrite silently strengthens
+ the predicate. The original bug surfaced as ``row > col`` masks of
+ ``.16x*b`` tcgen05 readbacks collapsing to plain ``warp_id > k``
+ comparisons (lower-triangle writes were silently dropped on the
+ boundary warp).
+ """
+ ck = CanonicalChecker()
+ tx = tvm.tirx.Var("tx", "int32")
+ warp = tvm.tirx.Var("warp", "int32")
+ ck.analyzer.bind(tx, tvm.ir.Range(0, 128))
+ ck.analyzer.bind(warp, tvm.ir.Range(0, 4))
+
+ # Same-source joint projection: the comparison genuinely depends on tx
+ # at warp == 0 (e.g. tx == 4 ⇒ 0 < 1 = True; tx == 1 ⇒ 2 < 0 = False),
+ # so the simplifier must keep both sides. Pre-fix this folded to
+ # ``0 < warp`` and dropped every True case in warp 0.
+ expr = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4
+ ck.verify(expr, expr)
+
+ # The simpler ``scale = -1`` with ``lower_factor = 1`` shape. Pre-fix
+ # this folded to ``False`` (drops all warp >= 1 cases where the rhs
+ # actually exceeds 8*warp).
+ expr = warp * 8 < (tx % 32)
+ ck.verify(expr, expr)
+
+ # The corresponding ``scale = +1`` Case 2 path (the rewrite this guards)
+ # must still optimize — verifies we did not over-restrict.
+ x1 = tvm.tirx.Var("x1", "int32")
+ y1 = tvm.tirx.Var("y1", "int32")
+ ck.verify(x1 * 64 + (y1 % 64) < 120, x1 * 8 + (y1 % 64) // 8 < 15)
+
+ # The truly-always-true comparison that arises from the same kernel
+ # (``r = 2 / va = 1`` in the tcgen05.ld.16x256b readback) must still
+ # fold to True so the masked store can be elided.
+ expr_true = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4 + 8
+ ck.verify(expr_true, tvm.tirx.const(True, "bool"))
+
+
if __name__ == "__main__":
tvm.testing.main()