This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 897019d  [Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite. (#10787)
897019d is described below

commit 897019df6a86720f0157a345e62b538975f11ae8
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Wed Mar 30 14:48:19 2022 -0500

    [Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite. (#10787)
    
    * [Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite.
    
    As a follow-up from https://github.com/apache/tvm/pull/9727, this
    restricts StorageRewrite to modifying only flat memory buffers.  When
    rewriting, the existing algorithm in StorageRewrite flattens N-d
    allocations into 1-d allocations, preventing them from being exposed
    to the codegen.
    
    * Bugfix, flattening of Allocate/AllocateConst extents
    
    Previously, these were ignored entirely.  This worked so long as all
    allocations were 1-d, as `StorageRewrite` erroneously flattened merged
    arrays into 1-d.
---
 src/tir/transforms/storage_flatten.cc | 97 ++++++++++++++++++++++++++++++++++-
 src/tir/transforms/storage_rewrite.cc | 77 +++++++++++++++++++++------
 2 files changed, 155 insertions(+), 19 deletions(-)
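
A note on terminology for readers of the diff below: an allocation is "flat" when it has exactly one physical dimension, i.e. its buffer has no axis_separators.  The following standalone sketch (hypothetical helper name, not TVM's API) illustrates the flattening rule the patch relies on: logical extents are multiplied together within each group delimited by the axis separators, so the flattened result has axis_separators.size() + 1 physical dimensions.

// Standalone C++ sketch of the axis-separator flattening rule.
// FlattenWithSeparators is a hypothetical helper, not part of TVM.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> FlattenWithSeparators(const std::vector<int64_t>& extents,
                                           const std::vector<std::size_t>& axis_separators) {
  std::vector<int64_t> physical(axis_separators.size() + 1, 1);
  std::size_t group = 0;
  for (std::size_t i = 0; i < extents.size(); i++) {
    // Advance to the next physical dimension at each separator index.
    while (group < axis_separators.size() && i >= axis_separators[group]) {
      group++;
    }
    physical[group] *= extents[i];
  }
  return physical;
}

int main() {
  // A 4-d allocation with one separator after axis 2 flattens to two
  // physical dimensions: {16*16, 4*4} == {256, 16}.
  for (int64_t extent : FlattenWithSeparators({16, 16, 4, 4}, {2})) {
    std::cout << extent << " ";
  }
  std::cout << std::endl;
  // With no separators the same allocation would flatten to {4096}.
  return 0;
}

In the patch itself, StorageFlattener::FlattenExtents reaches the same result by constructing a dummy buffer with the chosen axis_separators and calling GetFlattenedBuffer(), as shown in the first hunk below.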

diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 2bfc842..0923517 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -1405,12 +1405,25 @@ class StorageFlattener : public StmtExprMutator {
   // rather than a buffer_var.
   Stmt VisitStmt_(const AllocateNode* op) final {
     buffer_var_defines_.insert(op->buffer_var.get());
-    return StmtExprMutator::VisitStmt_(op);
+    auto stmt = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+    return Allocate(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), stmt->condition,
+                    stmt->body, stmt->annotations, stmt->span);
   }
 
   Stmt VisitStmt_(const AllocateConstNode* op) final {
     buffer_var_defines_.insert(op->buffer_var.get());
-    return StmtExprMutator::VisitStmt_(op);
+    auto stmt = Downcast<AllocateConst>(StmtExprMutator::VisitStmt_(op));
+    ObjectRef data_or_idx;
+    if (stmt->data) {
+      data_or_idx = stmt->data.value();
+    } else if (stmt->irmod_storage_idx) {
+      data_or_idx = stmt->irmod_storage_idx.value();
+    } else {
+      LOG(FATAL) << "Neither data array nor data index specified for 
allocation of const "
+                 << op->buffer_var->name_hint;
+    }
+    return AllocateConst(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), data_or_idx,
+                         stmt->body, stmt->span);
   }
 
   Stmt VisitStmt_(const LetStmtNode* op) final {
@@ -1598,6 +1611,82 @@ class StorageFlattener : public StmtExprMutator {
   }
 
  private:
+  // Helper function for visiting Allocate and AllocateConst.  If, in
+  // the future, these are updated to hold a buffer (Buffer) object
+  // rather than a buffer_var (Var), this function can be replaced
+  // with a call to GetBufferEntry.
+  template <typename Node>
+  Array<PrimExpr> FlattenExtents(const Node& node) {
+    arith::Analyzer analyzer;
+
+    // Returns true if the allocation's extents match the buffer's shape.
+    auto is_compatible_buffer = [&](const Buffer& buffer) {
+      if (buffer->shape.size() != node->extents.size()) {
+        return false;
+      }
+      for (size_t i = 0; i < buffer->shape.size(); i++) {
+        if (!analyzer.CanProveEqual(buffer->shape[i], node->extents[i])) {
+          return false;
+        }
+      }
+
+      return true;
+    };
+
+    auto int_array_equal = [](const Array<IntImm>& a, const Array<IntImm>& b) {
+      if (a.size() != b.size()) {
+        return false;
+      }
+
+      for (size_t i = 0; i < a.size(); i++) {
+        if (a[i]->value != b[i]->value) {
+          return false;
+        }
+      }
+
+      return true;
+    };
+
+    Array<IntImm> axis_separators;
+    auto it = buffer_var_map_.find(node->buffer_var.get());
+    if (it != buffer_var_map_.end()) {
+      const auto& buffers = it->second;
+      if (buffers.size() == 0) {
+        // No buffers use this allocation, treat as flat and optimize
+        // out later.
+      } else if (buffers.size() == 1) {
+        // Only one buffer uses this allocation, so use its axis
+        // separators.
+        axis_separators = buffers[0]->axis_separators;
+      } else {
+        // Try to find a buffer using this allocation with a matching
+        // shape.
+        Buffer compatible_buffer;
+        for (const auto& buffer : buffers) {
+          if (is_compatible_buffer(buffer)) {
+            ICHECK(!compatible_buffer.defined() ||
+                   int_array_equal(compatible_buffer->axis_separators, buffer->axis_separators))
+                << "Cannot determine axis separators to use when flattening "
+                << node->buffer_var->name_hint
+                << ", multiple buffer objects found with conflicting axis 
separators";
+            compatible_buffer = buffer;
+          }
+        }
+        ICHECK(compatible_buffer.defined())
+            << "Cannot determine axis separators to use when flattening "
+            << node->buffer_var->name_hint << ", no buffers found with matching shape";
+        axis_separators = compatible_buffer->axis_separators;
+      }
+    }
+
+    // Use GetFlattenedBuffer to determine the flattened shape of the
+    // output.  We only need the shape and axis separators defined,
+    // everything else can be dummy values.
+    Buffer dummy_buffer =
+        decl_buffer(node->extents, DataType::Float(32), "buffer", "", axis_separators);
+    return dummy_buffer.GetFlattenedBuffer()->shape;
+  }
+
   // The buffer entry in the flatten map
   struct DimAlignInfo {
     int align_factor{0};
@@ -1665,6 +1754,10 @@ class StorageFlattener : public StmtExprMutator {
   // Set of vars that have occurred in an AllocateNode, but haven't
   // yet occurred in a BufferLoad/BufferStore.
   std::unordered_set<const VarNode*> buffer_var_defines_;
+  // Map from an allocation variable to the buffer(s) that it backs.
+  // Used to determine the axis_separators that should be
+  // used for flattening the extents of an AllocateNode.
+  std::unordered_map<const VarNode*, std::vector<Buffer>> buffer_var_map_;
   // Buffer map
  std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
   // The extern buffer map, updated to include flattened buffers.
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 0534f31..d1a37e1 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -76,6 +76,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
   };
   // The scope of each allocation
   struct AllocEntry {
+    // The number of physical dimensions of the allocation.
+    size_t num_physical_dimensions{0};
     // scope level
     size_t level{0};
     // allocation stmt
@@ -85,8 +87,16 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
   void VisitStmt_(const AllocateNode* op) final {
     size_t level = scope_.size();
     const VarNode* buf = op->buffer_var.get();
-    alloc_info_[buf].alloc = op;
-    alloc_info_[buf].level = level;
+
+    AllocEntry entry;
+    entry.alloc = op;
+    entry.level = level;
+    // Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
+    // an allocation's extents correspond to its physical dimensions,
+    // of which there is exactly one for a flat memory space.
+    entry.num_physical_dimensions = op->extents.size();
+    alloc_info_[buf] = entry;
+
     StmtExprVisitor::VisitStmt_(op);
   }
 
@@ -104,6 +114,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
       scope_[it->second.level].touched.push_back(buf);
+
+      ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
+          << "Buffer " << op->buffer->name << " is allocated with "
+          << it->second.num_physical_dimensions
+          << " physical dimensions, but is accessed as having "
+          << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
     }
     StmtEntry e = scope_.back();
     scope_.pop_back();
@@ -125,6 +141,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     if (it != alloc_info_.end() && it->second.alloc) {
      ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
       scope_[it->second.level].touched.push_back(buf);
+
+      ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
+          << "Buffer " << op->buffer->name << " is allocated with "
+          << it->second.num_physical_dimensions
+          << " physical dimensions, but is accessed as having "
+          << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
     }
   }
 
@@ -530,6 +552,10 @@ class StoragePlanRewriter : public StmtExprMutator {
     uint64_t const_nbits{0};
     // The storage scope.
     StorageScope scope;
+    // The physical dimensionality of the allocations.  Since
+    // StorageRewrite is applied after StorageFlatten/FlattenBuffer,
+    // this is the size of `AllocateNode::extents`.
+    size_t ndim;
     // Allocs that shares this entry.
     std::vector<const AllocateNode*> allocs;
     // The children of this entry, not including itself.
@@ -629,8 +655,8 @@ class StoragePlanRewriter : public StmtExprMutator {
           // simply use the original allocation.
          PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
                              make_const(DataType::Int(32), 1), e->allocs[0]->extents);
-          e->new_alloc =
-              Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0));
+          e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
+                                  e->allocs[0]->condition, Evaluate(0));
           if (IsSpecialTaggedMemory(e->scope)) {
             MemoryInfo info = GetMemoryInfo(e->scope.to_string());
             uint64_t total_elem = e->const_nbits / e->elem_type.bits();
@@ -641,8 +667,13 @@ class StoragePlanRewriter : public StmtExprMutator {
           // Build a merged allocation
           PrimExpr combo_size;
           for (const AllocateNode* op : e->allocs) {
-            PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-                                make_const(DataType::Int(32), 1), op->extents);
+            ICHECK_EQ(op->extents.size(), 1)
+                << "Buffer var " << op->buffer_var->name_hint
+                << " was identified as a re-usable allocation, but has " << 
op->extents.size()
+                << " physical dimensions.  "
+                << "Currently, only flat 1-d memory spaces should be 
identified as re-usable "
+                   "allocations.";
+            PrimExpr sz = op->extents[0];
             auto nbits = op->dtype.bits() * op->dtype.lanes();
             if (const auto* imm = sz.as<IntImmNode>()) {
               if (imm->value > std::numeric_limits<int>::max() / nbits) {
@@ -790,7 +821,8 @@ class StoragePlanRewriter : public StmtExprMutator {
 
         for (const VarNode* var : it->second.gen) {
           ICHECK(alloc_info.count(var));
-          const AllocateNode* alloc = alloc_info.at(var).alloc;
+          const AllocEntry& entry = alloc_info.at(var);
+          const AllocateNode* alloc = entry.alloc;
          auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
           StorageEntry* dst_entry = nullptr;
           // inplace detection
@@ -818,7 +850,8 @@ class StoragePlanRewriter : public StmtExprMutator {
             }
           }
           if (dst_entry == nullptr) {
-            dst_entry = FindAlloc(alloc, thread_scope_, storage_scope);
+            dst_entry =
+                FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions);
           }
           dst_entry->allocs.emplace_back(alloc);
           alloc_map_[var] = dst_entry;
@@ -871,24 +904,34 @@ class StoragePlanRewriter : public StmtExprMutator {
   }
 
   StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
-                          const StorageScope& scope) {
+                          const StorageScope& scope, size_t num_physical_dimensions) {
     ICHECK(op != nullptr);
     // skip plan for local variable,
     // compiler can do a better job with register allocation.
     const uint64_t match_range = 16;
     uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
    uint64_t const_nbits = static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
+
+    // If the size of the array isn't known at compile-time, it must
+    // have its own allocation with size determined at runtime.
+    bool is_known_size = (const_nbits != 0);
+
+    // Currently, only flat memory spaces can be re-used.  Packing
+    // into N-d space (e.g. 2-d texture memory on GPUs) will require
+    // more in-depth algorithms.
+    bool is_flat_memory_space = (num_physical_dimensions == 1);
+
     // disable reuse of small arrays, they will be lowered to registers in LLVM
     // This rules only apply if we are using non special memory
-    if (scope.tag.length() == 0) {
-      if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) {
-        return NewAlloc(op, attach_scope, scope, const_nbits);
-      }
-      if (const_nbits > 0 && const_nbits <= 32) {
-        return NewAlloc(op, attach_scope, scope, const_nbits);
-      }
+    bool is_small_array =
+        (scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
+                                      (is_known_size && const_nbits <= 32));
+
+    if (is_small_array || !is_flat_memory_space) {
+      return NewAlloc(op, attach_scope, scope, const_nbits);
     }
-    if (const_nbits != 0) {
+
+    if (is_known_size) {
       // constant allocation.
       auto begin = const_free_map_.lower_bound(const_nbits / match_range);
       auto mid = const_free_map_.lower_bound(const_nbits);
