This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fc606c09b2 [TIR][TVMScript] Cleaner printing of And/Or chains (#13432)
fc606c09b2 is described below

commit fc606c09b223df445dd6bc6a33d3e3bfbd670535
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Nov 25 16:37:45 2022 -0600

    [TIR][TVMScript] Cleaner printing of And/Or chains (#13432)
    
    Even though `And` has a higher operator precedence than `Or`,
    removing parentheses based on this precedence alone can harm
    readability.  This commit adds an exception to the TVMScript
    parenthesization rules so that parentheses are always inserted when
    an `And` expression appears as an operand of an `Or` expression.
    
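    In addition, this commit adds rewrite rules that reassociate nested
    And/Or expressions into left-associative form (`x && (y && z)`
    becomes `(x && y) && z`, and likewise for `||`), so that such chains
    can be expressed as a single left-associative chain of operators.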
    
    Together, these two changes improve the readability of printed
    boolean expressions.  Below is the motivating example for this
    change, showing the printed output before (first) and after (second)
    the change.  In each case, the output was passed through the `black`
    formatter.  Both expressions are equivalent, but the first is much
    more difficult to read.
    
    ```python
    x = (
        AAA == 0
        and BBB < 4
        or AAA == 7
        and 6 <= BBB
        or (CCC == 0 and DDD < 4 or CCC == 7 and 6 <= DDD)
    )
    
    x = (
        (AAA == 0 and BBB < 4)
        or (AAA == 7 and 6 <= BBB)
        or (CCC == 0 and DDD < 4)
        or (CCC == 7 and 6 <= DDD)
    )
    ```
---
 src/arith/rewrite_simplify.cc                      |  9 ++++--
 src/printer/tvmscript_printer.cc                   |  6 ++--
 .../python/unittest/test_arith_rewrite_simplify.py |  9 ++++++
 .../unittest/test_tir_schedule_transform_layout.py | 12 ++++----
 tests/python/unittest/test_tvmscript_roundtrip.py  | 32 ++++++++++++++++++++++
 5 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 90c448f4ea..c9d92f9925 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1737,7 +1737,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
AndNode* op) {
   }
 
   // Pattern var to match any expression
-  PVar<PrimExpr> x, y;
+  PVar<PrimExpr> x, y, z;
   // Pattern var match IntImm
   PVar<IntImm> c1, c2, c3;
   PVar<int> lanes;
@@ -1815,6 +1815,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
AndNode* op) {
                             c1 * c2 + c3 < x && x < (c1 + 1) * c2);
   TVM_TRY_RECURSIVE_REWRITE(c3 < floormod(x, c2) && floordiv(x, c2) == c1,
                             c1 * c2 + c3 < x && x < (c1 + 1) * c2);
+
+  TVM_TRY_RECURSIVE_REWRITE(x && (y && z), (x && y) && z);
+
   return ret;
 }
 
@@ -1874,7 +1877,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
OrNode* op) {
   }
 
   // Pattern var to match any expression
-  PVar<PrimExpr> x, y;
+  PVar<PrimExpr> x, y, z;
   // Pattern var match IntImm
   PVar<IntImm> c1, c2;
   PVar<int> lanes;
@@ -1912,6 +1915,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
OrNode* op) {
   TVM_TRY_RECURSIVE_REWRITE(x == y || x < y, x <= y);
   TVM_TRY_RECURSIVE_REWRITE(y == x || x < y, x <= y);
 
+  TVM_TRY_RECURSIVE_REWRITE(x || (y || z), (x || y) || z);
+
   return ret;
 }
 
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 8f012f3b0e..05e514295c 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -823,13 +823,15 @@ bool WillPrintConstScalar(const PrimExpr& expr) {
     ICHECK(rhs_precedence != ExprPrecedence::kUnknown);                        
                   \
     /* Update out_precedence of current node. */                               
                   \
     *out_precedence = OpPrecedence;                                            
                   \
-    if (lhs_precedence > OpPrecedence) {                                       
                   \
+    if (lhs_precedence > OpPrecedence ||                                       
                   \
+        (lhs_precedence == ExprPrecedence::kAnd && OpPrecedence == 
ExprPrecedence::kOr)) {        \
       doc << "(" << lhs_doc << ")";                                            
                   \
     } else {                                                                   
                   \
       doc << lhs_doc;                                                          
                   \
     }                                                                          
                   \
     doc << OpString;                                                           
                   \
-    if (rhs_precedence >= OpPrecedence) {                                      
                   \
+    if (rhs_precedence >= OpPrecedence ||                                      
                   \
+        (rhs_precedence == ExprPrecedence::kAnd && OpPrecedence == 
ExprPrecedence::kOr)) {        \
       doc << "(" << rhs_doc << ")";                                            
                   \
     } else {                                                                   
                   \
       doc << rhs_doc;                                                          
                   \
diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py 
b/tests/python/unittest/test_arith_rewrite_simplify.py
index 4199cb9a56..d6c2cfe8bb 100644
--- a/tests/python/unittest/test_arith_rewrite_simplify.py
+++ b/tests/python/unittest/test_arith_rewrite_simplify.py
@@ -992,6 +992,15 @@ def test_logical_simplify():
     ck.verify(tvm.tir.Or(2 <= x, x <= 1), tvm.tir.const(True, "bool"))
     ck.verify(tvm.tir.Or(x != 1, x == 2), x != 1)
 
+    ck.verify(
+        tvm.tir.Or(x == 1, tvm.tir.Or(y == 1, z == 1)),
+        tvm.tir.Or(tvm.tir.Or(x == 1, y == 1), z == 1),
+    )
+    ck.verify(
+        tvm.tir.And(x == 1, tvm.tir.And(y == 1, z == 1)),
+        tvm.tir.And(tvm.tir.And(x == 1, y == 1), z == 1),
+    )
+
 
 def test_let_simplify():
     ck = RewriteChecker()
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py 
b/tests/python/unittest/test_tir_schedule_transform_layout.py
index e904789223..faeaf87686 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -872,13 +872,11 @@ class 
TestTransformLayoutWithVar(tvm.testing.CompareBeforeAfter):
                 B[vi, vj] = T.if_then_else(
                     # Checks if the transform introduced padding
                     -16 % n != 0
-                    and (
-                        # If so, is vi in the last group (which may
-                        # include padding).
-                        (vj + vi * n) // n == 16 // n
-                        # And is vj within the padding
-                        and 16 % n <= (vj + vi * n) % n
-                    ),
+                    # If so, is vi in the last group (which may
+                    # include padding).
+                    and (vj + vi * n) // n == 16 // n
+                    # And is vj within the padding
+                    and 16 % n <= (vj + vi * n) % n,
                     0,
                     A[vj + vi * n],
                     dtype="int32",
diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py 
b/tests/python/unittest/test_tvmscript_roundtrip.py
index 53b3cd69ea..0ead66bd60 100644
--- a/tests/python/unittest/test_tvmscript_roundtrip.py
+++ b/tests/python/unittest/test_tvmscript_roundtrip.py
@@ -3506,6 +3506,37 @@ def elif_chain_with_else():
     return func
 
 
+def nested_boolean_expressions():
+    expressions = {
+        "and_lhs_and": lambda i, j, k: tir.all(tir.all(i, j), k),
+        "and_rhs_and": lambda i, j, k: tir.all(i, tir.all(j, k)),
+        "and_lhs_or": lambda i, j, k: tir.all(tir.any(i, j), k),
+        "and_rhs_or": lambda i, j, k: tir.all(i, tir.any(j, k)),
+        "or_lhs_and": lambda i, j, k: tir.any(tir.all(i, j), k),
+        "or_rhs_and": lambda i, j, k: tir.any(i, tir.all(j, k)),
+        "or_lhs_or": lambda i, j, k: tir.any(tir.any(i, j), k),
+        "or_rhs_or": lambda i, j, k: tir.any(i, tir.any(j, k)),
+        "and_of_ors": lambda i, j, k: tir.all(tir.any(i, j), tir.any(j, k), 
tir.any(i, k), i, j, k),
+        "or_of_ands": lambda i, j, k: tir.any(tir.all(i, j), tir.all(j, k), 
tir.all(i, k), i, j, k),
+    }
+
+    def make_ir_generator(name, expression):
+        def inner():
+            @T.prim_func
+            def func(A: T.Buffer[1, "bool"], i: T.bool, j: T.bool, k: T.bool):
+                A[0] = expression(i, j, k)
+
+            return func
+
+        inner.__name__ = f"nested_boolean_expr_{name}"
+        return inner
+
+    for name, expression in expressions.items():
+        generator = make_ir_generator(name, expression)
+
+        yield generator
+
+
 ir_generator = tvm.testing.parameter(
     opt_gemm_normalize,
     opt_gemm_lower,
@@ -3561,6 +3592,7 @@ ir_generator = tvm.testing.parameter(
     if_true_else,
     elif_chain_without_else,
     elif_chain_with_else,
+    *nested_boolean_expressions(),
 )
 
 

Reply via email to