This is an automated email from the ASF dual-hosted git repository.
bohan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4b1bd6d431 [TIR] Further robustify floordiv/mod intrin lowering to
prevent overflow (#18699)
4b1bd6d431 is described below
commit 4b1bd6d431a662635b4f2817b2f7edd8be8ccbd3
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Jan 30 13:17:48 2026 -0500
[TIR] Further robustify floordiv/mod intrin lowering to prevent overflow
(#18699)
This PR further robustifies floordiv/floormod intrinsic lowering in cases where
we can shift negative values into the positive range, while carefully
preventing overflow in the intermediate compile-time and runtime computations.
---
src/tir/transforms/lower_intrin.cc | 161 ++++++++-------------
.../test_tir_transform_lower_intrin.py | 89 +++++++++---
2 files changed, 125 insertions(+), 125 deletions(-)
diff --git a/src/tir/transforms/lower_intrin.cc
b/src/tir/transforms/lower_intrin.cc
index 6a7d2b2776..ef844d9e05 100644
--- a/src/tir/transforms/lower_intrin.cc
+++ b/src/tir/transforms/lower_intrin.cc
@@ -115,59 +115,15 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
if (analyzer_->CanProveGreaterEqual(op->a, 0) ||
analyzer_->CanProveGreaterEqual(e, 0)) {
return truncdiv(op->a, op->b);
}
-
- // If the numerator's lower bound is known, express the floordiv
- // in terms of truncdiv using only positive operands.
-
- // The optimization below rewrites expressions involving `-a_min + (b -
1)`.
- // Without proper bounds checking, this expression may overflow the dtype
- // maximum, leading to non-equivalent transformations.
- // To ensure safety, we require:
- // b_max - a_min <= max_value_of_dtype + 1
- // This provides a conservative upper bound that prevents overflow and
- // preserves the original semantics.
- arith::ConstIntBound const_int_bound_a =
analyzer_->const_int_bound(op->a);
- arith::ConstIntBound const_int_bound_b =
analyzer_->const_int_bound(op->b);
- const int64_t max_value_of_dtype =
- Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
- if (const_int_bound_a->min_value < 0 &&
- const_int_bound_b->max_value - const_int_bound_a->min_value <=
max_value_of_dtype + 1) {
- // The goal is to write floordiv(a,b) in terms of truncdiv, without
using
- // negative operands.
- //
- // For any integer c
- //
- // floordiv(a,b) == floordiv(a + b*c - b*c, b)
- // == floordiv(a + b*c, b) - c
- //
- // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms
of
- // truncdiv as follows.
- //
- // c == ceildiv(-a_min,b)
- // == floordiv(-a_min + (b-1), b)
- // == truncdiv(-a_min + (b-1), b)
- //
- // When substituted into `a + b*c`, this results in a positive
argument.
- //
- // a + b*c
- // == a + b*ceildiv(-a_min,b)
- // == a - b*floordiv(a_min,b)
- // >= a - b*floordiv(a,b)
- // == floormod(a, b)
- // >= 0
- //
- // Since the argument is positive, this allows floordiv to be written
as
- // followed.
- //
- // floordiv(a,b)
- // == floordiv(a + b*c, b) - c
- // == truncdiv(a + b*c, b) - c
- IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
- PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
- PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b *
ceildiv);
- return truncdiv(offset_numerator, op->b) - ceildiv;
+ if (const IntImmNode* b_as_intimm = op->b.as<IntImmNode>()) {
+ int64_t b_value = b_as_intimm->value;
+ if (auto opt_c_value = TryFindShiftCoefficientForPositiveRange(op->a,
b_value)) {
+ int64_t c_value = *opt_c_value;
+ // now we can safely lower to truncdiv
+ return truncdiv(op->a + make_const(dtype, b_value * c_value), op->b)
-
+ make_const(dtype, c_value);
+ }
}
-
DLOG(INFO) << "LowerFloorDiv: Cannot decide the sign of divident";
PrimExpr rdiv = truncdiv(op->a, op->b);
PrimExpr rmod = truncmod(op->a, op->b);
@@ -221,58 +177,14 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
if (analyzer_->CanProveGreaterEqual(op->a, 0)) {
return truncmod(op->a, op->b);
}
-
- // If the numerator's lower bound is known, express the floormod
- // in terms of truncmod using only positive operands.
-
- // The optimization below rewrites expressions involving `-a_min + (b -
1)`.
- // Without proper bounds checking, this expression may overflow the dtype
- // maximum, leading to non-equivalent transformations.
- // To ensure safety, we require:
- // b_max - a_min <= max_value_of_dtype + 1
- // This provides a conservative upper bound that prevents overflow and
- // preserves the original semantics.
- arith::ConstIntBound const_int_bound_a =
analyzer_->const_int_bound(op->a);
- arith::ConstIntBound const_int_bound_b =
analyzer_->const_int_bound(op->b);
- const int64_t max_value_of_dtype =
- Downcast<IntImm>(tvm::max_value(op->a->dtype.element_of()))->value;
- if (const_int_bound_a->min_value < 0 &&
- const_int_bound_b->max_value - const_int_bound_a->min_value <=
max_value_of_dtype + 1) {
- // The goal is to write floormod(a,b) in terms of truncdiv and
truncmod,
- // without using negative operands.
- //
- // For any integer c
- //
- // floormod(a, b) == floormod(a + b*c, b)
- //
- // Choosing `c = ceildiv(-a_min, b)`. This can be rewritten in terms
of
- // truncdiv as follows.
- //
- // c == ceildiv(-a_min,b)
- // == floordiv(-a_min + (b-1), b)
- // == truncdiv(-a_min + (b-1), b)
- //
- // When substituted into `a + b*c`, this results in a positive
argument.
- //
- // a + b*c
- // == a + b*ceildiv(-a_min,b)
- // == a - b*floordiv(a_min,b)
- // >= a - b*floordiv(a,b)
- // == floormod(a, b)
- // >= 0
- //
- // Since the argument is positive, this allows floordiv to be written
as
- // followed.
- //
- // floormod(a,b)
- // == floormod(a + b*c, b)
- // == truncmod(a + b*c, b)
- IntImm min(op->a->dtype.element_of(), const_int_bound_a->min_value);
- PrimExpr ceildiv = truncdiv(-min + (op->b - 1), op->b);
- PrimExpr offset_numerator = analyzer_->Simplify(op->a + op->b *
ceildiv);
- return truncmod(offset_numerator, op->b);
+ if (const IntImmNode* b_as_intimm = op->b.as<IntImmNode>()) {
+ int64_t b_value = b_as_intimm->value;
+ if (auto opt_c_value = TryFindShiftCoefficientForPositiveRange(op->a,
b_value)) {
+ int64_t c_value = *opt_c_value;
+ // floormod(a, b) == floormod(a + b*c, b) == truncmod(a + b*c, b)
+ return truncmod(op->a + make_const(dtype, c_value * b_value), op->b);
+ }
}
-
DLOG(INFO) << "LowerFloorMod: Cannot decide the sign of divident";
// NOTE:condition on b >= 0.
// mod(a, b) < 0 will imply we are doing ceildiv,
@@ -388,6 +300,49 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
+ /*!
+ * \brief Try to find a shift coefficient c such that a + b*c is non-negative and
does not overflow.
+ *
+ * \param a the dividend
+ * \param b_value the (constant) divisor; only b_value > 0 is handled
+ * \return the shift coefficient c (c >= 1), or nullopt if not found / not needed
+ */
+ std::optional<int64_t> TryFindShiftCoefficientForPositiveRange(const
PrimExpr& a,
+ int64_t
b_value) {
+ if (b_value <= 0) {
+ return std::nullopt;
+ }
+ // NOTE: we need to be very careful in the checks below, to make sure
+ // all the intermediate calculations in both compiler checks and runtime
checks
+ // do not overflow
+ arith::ConstIntBound const_int_bound_a = analyzer_->const_int_bound(a);
+ if (const_int_bound_a->min_value >= 0) {
+ return std::nullopt;
+ }
+ const int64_t max_value_of_dtype =
+ Downcast<IntImm>(tvm::max_value(a->dtype.element_of()))->value;
+
+ // NOTE: ensures that (b-1) - a_min does not overflow
+ // also note: max_value_of_dtype + const_int_bound_a->min_value won't
overflow
+ // since a_min is negative, adding it to a positive value will not overflow
+ if (b_value - 1 > max_value_of_dtype + const_int_bound_a->min_value) {
+ return std::nullopt;
+ }
+ int64_t c_value = ((b_value - 1) - const_int_bound_a->min_value) / b_value;  // == ceildiv(-a_min, b_value)
+ ICHECK_GT(c_value, 0);  // a_min <= -1 implies c_value >= 1
+ // NOTE: guard c_value * b_value against overflow (defensive: the check above already implies c_value <= max_value_of_dtype / b_value)
+ if (c_value > max_value_of_dtype / b_value) return std::nullopt;
+ // need to check that the offset numerator a_max + b_value * c_value does not overflow;
+ // to compare without overflow, check a_max against max_value_of_dtype -
b_value * c_value
+ // note that b_value * c_value is positive, max_value_of_dtype is also
positive, so the
+ // subtraction will not overflow
+ if (const_int_bound_a->max_value > max_value_of_dtype - b_value * c_value)
{
+ // a + b * c risks overflow
+ return std::nullopt;
+ }
+ return c_value;
+ }
+
// attribute maps, shared only when FLegalize == FLowerIntrinsic
std::vector<OpAttrMap<FLowerGeneral>> attr_maps_;
FLowerGeneral fma_{nullptr};
diff --git a/tests/python/tir-transform/test_tir_transform_lower_intrin.py
b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
index 63f37e6f41..a0a6ab2508 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_intrin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_intrin.py
@@ -35,24 +35,35 @@ def lower_intrin(params, stmt):
return stmt.value if lower_expr else stmt.body
-def check_value(expr, vx, vy, data, fref):
+def check_value(expr, variables, data, fref):
+ """
+ Check that expr, compiled for llvm, evaluates to fref(*row) for each row in data.
+ variables: list of TIR vars (e.g. [x] or [x, y]); variables[j] is bound to column j of data.
+ data: list of tuples, each with len(variables) elements (asserted below).
+ """
n = len(data)
- A = te.placeholder((n,), name="A", dtype=expr.dtype)
- B = te.placeholder((n,), name="B", dtype=expr.dtype)
+ num_vars = len(variables)
+ assert num_vars >= 1 and all(len(row) == num_vars for row in data)
+
+ placeholders = [
+ te.placeholder((n,), name=f"v{i}", dtype=variables[i].dtype) for i in
range(num_vars)
+ ]
def make_binds(i):
x = expr
- x = tvm.tir.Let(vx, A[i], x)
- x = tvm.tir.Let(vy, B[i], x)
+ for j in range(num_vars - 1, -1, -1):
+ x = tvm.tir.Let(variables[j], placeholders[j][i], x)
return x
C = te.compute((n,), make_binds)
- f = tvm.compile(te.create_prim_func([A, B, C]), "llvm")
- a = tvm.runtime.tensor(np.array([x for x, y in data], dtype=expr.dtype))
- b = tvm.runtime.tensor(np.array([y for x, y in data], dtype=expr.dtype))
- c = tvm.runtime.tensor(np.zeros(len(data), dtype=expr.dtype))
- f(a, b, c)
- cref = np.array([fref(x, y) for x, y in data])
+ f = tvm.compile(te.create_prim_func(placeholders + [C]), "llvm")
+ arrays = [
+ tvm.runtime.tensor(np.array([row[j] for row in data],
dtype=variables[j].dtype))
+ for j in range(num_vars)
+ ]
+ c = tvm.runtime.tensor(np.zeros(n, dtype=expr.dtype))
+ f(*arrays, c)
+ cref = np.array([fref(*row) for row in data])
np.testing.assert_equal(c.numpy(), cref)
@@ -75,29 +86,29 @@ def test_lower_floordiv():
zero = tvm.tir.const(0, dtype)
# no constraints
res = lower_intrin([x, y], tvm.te.floordiv(x, y))
- check_value(res, x, y, data, lambda a, b: a // b)
+ check_value(res, [x, y], data, lambda a, b: a // b)
# rhs >= 0
res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floordiv(x,
y), zero))
- check_value(res, x, y, data, lambda a, b: a // b if b > 0 else 0)
+ check_value(res, [x, y], data, lambda a, b: a // b if b > 0 else 0)
# involves max
res = lower_intrin(
[x, y], tvm.tir.Select(y >= 0, tvm.te.max(tvm.te.floordiv(x, y),
zero), zero)
)
- check_value(res, x, y, data, lambda a, b: max(a // b, 0) if b > 0 else
0)
+ check_value(res, [x, y], data, lambda a, b: max(a // b, 0) if b > 0
else 0)
# lhs >= 0
res = lower_intrin(
[x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0),
tvm.te.floordiv(x, y), zero)
)
- check_value(res, x, y, data, lambda a, b: a // b if b > 0 and a >= 0
else 0)
+ check_value(res, [x, y], data, lambda a, b: a // b if b > 0 and a >= 0
else 0)
# const power of two
res = lower_intrin([x, y], tvm.te.floordiv(x, tvm.tir.const(8,
dtype=dtype)))
- check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a,
b: a // b)
+ check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda
a, b: a // b)
# floordiv(x + m, k), m and k are positive constant. 2 <= m <= k-1.
res = lower_intrin(
[x, y],
tvm.te.floordiv(x + tvm.tir.const(4, dtype=dtype),
tvm.tir.const(5, dtype=dtype)),
)
- check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a,
b: (a + 4) // b)
+ check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda
a, b: (a + 4) // b)
@tvm.testing.requires_llvm
@@ -109,26 +120,60 @@ def test_lower_floormod():
zero = tvm.tir.const(0, dtype)
# no constraints
res = lower_intrin([x, y], tvm.te.floormod(x, y))
- check_value(res, x, y, data, lambda a, b: a % b)
+ check_value(res, [x, y], data, lambda a, b: a % b)
# rhs >= 0
res = lower_intrin([x, y], tvm.tir.Select(y >= 0, tvm.te.floormod(x,
y), zero))
- check_value(res, x, y, data, lambda a, b: a % b if b > 0 else 0)
+ check_value(res, [x, y], data, lambda a, b: a % b if b > 0 else 0)
# lhs >= 0
res = lower_intrin(
[x, y], tvm.tir.Select(tvm.tir.all(y >= 0, x >= 0),
tvm.te.floormod(x, y), zero)
)
- check_value(res, x, y, data, lambda a, b: a % b if b > 0 and a >= 0
else 0)
+ check_value(res, [x, y], data, lambda a, b: a % b if b > 0 and a >= 0
else 0)
# const power of two
res = lower_intrin([x, y], tvm.te.floormod(x, tvm.tir.const(8,
dtype=dtype)))
- check_value(res, x, y, [(a, b) for a, b in data if b == 8], lambda a,
b: a % b)
+ check_value(res, [x, y], [(a, b) for a, b in data if b == 8], lambda
a, b: a % b)
# floormod(x + m, k), m and k are positive constant. 2 <= m <= k-1.
res = lower_intrin(
[x, y],
tvm.te.floormod(x + tvm.tir.const(4, dtype=dtype),
tvm.tir.const(5, dtype=dtype)),
)
- check_value(res, x, y, [(a, b) for a, b in data if b == 5], lambda a,
b: (a + 4) % b)
+ check_value(res, [x, y], [(a, b) for a, b in data if b == 5], lambda
a, b: (a + 4) % b)
+
+
[email protected]_llvm
+def test_lower_floordiv_overflow_checks():
+ """
+ Regression tests for overflow checks in
TryFindShiftCoefficientForPositiveRange.
+ Divisor is constant 3 (not 1 to avoid CSE, not power-of-two so we don't
take the shift path).
+ Reuses lower_intrin and check_value; overflow tests use one var [x].
+ """
+ # Check 3: (b-1) - a_min must not overflow (numerator and C++ int64).
+ # x (int64) full range -> min_value = -2^63. With b = 3: numerator = 2 -
(-2^63) > LLONG_MAX.
+ x = te.var("x", dtype="int64")
+ res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int64")))
+ data_check3 = [(-(2**63),), (0,), (100,)]
+ check_value(res, [x], data_check3, lambda a: a // 3)
+
+ # Check 4: c_value * b_value must not overflow dtype.
+ # x (int16) full range -> min_value = -32768, c = ceil(32770/3) = 10923;
10923*3 > 32767.
+ x = te.var("x", dtype="int16")
+ res = lower_intrin([x], tvm.te.floordiv(x, tvm.tir.const(3, "int16")))
+ data_check4 = [(-32768,), (0,), (100,)]
+ check_value(res, [x], data_check4, lambda a: a // 3)
+
+ # Check 5: a_max + b*c must not overflow (offset numerator).
+ # tir.min(tir.max(x, -10), 32758) can give bounds [-10, 32758]; b=3, c=4;
a_max + 12 > 32767.
+ # In practice this path may not be triggered. This test still validates
correct lowering.
+ x = te.var("x", dtype="int16")
+ clamped = tvm.tir.min(
+ tvm.tir.max(x, tvm.tir.const(-10, "int16")), tvm.tir.const(32758,
"int16")
+ )
+ res = lower_intrin([x], tvm.te.floordiv(clamped, tvm.tir.const(3,
"int16")))
+ data_check5 = [(-10,), (0,), (32758,), (32757,)]
+ check_value(res, [x], data_check5, lambda a: (min(max(a, -10), 32758)) //
3)
if __name__ == "__main__":
test_lower_floordiv()
test_lower_floormod()
+ test_lower_floordiv_overflow_checks()