This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 898f87ffd6 [Bugfix][TIR] Handle AttrStmt of upcoming tir.Var in ConvertSSA (#16682)
898f87ffd6 is described below

commit 898f87ffd6ea74fc839f5c002965cd848ce0adb1
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Mar 8 21:16:25 2024 -0600

    [Bugfix][TIR] Handle AttrStmt of upcoming tir.Var in ConvertSSA (#16682)
    
    In some cases, an `AttrStmt` may legally refer to a TIR variable that
    hasn't yet been defined.  For example, the
    `"pragma_parallel_launch_point"` attribute annotates a variable that
    is about to be defined by a ForNode.  Prior to this commit,
    `ConvertSSA` treated the `AttrStmt` as the first occurrence of the
    variable, so the ForNode's subsequent definition of it was treated as
    a nested re-definition and de-duplicated.  This resulted in the
    output `AttrStmt` containing a reference to an undefined variable.
    
    This commit updates `ConvertSSA` to handle this case.  If an
    `AttrStmt` refers to a not-yet-defined variable, the body is visited
    first, and the variable is marked as defined only afterwards (and
    only if the body did not already define it).
    
    This implementation may be simplified in the future by
    moving "pragma_parallel_launch_point" to be an annotation
    on the `ForNode`, rather than an `AttrStmt`.
---
 src/tir/transforms/ir_utils.cc                     | 34 ++++++++++--
 .../test_tir_transform_convert_ssa.py              | 61 +++++++++++++++++++++-
 2 files changed, 90 insertions(+), 5 deletions(-)

diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index a85bde6787..584b3cbf58 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -358,6 +358,7 @@ class IRConvertSSA final : public StmtExprMutator {
       }
 
       Var var = iter_var->var;
+      bool delayed_define = false;
       if (auto it = function_scope_var_remap_.find(var.get());
           it != function_scope_var_remap_.end()) {
         var = it->second;
@@ -373,8 +374,23 @@ class IRConvertSSA final : public StmtExprMutator {
         function_scope_var_remap_.insert({var.get(), new_var});
         var = new_var;
       } else {
-        function_scope_var_remap_.insert({var.get(), var});
-        defined_.insert(var.get());
+        // The AttrStmt refers to an undefined variable.  This is
+        // allowed for some attributes, such as
+        // "pragma_parallel_launch_point", which annotates a variable
+        // that is about to occur in a ForNode.  In these cases, the
+        // ForNode and the AttrStmt must continue using the same
+        // variable definition.
+        //
+        // However, other AttrStmt, such as "thread_extent", act as
+        // points of definition for the variable they annotate.  If
+        // the variable has not been defined after visiting the body,
+        // we should mark it as defined before exiting.  This ensures
+        // correct de-duplication between multiple functions.
+        //
+        // This implementation may be simplified in the future by
+        // moving "pragma_parallel_launch_point" to be an annotation
+        // on the `ForNode`, rather than an `AttrStmt`.
+        delayed_define = true;
       }
 
       IterVar new_iter_var;
@@ -387,12 +403,22 @@ class IRConvertSSA final : public StmtExprMutator {
       auto value = VisitExpr(op->value);
       auto body = VisitStmt(op->body);
 
+      Stmt output;
       if (new_iter_var.get() == iter_var && body.same_as(op->body) && 
value.same_as(op->value)) {
-        return GetRef<Stmt>(op);
+        output = GetRef<Stmt>(op);
       } else {
-        return AttrStmt(new_iter_var, op->attr_key, value, body, 
iter_var->span);
+        output = AttrStmt(new_iter_var, op->attr_key, value, body, 
iter_var->span);
       }
 
+      if (delayed_define) {
+        if (!defined_.count(var.get())) {
+          function_scope_var_remap_.insert({var.get(), var});
+          defined_.insert(var.get());
+        }
+      }
+
+      return output;
+
     } else if (const VarNode* v = op->node.as<VarNode>()) {
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       op = stmt.as<AttrStmtNode>();
diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py
index 140adcd35b..644ab3b624 100644
--- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py
+++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py
@@ -17,7 +17,7 @@
 
 import tvm
 import tvm.testing
-from tvm import tir
+from tvm import tir, ir
 from tvm.script import tir as T, ir as I
 
 
@@ -485,5 +485,64 @@ class TestThreadIdxReusedWithinAndAcrossFunctions(BaseBeforeAfter):
         return mod
 
 
+class TestTrackForwardDeclarationsInAttrStmt(BaseBeforeAfter):
+    """T.attr statements may refer to an about-to-be-defined tir.Var"""
+
+    def before(self):
+        """Generate the PrimFunc, which is already SSA
+
+        This is constructed directly, rather than using TVMScript or
+        the `tvm.tir.ir_builder`.  This test case requires a
+        `tir.AttrStmt` that references a variable, followed by the
+        `tir.For` defining that variable.  This is not expressible in
+        either TVMScript or `tvm.tir.ir_builder`, as they only provide
+        the loop iterator within the body of the loop.
+        """
+        i0_outer_outer = tir.Var("i0_outer_outer", "int32")
+        i0_outer_inner = tir.Var("i0_outer_inner", "int32")
+        i0_inner = tir.Var("i0_inner", "int32")
+
+        A = tir.decl_buffer(1024, "float32", "A")
+        B = tir.decl_buffer(1024, "float32", "B")
+
+        index = i0_outer_outer * 52 + i0_outer_inner * 4 + i0_inner
+
+        stmt = tir.BufferStore(B, tir.BufferLoad(A, [index]), [index])
+        stmt = tir.IfThenElse(i0_outer_outer * 13 + i0_outer_inner < 256, stmt, None)
+        stmt = tir.For(i0_inner, 0, 4, tir.ForKind.VECTORIZED, stmt)
+        stmt = tir.For(i0_outer_inner, 0, 13, tir.ForKind.PARALLEL, stmt)
+        stmt = tir.AttrStmt(
+            T.iter_var(i0_outer_inner, None, "DataPar", ""),
+            "pragma_parallel_barrier_when_finish",
+            1,
+            stmt,
+        )
+        stmt = tir.AttrStmt(
+            T.iter_var(i0_outer_inner, None, "DataPar", ""),
+            "pragma_parallel_stride_pattern",
+            1,
+            stmt,
+        )
+        stmt = tir.For(i0_outer_outer, 0, 20, tir.ForKind.SERIAL, stmt)
+        stmt = tir.AttrStmt(
+            T.iter_var(i0_outer_outer, None, "DataPar", ""),
+            "pragma_parallel_launch_point",
+            1,
+            stmt,
+        )
+
+        A_handle = tir.Var("A_handle", "handle")
+        B_handle = tir.Var("B_handle", "handle")
+
+        func = tir.PrimFunc(
+            [A_handle, B_handle],
+            stmt,
+            buffer_map={A_handle: A, B_handle: B},
+        )
+        return func
+
+    expected = before
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to