This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 913fc4bf63 [Arith] Gate canonical-simplify LT Case 2 on extra scale == +1 (#19669)
913fc4bf63 is described below

commit 913fc4bf63a3773a94bd415c3250272151178038
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu Jun 4 13:01:12 2026 -0400

    [Arith] Gate canonical-simplify LT Case 2 on extra scale == +1 (#19669)
    
    ## Summary
    
    `CanonicalSimplifier::Impl::VisitExpr_(LTNode)` Case 2 rewrites
    
        S + xn < 0  ⇔  S/d + xn // d < 0      where d = gcd(scales)
    
    The Case 1 derivation only holds when the leftover term contributes a
    non-negative amount (`xn ≥ 0` entering the sum with scale +1). With
    `scale = -1` the equivalence becomes `≤` rather than `<`, and the
    rewrite silently strengthens the predicate by dropping the boundary
    case `S/d == xn // d`.
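To make the boundary argument concrete, here is a small brute-force check in plain Python. It models the integer arithmetic only and does not call TVM; `rewrite_mismatches` is a hypothetical helper. It enumerates `S` as multiples of `d` and a non-negative leftover `xn`, then compares `S + scale*xn < 0` with the rewritten `S/d + scale*(xn // d) < 0`, using Python's floor division `//`.

```python
def rewrite_mismatches(scale, d=4, span=12):
    """Hypothetical helper: list (S, xn) pairs where the Case 2 rewrite
    S + scale*xn < 0  ->  S/d + scale*(xn // d) < 0 changes the result.

    S ranges over multiples of d in [-span*d, span*d] (so S/d is exact)
    and the leftover xn over [0, span).
    """
    mismatches = []
    for s in range(-span, span + 1):
        S = s * d  # S is divisible by d, as Case 2 requires
        for xn in range(span):
            original = S + scale * xn < 0
            rewritten = s + scale * (xn // d) < 0
            if original != rewritten:
                mismatches.append((S, xn))
    return mismatches


# scale = +1: the rewrite is an exact equivalence.
print(rewrite_mismatches(+1))       # []
# scale = -1: mismatches sit on the boundary S/d == xn // d with a
# non-zero remainder, where the original is True and the rewrite is False.
print(rewrite_mismatches(-1)[:3])   # [(0, 1), (0, 2), (0, 3)]
```

Every mismatch for `scale = -1` is a point where the original predicate is true and the rewritten one is false, which is exactly the "strengthening" described above.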
    
    After CSE/inlining, a comparison such as `2*(tx%4) < 16*warp +
    (tx%32)//4` (a `col < row` mask, where `col = 2*(tx%4)` and
    `row = 16*warp + (tx%32)//4` are independent projections of the same
    lane id) reaches canonical_simplify. Once it is normalized to the form
    `S + xn < 0`, the divided projection `(tx%32)//4` is the leftover term
    with scale = -1, and Case 2 folds the comparison to a plain
    `0 < warp`, zeroing every thread in warp 0 that should have written
    `val`. The same path also folds other configurations (e.g.
    `0 < (tx%32) - 8*warp`) all the way to `False`.
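The two concrete predicates above can be checked exhaustively over the ranges the regression test binds (`tx` in `[0, 128)`, `warp` in `[0, 4)`). This plain-Python sketch involves no TVM; the function names are illustrative. It lists the points where the pre-fix folds, `0 < warp` and `False`, disagree with the original comparison.

```python
def joint(tx, warp):
    # 2*(tx%4) < 16*warp + (tx%32)//4
    return 2 * (tx % 4) < 16 * warp + (tx % 32) // 4


def shifted(tx, warp):
    # 0 < (tx%32) - 8*warp
    return 0 < (tx % 32) - 8 * warp


points = [(tx, warp) for tx in range(128) for warp in range(4)]

# Pre-fix fold of `joint` was `0 < warp`; collect every point where they differ.
lost_joint = [(tx, w) for tx, w in points if joint(tx, w) != (0 < w)]
# Pre-fix fold of `shifted` was `False`; every point where it is True is lost.
lost_shifted = [(tx, w) for tx, w in points if shifted(tx, w)]

print(len(lost_joint), sorted({w for _, w in lost_joint}))      # 64 [0]
print(len(lost_shifted), sorted({w for _, w in lost_shifted}))  # 304 [0, 1, 2, 3]
```

All disagreements for the first predicate fall in warp 0 and are points where the original is true (e.g. `tx = 4`), while the second predicate loses true points in every warp.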
    
    The fix gates Case 2 with `extra->args[0]->scale == 1`. The original
    target shape (`yn % m` with positive scale and `lower_factor=1`, plus
    the `scale = +1 / lower_factor > 1` generalization) is unchanged;
    truly-always-true comparisons still fold to `True`.
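The cases the gate keeps can be sanity-checked the same way. This sketch is again plain Python and assumes Python's floor `//` and `%` match the floordiv/floormod semantics the simplifier uses. It checks (1) the identity quoted in the new Case 2 comment in `canonical_simplify.cc`, `((yn % m) // L) // d == (yn // (d*L)) % (m/(d*L))` when `m % (d*L) == 0`, (2) the `scale = +1` rewrite pinned by the new test, and (3) that the always-true variant really is always true over the test's ranges. `identity_holds` is a hypothetical helper.

```python
def identity_holds(m, L, d, lo=-200, hi=200):
    """Hypothetical helper: check ((yn % m) // L) // d == (yn // (d*L)) % (m // (d*L))
    for every integer yn in [lo, hi)."""
    return all(((yn % m) // L) // d == (yn // (d * L)) % (m // (d * L))
               for yn in range(lo, hi))


# (1) Holds when m % (d*L) == 0; fails without that precondition.
print(identity_holds(64, 1, 8), identity_holds(32, 4, 2))  # True True
print(identity_holds(12, 1, 8))                            # False

# (2) x1*64 + y1%64 < 120  <=>  x1*8 + (y1%64)//8 < 15 (the scale = +1 test case).
plus_one_ok = all((x * 64 + y % 64 < 120) == (x * 8 + (y % 64) // 8 < 15)
                  for x in range(-5, 6) for y in range(-130, 130))

# (3) The always-true variant: the LHS is at most 6, the RHS at least 8.
always_true = all((tx % 4) * 2 < warp * 16 + (tx % 32) // 4 + 8
                  for tx in range(128) for warp in range(4))

print(plus_one_ok, always_true)  # True True
```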
    
    ## Test plan
    
    - New regression test `test_simplify_le_negative_scale_extra` in
    `tests/python/arith/test_arith_canonical_simplify.py` — asserts on
    simplified `PrimExpr`, no GPU required; pre-fix fails, post-fix passes.
    It also pins the buggy `scale = -1` shapes to their unsimplified form,
    confirms the `scale = +1` Case 2 path still optimizes, and re-asserts
    the truly-always-true variant still folds to `True`.
    - Existing `test_simplify_le` (the original Case 2 target with `scale =
    +1`) still passes.
    - `tests/python/arith/test_arith_canonical_simplify.py`: 16 passed.
    - Full `tests/python/arith/`: 932 passed. One pre-existing, flaky
    random-seed failure in `test_arith_solve_linear_equations.py` is
    unrelated to this change and passes on rerun.
---
 src/arith/canonical_simplify.cc                    | 11 +++++-
 .../python/arith/test_arith_canonical_simplify.py  | 44 ++++++++++++++++++++++
 2 files changed, 53 insertions(+), 2 deletions(-)

diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index ac1b89f97a..0001afbdfe 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -1419,10 +1419,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
       // Case 1. 0 <= xn < d
       divisible.CopyOnWrite()->DivideBy(gcd);
       return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype));
-    } else if (extra->args.size() == 1 &&
+    } else if (extra->args.size() == 1 && extra->args[0]->scale == 1 &&
                extra->args[0]->upper_factor != ConstIntBoundNode::kPosInf &&
               extra->args[0]->upper_factor % (gcd * extra->args[0]->lower_factor) == 0) {
-      // Case 2. xn == yn % m, where m % d == 0
+      // Case 2. xn == ((yn % m) // L), scale = +1, m % (d*L) == 0.
+      // S + xn < 0 with S divisible by d  ⇔  S/d + xn // d < 0, because
+      // xn % d ∈ [0, d) lets us drop the remainder via the Case 1 argument,
+      // and xn // d = (yn // (d*L)) % (m/(d*L)).
+      // The scale must be +1: with scale = -1 the equivalence becomes ≤
+      // rather than <, so the rewrite would strengthen the predicate and
+      // silently drop the boundary S/d == xn // d (e.g. row > col where
+      // row and col are independent projections of the same lane id).
       divisible.CopyOnWrite()->DivideBy(gcd);
       const auto split_expr = extra->args[0];
       int64_t lower_factor = gcd * extra->args[0]->lower_factor;
diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py
index 35ecf3b700..4d81f9031c 100644
--- a/tests/python/arith/test_arith_canonical_simplify.py
+++ b/tests/python/arith/test_arith_canonical_simplify.py
@@ -490,5 +490,49 @@ def test_simplify_le():
     ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0)
 
 
+def test_simplify_le_negative_scale_extra():
+    """Regression: Case 2 of the LT-with-divisible-coeffs rewrite must not
+    fire when the leftover split term has a negative scale.
+
+    The rewrite ``S + xn < 0  ⇔  S/d + xn // d < 0`` is only sound when
+    the leftover ``xn`` has scale ``+1``. With scale ``-1`` the equivalence
+    becomes ``≤`` rather than ``<`` and the rewrite silently strengthens
+    the predicate. The original bug surfaced as ``row > col`` masks of
+    ``.16x*b`` tcgen05 readbacks collapsing to plain ``warp_id > k``
+    comparisons (lower-triangle writes were silently dropped on the
+    boundary warp).
+    """
+    ck = CanonicalChecker()
+    tx = tvm.tirx.Var("tx", "int32")
+    warp = tvm.tirx.Var("warp", "int32")
+    ck.analyzer.bind(tx, tvm.ir.Range(0, 128))
+    ck.analyzer.bind(warp, tvm.ir.Range(0, 4))
+
+    # Same-source joint projection: the comparison genuinely depends on tx
+    # at warp == 0 (e.g. tx == 4 ⇒ 0 < 1 = True; tx == 1 ⇒ 2 < 0 = False),
+    # so the simplifier must keep both sides.  Pre-fix this folded to
+    # ``0 < warp`` and dropped every True case in warp 0.
+    expr = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4
+    ck.verify(expr, expr)
+
+    # The simpler ``scale = -1`` with ``lower_factor = 1`` shape.  Pre-fix
+    # this folded to ``False``, dropping every case (in any warp) where
+    # ``tx % 32`` actually exceeds ``8 * warp``.
+    expr = warp * 8 < (tx % 32)
+    ck.verify(expr, expr)
+
+    # The corresponding ``scale = +1`` Case 2 path (the rewrite this guards)
+    # must still optimize — verifies we did not over-restrict.
+    x1 = tvm.tirx.Var("x1", "int32")
+    y1 = tvm.tirx.Var("y1", "int32")
+    ck.verify(x1 * 64 + (y1 % 64) < 120, x1 * 8 + (y1 % 64) // 8 < 15)
+
+    # The truly-always-true comparison that arises from the same kernel
+    # (``r = 2 / va = 1`` in the tcgen05.ld.16x256b readback) must still
+    # fold to True so the masked store can be elided.
+    expr_true = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4 + 8
+    ck.verify(expr_true, tvm.tirx.const(True, "bool"))
+
+
 if __name__ == "__main__":
     tvm.testing.main()