This is an automated email from the ASF dual-hosted git repository. csullivan pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 3eb673478b [LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite (#12364)
3eb673478b is described below

commit 3eb673478bc444daf24ee8d6308a42a71c81b74f
Author: Tristan Konolige <tkonol...@octoml.ai>
AuthorDate: Fri Aug 12 13:16:23 2022 -0700

    [LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite (#12364)

    * [LowerVTCMAlloc] Move LowerVtcmAlloc to after StorageRewrite

    Vtcm allocations were being moved inside loops even if they were originally allocated outside of the loops. Normally PlanAndUpdateBufferAllocationLocation moves allocations as close to use as possible and then StorageRewrite moves them back out as far as possible. However, with Vtcm allocation, PlanAndUpdateBufferAllocationLocation would move the Vtcm allocation close to the compute, then LowerVtcmAlloc would convert the allocation to a LetStmt. StorageRewrite would not move this LetStmt as it only handles allocations. Moving LowerVtcmAlloc to after StorageRewrite ensures that the vtcm allocations are in their final spot before converting them to a LetStmt.

* fix issues with tagging and storage rewrite --- src/driver/driver_api.cc | 3 ++- src/tir/transforms/storage_rewrite.cc | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index cbf809a267..9bd2e8a812 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -204,7 +204,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::LowerVtcmAlloc()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); @@ -223,6 +222,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) { if (!disable_storage_rewrite) { pass_list.push_back(tir::transform::StorageRewrite()); } + // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations + pass_list.push_back(tir::transform::LowerVtcmAlloc()); pass_list.push_back(tir::transform::UnrollLoop()); // Add user-defined phase-2 passes diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 5a326d9fac..d15bed56fd 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -583,8 +583,10 @@ class StoragePlanRewriter : public StmtExprMutator { }; // Checks whether the storage_scope is especially tagged for a specific memory. + // Special memory is all combined into a single allocation. bool IsSpecialTaggedMemory(const StorageScope& scope) { - return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace"; + return scope.tag.length() != 0 && scope.tag != ".dyn" && scope.tag != ".workspace" && + scope.tag != ".vtcm"; } // Alllocate entry of node. 
@@ -655,8 +657,6 @@ class StoragePlanRewriter : public StmtExprMutator { if (e->allocs.size() == 1) { // simply use the original allocation. - PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); }, - make_const(DataType::Int(32), 1), e->allocs[0]->extents); e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents, e->allocs[0]->condition, Evaluate(0)); if (IsSpecialTaggedMemory(e->scope)) {