This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 898f87ffd6 [Bugfix][TIR] Handle AttrStmt of upcoming tir.Var in
ConvertSSA (#16682)
898f87ffd6 is described below
commit 898f87ffd6ea74fc839f5c002965cd848ce0adb1
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Mar 8 21:16:25 2024 -0600
[Bugfix][TIR] Handle AttrStmt of upcoming tir.Var in ConvertSSA (#16682)
In some cases, an `AttrStmt` may legally refer to a TIR variable that
hasn't yet been defined. For example, the
`"pragma_parallel_launch_point"` attribute annotates a variable that is
defined by an upcoming `ForNode`. Prior to this commit, `ConvertSSA`
treated the `AttrStmt` as the point of definition for the variable, so
the subsequent `ForNode` was treated as a nested redefinition to be
de-duplicated. This resulted in the output `AttrStmt` containing a
reference to an undefined variable.
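This commit updates `ConvertSSA` to handle this case. If an
`AttrStmt` refers to a not-yet-defined variable, the body is visited
before the variable is marked as defined. Afterwards, the variable is
marked as defined only if the body did not already define it.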
This implementation may be simplified in the future by
moving "pragma_parallel_launch_point" to be an annotation
on the `ForNode`, rather than an `AttrStmt`.
---
src/tir/transforms/ir_utils.cc | 34 ++++++++++--
.../test_tir_transform_convert_ssa.py | 61 +++++++++++++++++++++-
2 files changed, 90 insertions(+), 5 deletions(-)
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index a85bde6787..584b3cbf58 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -358,6 +358,7 @@ class IRConvertSSA final : public StmtExprMutator {
}
Var var = iter_var->var;
+ bool delayed_define = false;
if (auto it = function_scope_var_remap_.find(var.get());
it != function_scope_var_remap_.end()) {
var = it->second;
@@ -373,8 +374,23 @@ class IRConvertSSA final : public StmtExprMutator {
function_scope_var_remap_.insert({var.get(), new_var});
var = new_var;
} else {
- function_scope_var_remap_.insert({var.get(), var});
- defined_.insert(var.get());
+ // The AttrStmt refers to an undefined variable. This is
+ // allowed for some attributes, such as
+ // "pragma_parallel_launch_point", which annotates a variable
+ // that is about to occur in a ForNode. In these cases, the
+ // ForNode and the AttrStmt must continue using the same
+ // variable definition.
+ //
+ // However, other AttrStmt, such as "thread_extent", act as
+ // points of definition for the variable they annotate. If
+ // the variable has not been defined after visiting the body,
+ // we should mark it as defined before exiting. This ensures
+ // correct de-duplication between multiple functions.
+ //
+ // This implementation may be simplified in the future by
+ // moving "pragma_parallel_launch_point" to be an annotation
+ // on the `ForNode`, rather than an `AttrStmt`.
+ delayed_define = true;
}
IterVar new_iter_var;
@@ -387,12 +403,22 @@ class IRConvertSSA final : public StmtExprMutator {
auto value = VisitExpr(op->value);
auto body = VisitStmt(op->body);
+ Stmt output;
if (new_iter_var.get() == iter_var && body.same_as(op->body) &&
value.same_as(op->value)) {
- return GetRef<Stmt>(op);
+ output = GetRef<Stmt>(op);
} else {
- return AttrStmt(new_iter_var, op->attr_key, value, body,
iter_var->span);
+ output = AttrStmt(new_iter_var, op->attr_key, value, body,
iter_var->span);
}
+ if (delayed_define) {
+ if (!defined_.count(var.get())) {
+ function_scope_var_remap_.insert({var.get(), var});
+ defined_.insert(var.get());
+ }
+ }
+
+ return output;
+
} else if (const VarNode* v = op->node.as<VarNode>()) {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AttrStmtNode>();
diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py
b/tests/python/tir-transform/test_tir_transform_convert_ssa.py
index 140adcd35b..644ab3b624 100644
--- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py
+++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py
@@ -17,7 +17,7 @@
import tvm
import tvm.testing
-from tvm import tir
+from tvm import tir, ir
from tvm.script import tir as T, ir as I
@@ -485,5 +485,64 @@ class
TestThreadIdxReusedWithinAndAcrossFunctions(BaseBeforeAfter):
return mod
+class TestTrackForwardDeclarationsInAttrStmt(BaseBeforeAfter):
+    """T.attr statements may refer to an about-to-be-defined tir.Var
+
+    The PrimFunc produced by `before` is already in SSA form, so it is
+    also used as `expected`: ConvertSSA should leave it unchanged.  In
+    particular, each forward-referencing `tir.AttrStmt` must keep
+    referring to the same `tir.Var` that the enclosed `tir.For`
+    defines, rather than having the loop variable renamed out from
+    under it.
+    """
+
+    def before(self):
+        """Generate the PrimFunc, which is already SSA
+
+        This is constructed directly, rather than using TVMScript or
+        the `tvm.tir.ir_builder`.  This test case requires a
+        `tir.AttrStmt` that references a variable, followed by the
+        `tir.For` defining that variable.  This is not expressible in
+        either TVMScript or `tvm.tir.ir_builder`, as they only provide
+        the loop iterator within the body of the loop.
+        """
+        i0_outer_outer = tir.Var("i0_outer_outer", "int32")
+        i0_outer_inner = tir.Var("i0_outer_inner", "int32")
+        i0_inner = tir.Var("i0_inner", "int32")
+
+        A = tir.decl_buffer(1024, "float32", "A")
+        B = tir.decl_buffer(1024, "float32", "B")
+
+        index = i0_outer_outer * 52 + i0_outer_inner * 4 + i0_inner
+
+        # The statement tree is built inside-out, starting from the
+        # innermost store.
+        stmt = tir.BufferStore(B, tir.BufferLoad(A, [index]), [index])
+        # Tail guard: the loop nest covers 20 * 13 = 260 groups of 4
+        # elements, but the 1024-element buffers only hold 256 groups.
+        stmt = tir.IfThenElse(i0_outer_outer * 13 + i0_outer_inner < 256, stmt, None)
+        stmt = tir.For(i0_inner, 0, 4, tir.ForKind.VECTORIZED, stmt)
+        stmt = tir.For(i0_outer_inner, 0, 13, tir.ForKind.PARALLEL, stmt)
+
+        # These AttrStmts wrap, and therefore precede, the loop that
+        # defines `i0_outer_inner`.  The keys use the real
+        # "pragma_parallel_*" spelling so the test models the
+        # attributes that motivated this case.
+        stmt = tir.AttrStmt(
+            T.iter_var(i0_outer_inner, None, "DataPar", ""),
+            "pragma_parallel_barrier_when_finish",
+            1,
+            stmt,
+        )
+        stmt = tir.AttrStmt(
+            T.iter_var(i0_outer_inner, None, "DataPar", ""),
+            "pragma_parallel_stride_pattern",
+            1,
+            stmt,
+        )
+        stmt = tir.For(i0_outer_outer, 0, 20, tir.ForKind.SERIAL, stmt)
+
+        # Likewise, this AttrStmt precedes the loop defining
+        # `i0_outer_outer`.
+        stmt = tir.AttrStmt(
+            T.iter_var(i0_outer_outer, None, "DataPar", ""),
+            "pragma_parallel_launch_point",
+            1,
+            stmt,
+        )
+
+        A_handle = tir.Var("A_handle", "handle")
+        B_handle = tir.Var("B_handle", "handle")
+
+        func = tir.PrimFunc(
+            [A_handle, B_handle],
+            stmt,
+            buffer_map={A_handle: A, B_handle: B},
+        )
+        return func
+
+    expected = before
+
+
if __name__ == "__main__":
tvm.testing.main()