This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new eea6268c92 [TIR] Handle DeclBuffer in
Inline/ComputeAt/ReverseComputeAt (#15038)
eea6268c92 is described below
commit eea6268c928a5d92b0c4b9c864c841edd0740c68
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Jun 10 01:03:18 2023 -0400
[TIR] Handle DeclBuffer in Inline/ComputeAt/ReverseComputeAt (#15038)
* [Util] Handle AllocateConst in MergeNest
* [TIR] Handle DeclBuffer in Inline/ComputeAt/ReverseComputeAt
Part of changes being split out from
https://github.com/apache/tvm/pull/14778 into independent portions.
This commit allows TIR `compute_inline`, `compute_at`, and
`reverse_compute_at` schedule primitives to preserve `DeclBuffer`
nodes.
---
src/tir/schedule/transform.cc | 28 +++---
src/tir/transforms/ir_utils.cc | 5 ++
.../unittest/test_tir_schedule_compute_at.py | 99 ++++++++++++++++++++++
3 files changed, 122 insertions(+), 10 deletions(-)
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index baa7f44bbc..9c209658c3 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -17,6 +17,7 @@
* under the License.
*/
+#include "../transforms/ir_utils.h"
#include "./utils.h"
namespace tvm {
@@ -261,21 +262,28 @@ void LeafBlockRemovalPlan(const ScheduleState& self,
const StmtSRef& leaf_block_
if (const auto* block = sref->StmtAs<BlockNode>()) {
auto body = block->body;
// Peel off AllocateConst nodes at the beginning of the block body.
- std::vector<const AllocateConstNode*> allocs;
- while (const auto* alloc = body.as<AllocateConstNode>()) {
- allocs.push_back(alloc);
- body = alloc->body;
+ std::vector<Stmt> allocs;
+ while (true) {
+ if (auto opt = body.as<AllocateConst>()) {
+ auto alloc = opt.value();
+ body = alloc->body;
+ alloc.CopyOnWrite()->body = Evaluate(0);
+ allocs.push_back(alloc);
+ } else if (auto opt = body.as<DeclBuffer>()) {
+ auto decl_buffer = opt.value();
+ body = decl_buffer->body;
+ decl_buffer.CopyOnWrite()->body = Evaluate(0);
+ allocs.push_back(decl_buffer);
+ } else {
+ break;
+ }
}
+
if (const auto* seq = body.as<SeqStmtNode>()) {
ObjectPtr<BlockNode> n = make_object<BlockNode>(*block);
auto new_seq = RemoveFromSeqStmt(GetRef<SeqStmt>(seq),
GetRef<Stmt>(last_stmt));
// Re-attach AllocateConst nodes
- auto new_body = new_seq;
- for (int i = 0; i < static_cast<int>(allocs.size()); ++i) {
- auto alloc = allocs[allocs.size() - 1 - i];
- new_body = AllocateConst(alloc->buffer_var, alloc->dtype,
alloc->extents, alloc->data,
- new_body, alloc->annotations, alloc->span);
- }
+ auto new_body = MergeNest(allocs, new_seq);
n->body = new_body;
*src_stmt = GetRef<Stmt>(block);
*tgt_stmt = Stmt(std::move(n));
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 604dbed325..43bf6b983e 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -75,6 +75,11 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
ICHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
+ } else if (const auto* alloc = s.as<AllocateConstNode>()) {
+ auto n = make_object<AllocateConstNode>(*alloc);
+ ICHECK(is_no_op(n->body));
+ n->body = body;
+ body = Stmt(n);
} else if (const auto* decl_buffer = s.as<DeclBufferNode>()) {
auto n = make_object<DeclBufferNode>(*decl_buffer);
ICHECK(is_no_op(n->body));
diff --git a/tests/python/unittest/test_tir_schedule_compute_at.py
b/tests/python/unittest/test_tir_schedule_compute_at.py
index 0623fb02f3..7efb4cccc0 100644
--- a/tests/python/unittest/test_tir_schedule_compute_at.py
+++ b/tests/python/unittest/test_tir_schedule_compute_at.py
@@ -1672,5 +1672,104 @@ def test_reverse_compute_at_layout_trans():
verify_trace_roundtrip(sch=sch, mod=before)
+@pytest.mark.parametrize("use_decl_buffer", [True, False])
+@pytest.mark.parametrize("use_reverse_compute_at", [True, False])
+def test_compute_at_allocate_const(use_decl_buffer, use_reverse_compute_at):
+ def apply_decl_buffer(*args, **kwargs):
+ if use_decl_buffer:
+ return T.decl_buffer(*args, **kwargs)
+ else:
+ return T.Buffer(*args, **kwargs)
+
+ @T.prim_func
+ def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256],
"float32")):
+ B = T.alloc_buffer([4])
+
+ offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32",
extents=[4])
+ offset = apply_decl_buffer([4], data=offset_ptr)
+ for i in range(4):
+ with T.block("compute_B"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = 10.0 * vi + offset[vi]
+
+ for i, j in T.grid(4, 256):
+ with T.block("compute_C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi] + 100.0 * vj
+
+ @T.prim_func
+ def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256],
"float32")):
+ B = T.alloc_buffer([4])
+
+ offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32",
extents=[4])
+ offset = apply_decl_buffer([4], data=offset_ptr)
+ for i in range(4):
+ with T.block("compute_B"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = 10.0 * vi + offset[vi]
+
+ for j in range(256):
+ with T.block("compute_C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi] + 100.0 * vj
+
+ sch = tir.Schedule(before, debug_mask="all")
+ if use_reverse_compute_at:
+ block = sch.get_block("compute_C")
+ axis = sch.get_loops("compute_B")[0]
+ sch.reverse_compute_at(block, axis)
+ else:
+ block = sch.get_block("compute_B")
+ axis = sch.get_loops("compute_C")[0]
+ sch.compute_at(block, axis)
+
+ after = sch.mod["main"]
+
+ tvm.ir.assert_structural_equal(expected, after)
+ verify_trace_roundtrip(sch=sch, mod=before)
+
+
+@pytest.mark.parametrize("use_decl_buffer", [True, False])
+def test_compute_inline_allocate_const(use_decl_buffer):
+ def apply_decl_buffer(*args, **kwargs):
+ if use_decl_buffer:
+ return T.decl_buffer(*args, **kwargs)
+ else:
+ return T.Buffer(*args, **kwargs)
+
+ @T.prim_func
+ def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256],
"float32")):
+ B = T.alloc_buffer([4])
+
+ offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32",
extents=[4])
+ offset = apply_decl_buffer([4], data=offset_ptr)
+ for i in range(4):
+ with T.block("compute_B"):
+ vi = T.axis.remap("S", [i])
+ B[vi] = 10.0 * vi + offset[vi]
+
+ for i, j in T.grid(4, 256):
+ with T.block("compute_C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = B[vi] + 100.0 * vj
+
+ @T.prim_func
+ def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256],
"float32")):
+ offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32",
extents=[4])
+ offset = apply_decl_buffer([4], data=offset_ptr)
+ for i, j in T.grid(4, 256):
+ with T.block("compute_C"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = (10.0 * vi + offset[vi]) + 100.0 * vj
+
+ sch = tir.Schedule(before, debug_mask="all")
+ block = sch.get_block("compute_B")
+ sch.compute_inline(block)
+ after = sch.mod["main"]
+
+ tvm.ir.assert_structural_equal(expected, after)
+ verify_trace_roundtrip(sch=sch, mod=before)
+
+
if __name__ == "__main__":
tvm.testing.main()