This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3c6f9c9bcc [Arith] Added simplification rule for multiple equality
compares (#15628)
3c6f9c9bcc is described below
commit 3c6f9c9bcc2b3fa2ca30ae1f6174b4f536f6d368
Author: Eric Lunderberg <[email protected]>
AuthorDate: Mon Aug 28 19:28:01 2023 -0400
[Arith] Added simplification rule for multiple equality compares (#15628)
The expression `(x==y) && (x==z)` can only hold if `y==z`. When `y` and
`z` are constants, rewriting `(x==c1) && (x==c2)` into
`(x==c1) && (c1==c2)` exposes the comparison `c1==c2` to constant
folding, which allows the whole expression to be simplified further.
This commit adds the above rewrite, and the corresponding rewrite of
the negated expression, `(x!=c1) || (x!=c2)` into
`(x!=c1) || (c1!=c2)`.
---
src/arith/rewrite_simplify.cc | 2 ++
tests/python/unittest/test_arith_rewrite_simplify.py | 2 ++
2 files changed, 4 insertions(+)
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 40088fd963..63becf8eb7 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1856,6 +1856,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
AndNode* op) {
}),
cfalse, c2.Eval()->value > c1.Eval()->value);
+ TVM_TRY_REWRITE((x == c1) && (x == c2), (x == c1) && (c1 == c2));
TVM_TRY_REWRITE(matches_one_of(x == c1 && x != c2, x != c2 && x == c1), x ==
c1 && c1 != c2);
TVM_TRY_RECURSIVE_REWRITE(matches_one_of(floordiv(x, c2) == c1 &&
floormod(x, c2) == c3,
@@ -2000,6 +2001,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
OrNode* op) {
TVM_TRY_REWRITE_IF(x <= c1 || c2 <= x, ctrue, c2.Eval()->value <=
c1.Eval()->value + 1);
TVM_TRY_REWRITE_IF(c2 <= x || x <= c1, ctrue, c2.Eval()->value <=
c1.Eval()->value + 1);
+ TVM_TRY_REWRITE(x != c1 || x != c2, x != c1 || c1 != c2);
TVM_TRY_REWRITE(x != c1 || x == c2, x != c1 || c1 == c2);
TVM_TRY_REWRITE(x == c2 || x != c1, x != c1 || c1 == c2);
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py
b/tests/python/unittest/test_arith_rewrite_simplify.py
index 46ac0f9751..0b0a43a7d3 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -951,6 +951,7 @@ class TestLogical(BaseCompare):
TestCase(tvm.tir.And(x <= 1, 2 <= x), tvm.tir.const(False, "bool")),
TestCase(tvm.tir.And(2 <= x, x <= 1), tvm.tir.const(False, "bool")),
TestCase(tvm.tir.And(x == 1, x != 2), x == 1),
+ TestCase(tvm.tir.And(x == 1, x == 2), tvm.tir.const(False, "bool")),
TestCase(tvm.tir.Or(tvm.tir.EQ(x, y), tvm.tir.NE(x, y)),
tvm.tir.const(True, "bool")),
TestCase(tvm.tir.Or(tvm.tir.NE(x, y), tvm.tir.EQ(x, y)),
tvm.tir.const(True, "bool")),
TestCase(tvm.tir.Or(x > y, tvm.tir.Not(x > y)), tvm.tir.const(True,
"bool")),
@@ -965,6 +966,7 @@ class TestLogical(BaseCompare):
TestCase(tvm.tir.Or(x <= 1, 2 <= x), tvm.tir.const(True, "bool")),
TestCase(tvm.tir.Or(2 <= x, x <= 1), tvm.tir.const(True, "bool")),
TestCase(tvm.tir.Or(x != 1, x == 2), x != 1),
+ TestCase(tvm.tir.Or(x != 1, x != 2), tvm.tir.const(True, "bool")),
TestCase(
tvm.tir.Or(x == 1, tvm.tir.Or(y == 1, z == 1)),
tvm.tir.Or(tvm.tir.Or(x == 1, y == 1), z == 1),