This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ecc2e56  [TIR][Schedule] Analysis functions to check if compute_inline and com… (#9743)
ecc2e56 is described below

commit ecc2e563df1a0b1d7e9d712bce90ee94948c3848
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Dec 15 00:31:45 2021 -0500

    [TIR][Schedule] Analysis functions to check if compute_inline and com… (#9743)
    
    * [TIR][Schedule] Analysis functions to check if compute_inline and reverse_compute_inline are allowed
    
    Co-authored-by: Siyuan Feng <hzfen...@sjtu.edu.cn>
    Co-authored-by: Bohan Hou <32121147+spectrometer...@users.noreply.github.com>
    Co-authored-by: Hongyi Jin <3231950...@qq.com>
    Co-authored-by: Ruihang Lai <lairuihangdongd...@qq.com>
    Co-authored-by: Junru Shao <junrushao1...@gmail.com>
    Co-authored-by: Wuwei Lin <wu...@apache.org>
    Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
    
    * Address comments
    
    Co-authored-by: Siyuan Feng <hzfen...@sjtu.edu.cn>
    Co-authored-by: Bohan Hou <32121147+spectrometer...@users.noreply.github.com>
    Co-authored-by: Hongyi Jin <3231950...@qq.com>
    Co-authored-by: Ruihang Lai <lairuihangdongd...@qq.com>
    Co-authored-by: Junru Shao <junrushao1...@gmail.com>
    Co-authored-by: Xiyou Zhou <xi...@octoml.ai>
---
 src/tir/schedule/analysis.h                        | 41 ++++++++++++++
 src/tir/schedule/primitive/compute_at.cc           | 46 ++++++++++++---
 src/tir/schedule/primitive/compute_inline.cc       | 66 +++++++++++++++++++---
 .../unittest/test_tir_schedule_compute_inline.py   | 29 ++++++++++
 4 files changed, 168 insertions(+), 14 deletions(-)

diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 42e0e00..82f4afa 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -393,6 +393,47 @@ std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters()
 bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
                           CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs);
 
+/******** Misc ********/
+
+/*!
+ * \brief Checks if a block could be successfully computed inline into its consumer
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \return A boolean indicating whether the block could be successfully computed inline
+ */
+bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref);
+
+/*!
+ * \brief Checks if a block could be successfully computed inline into its producer
+ * \param self The schedule state
+ * \param block_sref The block to be checked
+ * \return A boolean indicating whether the block could be successfully computed inline
+ */
+bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref);
+
+/*!
+ * \brief Checks if a producer block could be successfully computed at the specific loop.
+ * \param self The schedule state
+ * \param block_sref The block to be moved
+ * \param loop_sref The loop that the block is to be moved to
+ * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \return A boolean indicating whether the block could be successfully computed at the specific loop
+ */
+bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
+                  bool preserve_unit_loops);
+
+/*!
+ * \brief Checks if a consumer block could be successfully computed at the specific loop.
+ * \param self The schedule state
+ * \param block_sref The block to be moved
+ * \param loop_sref The loop that the block is to be moved to
+ * \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
+ * \return A boolean indicating whether the block could be successfully reverse computed at the
+ * specific loop
+ */
+bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
+                         const StmtSRef& loop_sref, bool preserve_unit_loops);
+
 }  // namespace tir
 }  // namespace tvm
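
For context: each of the four Can* functions declared above is implemented by
running the corresponding primitive in a check-only mode and catching the error
it would throw (see the compute_at.cc and compute_inline.cc hunks below). A rough
user-level analogue of the same probe is possible from Python by trying the
primitive on a copy of the schedule; the helper name can_compute_inline below is
hypothetical, not part of the TVM API:

    import tvm
    from tvm import tir

    def can_compute_inline(sch: tir.Schedule, block) -> bool:
        """Hypothetical probe: True iff compute_inline would succeed on `block`."""
        probe = sch.copy()  # RVs stay valid in the copy; the original is untouched
        try:
            probe.compute_inline(block)
        except tvm.tir.ScheduleError:
            return False
        return True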
 
diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc
index 0dae50a..00886e8 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -451,7 +451,8 @@ void CalculateProvidedRequiredRegions(
 
 template <bool is_compute_at>
 void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
-                                     const StmtSRef& loop_sref, bool preserve_unit_loops) {
+                                     const StmtSRef& loop_sref, bool preserve_unit_loops,
+                                     arith::Analyzer* analyzer, bool check_only = false) {
   const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
   const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
   // Step 1. Bunch of checks
@@ -463,11 +464,10 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
   BlockScope scope = self->GetBlockScope(scope_root_sref);
   Array<StmtSRef> producer_srefs = GetProducers(block_sref, scope);
   Array<StmtSRef> consumer_srefs = GetConsumers(block_sref, scope);
-  arith::Analyzer analyzer;
   // Check condition 3): `block` and `loop` are under the same scope,
   // and `loop` is not the ancestor of `block`
   NotInSameScopeError::CheckAndBindLoopDomain(self, block_sref, loop_sref, scope_root_sref,
-                                              &analyzer);
+                                              analyzer);
   // Check condition 4): `block` is not an output block
   if (is_compute_at) {
     CheckNotOutputBlock(self, block_sref, scope_root_sref);
@@ -501,29 +501,61 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
       CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars,
                               /*provided_regions=*/std::move(provided_regions),
                               /*required_regions=*/std::move(required_regions),
-                              /*analyzer=*/&analyzer);
+                              /*analyzer=*/analyzer);
   // Step 6. Create the new scope according to the iteration domain
   reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms),
                             /*preserve_unit_loops=*/preserve_unit_loops);
   Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
+
   // Step 7. Do the actual replacement
+  if (check_only) {
+    return;
+  }
   self->Replace(scope_root_sref, new_scope_root, {{scope_root, new_scope_root}});
   // Step 8. Update the cached flags
   BlockInfo& block_info = self->block_info[block_sref];
   block_info.affine_binding = IsAffineBinding(
       /*realize=*/reconstructor.new_block_realize_,
       /*loop_var_ranges=*/LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent)),
-      /*analyzer=*/&analyzer);
+      /*analyzer=*/analyzer);
 }
 
 void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
                bool preserve_unit_loops) {
-  ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops);
+  arith::Analyzer analyzer;
+  ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
+                                        &analyzer);
 }
 
 void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
                       bool preserve_unit_loops) {
-  ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops);
+  arith::Analyzer analyzer;
+  ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
+                                         &analyzer);
+}
+
+bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
+                  bool preserve_unit_loops) {
+  arith::Analyzer analyzer;
+  try {
+    ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
+                                          &analyzer, true);
+  } catch (const tvm::runtime::Error& e) {
+    return false;
+  }
+  return true;
+}
+
+bool CanReverseComputeAt(const ScheduleState& self, const StmtSRef& block_sref,
+                         const StmtSRef& loop_sref, bool preserve_unit_loops) {
+  arith::Analyzer analyzer;
+  try {
+    ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
+                                           &analyzer, true);
+  } catch (const tvm::runtime::Error& e) {
+    return false;
+  }
+  return true;
 }
 
 /******** InstructionKind Registration ********/
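
The compute-at path follows the same pattern: the arith::Analyzer is now threaded
in from the caller so that CanComputeAt/CanReverseComputeAt can reuse the shared
implementation with check_only=true and simply catch the error. A minimal
Python-side sketch of the same probe idea (the helper can_compute_at is
hypothetical, and assumes the block/loop RVs come from sch):

    def can_compute_at(sch, block, loop, preserve_unit_loops=False):
        """Hypothetical probe: True iff compute_at(block, loop) would succeed."""
        probe = sch.copy()
        try:
            probe.compute_at(block, loop, preserve_unit_loops=preserve_unit_loops)
        except tvm.tir.ScheduleError:
            return False
        return True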
diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index 12ae021..fe2c679 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -60,11 +60,27 @@ class NotSingleReadWriteBuffer : public ScheduleError {
   bool is_read_;
   Block block_;
 
-  static Buffer GetSingleRead(const ScheduleState& self, const Block& block) {
-    if (block->reads.size() != 1) {
+  static Buffer GetSingleRead(const ScheduleState& self, const Block& block,
+                              const StmtSRef& scope_root_sref) {
+    const std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual>&
+        buffer_writers = self->block_info.at(scope_root_sref).scope->buffer_writers;
+    const BufferNode* read_buffer = nullptr;
+    for (const BufferRegion& read_region : block->reads) {
+      const BufferNode* buffer = read_region->buffer.get();
+      if (buffer == read_buffer) {
+        continue;
+      }
+      if (buffer_writers.count(GetRef<Buffer>(buffer)) > 0) {
+        if (read_buffer != nullptr) {
+          throw NotSingleReadWriteBuffer(self->mod, true, block);
+        }
+        read_buffer = buffer;
+      }
+    }
+    if (read_buffer == nullptr) {
       throw NotSingleReadWriteBuffer(self->mod, true, block);
     }
-    return block->reads[0]->buffer;
+    return GetRef<Buffer>(read_buffer);
   }
 
   static Buffer GetSingleWrite(const ScheduleState& self, const Block& block) {
@@ -167,7 +183,7 @@ class OpaqueAccessError : public ScheduleError {
  * \brief The base class of the inliner, which handles:
  * 1) Substitute a subtree with the specific block being inlined
  * 2) Update the block signature to reflect the changes of read/write/allocated buffers
- * 3) Maintain a list of index variables and their substition of the buffer being inlined
+ * 3) Maintain a list of index variables and their substitution of the buffer being inlined
  */
 class BaseInliner : public StmtExprMutator {
  protected:
@@ -526,7 +542,8 @@ class ReverseComputeInliner : public BaseInliner {
   PrimExpr producer_rhs_{nullptr};
 };
 
-void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
+void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
+                       bool check_only = false) {
   const BlockNode* _producer_block = TVM_SREF_TO_BLOCK(_producer_block, producer_block_sref);
   Block producer_block = GetRef<Block>(_producer_block);
   Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block);
@@ -535,6 +552,7 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
                                           /*require_stage_pipeline=*/true,
                                           /*require_subtree_compact_dataflow=*/false);
   // Step 2. Check completeness
+  CheckNotOutputBlock(self, producer_block_sref, scope_root_sref);
   CheckCompleteBlock(self, producer_block_sref, scope_root_sref);
   // Step 3. Analyze the block body
   ComputeInliner inliner(inlined_buffer, producer_block, scope_root_sref);
@@ -550,17 +568,35 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
     throw OpaqueAccessError(self->mod, scope_root_sref);
   }
   // Step 6. Do the real mutation on the AST and the sref tree in the schedule state
+  if (check_only) {
+    return;
+  }
   self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
 }
 
-void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
+void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) {
+  ComputeInlineImpl(self, producer_block_sref);
+}
+
+bool CanComputeInline(const ScheduleState& self, const StmtSRef& producer_block_sref) {
+  try {
+    ComputeInlineImpl(self, producer_block_sref, true);
+  } catch (const tvm::runtime::Error& e) {
+    return false;
+  }
+  return true;
+}
+
+void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block_sref,
+                              bool check_only = false) {
   const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(_consumer_block, consumer_block_sref);
   Block consumer_block = GetRef<Block>(_consumer_block);
-  Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block);
   // Step 1. Get the scope block
   StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref,  //
                                           /*require_stage_pipeline=*/true,
                                           /*require_subtree_compact_dataflow=*/false);
+  Buffer inlined_buffer =
+      NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block, scope_root_sref);
   // Step 2. Check completeness
   CheckCompleteBlock(self, consumer_block_sref, scope_root_sref);
   // Step 3. Check if the consumer has a single complete producer
@@ -579,9 +615,25 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
     throw OpaqueAccessError(self->mod, scope_root_sref);
   }
   // Step 7. Do the real mutation on the AST and the sref tree in the schedule state
+  if (check_only) {
+    return;
+  }
   self->Replace(scope_root_sref, tgt_stmt, inliner.block_reuse);
 }
 
+bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref) {
+  try {
+    ReverseComputeInlineImpl(self, block_sref, true);
+  } catch (const tvm::runtime::Error& e) {
+    return false;
+  }
+  return true;
+}
+
+void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sref) {
+  ReverseComputeInlineImpl(self, consumer_block_sref);
+}
+
 /******** InstructionKind Registration ********/
 
 struct ComputeInlineTraits : public UnpackedInstTraits<ComputeInlineTraits> {
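
Note the CheckNotOutputBlock call added to ComputeInlineImpl above: a block that
writes one of the scope's output buffers can no longer be inlined, which the new
test in the diff below exercises. With the hypothetical probe helper sketched
earlier, one would expect, for example (assuming the matmul_relu PrimFunc defined
in the test diff below):

    sch = tir.Schedule(matmul_relu)
    # "compute" writes the output buffer, so inlining it should be rejected
    assert not can_compute_inline(sch, sch.get_block("compute"))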
diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py
index a078c0e..5cc36c0 100644
--- a/tests/python/unittest/test_tir_schedule_compute_inline.py
+++ b/tests/python/unittest/test_tir_schedule_compute_inline.py
@@ -329,6 +329,28 @@ def access_opaque_ptr_then_elemwise_inline(a: T.handle, b: T.handle) -> None:
             B[vi] = A_cache[vi] * 2.0 + 1.0
 
 
+@T.prim_func
+def matmul_relu(var_A: T.handle, var_B: T.handle, var_compute: T.handle) -> None:
+    A = T.match_buffer(var_A, [512, 512], dtype="float32")
+    B = T.match_buffer(var_B, [512, 512], dtype="float32")
+    compute = T.match_buffer(var_compute, [512, 512], dtype="float32")
+    C = T.alloc_buffer([512, 512], dtype="float32")
+    for i0, i1, i2 in T.grid(512, 512, 512):
+        with T.block("C"):
+            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
+            T.reads([C[i, j], A[i, k], B[k, j]])
+            T.writes([C[i, j]])
+            with T.init():
+                C[i, j] = T.float32(0)
+            C[i, j] = C[i, j] + A[i, k] * B[k, j]
+    for i0, i1 in T.grid(512, 512):
+        with T.block("compute"):
+            i0_1, i1_1 = T.axis.remap("SS", [i0, i1])
+            T.reads([C[i0_1, i1_1]])
+            T.writes([compute[i0_1, i1_1]])
+            compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0))
+
+
 # pylint: enable=no-member,invalid-name,unused-variable
 
 
@@ -458,6 +480,13 @@ def test_buffer_matched():
         sch.compute_inline(block_b)
 
 
+def test_output_block():
+    sch = tir.Schedule(matmul_relu, debug_mask="all")
+    block = sch.get_block("compute")
+    with pytest.raises(tvm.tir.ScheduleError):
+        sch.compute_inline(block)
+
+
 def test_compute_inline_predicate():
     sch = tir.Schedule(elementwise_predicate, debug_mask="all")
     block_b = sch.get_block("B")
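
To try the new behavior locally, the updated unit tests can be run with pytest,
e.g.:

    pytest tests/python/unittest/test_tir_schedule_compute_inline.py -k test_output_block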
