This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e5c1a24bda [TIR] Refactor BlockScope outside schedule (#15034)
e5c1a24bda is described below

commit e5c1a24bdacff2ba15098e9b38fbf39fa66b878c
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Thu Jun 8 11:49:18 2023 +0530

    [TIR] Refactor BlockScope outside schedule (#15034)
    
    This PR makes two major changes to TIR.
    
    1. Move the `stage_pipeline` member from `BlockScope` into `BlockInfo`
    2. Move `BlockScope` out of the schedule module
    
    This is the first step toward exposing block-level dependence
    analysis to TIR passes. This PR moves `BlockScope` out of the schedule
    module so that a follow-up analysis pass can use it to compute
    block-level dependences when a transformation pass needs them.
    
    There is some extra discussion about this change in
    [this discussion forum post](https://discuss.tvm.apache.org/t/why-is-block-dependency-info-attached-to-schedules/14975)
---
 include/tvm/tir/{schedule => }/block_scope.h    | 18 ++---
 include/tvm/tir/schedule/state.h                | 20 ++++--
 include/tvm/tir/utils.h                         | 96 +++++++++++++++++++++++++
 python/tvm/tir/{schedule => }/block_scope.py    |  0
 python/tvm/tir/schedule/__init__.py             |  2 +-
 python/tvm/tir/schedule/state.py                |  2 +-
 src/tir/{schedule => ir}/block_scope.cc         | 25 ++++---
 src/tir/schedule/analysis/analysis.cc           |  2 +-
 src/tir/schedule/analysis/verify.cc             |  6 +-
 src/tir/schedule/concrete_schedule.cc           |  1 -
 src/tir/schedule/primitive/cache_index.cc       |  4 +-
 src/tir/schedule/primitive/cache_read_write.cc  | 14 ++--
 src/tir/schedule/primitive/decompose_padding.cc |  4 +-
 src/tir/schedule/primitive/read_write_at.cc     |  2 +-
 src/tir/schedule/primitive/reduction.cc         |  2 +-
 src/tir/schedule/state.cc                       |  9 ++-
 src/tir/schedule/utils.h                        | 71 +-----------------
 17 files changed, 152 insertions(+), 126 deletions(-)

diff --git a/include/tvm/tir/schedule/block_scope.h 
b/include/tvm/tir/block_scope.h
similarity index 93%
rename from include/tvm/tir/schedule/block_scope.h
rename to include/tvm/tir/block_scope.h
index be3e79a183..f09beecd54 100644
--- a/include/tvm/tir/schedule/block_scope.h
+++ b/include/tvm/tir/block_scope.h
@@ -17,13 +17,13 @@
  * under the License.
  */
 /*!
- * \file tvm/tir/schedule/block_scope.h
+ * \file tvm/tir/block_scope.h
  * \brief Definition of two pillar data structure for TensorIR scheduling: 
StmtSRef, BlockScope.
  * \sa StmtSRefNode
  * \sa BlockScopeNode
  */
-#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
-#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#ifndef TVM_TIR_BLOCK_SCOPE_H_
+#define TVM_TIR_BLOCK_SCOPE_H_
 
 #include <tvm/tir/stmt.h>
 
@@ -216,16 +216,6 @@ class BlockScopeNode : public Object {
   std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, 
ObjectPtrEqual> dst2deps;
   /*! \brief The mapping from the buffer to the blocks who write it */
   std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> 
buffer_writers;
-  /*!
-   * \brief This property indicates that the block scope (rooted at its 
corresponding block) is
-   * equivalent to of a stage pipeline. Under the following conditions:
-   *
-   * 1) The region cover property holds for every of its child blocks
-   * 2) No write-after-read dependency or opaque dependency, only 
read-after-write and
-   * write-after-write are allowed
-   * 3) All the statements in the scope are schedulable statements, i.e. Block 
and For
-   */
-  bool stage_pipeline{false};
 
   void VisitAttrs(AttrVisitor* v) {}
 
@@ -270,4 +260,4 @@ class BlockScope : public ObjectRef {
 }  // namespace tir
 }  // namespace tvm
 
-#endif  // TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#endif  // TVM_TIR_BLOCK_SCOPE_H_
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
index a089de2799..d2d90812dd 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/tir/schedule/state.h
@@ -24,8 +24,8 @@
 #define TVM_TIR_SCHEDULE_STATE_H_
 
 #include <tvm/ir/module.h>
+#include <tvm/tir/block_scope.h>
 #include <tvm/tir/function.h>
-#include <tvm/tir/schedule/block_scope.h>
 
 #include <unordered_map>
 #include <utility>
@@ -51,13 +51,25 @@ struct BlockInfo {
    * produced by its producers
    */
   bool region_cover{false};
+  /*!
+   * \brief This property indicates that the block scope (rooted at its 
corresponding block) is
+   * equivalent to of a stage pipeline. Under the following conditions:
+   *
+   * 1) The region cover property holds for every of its child blocks
+   * 2) No write-after-read dependency or opaque dependency, only 
read-after-write and
+   * write-after-write are allowed
+   * 3) All the statements in the scope are schedulable statements, i.e. Block 
and For
+   */
+  bool stage_pipeline{false};
 
   BlockInfo() = default;
 
-  explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool 
region_cover = false)
+  explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool 
region_cover = false,
+                     bool stage_pipeline = false)
       : scope(std::move(scope)),         //
         affine_binding(affine_binding),  //
-        region_cover(region_cover) {}
+        region_cover(region_cover),
+        stage_pipeline(stage_pipeline) {}
 };
 
 /*!
@@ -185,7 +197,7 @@ class ScheduleStateNode : public Object {
    * \return The corresponding BlockScope
    */
   bool IsStagePipeline(const StmtSRef& scope_root) const {
-    return GetBlockScope(scope_root)->stage_pipeline;
+    return GetBlockInfo(scope_root).stage_pipeline;
   }
 };
 
diff --git a/include/tvm/tir/utils.h b/include/tvm/tir/utils.h
new file mode 100644
index 0000000000..c9aad23dad
--- /dev/null
+++ b/include/tvm/tir/utils.h
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_UTILS_H_
+#define TVM_TIR_UTILS_H_
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A helper macro to convert an sref to the statement it points to,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be cast
+ * \param Type The type to be cast to, can be Block or For
+ */
+#define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \
+  SRef->StmtAs<Type>();                        \
+  ICHECK(Result)
+
+/*!
+ * \brief A helper macro to convert an sref to the block it points to,
+ *
+ * Throws an internal error if downcasting fails.  The variable name
+ * in the parent scope is used for the error message.
+ *
+ * \param SRef The SRef to be cast
+ */
+#define TVM_SREF_TO_BLOCK(SRef)                                                
                    \
+  [&]() {                                                                      
                    \
+    auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::BlockNode)    
                    \
+                  << "TypeError: Expects StmtSRef `" << #SRef << "` points to 
`Block`, but gets: " \
+                  << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");     
                    \
+    return result;                                                             
                    \
+  }()
+
+/*!
+ * \brief A helper macro to convert an sref to the for-loop it points to
+ *
+ * Throws an internal error if downcasting fails.  The variable name
+ * in the parent scope is used for the error message.
+ *
+ * \param SRef The SRef to be cast
+ */
+#define TVM_SREF_TO_FOR(SRef)                                                  
                   \
+  [&]() {                                                                      
                   \
+    auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode)      
                   \
+                  << "TypeError: Expects StmtSRef `" << #SRef << "` points to 
`Loop`, but gets: " \
+                  << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");     
                   \
+    return result;                                                             
                   \
+  }()
+
+/*!
+ * \brief Downcast a TVM ObjectRef to its corresponding container using 
`ObjectRef::as<Type>`,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param From The ObjectRef to be downcast
+ * \param Type The type to be downcast to
+ */
+#define TVM_TYPE_AS_OR_ERR(Result, From, Type) \
+  From.as<Type>();                             \
+  ICHECK(Result)
+
+/*!
+ * \brief Downcast a TVM ObjectRef to its corresponding container using 
`ObjectRef::as<Type>`,
+ * throwing an internal error if downcast fails.
+ * \param From The ObjectRef to be downcast
+ * \param Type The type to be downcast to
+ */
+#define TVM_TYPE_AS(From, Type)                                                
               \
+  [&]() {                                                                      
               \
+    auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type)                     
               \
+                  << "TypeError: Expects `" << #From << "` to have type `" << 
Type::_type_key \
+                  << "`, but gets: " << ((From).defined() ? 
(From)->GetTypeKey() : "None");   \
+    return result;                                                             
               \
+  }()
+
+}  // namespace tir
+}  // namespace tvm
+
+#endif  // TVM_TIR_UTILS_H_
diff --git a/python/tvm/tir/schedule/block_scope.py 
b/python/tvm/tir/block_scope.py
similarity index 100%
rename from python/tvm/tir/schedule/block_scope.py
rename to python/tvm/tir/block_scope.py
diff --git a/python/tvm/tir/schedule/__init__.py 
b/python/tvm/tir/schedule/__init__.py
index 63638a8945..1f68c487c0 100644
--- a/python/tvm/tir/schedule/__init__.py
+++ b/python/tvm/tir/schedule/__init__.py
@@ -17,7 +17,7 @@
 # pylint: disable=unused-import
 """Namespace for the TensorIR schedule API."""
 
-from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
+from ..block_scope import BlockScope, Dependency, DepKind, StmtSRef
 from .instruction import Instruction, InstructionKind
 from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError
 from .state import ScheduleDebugMask, ScheduleState
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index 0e67411103..df2eb534e6 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -26,7 +26,7 @@ from tvm.runtime import Object
 from tvm.tir import Block, BlockRealize, For, PrimFunc
 
 from . import _ffi_api
-from .block_scope import BlockScope, StmtSRef
+from ..block_scope import BlockScope, StmtSRef
 
 CachedFlags = namedtuple("CachedFlags", ["affine_binding", "region_cover", 
"stage_pipeline"])
 
diff --git a/src/tir/schedule/block_scope.cc b/src/tir/ir/block_scope.cc
similarity index 89%
rename from src/tir/schedule/block_scope.cc
rename to src/tir/ir/block_scope.cc
index 31452f4a8f..6b46396317 100644
--- a/src/tir/schedule/block_scope.cc
+++ b/src/tir/ir/block_scope.cc
@@ -16,7 +16,8 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-#include "./utils.h"
+#include <tvm/tir/block_scope.h>
+#include <tvm/tir/utils.h>
 
 namespace tvm {
 namespace tir {
@@ -141,21 +142,19 @@ TVM_REGISTER_NODE_TYPE(StmtSRefNode);
 TVM_REGISTER_NODE_TYPE(DependencyNode);
 TVM_REGISTER_NODE_TYPE(BlockScopeNode);
 
-TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefStmt")
-    .set_body_typed([](StmtSRef sref) -> Optional<Stmt> {
-      return GetRef<Optional<Stmt>>(sref->stmt);
-    });
-TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefParent")
-    .set_body_typed([](StmtSRef sref) -> Optional<StmtSRef> {
-      return GetRef<Optional<StmtSRef>>(sref->parent);
-    });
-TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefRootMark")  //
+TVM_REGISTER_GLOBAL("tir.StmtSRefStmt").set_body_typed([](StmtSRef sref) -> 
Optional<Stmt> {
+  return GetRef<Optional<Stmt>>(sref->stmt);
+});
+TVM_REGISTER_GLOBAL("tir.StmtSRefParent").set_body_typed([](StmtSRef sref) -> 
Optional<StmtSRef> {
+  return GetRef<Optional<StmtSRef>>(sref->parent);
+});
+TVM_REGISTER_GLOBAL("tir.StmtSRefRootMark")  //
     .set_body_typed(StmtSRef::RootMark);
-TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefInlineMark")  //
+TVM_REGISTER_GLOBAL("tir.StmtSRefInlineMark")  //
     .set_body_typed(StmtSRef::InlineMark);
-TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsBySrc")
+TVM_REGISTER_GLOBAL("tir.BlockScopeGetDepsBySrc")
     .set_body_method<BlockScope>(&BlockScopeNode::GetDepsBySrc);
-TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsByDst")
+TVM_REGISTER_GLOBAL("tir.BlockScopeGetDepsByDst")
     .set_body_method<BlockScope>(&BlockScopeNode::GetDepsByDst);
 
 }  // namespace tir
diff --git a/src/tir/schedule/analysis/analysis.cc 
b/src/tir/schedule/analysis/analysis.cc
index 2c4da4aaf7..1f989ef939 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -104,7 +104,7 @@ Definition of a scope that is a stage pipeline:
   }
   // Step 2. Handle `require_stage_pipeline`
   if (require_stage_pipeline && self->enable_check) {
-    bool stage_pipeline = 
self->GetBlockInfo(scope_root_sref).scope->stage_pipeline;
+    bool stage_pipeline = self->GetBlockInfo(scope_root_sref).stage_pipeline;
     if (stage_pipeline == false) {
       const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref);
       throw NotStagePipelineError(self->mod, GetRef<Block>(block));
diff --git a/src/tir/schedule/analysis/verify.cc 
b/src/tir/schedule/analysis/verify.cc
index ef45f7f8c7..b29d13c3b9 100644
--- a/src/tir/schedule/analysis/verify.cc
+++ b/src/tir/schedule/analysis/verify.cc
@@ -172,10 +172,10 @@ void VerifyCachedFlags(const ScheduleState& self) {
                                                  new_block_info.region_cover,
                                                  old_block_info.region_cover);
     }
-    if (new_block_info.scope->stage_pipeline != 
old_block_info.scope->stage_pipeline) {
+    if (new_block_info.stage_pipeline != old_block_info.stage_pipeline) {
       block_info_wrong_stage_pipeline.emplace_back(new_sref,  //
-                                                   
new_block_info.scope->stage_pipeline,
-                                                   
old_block_info.scope->stage_pipeline);
+                                                   
new_block_info.stage_pipeline,
+                                                   
old_block_info.stage_pipeline);
     }
   }
 
diff --git a/src/tir/schedule/concrete_schedule.cc 
b/src/tir/schedule/concrete_schedule.cc
index d485127242..2359a248fb 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -149,7 +149,6 @@ class ScheduleCopier {
       scope->src2deps = Copy(old_info.scope->src2deps);
       scope->dst2deps = Copy(old_info.scope->dst2deps);
       scope->buffer_writers = Copy(old_info.scope->buffer_writers);
-      scope->stage_pipeline = old_info.scope->stage_pipeline;
       new_info.scope = BlockScope(std::move(scope));
       result[Copy(old_sref)] = std::move(new_info);
     }
diff --git a/src/tir/schedule/primitive/cache_index.cc 
b/src/tir/schedule/primitive/cache_index.cc
index ac62d35b26..58bcd368c8 100644
--- a/src/tir/schedule/primitive/cache_index.cc
+++ b/src/tir/schedule/primitive/cache_index.cc
@@ -463,7 +463,7 @@ Array<StmtSRef> CacheIndex(ScheduleState self, const 
StmtSRef& block_sref,
   Array<Block> cache_stages = MakeIndexCacheStage(&info, storage_scope);
   Stmt new_scope = CacheIndexRewriter::Rewrite(/*scope_sref=*/scope_sref, 
/*info=*/&info);
 
-  bool old_stage_pipeline = self->block_info[block_sref].scope->stage_pipeline;
+  bool old_stage_pipeline = self->block_info[block_sref].stage_pipeline;
 
   // Step 3. Replacing and updating flags.
   self->Replace(scope_sref, new_scope, info.block_reuse);
@@ -486,7 +486,7 @@ Array<StmtSRef> CacheIndex(ScheduleState self, const 
StmtSRef& block_sref,
 
     block_info.affine_binding = affine_binding;
     block_info.region_cover = true;
-    block_info.scope->stage_pipeline = old_stage_pipeline;
+    block_info.stage_pipeline = old_stage_pipeline;
   }
 
   return result_block_srefs;
diff --git a/src/tir/schedule/primitive/cache_read_write.cc 
b/src/tir/schedule/primitive/cache_read_write.cc
index cf139c7df7..74a960eefb 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -1526,7 +1526,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& 
block_sref, int read_buff
   BlockInfo& block_info = self->block_info[result_block_sref];
   block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
   block_info.region_cover = true;
-  block_info.scope->stage_pipeline = true;
+  block_info.stage_pipeline = true;
   return result_block_sref;
 }
 
@@ -1591,7 +1591,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& 
block_sref, int write_bu
   BlockInfo& block_info = self->block_info[result_block_sref];
   block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
   block_info.region_cover = true;
-  block_info.scope->stage_pipeline = true;
+  block_info.stage_pipeline = true;
   return result_block_sref;
 }
 
@@ -1812,7 +1812,7 @@ StmtSRef ReindexCacheRead(ScheduleState self, const 
StmtSRef& block_sref, int re
   BlockInfo& block_info = self->block_info[result_block_sref];
   block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
   block_info.region_cover = true;
-  block_info.scope->stage_pipeline = true;
+  block_info.stage_pipeline = true;
   return result_block_sref;
 }
 
@@ -1876,7 +1876,7 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const 
StmtSRef& block_sref, int w
   BlockInfo& block_info = self->block_info[result_block_sref];
   block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
   block_info.region_cover = true;
-  block_info.scope->stage_pipeline = true;
+  block_info.stage_pipeline = true;
   return result_block_sref;
 }
 
@@ -1954,7 +1954,7 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const 
StmtSRef& block_sref, int
   BlockInfo& block_info_read = self->block_info[result_block_sref];
   block_info_read.affine_binding = CalculateAffineFlag(self, 
result_block_sref);
   block_info_read.region_cover = true;
-  block_info_read.scope->stage_pipeline = false;
+  block_info_read.stage_pipeline = false;
   results_block_sref.push_back(result_block_sref);
 
   // Do cache write
@@ -1983,7 +1983,7 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const 
StmtSRef& block_sref, int
   BlockInfo& block_info_write = self->block_info[result_block_sref];
   block_info_write.affine_binding = CalculateAffineFlag(self, 
result_block_sref);
   block_info_write.region_cover = true;
-  block_info_write.scope->stage_pipeline = false;
+  block_info_write.stage_pipeline = false;
   results_block_sref.push_back(result_block_sref);
 
   return results_block_sref;
@@ -2058,7 +2058,7 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& 
block_sref, int buffer_inde
   BlockInfo& block_info = self->block_info[result_block_sref];
   block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
   block_info.region_cover = true;
-  block_info.scope->stage_pipeline = true;
+  block_info.stage_pipeline = true;
   return result_block_sref;
 }
 
diff --git a/src/tir/schedule/primitive/decompose_padding.cc 
b/src/tir/schedule/primitive/decompose_padding.cc
index de7bd93094..1743a34088 100644
--- a/src/tir/schedule/primitive/decompose_padding.cc
+++ b/src/tir/schedule/primitive/decompose_padding.cc
@@ -496,7 +496,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const 
StmtSRef& block_sref,
   BlockInfo& block_info = self->block_info[new_block_sref];
   block_info.affine_binding = true;
   block_info.region_cover = true;
-  block_info.scope->stage_pipeline = true;
+  block_info.stage_pipeline = true;
 
   // If the const pad value filling block is lifted out of the original 
subtree,
   // set the region_cover flag as false since region_cover is the property 
under the subtree.
@@ -518,7 +518,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const 
StmtSRef& block_sref,
     }
   }
   if (!preserve_stage_pipeline) {
-    self->block_info[scope_root_sref].scope->stage_pipeline = false;
+    self->block_info[scope_root_sref].stage_pipeline = false;
   }
   return new_block_sref;
 }
diff --git a/src/tir/schedule/primitive/read_write_at.cc 
b/src/tir/schedule/primitive/read_write_at.cc
index 8b7d78f669..d7afd7d330 100644
--- a/src/tir/schedule/primitive/read_write_at.cc
+++ b/src/tir/schedule/primitive/read_write_at.cc
@@ -162,7 +162,7 @@ struct ReadWriteAtImpl {
     BlockInfo& block_info = self_->block_info[new_block_sref];
     block_info.affine_binding = affine_binding;
     block_info.region_cover = true;
-    block_info.scope->stage_pipeline = true;
+    block_info.stage_pipeline = true;
   }
 
   template <bool is_read>
diff --git a/src/tir/schedule/primitive/reduction.cc 
b/src/tir/schedule/primitive/reduction.cc
index 9c0e45544a..6069f4289c 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -1272,7 +1272,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& 
rf_loop_sref, int factor_ax
     BlockInfo& info = self->block_info[new_block_sref];
     info.affine_binding = true;
     info.region_cover = true;
-    info.scope->stage_pipeline = true;
+    info.stage_pipeline = true;
   }
   return new_block_srefs[0];
 }
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 45ed97af51..33f8598289 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -224,8 +224,7 @@ class BlockInfoCollector : private StmtVisitor {
     // Set `region_cover` to true, will be updated on its scope block
     info.region_cover = true;
     // Set `stage_pipeline` and `region_cover` for its intermediate children
-    info.scope->stage_pipeline =
-        CheckRegionCoverAndStagePipeline(info, scope_root, child_block_srefs);
+    info.stage_pipeline = CheckRegionCoverAndStagePipeline(info, scope_root, 
child_block_srefs);
   }
 
   bool CheckRegionCoverAndStagePipeline(const BlockInfo& info, const StmtSRef& 
scope_root,
@@ -796,11 +795,11 @@ class SRefUpdater : public StmtVisitor {
       BlockInfo& info = insert_result.first->second;
       info.affine_binding = false;
       info.region_cover = false;
-      info.scope->stage_pipeline = false;
+      info.stage_pipeline = false;
     } else {
       // Insertion didn't take place, because the entry has been there before.
       // In this case, we assume that flags are still valid so intentionally 
keep them unchanged
-      new_info.scope->stage_pipeline = info.scope->stage_pipeline;
+      new_info.stage_pipeline = info.stage_pipeline;
       info.scope = std::move(new_info.scope);
     }
   }
@@ -1111,7 +1110,7 @@ TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& 
self, const StmtSRef& bl
   const BlockInfo& info = self->GetBlockInfo(block_sref);
   return {Bool(info.affine_binding),  //
           Bool(info.region_cover),    //
-          Bool(info.scope->stage_pipeline)};
+          Bool(info.stage_pipeline)};
 }
 
 /**************** FFI ****************/
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index df92c7f807..7f99c92f88 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -31,6 +31,7 @@
 #include <tvm/tir/schedule/state.h>
 #include <tvm/tir/schedule/trace.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/utils.h>
 
 #include <string>
 #include <unordered_map>
@@ -50,76 +51,6 @@
 
 namespace tvm {
 namespace tir {
-
-/*!
- * \brief A helper macro to convert an sref to the statement it points to,
- * then check if the downcasting succeeded.
- * \param Result The result variable, used for checking
- * \param SRef The SRef to be cast
- * \param Type The type to be cast to, can be Block or For
- */
-#define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \
-  SRef->StmtAs<Type>();                        \
-  ICHECK(Result)
-
-/*!
- * \brief A helper macro to convert an sref to the block it points to,
- *
- * Throws an internal error if downcasting fails.  The variable name
- * in the parent scope is used for the error message.
- *
- * \param SRef The SRef to be cast
- */
-#define TVM_SREF_TO_BLOCK(SRef)                                                
                    \
-  [&]() {                                                                      
                    \
-    auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::BlockNode)    
                    \
-                  << "TypeError: Expects StmtSRef `" << #SRef << "` points to 
`Block`, but gets: " \
-                  << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");     
                    \
-    return result;                                                             
                    \
-  }()
-
-/*!
- * \brief A helper macro to convert an sref to the for-loop it points to
- *
- * Throws an internal error if downcasting fails.  The variable name
- * in the parent scope is used for the error message.
- *
- * \param SRef The SRef to be cast
- */
-#define TVM_SREF_TO_FOR(SRef)                                                  
                   \
-  [&]() {                                                                      
                   \
-    auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode)      
                   \
-                  << "TypeError: Expects StmtSRef `" << #SRef << "` points to 
`Loop`, but gets: " \
-                  << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");     
                   \
-    return result;                                                             
                   \
-  }()
-
-/*!
- * \brief Downcast a TVM ObjectRef to its corresponding container using 
`ObjectRef::as<Type>`,
- * then check if the downcasting succeeded.
- * \param Result The result variable, used for checking
- * \param From The ObjectRef to be downcast
- * \param Type The type to be downcast to
- */
-#define TVM_TYPE_AS_OR_ERR(Result, From, Type) \
-  From.as<Type>();                             \
-  ICHECK(Result)
-
-/*!
- * \brief Downcast a TVM ObjectRef to its corresponding container using 
`ObjectRef::as<Type>`,
- * throwing an internal error if downcast fails.
- * \param Result The result variable, used for checking
- * \param From The ObjectRef to be downcast
- * \param Type The type to be downcast to
- */
-#define TVM_TYPE_AS(From, Type)                                                
               \
-  [&]() {                                                                      
               \
-    auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type)                     
               \
-                  << "TypeError: Expects `" << #From << "` to have type `" << 
Type::_type_key \
-                  << "`, but gets: " << ((From).defined() ? 
(From)->GetTypeKey() : "None");   \
-    return result;                                                             
               \
-  }()
-
 /*!
  * \brief Convert an array of loop StmtSRefs to an array of loops
  * \param loop_srefs The loop StmtSRefs to be converted

Reply via email to