This is an automated email from the ASF dual-hosted git repository.

jcf94 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dee3133  [TensorIR][PASS] CompactBufferAllocation (#7923)
dee3133 is described below

commit dee3133c5418fc1d44ab202bff8b2c6906593d1a
Author: Siyuan Feng <hzfen...@vip.qq.com>
AuthorDate: Wed Apr 28 13:11:43 2021 +0800

    [TensorIR][PASS] CompactBufferAllocation (#7923)
    
    Co-authored-by: Tianqi Chen <tqc...@users.noreply.github.com>
    Co-authored-by: Junru Shao <junrushao1...@gmail.com>
    Co-authored-by: Cody Yu <comaniac0...@gmail.com>
---
 include/tvm/tir/expr.h                             |   1 +
 include/tvm/tir/stmt.h                             |  12 +-
 include/tvm/tir/transform.h                        |  46 ++
 python/tvm/tir/transform/transform.py              |  50 +++
 src/support/utils.h                                |  19 +
 src/tir/ir/stmt.cc                                 |   8 +
 src/tir/transforms/compact_buffer_region.cc        | 468 +++++++++++++++++++++
 src/tir/transforms/convert_blocks_to_opaque.cc     | 104 +++++
 .../test_tir_transform_compact_buffer_region.py    | 331 +++++++++++++++
 .../test_tir_transform_convert_blocks_to_opaque.py |  77 ++++
 10 files changed, 1115 insertions(+), 1 deletion(-)

diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 7cab197..e1d0974 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -638,6 +638,7 @@ class BufferLoad : public PrimExpr {
  public:
   TVM_DLL explicit BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span 
span = Span());
   TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode);
 };
 
 /*!
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 0931768..cc10c21 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -324,6 +324,7 @@ class BufferStore : public Stmt {
                                Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
 };
 
 /*!
@@ -991,13 +992,22 @@ class BufferRegion : public ObjectRef {
   TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);
 
   /*!
-   * \brief Create a BufferRegion which is full region of the given buffer..
+   * \brief Create a BufferRegion which is full region of the given buffer.
    * \param buffer The buffer to generate full BufferRegion.
    * \return The BufferRegion which covers all region of the given buffer
    */
   TVM_DLL static BufferRegion FullRegion(Buffer buffer);
 
+  /*!
+   * \brief Create a BufferRegion which is a single point of the given buffer.
+   * \param buffer The buffer to generate single point BufferRegion.
+   * \param indices The access point indices of the buffer
+   * \return The BufferRegion which is the single point of the given buffer.
+   */
+  TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array<PrimExpr> 
indices);
+
   TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode);
 };
 
 /*!
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 8e7c16b..a236c50 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -360,6 +360,52 @@ TVM_DLL Pass LowerInitBlock();
  */
 TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();
 
+/*!
+ * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the
+ *        corresponding iter_values in BlockRealize, and then convert the blocks into opaque
+ *        ones by removing all the iter_values in BlockRealize and iter_vars in Block.
+ * \return The pass.
+ */
+TVM_DLL Pass ConvertBlocksToOpaque();
+
+/*!
+ * \brief Compact the buffer access region by removing the buffer regions that are not accessed,
+ *        i.e. narrowing the buffer shape and adjusting the access region if necessary.
+ *
+ * \par Example
+ *  Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is
+ *  accessed.
+ *  \code
+ *
+ *  for i in range(0, 16):
+ *      with tir.block([]):
+ *          B = tir.alloc_buffer((16, 16), "float32")
+ *          for j in range(0, 16):
+ *              B[i, j] = A[i, j] + 1
+ *          for j in range(0, 16):
+ *              C[i, j] = B[i, j] + 1
+ *
+ *  \endcode
+ *
+ * This pass narrows the buffer shape and adjusts its accessed region accordingly.
+ * In this particular case, because only a `1 * 16` vector of `B` is accessed,
+ * the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`.
+ *
+ *  \code
+ *
+ *  for i in range(0, 16):
+ *      with tir.block([]):
+ *          B = tir.alloc_buffer((1, 16), "float32")
+ *          for j in range(0, 16):
+ *              B[0, j] = A[i, j] + 1
+ *          for j in range(0, 16):
+ *              C[i, j] = B[0, j] + 1
+ *
+ *  \endcode
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass CompactBufferAllocation();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/tir/transform/transform.py 
b/python/tvm/tir/transform/transform.py
index 8317421..2ae75d2 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -560,3 +560,53 @@ def PlanAndUpdateBufferAllocationLocation():
         The result pass
     """
     return _ffi_api.PlanAndUpdateBufferAllocationLocation()
+
+
+def ConvertBlocksToOpaque():
+    """Substitute all the block vars with the PrimExprs they are bound to, indicated by
+    the corresponding iter_values in BlockRealize, and then convert the blocks into
+    opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    fpass = _ffi_api.ConvertBlocksToOpaque()
+    return fpass
+
+
+def CompactBufferAllocation():
+    """Compact the buffer access region by removing the buffer regions that are not accessed,
+    i.e. narrowing the buffer shape and adjusting the access region if necessary.
+
+    Example
+    -------
+    Before narrowing, ``B`` is a ``[16, 16]`` buffer, but only a skinny vector ``B[i, 0:16]``
+    is accessed.
+
+    .. code-block:: python
+
+        for i in range(0, 16):
+            with tir.block([]):
+                B = tir.alloc_buffer((16, 16), "float32")
+                for j in range(0, 16):
+                    B[i, j] = A[i, j] + 1
+                for j in range(0, 16):
+                    C[i, j] = B[i, j] + 1
+
+    This pass narrows the buffer shape and adjusts its accessed region accordingly.
+    In this particular case, because only a ``1 * 16`` vector of ``B`` is accessed,
+    the pass narrows ``B`` to shape ``[1, 16]``, and changes the access to ``B[i, j]`` to
+    ``B[0, j]``.
+
+    .. code-block:: python
+
+        for i in range(0, 16):
+            with tir.block([]):
+                B = tir.alloc_buffer((1, 16), "float32")
+                for j in range(0, 16):
+                    B[0, j] = A[i, j] + 1
+                for j in range(0, 16):
+                    C[i, j] = B[0, j] + 1
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.CompactBufferAllocation()
diff --git a/src/support/utils.h b/src/support/utils.h
index 2f55d40..0753517 100644
--- a/src/support/utils.h
+++ b/src/support/utils.h
@@ -31,6 +31,9 @@
 #include <sys/wait.h>
 #endif  // __hexagon__
 #endif  // _WIN32
+
+#include <tvm/runtime/container.h>
+
 #include <algorithm>
 #include <array>
 #include <cctype>
@@ -129,6 +132,22 @@ inline std::vector<std::string> Split(const std::string& 
str, char delim) {
 }
 
 /*!
+ * \brief Check whether the string starts with a given prefix.
+ * \param str The given string.
+ * \param prefix The given prefix, a null-terminated C string.
+ * \return Whether \p prefix is a prefix of \p str. This is true when the two are equal,
+ *         and also true for an empty prefix.
+ */
+inline bool StartsWith(const String& str, const char* prefix) {
+  size_t n = str.length();
+  for (size_t i = 0; i < n; i++) {
+    if (prefix[i] == '\0') return true;
+    if (str.data()[i] != prefix[i]) return false;
+  }
+  // Every character of `str` matched, so `prefix` is a prefix of `str` only if it ends right
+  // here, i.e. the two strings are equal. (Checking prefix[n + 1] would skip the terminator:
+  // it reads out of bounds when they are equal, and wrongly accepts a prefix that is exactly
+  // one character longer than `str`.)
+  return prefix[n] == '\0';
+}
+
+/*!
  * \brief EndsWith check whether the strings ends with
  * \param value The full string
  * \param end The end substring
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 87ead3e..b2016eb 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -646,6 +646,14 @@ BufferRegion BufferRegion::FullRegion(Buffer buffer) {
   return BufferRegion(buffer, region);
 }
 
+BufferRegion BufferRegion::FromPoint(Buffer buffer, Array<PrimExpr> indices) {
+  // Each index becomes a unit-extent range starting at that index.
+  Array<Range> point_region;
+  for (size_t dim = 0; dim < indices.size(); ++dim) {
+    point_region.push_back(Range::FromMinExtent(indices[dim], 1));
+  }
+  return BufferRegion(buffer, point_region);
+}
+
 TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, 
Array<Range> region) {
   return BufferRegion(buffer, region);
 });
diff --git a/src/tir/transforms/compact_buffer_region.cc 
b/src/tir/transforms/compact_buffer_region.cc
new file mode 100644
index 0000000..a5ca67e
--- /dev/null
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -0,0 +1,468 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file compact_buffer_region.cc
+ * \brief Compact the buffer size into its exact need.
+ */
+
+#include <tvm/arith/int_set.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <stack>
+
+#include "../../runtime/thread_storage_scope.h"
+#include "../../support/arena.h"
+#include "../../support/utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*! \brief An N-dimensional integer set: one arith::IntSet per buffer dimension. */
+using NDIntSet = std::vector<arith::IntSet>;
+
+/*! \brief Build the 1-D integer set of the range with the given min and extent. */
+arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) {
+  Range range = Range::FromMinExtent(min, extent);
+  return arith::IntSet::FromRange(range);
+}
+
+/*! \brief Convert every Range of a region into its integer set. */
+NDIntSet NDIntSetFromRegion(const Region& region) {
+  NDIntSet sets;
+  sets.reserve(region.size());
+  for (size_t dim = 0; dim < region.size(); ++dim) {
+    sets.push_back(arith::IntSet::FromRange(region[dim]));
+  }
+  return sets;
+}
+
+/*! \brief Build, for every dimension of a shape, the integer set with min 0 and that extent. */
+NDIntSet NDIntSetFromShape(const Array<PrimExpr>& shape) {
+  const PrimExpr zero = Integer(0);
+  NDIntSet sets;
+  sets.reserve(shape.size());
+  for (const PrimExpr& dim_extent : shape) {
+    sets.push_back(IntSetFromMinExtent(zero, dim_extent));
+  }
+  return sets;
+}
+
+/*! \brief Build the single-point integer set of an access point, one per index. */
+NDIntSet NDIntSetFromPoint(const Array<PrimExpr>& indices) {
+  NDIntSet sets;
+  sets.reserve(indices.size());
+  for (const PrimExpr& idx : indices) {
+    sets.push_back(arith::IntSet::SinglePoint(idx));
+  }
+  return sets;
+}
+
+/*! \brief In-place, per-dimension union of two N-D integer sets of equal rank. */
+void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) {
+  ICHECK_EQ(lhs->size(), rhs.size());
+  for (size_t dim = 0; dim < rhs.size(); ++dim) {
+    arith::IntSet& target = (*lhs)[dim];
+    target = arith::Union({target, rhs[dim]});
+  }
+}
+
+/*! \brief An N-D integer set of rank \p ndim whose every dimension is IntSet::Nothing(). */
+NDIntSet NDIntSetEmpty(int ndim) { return NDIntSet(ndim, arith::IntSet::Nothing()); }
+
+/*! \brief Evaluate every dimension of \p nd_int_set under the variable domains \p dom_map. */
+NDIntSet EvalNDIntSet(const NDIntSet& nd_int_set,
+                      const std::unordered_map<const VarNode*, arith::IntSet>& dom_map) {
+  NDIntSet evaluated;
+  evaluated.reserve(nd_int_set.size());
+  for (size_t dim = 0; dim < nd_int_set.size(); ++dim) {
+    evaluated.push_back(arith::EvalSet(nd_int_set[dim], dom_map));
+  }
+  return evaluated;
+}
+
+/*!
+ * \brief Convert the collected N-D integer set into a buffer region.
+ * \param nd_int_set The accessed integer set of each dimension.
+ * \param original_shape The shape of the buffer before compaction.
+ * \return One covering Range per dimension, computed by IntSet::CoverRange with the full range
+ *         [0, original_shape[i]) as its bound. That full range, i.e. the original buffer shape,
+ *         is returned if the int_set is empty.
+ */
+Region NarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
+                                      const Array<PrimExpr>& original_shape) {
+  Array<Range> narrowed;
+  narrowed.reserve(nd_int_set.size());
+  size_t dim = 0;
+  for (const arith::IntSet& dim_set : nd_int_set) {
+    Range full_range(/*begin=*/0, /*end=*/original_shape[dim]);
+    narrowed.push_back(dim_set.CoverRange(full_range));
+    ++dim;
+  }
+  return narrowed;
+}
+
+/*!
+ * \brief Collect the access region of each buffer allocated in a block.
+ * \note The param buffer regions will not be collected.
+ * \details Accesses are pushed onto a stack in DFS order. When the visitor leaves a block, the
+ *          accesses recorded inside it are popped, relaxed over the loops whose visit has
+ *          already finished (post-order), and unioned per buffer. For the buffers allocated by
+ *          that block, they are additionally relaxed over the enclosing thread-binding loops
+ *          that the buffer's storage scope requires.
+ */
+class BufferAccessRegionCollector : public StmtExprVisitor {
+ public:
+  /*!
+   * \brief Collect the access region of every block-allocated buffer in \p f.
+   * \param f The function to analyze.
+   * \return A map from each block-allocated buffer to the region of it that is accessed.
+   */
+  static std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> Collect(
+      const PrimFunc& f) {
+    BufferAccessRegionCollector collector;
+    collector(f->body);
+    return std::move(collector.buffer_access_region_);
+  }
+
+ private:
+  /*!
+   * \brief One recorded access to a buffer.
+   * \note Stored by value in the stack. It owns an ObjectRef and a std::vector, so it must not
+   *       be placed in a support::Arena, which never runs destructors.
+   */
+  struct BufferAccessInfo {
+    /*! \brief The buffer. */
+    Buffer buffer;
+    /*! \brief The buffer access region, which can be updated during visiting. */
+    NDIntSet accessed_region;
+
+    explicit BufferAccessInfo(const Buffer& buffer, const NDIntSet& region)
+        : buffer(buffer), accessed_region(region) {}
+  };
+
+  BufferAccessRegionCollector() = default;
+
+  /**************** Visitor overload ****************/
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
+    // Also visit the stored value and the indices. They may load from other buffers allocated
+    // in scope, and skipping them would under-estimate those buffers' access regions.
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
+    // Also visit the indices, which may themselves load from buffers allocated in scope.
+    StmtExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef<Var>(op)); }
+
+  void VisitExpr_(const LoadNode* op) final {
+    StmtExprVisitor::VisitExpr_(op);
+    VisitBufferVar(op->buffer_var);
+  }
+
+  void VisitStmt_(const StoreNode* op) final {
+    StmtExprVisitor::VisitStmt_(op);
+    VisitBufferVar(op->buffer_var);
+  }
+
+  void VisitStmt_(const ForNode* op) final {
+    ancestor_loops_.push_back(op);
+    StmtExprVisitor::VisitStmt_(op);
+    ancestor_loops_.pop_back();
+    // The iter_dom_map is updated by post DFS order.
+    // If the union point is under the for node, the loop var will not be relaxed.
+    // If the union point is outer of the for loop, the loop var should be relaxed.
+    iter_dom_map_on_post_order_[op->loop_var.get()] = IntSetFromMinExtent(op->min, op->extent);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    // Step 0. Check there is no init part.
+    ICHECK(!op->init.defined());
+    // Step 1. Update outer buffer access info using buffer region
+    for (const BufferRegion& region : op->reads) {
+      VisitBufferAccess(region);
+    }
+    for (const BufferRegion& region : op->writes) {
+      VisitBufferAccess(region);
+    }
+
+    // Step 2. Update inner buffer
+    // Step 2.1. rebuild map buffer_var_in_scope
+    std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_var_in_scope;
+    for (const Buffer& buffer : op->alloc_buffers) {
+      buffer_var_in_scope.emplace(buffer->data, buffer);
+    }
+    // Step 2.2 Record top stack element before recursive visiting.
+    size_t stack_top = buffer_access_stack_.size();
+
+    // Step 2.3. Update the buffer_var_in_scope_ of visitor and visit recursively
+    std::swap(buffer_var_in_scope, buffer_var_in_scope_);
+    StmtExprVisitor::VisitStmt_(op);
+    std::swap(buffer_var_in_scope, buffer_var_in_scope_);
+
+    // Step 2.4. Combine and relax access
+    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> relaxed_region =
+        CombineAndRelax(stack_top);
+
+    // Step 2.5. Visit ancestor_loops and try to relax outer thread loops.
+    for (const Buffer& buffer : op->alloc_buffers) {
+      auto it = relaxed_region.find(buffer);
+      if (it == relaxed_region.end()) {
+        // No access to this buffer was collected. Either it is never used, or it is only used
+        // inside a nested block whose read/write regions do not mention it. Keeping the full
+        // shape is always safe.
+        buffer_access_region_[buffer] = BufferRegion::FullRegion(buffer)->region;
+        continue;
+      }
+      const NDIntSet& nd_int_set = it->second;
+      const runtime::StorageScope scope = runtime::StorageScope::Create(buffer->scope);
+      std::unordered_map<const VarNode*, arith::IntSet> dom_map;
+      for (const ForNode* loop : ancestor_loops_) {
+        if (NeedRelaxThread(GetRef<For>(loop), scope)) {
+          dom_map[loop->loop_var.get()] = IntSetFromMinExtent(loop->min, loop->extent);
+        }
+      }
+      NDIntSet int_set = EvalNDIntSet(nd_int_set, dom_map);
+      buffer_access_region_[buffer] = NarrowBufferRegionFromNDIntSet(int_set, buffer->shape);
+    }
+  }
+
+  /**************** Helper functions ****************/
+
+  /*! \brief Record an access if it targets a buffer allocated in the current block scope. */
+  void VisitBufferAccess(const BufferRegion& buffer_region) {
+    auto it = buffer_var_in_scope_.find(buffer_region->buffer->data);
+    if (it != buffer_var_in_scope_.end()) {
+      const Buffer& alloc_buffer = it->second;
+      buffer_access_stack_.emplace(alloc_buffer, NDIntSetFromRegion(buffer_region->region));
+    }
+  }
+
+  /*! \brief Treat a direct use of a buffer's data var as an access to the full buffer. */
+  void VisitBufferVar(const Var& var) {
+    auto it = buffer_var_in_scope_.find(var);
+    if (it != buffer_var_in_scope_.end()) {
+      const Buffer& buffer = it->second;
+      VisitBufferAccess(BufferRegion::FullRegion(buffer));
+    }
+  }
+
+  /*!
+   * \brief Combine buffer accesses in the sub-tree.
+   * \details The access info is stored in a stack by DFS order, so that the accesses in the
+   *          sub-tree are top-n elements in the stack.
+   * \param stack_top compact the access information in `stack[stack_top:end]`.
+   * \return The relaxed, per-buffer union of the popped accesses.
+   */
+  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> CombineAndRelax(
+      size_t stack_top) {
+    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> accesses;
+    while (buffer_access_stack_.size() > stack_top) {
+      BufferAccessInfo info = std::move(buffer_access_stack_.top());
+      buffer_access_stack_.pop();
+      NDIntSet nd_int_set = EvalNDIntSet(info.accessed_region, iter_dom_map_on_post_order_);
+      auto it = accesses.find(info.buffer);
+      if (it != accesses.end()) {
+        NDIntSetUnionWith(&it->second, nd_int_set);
+      } else {
+        accesses.emplace(info.buffer, std::move(nd_int_set));
+      }
+    }
+    return accesses;
+  }
+
+  /*!
+   * \brief Combine buffer accesses in the sub-tree and push the combined result into the stack.
+   * \details The access info is stored in a stack by DFS order, so that the accesses in the
+   *          sub-tree are top-n elements in the stack.
+   * \param stack_top The top element of the stack before visiting the sub-tree.
+   * \return The relaxed, per-buffer union that was pushed back onto the stack.
+   */
+  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> CombineRelaxAndPushStack(
+      size_t stack_top) {
+    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> accesses =
+        CombineAndRelax(stack_top);
+    for (const auto& kv : accesses) {
+      buffer_access_stack_.emplace(kv.first, kv.second);
+    }
+    return accesses;
+  }
+
+  /*! \brief Check whether the thread binding loop should be relaxed with given storage scope. */
+  static bool NeedRelaxThread(const For& loop, const runtime::StorageScope& scope) {
+    if (loop->kind != ForKind::kThreadBinding) {
+      return false;
+    }
+    ICHECK(loop->thread_binding.defined());
+    IterVar binding = loop->thread_binding.value();
+    runtime::ThreadScope ts = runtime::ThreadScope::Create(binding->thread_tag);
+
+    // When there is warp memory
+    // threadIdx.x must be set to be warp index.
+    if (scope.rank == runtime::StorageRank::kWarp && ts.rank == 1 && ts.dim_index == 0) {
+      return true;
+    }
+    return static_cast<int>(scope.rank) <= ts.rank;
+  }
+
+  /**************** Class members ****************/
+
+  /*! \brief Buffer access in DFS order (owned by value so destructors run). */
+  std::stack<BufferAccessInfo> buffer_access_stack_;
+  /*! \brief The loops from the current node up to the root. */
+  std::vector<const ForNode*> ancestor_loops_;
+  /*! \brief The vars of the buffer allocated under the current block. */
+  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_var_in_scope_;
+  /*! \brief The map from loop vars to their iter range. */
+  std::unordered_map<const VarNode*, arith::IntSet> iter_dom_map_on_post_order_;
+  /*! \brief The map from Buffer to its entire access region, used for returning. */
+  std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> buffer_access_region_;
+};
+
+/*!
+ * \brief Reallocate the buffers with minimal region: each block-allocated buffer gets the
+ *        shape of its accessed region, and every access to it is shifted by the region's min.
+ */
+class BufferCompactor : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Rewrite the body of \p f using the collected access regions.
+   * \param f The function whose body is rewritten.
+   * \param regions The access region of each block-allocated buffer.
+   * \return The rewritten body.
+   */
+  static Stmt Compact(
+      const PrimFunc& f,
+      const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions) {
+    std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> alloc_infos;
+    for (const auto& entry : regions) {
+      alloc_infos.emplace(entry.first, BufferAllocInfo(entry.second));
+    }
+    BufferCompactor compactor(std::move(alloc_infos));
+    return compactor(f->body);
+  }
+
+ private:
+  /*! \brief Reallocation info of one block-allocated buffer. */
+  struct BufferAllocInfo {
+    /*! \brief The buffer access region. */
+    Region region;
+    /*!
+     * \brief The reallocated buffer with minimal size.
+     * \note Only block-allocated buffers have an entry (parameter buffers are absent from the
+     *       map). This field is filled in by RewriteAllocBuffer when the allocating block is
+     *       visited.
+     */
+    Buffer new_buffer;
+
+    explicit BufferAllocInfo(Region region) : region(std::move(region)) {}
+  };
+
+  explicit BufferCompactor(
+      std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info)
+      : buffer_info_(std::move(buffer_info)) {}
+
+  Stmt VisitStmt_(const BufferStoreNode* _op) final {
+    BufferStore rewritten = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
+    BufferStoreNode* node = rewritten.CopyOnWrite();
+    RewriteBufferAccess(&node->buffer, &node->indices);
+    return std::move(rewritten);
+  }
+
+  PrimExpr VisitExpr_(const BufferLoadNode* _op) final {
+    BufferLoad rewritten = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_op));
+    BufferLoadNode* node = rewritten.CopyOnWrite();
+    RewriteBufferAccess(&node->buffer, &node->indices);
+    return std::move(rewritten);
+  }
+
+  Stmt VisitStmt_(const BlockNode* op) final {
+    ICHECK(!op->init.defined());
+    // The new buffers must exist before the body is mutated, since accesses in the body are
+    // redirected to them.
+    Array<Buffer> new_alloc_buffers = RewriteAllocBuffer(op->alloc_buffers);
+    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
+    // Update the block signature to refer to the compacted buffers.
+    BlockNode* node = block.CopyOnWrite();
+    RewriteBufferRegions(&node->reads);
+    RewriteBufferRegions(&node->writes);
+    node->alloc_buffers = std::move(new_alloc_buffers);
+    return std::move(block);
+  }
+
+  /*! \brief Create the compacted buffers and record them in buffer_info_. */
+  Array<Buffer> RewriteAllocBuffer(const Array<Buffer>& buffers) {
+    Array<Buffer> new_buffers;
+    new_buffers.reserve(buffers.size());
+    for (const Buffer& old_buffer : buffers) {
+      auto it = buffer_info_.find(old_buffer);
+      ICHECK(it != buffer_info_.end());
+      BufferAllocInfo& info = it->second;
+      // The compacted shape is the extent of the accessed region in each dimension.
+      Array<PrimExpr> compact_shape;
+      compact_shape.reserve(info.region.size());
+      for (const Range& dim_range : info.region) {
+        compact_shape.push_back(dim_range->extent);
+      }
+      ObjectPtr<BufferNode> node = make_object<BufferNode>(*old_buffer.get());
+      node->shape = std::move(compact_shape);
+      info.new_buffer = Buffer(std::move(node));
+      new_buffers.push_back(info.new_buffer);
+    }
+    return new_buffers;
+  }
+
+  /*! \brief Redirect a point access to the compacted buffer, shifting indices by region min. */
+  void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) const {
+    auto it = buffer_info_.find(*buffer);
+    if (it == buffer_info_.end()) {
+      // Skip if the buffer is parameter
+      return;
+    }
+    const BufferAllocInfo& info = it->second;
+    ICHECK_EQ(indices->size(), info.region.size());
+    Array<PrimExpr> shifted;
+    shifted.reserve(info.region.size());
+    for (size_t dim = 0; dim < info.region.size(); ++dim) {
+      shifted.push_back((*indices)[dim] - info.region[dim]->min);
+    }
+    *buffer = info.new_buffer;
+    *indices = std::move(shifted);
+  }
+
+  /*! \brief Redirect a region access to the compacted buffer, shifting mins by region min. */
+  void RewriteBufferRegion(Buffer* buffer, Region* region) const {
+    auto it = buffer_info_.find(*buffer);
+    if (it == buffer_info_.end()) {
+      // Skip if the buffer is parameter
+      return;
+    }
+    const BufferAllocInfo& info = it->second;
+    ICHECK_EQ(region->size(), info.region.size());
+    Region shifted;
+    shifted.reserve(info.region.size());
+    for (size_t dim = 0; dim < info.region.size(); ++dim) {
+      const Range& access = (*region)[dim];
+      PrimExpr new_min = access->min - info.region[dim]->min;
+      shifted.push_back(Range::FromMinExtent(new_min, access->extent));
+    }
+    *buffer = info.new_buffer;
+    *region = std::move(shifted);
+  }
+
+  /*! \brief Apply RewriteBufferRegion to every BufferRegion of a block signature. */
+  void RewriteBufferRegions(Array<BufferRegion>* regions) const {
+    Array<BufferRegion> rewritten;
+    rewritten.reserve(regions->size());
+    for (const BufferRegion& old_region : *regions) {
+      BufferRegion copy = old_region;
+      BufferRegionNode* node = copy.CopyOnWrite();
+      RewriteBufferRegion(&node->buffer, &node->region);
+      rewritten.push_back(copy);
+    }
+    *regions = std::move(rewritten);
+  }
+
+  /*! \brief The allocation information about each buffer. */
+  std::unordered_map<Buffer, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info_;
+};
+
+/*! \brief Compact every block-allocated buffer of \p f down to the region it accesses. */
+PrimFunc CompactBufferAllocation(PrimFunc f) {
+  PrimFuncNode* func_node = f.CopyOnWrite();
+  const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> access_regions =
+      BufferAccessRegionCollector::Collect(f);
+  Stmt compacted_body = BufferCompactor::Compact(f, access_regions);
+  func_node->body = std::move(compacted_body);
+  return f;
+}
+
+namespace transform {
+
+Pass CompactBufferAllocation() {
+  // The pass keeps no state, so the lambda captures nothing.
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    return tir::CompactBufferAllocation(std::move(f));
+  };
+  return CreatePrimFuncPass(pass_func, /*opt_level=*/0, "tir.CompactBufferAllocation", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.CompactBufferAllocation")
+    .set_body_typed(CompactBufferAllocation);
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc 
b/src/tir/transforms/convert_blocks_to_opaque.cc
new file mode 100644
index 0000000..4c5e1dd
--- /dev/null
+++ b/src/tir/transforms/convert_blocks_to_opaque.cc
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file convert_blocks_to_opaque.cc
+ * \brief Convert the blocks to opaque blocks which do not have block vars.
+ */
+
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Substitute expr via BlockRealize value bindings and convert each block into an opaque
+ *        block.
+ */
+class OpaqueBlockConverter : public StmtExprMutator {
+ public:
+  /*!
+   * \brief Rewrite the body of \p f so that every block becomes opaque.
+   * \param f The function to rewrite.
+   * \return The rewritten body.
+   */
+  static Stmt Substitute(const PrimFunc& f) {
+    OpaqueBlockConverter converter;
+    return converter.VisitStmt(f->body);
+  }
+
+ private:
+  OpaqueBlockConverter() = default;
+
+  PrimExpr VisitExpr_(const VarNode* var) final {
+    auto it = var_substitutes_.find(var);
+    if (it == var_substitutes_.end()) {
+      return GetRef<Var>(var);
+    }
+    return it->second;
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    ICHECK(!block->init.defined())
+        << "Block Init part is not allowed in pass ConvertBlocksToOpaque";
+    Block result = Downcast<Block>(StmtExprMutator::VisitStmt_(block));
+    // Drop the block vars: their uses were replaced by the bound values during the visit.
+    if (!result->iter_vars.empty()) {
+      result.CopyOnWrite()->iter_vars.clear();
+    }
+    return std::move(result);
+  }
+
+  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+    const auto* block_op = realize->block.get();
+    ICHECK(!block_op->init.defined());
+    // Step 1. Record "block var => bound value", with outer substitutions already applied.
+    ICHECK_EQ(block_op->iter_vars.size(), realize->iter_values.size());
+    const size_t num_vars = block_op->iter_vars.size();
+    for (size_t i = 0; i < num_vars; ++i) {
+      const IterVar& block_var = block_op->iter_vars[i];
+      PrimExpr bound_value = this->VisitExpr(realize->iter_values[i]);
+      var_substitutes_.emplace(block_var->var.get(), bound_value);
+    }
+    // Step 2. Visit recursively, then drop the now-unused bindings.
+    BlockRealize result = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(realize));
+    if (!result->iter_values.empty()) {
+      result.CopyOnWrite()->iter_values.clear();
+    }
+    return std::move(result);
+  }
+
+  /*! \brief The map from block vars to their binding values. */
+  std::unordered_map<const VarNode*, PrimExpr> var_substitutes_;
+};
+
+/*! \brief Return \p f with every block in its body converted to an opaque block. */
+PrimFunc ConvertBlocksToOpaque(PrimFunc f) {
+  PrimFuncNode* func_node = f.CopyOnWrite();
+  Stmt opaque_body = OpaqueBlockConverter::Substitute(f);
+  func_node->body = std::move(opaque_body);
+  return f;
+}
+
+namespace transform {
+
+Pass ConvertBlocksToOpaque() {
+  // The pass keeps no state, so the lambda captures nothing.
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    return tir::ConvertBlocksToOpaque(std::move(f));
+  };
+  return CreatePrimFuncPass(pass_func, /*opt_level=*/0, "tir.ConvertBlocksToOpaque", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.ConvertBlocksToOpaque")
+    .set_body_typed(ConvertBlocksToOpaque);
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py 
b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
new file mode 100644
index 0000000..7c06b5e
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -0,0 +1,331 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+def _check(original, transformed):
+    func = original
+    mod = tvm.IRModule.from_expr(func)
+    mod = tvm.tir.transform.CompactBufferAllocation()(mod)
+    mod = tvm.tir.transform.Simplify()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], transformed)
+
+
+# Input for test_elementwise. Each iteration of the outer loop over i allocates
+# a 16x16 intermediate buffer B, but only row i of B is ever written or read
+# (B[i, j] for j in 0..15).
+@tvm.script.tir
+def elementwise_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            # Full-size allocation; only row i is touched below.
+            B = tir.alloc_buffer((16, 16), "float32")
+            for j in range(0, 16):
+                with tir.block([]) as []:
+                    tir.reads(A[i, j])
+                    tir.writes(B[i, j])
+                    B[i, j] = A[i, j] + 1.0
+            for j in range(0, 16):
+                with tir.block([]) as []:
+                    tir.reads(B[i, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[i, j] * 2.0
+
+
+# Expected output of CompactBufferAllocation (+ Simplify) on elementwise_func:
+# B shrinks from (16, 16) to (1, 16) and every access B[i, j] becomes B[0, j].
+@tvm.script.tir
+def compacted_elementwise_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            # Compacted to the single row that the loop body uses.
+            B = tir.alloc_buffer((1, 16), "float32")
+            for j in range(0, 16):
+                with tir.block() as []:
+                    tir.reads(A[i, j])
+                    tir.writes(B[0, j])
+                    B[0, j] = A[i, j] + 1.0
+            for j in range(0, 16):
+                with tir.block() as []:
+                    tir.reads(B[0, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[0, j] * 2.0
+
+
+# Input (and expected output) for test_unschedulable_block. B is written with a
+# raw tir.store on B.data at flat offset i * 16 + j and read with B[i, j], with no
+# inner blocks around either access; test_unschedulable_block expects the pass
+# to leave this function unchanged.
+@tvm.script.tir
+def unschedulable_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            B = tir.alloc_buffer((16, 16), "float32")
+            for j in range(0, 16):
+                # Raw pointer store into B's data with a hand-computed flat offset.
+                tir.store(B.data, i * 16 + j, A[i, j] + 1.0)
+            for j in range(0, 16):
+                C[i, j] = B[i, j] * 2.0
+
+
+# Input (and expected output) for test_param_access. Only the parameter buffers
+# A and B (declared with tir.match_buffer, shape 20x20) are accessed and no
+# buffer is allocated, so test_param_access expects no change.
+@tvm.script.tir
+def param_buffer_access_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (20, 20), "float32")
+    B = tir.match_buffer(c, (20, 20), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(B[i, 0:16])
+            for j in range(0, 16):
+                with tir.block([]) as []:
+                    tir.reads(A[i, j])
+                    tir.writes(B[i, j])
+                    B[i, j] = A[i, j] + 1.0
+
+
+@tvm.script.tir
+def shared_mem_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"):
+        for i1 in tir.thread_binding(0, 2, thread="vthread"):
+            for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"):
+                with tir.block([]):
+                    tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                    tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                    B = tir.alloc_buffer((16, 16), "float32", scope="shared")
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(B[i0 * 8 + i1 * 4 + i2, j])
+                            B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + 
i2, j] + 1.0
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(B[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                            C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + 
i2, j] * 2.0
+
+
+@tvm.script.tir
+def compacted_shared_mem_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"):
+        for i1 in tir.thread_binding(0, 2, thread="vthread"):
+            for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"):
+                with tir.block([]):
+                    tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                    tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                    B = tir.alloc_buffer((8, 16), "float32", scope="shared")
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(B[i1 * 4 + i2, j])
+                            B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 
1.0
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(B[i1 * 4 + i2, j])
+                            tir.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                            C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 
2.0
+
+
+@tvm.script.tir
+def warp_mem_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"):
+        for i1 in tir.thread_binding(0, 2, thread="vthread"):
+            for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"):
+                with tir.block([]):
+                    tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                    tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                    B = tir.alloc_buffer((16, 16), "float32", scope="warp")
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(B[i0 * 8 + i1 * 4 + i2, j])
+                            B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + 
i2, j] + 1.0
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(B[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                            C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + 
i2, j] * 2.0
+
+
+# Expected output of CompactBufferAllocation (+ Simplify) on warp_mem_func:
+# the warp buffer B shrinks from (16, 16) to (4, 16) and is indexed by i2
+# (the threadIdx.x loop variable) alone.
+@tvm.script.tir
+def compacted_warp_mem_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"):
+        for i1 in tir.thread_binding(0, 2, thread="vthread"):
+            for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"):
+                with tir.block([]):
+                    tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16])
+                    tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16])
+                    B = tir.alloc_buffer((4, 16), "float32", scope="warp")
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(A[i0 * 8 + i1 * 4 + i2, j])
+                            tir.writes(B[i2, j])
+                            B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0
+                    for j in range(0, 16):
+                        with tir.block([]) as []:
+                            tir.reads(B[i2, j])
+                            tir.writes(C[i0 * 8 + i1 * 4 + i2, j])
+                            C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0
+
+
+# Input for test_symbolic. Buffer shapes depend on the symbolic parameter n.
+# B is allocated with shape (n * 8,), but each outer iteration i only accesses
+# B[i * 8 + j] for j in 0..7.
+@tvm.script.tir
+def symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None:
+    A = tir.match_buffer(a, (n * 8,), "float32")
+    C = tir.match_buffer(c, (n * 8,), "float32")
+    for i in range(0, n):
+        with tir.block([]):
+            tir.reads(A[i * 8 : i * 8 + 8])
+            tir.writes(C[i * 8 : i * 8 + 8])
+            B = tir.alloc_buffer((n * 8,), "float32")
+            for j in range(0, 8):
+                with tir.block([]) as []:
+                    tir.reads(A[i * 8 + j])
+                    tir.writes(B[i * 8 + j])
+                    B[i * 8 + j] = A[i * 8 + j] + 1.0
+            for j in range(0, 8):
+                with tir.block([]) as []:
+                    tir.reads(B[i * 8 + j])
+                    tir.writes(C[i * 8 + j])
+                    C[i * 8 + j] = B[i * 8 + j] * 2.0
+
+
+# Expected output of CompactBufferAllocation (+ Simplify) on symbolic_func:
+# B becomes a constant-size (8,) buffer and B[i * 8 + j] becomes B[j].
+@tvm.script.tir
+def compacted_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None:
+    A = tir.match_buffer(a, (n * 8,), "float32")
+    C = tir.match_buffer(c, (n * 8,), "float32")
+    for i in range(0, n):
+        with tir.block([]):
+            tir.reads(A[i * 8 : i * 8 + 8])
+            tir.writes(C[i * 8 : i * 8 + 8])
+            B = tir.alloc_buffer((8,), "float32")
+            for j in range(0, 8):
+                with tir.block([]) as []:
+                    tir.reads(A[i * 8 + j])
+                    tir.writes(B[j])
+                    B[j] = A[i * 8 + j] + 1.0
+            for j in range(0, 8):
+                with tir.block([]) as []:
+                    tir.reads(B[j])
+                    tir.writes(C[i * 8 + j])
+                    C[i * 8 + j] = B[j] * 2.0
+
+
+# Input for test_complex. Exercises:
+# - a buffer D allocated inside an inner block;
+# - loops with non-zero start offsets;
+# - a raw tir.store on B.data;
+# - reads of B from two disjoint column ranges (3..4 and 6..7).
+# NOTE(review): the outer block's regions A[0, 8] / C[0, 8] name a single element
+# in column 8, which is outside the (8, 8) shape; presumably a full-row range was
+# intended. The same regions appear in compacted_complex_func, so the test is
+# self-consistent -- TODO confirm. The parameter n is unused.
+@tvm.script.tir
+def complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None:
+    A = tir.match_buffer(a, (8, 8), "float32")
+    C = tir.match_buffer(c, (8, 8), "float32")
+    for i in range(0, 8):
+        with tir.block([]):
+            tir.reads(A[0, 8])
+            tir.writes(C[0, 8])
+            B = tir.alloc_buffer((8, 8), "float32")
+            for j in range(0, 4):
+                with tir.block([]) as []:
+                    # D: rows 4..7 are written and rows 2..3 are read, in column j only.
+                    D = tir.alloc_buffer((8, 8), "float32")
+                    tir.reads(A[i, j])
+                    tir.writes(B[i, j])
+                    for k in range(4, 8):
+                        D[k, j] = 1.0
+                    for k in range(2, 4):
+                        tir.store(B.data, j, A[i, j] + D[k, j])
+            for j in range(3, 5):
+                with tir.block([]) as []:
+                    tir.reads(B[i, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[i, j]
+            for j in range(6, 8):
+                with tir.block([]) as []:
+                    tir.reads(B[i, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[i, j]
+
+
+# Expected output of CompactBufferAllocation (+ Simplify) on complex_func:
+# - B: only row i is used. Columns 0..7 are covered by the union of the write
+#   range (0..3) and the read ranges (3..4, 6..7), so B becomes (1, 8).
+# - D: rows 2..7 (write rows 4..7 plus read rows 2..3) of column j are used,
+#   so D becomes (6, 1) and D[k, j] becomes D[k - 2, 0].
+# - The raw tir.store(B.data, j, ...) is left as written.
+@tvm.script.tir
+def compacted_complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None:
+    A = tir.match_buffer(a, (8, 8), "float32")
+    C = tir.match_buffer(c, (8, 8), "float32")
+    for i in range(0, 8):
+        with tir.block([]):
+            tir.reads(A[0, 8])
+            tir.writes(C[0, 8])
+            B = tir.alloc_buffer((1, 8), "float32")
+            for j in range(0, 4):
+                with tir.block([]) as []:
+                    D = tir.alloc_buffer((6, 1), "float32")
+                    tir.reads(A[i, j])
+                    tir.writes(B[0, j])
+                    for k in range(4, 8):
+                        D[k - 2, 0] = 1.0
+                    for k in range(2, 4):
+                        tir.store(B.data, j, A[i, j] + D[k - 2, 0])
+            for j in range(3, 5):
+                with tir.block([]) as []:
+                    tir.reads(B[0, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[0, j]
+            for j in range(6, 8):
+                with tir.block([]) as []:
+                    tir.reads(B[0, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[0, j]
+
+
+def test_elementwise():
+    _check(elementwise_func, compacted_elementwise_func)
+
+
+def test_unschedulable_block():
+    _check(unschedulable_func, unschedulable_func)  # changes nothing
+
+
+def test_param_access():
+    _check(param_buffer_access_func, param_buffer_access_func)  # changes 
nothing
+
+
+def test_shared_mem():
+    _check(shared_mem_func, compacted_shared_mem_func)
+
+
+def test_warp_mem():
+    _check(warp_mem_func, compacted_warp_mem_func)
+
+
+def test_symbolic():
+    _check(symbolic_func, compacted_symbolic_func)
+
+
+def test_complex():
+    _check(complex_func, compacted_complex_func)
+
+
+if __name__ == "__main__":
+    test_elementwise()
+    test_unschedulable_block()
+    test_param_access()
+    test_shared_mem()
+    test_warp_mem()
+    test_symbolic()
+    test_complex()
diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
new file mode 100644
index 0000000..38fe1c9
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
@@ -0,0 +1,77 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.script import ty
+
+
+def _check(original, transformed):
+    func = original
+    mod = tvm.IRModule.from_expr(func)
+    mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod)
+    mod = tvm.tir.transform.Simplify()(mod)
+    tvm.ir.assert_structural_equal(mod["main"], transformed)
+
+
+# Input for test_elementwise. The inner blocks declare block iteration variables
+# vi, vj (domain [16, 16]) bound to the loop variables i, j via tir.bind; their
+# read/write regions are not written out explicitly.
+@tvm.script.tir
+def elementwise_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            B = tir.alloc_buffer((16, 16), "float32")
+            for j in range(0, 16):
+                with tir.block([16, 16]) as [vi, vj]:
+                    tir.bind(vi, i)
+                    tir.bind(vj, j)
+                    B[vi, vj] = A[vi, vj] + 1.0
+            for j in range(0, 16):
+                with tir.block([16, 16]) as [vi, vj]:
+                    tir.bind(vi, i)
+                    tir.bind(vj, j)
+                    C[vi, vj] = B[vi, vj] * 2.0
+
+
+# Expected output of ConvertBlocksToOpaque (+ Simplify) on elementwise_func:
+# - the inner blocks have no iteration variables;
+# - every use of vi / vj is replaced by the bound loop variable i / j;
+# - the read/write regions, left implicit in the input (presumably inferred by
+#   the script parser), are spelled out in terms of i and j.
+@tvm.script.tir
+def substituted_elementwise_func(a: ty.handle, c: ty.handle) -> None:
+    A = tir.match_buffer(a, (16, 16), "float32")
+    C = tir.match_buffer(c, (16, 16), "float32")
+    for i in range(0, 16):
+        with tir.block([]):
+            tir.reads(A[i, 0:16])
+            tir.writes(C[i, 0:16])
+            B = tir.alloc_buffer([16, 16], "float32")
+            for j in range(0, 16):
+                with tir.block() as []:
+                    tir.reads(A[i, j])
+                    tir.writes(B[i, j])
+                    B[i, j] = A[i, j] + 1.0
+            for j in range(0, 16):
+                with tir.block() as []:
+                    tir.reads(B[i, j])
+                    tir.writes(C[i, j])
+                    C[i, j] = B[i, j] * 2.0
+
+
+def test_elementwise():
+    _check(elementwise_func, substituted_elementwise_func)
+
+
+if __name__ == "__main__":
+    test_elementwise()

Reply via email to