This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7214f52  [TIR] Fix opaque access in buffer locator pass and match_buffer in region detector (#8855)
7214f52 is described below

commit 7214f5239dbb8da4585d4d10fbc8c65c8f155b12
Author: Siyuan Feng <hzfen...@vip.qq.com>
AuthorDate: Sat Aug 28 17:23:43 2021 +0800

    [TIR] Fix opaque access in buffer locator pass and match_buffer in region detector (#8855)
    
    * init
    
    * fix
    
    * Update src/tir/transforms/plan_update_buffer_allocation_location.cc
    
    Co-authored-by: Ruihang Lai <lairuihangdongd...@qq.com>
    
    * Update src/tir/transforms/plan_update_buffer_allocation_location.cc
    
    Co-authored-by: Ruihang Lai <lairuihangdongd...@qq.com>
    
    * address
    
    Co-authored-by: Junru Shao <junrushao1...@gmail.com>
    Co-authored-by: Ruihang Lai <lairuihangdongd...@qq.com>
---
 src/tir/analysis/block_access_region_detector.cc   |  7 ++-
 .../plan_update_buffer_allocation_location.cc      | 39 +++++++++-----
 .../test_tir_analysis_get_block_access_region.py   | 21 +++++---
 ...sform_plan_update_buffer_allocation_location.py | 62 ++++++++++++++++++++++
 4 files changed, 109 insertions(+), 20 deletions(-)

diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index 8f87ef9..dd01aed 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -110,8 +110,11 @@ void BlockReadWriteDetector::operator()(const Stmt& stmt) {
   ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << 
stmt->GetTypeKey();
   for (const MatchBufferRegion& match_buffer : block->match_buffers) {
     const Var& target_var = match_buffer->buffer->data;
-    match_buffers_[target_var.get()] = match_buffer;
-    buffer_var_map_.Set(target_var, match_buffer->buffer);
+    const Var& source_var = match_buffer->source->buffer->data;
+    if (buffer_var_map_.find(source_var) != buffer_var_map_.end()) {
+      match_buffers_[target_var.get()] = match_buffer;
+      buffer_var_map_.Set(target_var, match_buffer->buffer);
+    }
   }
   StmtExprVisitor::operator()(stmt);
 }
diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc
index bee11ad..59f9170 100644
--- a/src/tir/transforms/plan_update_buffer_allocation_location.cc
+++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -75,8 +75,6 @@ class BufferAllocationLocator : public StmtExprMutator {
 
   Stmt VisitStmt_(const BlockNode* op) final {
     ICHECK(!op->init.defined());
-    bool is_root = is_root_;
-    is_root_ = false;
     Array<Buffer> alloc_buffers;
     auto it = alloc_buffers_.find(op);
     if (it != alloc_buffers_.end()) {
@@ -85,11 +83,23 @@ class BufferAllocationLocator : public StmtExprMutator {
         buffer_data_to_buffer_.Set(buf->data, buf);
       }
     }
+    for (const MatchBufferRegion match_buffer : op->match_buffers) {
+      const Var& target_var = match_buffer->buffer->data;
+      const Var& source_var = match_buffer->source->buffer->data;
+      ICHECK(buffer_data_to_buffer_.count(source_var));
+      buffer_data_to_buffer_.Set(target_var, match_buffer->buffer);
+    }
     Stmt stmt = StmtMutator::VisitStmt_(op);
     op = stmt.as<BlockNode>();
     ICHECK(op != nullptr);
 
-    // Ignore buffer allocated inside the block when getting access region.
+    // No longer consider buffers created by match_buffer inside the block 
when updating access
+    // region.
+    for (const MatchBufferRegion match_buffer : op->match_buffers) {
+      const Var& target_var = match_buffer->buffer->data;
+      buffer_data_to_buffer_.erase(target_var);
+    }
+    // No longer consider buffers allocated inside the block when updating 
access region.
     if (it != alloc_buffers_.end()) {
       for (const Buffer& buf : it->second) {
         buffer_data_to_buffer_.erase(buf->data);
@@ -98,12 +108,9 @@ class BufferAllocationLocator : public StmtExprMutator {
 
     ObjectPtr<BlockNode> n = CopyOnWrite(op);
     n->alloc_buffers = std::move(alloc_buffers);
-    // The read/write regions of root block are always empty.
-    if (!is_root) {
-      // Recalculate block access region
-      CollectReadWrite(GetRef<Block>(op), &n->reads, &n->writes);
-    }
-
+    // Erase buffer allocated inside the block from access region.
+    n->reads = RemoveRedundantBufferRegion(n->reads);
+    n->writes = RemoveRedundantBufferRegion(n->writes);
     return Stmt(n);
   }
 
@@ -127,8 +134,18 @@ class BufferAllocationLocator : public StmtExprMutator {
     return std::move(realize);
   }
 
+  Array<BufferRegion> RemoveRedundantBufferRegion(const Array<BufferRegion>& 
region) const {
+    Array<BufferRegion> result;
+    for (const BufferRegion& buffer_region : region) {
+      if (buffer_data_to_buffer_.count(buffer_region->buffer->data)) {
+        result.push_back(buffer_region);
+      }
+    }
+    return result;
+  }
+
   void CollectReadWrite(const Block& block, Array<BufferRegion>* reads,
-                        Array<BufferRegion>* writes) {
+                        Array<BufferRegion>* writes) const {
     Array<Array<BufferRegion>> access = GetBlockAccessRegion(block, 
buffer_data_to_buffer_);
     *reads = access[0];
     *writes = access[1];
@@ -142,8 +159,6 @@ class BufferAllocationLocator : public StmtExprMutator {
   std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
   /*! \brief The buffer already allocated during recursive visiting. */
   Map<Var, Buffer> buffer_data_to_buffer_;
-  /*! \brief indicate the whether the block is root. */
-  bool is_root_{true};
 };
 
 PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
index 7641f0a..9c95b98 100644
--- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -114,20 +114,29 @@ def test_match_buffer():
     root_block = match_buffer_func.body.block
     block = root_block.body.body.body.block
     block_inner = block.body[0].body.body.block
-    alloc_buffers = func.body.block.alloc_buffers
+    alloc_buffers = match_buffer_func.body.block.alloc_buffers
     buffer_var_map = {buf.data: buf for buf in alloc_buffers}
 
-    # Check inner block AAA
-    ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
-    tvm.ir.assert_structural_equal(block_inner.reads, ret[0])
-    tvm.ir.assert_structural_equal(block_inner.writes, ret[1])
-
     # Check block
     ret = tir.analysis.get_block_access_region(block, buffer_var_map)
     tvm.ir.assert_structural_equal(block.writes, ret[1])
     # B is opaque access
     tvm.ir.assert_structural_equal(block.reads, ret[2])
 
+    # Check inner block AAA without updating buffer_var_map
+    ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
+    # Since AA is not in the buffer_var_map, region of AA will not be 
collected.
+    tvm.ir.assert_structural_equal([], ret[1])
+
+    # Check inner block AAA
+    for match_buffer in block.match_buffers:
+        target_buffer = match_buffer.buffer
+        buffer_var_map[target_buffer.data] = target_buffer
+
+    ret = tir.analysis.get_block_access_region(block_inner, buffer_var_map)
+    tvm.ir.assert_structural_equal(block_inner.reads, ret[0])
+    tvm.ir.assert_structural_equal(block_inner.writes, ret[1])
+
 
 if __name__ == "__main__":
     test_block_access_region_detector()
diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
index 8418e19..07140ab 100644
--- a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
+++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py
@@ -137,6 +137,63 @@ def transformed_match_buffer_func() -> None:
                 C1[()] = 0
 
 
+@tvm.script.tir
+def opaque_access(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, [1024])
+    B = tir.match_buffer(b, [1024])
+    A_cache = tir.alloc_buffer([1024])
+    for i in tir.serial(0, 8):
+        with tir.block([8]) as [vi]:
+            with tir.block([8]) as [v]:
+                tir.bind(v, vi)
+                tir.reads([A[(v * 128) : ((v * 128) + 128)]])
+                tir.writes([A_cache[(v * 128) : ((v * 128) + 128)]])
+                tir.evaluate(
+                    tir.call_extern(
+                        "test",
+                        A_cache.data,
+                        (v * 128),
+                        128,
+                        A.data,
+                        (v * 128),
+                        128,
+                        dtype="float32",
+                    )
+                )
+            for j in tir.serial(0, 128):
+                with tir.block([1024]) as [v]:
+                    tir.bind(v, ((vi * 128) + j))
+                    tir.reads([A_cache[v]])
+                    tir.writes([B[v]])
+                    B[v] = A_cache[v]
+
+
+@tvm.script.tir
+def transformed_opaque_access(a: ty.handle, b: ty.handle) -> None:
+    A = tir.match_buffer(a, [1024])
+    B = tir.match_buffer(b, [1024])
+    for i in tir.serial(0, 8):
+        with tir.block([8]) as [vi]:
+            tir.reads(A[vi * 128 : vi * 128 + 128])
+            tir.writes(B[vi * 128 : vi * 128 + 128])
+            A_cache = tir.alloc_buffer([1024])
+            with tir.block([8]) as [v]:
+                tir.bind(v, vi)
+                tir.reads([A[v * 128 : v * 128 + 128]])
+                tir.writes([A_cache[v * 128 : v * 128 + 128]])
+                tir.evaluate(
+                    tir.call_extern(
+                        "test", A_cache.data, v * 128, 128, A.data, v * 128, 
128, dtype="float32"
+                    )
+                )
+            for j in tir.serial(0, 128):
+                with tir.block([1024]) as [v]:
+                    tir.bind(v, ((vi * 128) + j))
+                    tir.reads([A_cache[v]])
+                    tir.writes([B[v]])
+                    B[v] = A_cache[v]
+
+
 def test_elementwise():
     _check(element_func, transformed_element_func)
 
@@ -149,6 +206,10 @@ def test_match_buffer_allocation():
     _check(match_buffer_func, transformed_match_buffer_func)
 
 
+def test_opaque_access():
+    _check(opaque_access, transformed_opaque_access)
+
+
 def test_lower_te():
     x = te.placeholder((1,))
     y = te.compute((1,), lambda i: x[i] + 2)
@@ -164,4 +225,5 @@ if __name__ == "__main__":
     test_elementwise()
     test_locate_buffer_allocation()
     test_match_buffer_allocation()
+    test_opaque_access()
     test_lower_te()

Reply via email to