This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e5c1a24bda [TIR] Refactor BlockScope outside schedule (#15034)
e5c1a24bda is described below
commit e5c1a24bdacff2ba15098e9b38fbf39fa66b878c
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Thu Jun 8 11:49:18 2023 +0530
[TIR] Refactor BlockScope outside schedule (#15034)
This PR makes two major changes to TIR:
1. Move the `stage_pipeline` member from `BlockScope` into `BlockInfo`.
2. Move `BlockScope` out of the schedule module so that it lives directly under `tir`.
This is the first step toward exposing block-level dependence
analysis to TIR passes. Moving `BlockScope` out of the schedule module
allows a follow-up analysis pass to use it. That pass will compute
block-level dependences whenever a transformation pass needs them.
This change is discussed further in
[this discussion forum post](https://discuss.tvm.apache.org/t/why-is-block-dependency-info-attached-to-schedules/14975).
---
include/tvm/tir/{schedule => }/block_scope.h | 18 ++---
include/tvm/tir/schedule/state.h | 20 ++++--
include/tvm/tir/utils.h | 96 +++++++++++++++++++++++++
python/tvm/tir/{schedule => }/block_scope.py | 0
python/tvm/tir/schedule/__init__.py | 2 +-
python/tvm/tir/schedule/state.py | 2 +-
src/tir/{schedule => ir}/block_scope.cc | 25 ++++---
src/tir/schedule/analysis/analysis.cc | 2 +-
src/tir/schedule/analysis/verify.cc | 6 +-
src/tir/schedule/concrete_schedule.cc | 1 -
src/tir/schedule/primitive/cache_index.cc | 4 +-
src/tir/schedule/primitive/cache_read_write.cc | 14 ++--
src/tir/schedule/primitive/decompose_padding.cc | 4 +-
src/tir/schedule/primitive/read_write_at.cc | 2 +-
src/tir/schedule/primitive/reduction.cc | 2 +-
src/tir/schedule/state.cc | 9 ++-
src/tir/schedule/utils.h | 71 +-----------------
17 files changed, 152 insertions(+), 126 deletions(-)
diff --git a/include/tvm/tir/schedule/block_scope.h
b/include/tvm/tir/block_scope.h
similarity index 93%
rename from include/tvm/tir/schedule/block_scope.h
rename to include/tvm/tir/block_scope.h
index be3e79a183..f09beecd54 100644
--- a/include/tvm/tir/schedule/block_scope.h
+++ b/include/tvm/tir/block_scope.h
@@ -17,13 +17,13 @@
* under the License.
*/
/*!
- * \file tvm/tir/schedule/block_scope.h
+ * \file tvm/tir/block_scope.h
* \brief Definition of two pillar data structure for TensorIR scheduling:
StmtSRef, BlockScope.
* \sa StmtSRefNode
* \sa BlockScopeNode
*/
-#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
-#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#ifndef TVM_TIR_BLOCK_SCOPE_H_
+#define TVM_TIR_BLOCK_SCOPE_H_
#include <tvm/tir/stmt.h>
@@ -216,16 +216,6 @@ class BlockScopeNode : public Object {
std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash,
ObjectPtrEqual> dst2deps;
/*! \brief The mapping from the buffer to the blocks who write it */
std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual>
buffer_writers;
- /*!
- * \brief This property indicates that the block scope (rooted at its
corresponding block) is
- * equivalent to of a stage pipeline. Under the following conditions:
- *
- * 1) The region cover property holds for every of its child blocks
- * 2) No write-after-read dependency or opaque dependency, only
read-after-write and
- * write-after-write are allowed
- * 3) All the statements in the scope are schedulable statements, i.e. Block
and For
- */
- bool stage_pipeline{false};
void VisitAttrs(AttrVisitor* v) {}
@@ -270,4 +260,4 @@ class BlockScope : public ObjectRef {
} // namespace tir
} // namespace tvm
-#endif // TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
+#endif // TVM_TIR_BLOCK_SCOPE_H_
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
index a089de2799..d2d90812dd 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/tir/schedule/state.h
@@ -24,8 +24,8 @@
#define TVM_TIR_SCHEDULE_STATE_H_
#include <tvm/ir/module.h>
+#include <tvm/tir/block_scope.h>
#include <tvm/tir/function.h>
-#include <tvm/tir/schedule/block_scope.h>
#include <unordered_map>
#include <utility>
@@ -51,13 +51,25 @@ struct BlockInfo {
* produced by its producers
*/
bool region_cover{false};
+ /*!
+ * \brief This property indicates that the block scope (rooted at its
corresponding block) is
+ * equivalent to of a stage pipeline. Under the following conditions:
+ *
+ * 1) The region cover property holds for every of its child blocks
+ * 2) No write-after-read dependency or opaque dependency, only
read-after-write and
+ * write-after-write are allowed
+ * 3) All the statements in the scope are schedulable statements, i.e. Block
and For
+ */
+ bool stage_pipeline{false};
BlockInfo() = default;
- explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool
region_cover = false)
+ explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool
region_cover = false,
+ bool stage_pipeline = false)
: scope(std::move(scope)), //
affine_binding(affine_binding), //
- region_cover(region_cover) {}
+ region_cover(region_cover),
+ stage_pipeline(stage_pipeline) {}
};
/*!
@@ -185,7 +197,7 @@ class ScheduleStateNode : public Object {
* \return The corresponding BlockScope
*/
bool IsStagePipeline(const StmtSRef& scope_root) const {
- return GetBlockScope(scope_root)->stage_pipeline;
+ return GetBlockInfo(scope_root).stage_pipeline;
}
};
diff --git a/include/tvm/tir/utils.h b/include/tvm/tir/utils.h
new file mode 100644
index 0000000000..c9aad23dad
--- /dev/null
+++ b/include/tvm/tir/utils.h
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_TIR_UTILS_H_
+#define TVM_TIR_UTILS_H_
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief A helper macro to convert an sref to the statement it points to,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param SRef The SRef to be cast
+ * \param Type The type to be cast to, can be Block or For
+ */
+#define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \
+ SRef->StmtAs<Type>(); \
+ ICHECK(Result)
+
+/*!
+ * \brief A helper macro to convert an sref to the block it points to,
+ *
+ * Throws an internal error if downcasting fails. The variable name
+ * in the parent scope is used for the error message.
+ *
+ * \param SRef The SRef to be cast
+ */
+#define TVM_SREF_TO_BLOCK(SRef)
\
+ [&]() {
\
+ auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::BlockNode)
\
+ << "TypeError: Expects StmtSRef `" << #SRef << "` points to
`Block`, but gets: " \
+ << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");
\
+ return result;
\
+ }()
+
+/*!
+ * \brief A helper macro to convert an sref to the for-loop it points to
+ *
+ * Throws an internal error if downcasting fails. The variable name
+ * in the parent scope is used for the error message.
+ *
+ * \param SRef The SRef to be cast
+ */
+#define TVM_SREF_TO_FOR(SRef)
\
+ [&]() {
\
+ auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode)
\
+ << "TypeError: Expects StmtSRef `" << #SRef << "` points to
`Loop`, but gets: " \
+ << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");
\
+ return result;
\
+ }()
+
+/*!
+ * \brief Downcast a TVM ObjectRef to its corresponding container using
`ObjectRef::as<Type>`,
+ * then check if the downcasting succeeded.
+ * \param Result The result variable, used for checking
+ * \param From The ObjectRef to be downcast
+ * \param Type The type to be downcast to
+ */
+#define TVM_TYPE_AS_OR_ERR(Result, From, Type) \
+ From.as<Type>(); \
+ ICHECK(Result)
+
+/*!
+ * \brief Downcast a TVM ObjectRef to its corresponding container using
`ObjectRef::as<Type>`,
+ * throwing an internal error if downcast fails.
+ * \param From The ObjectRef to be downcast
+ * \param Type The type to be downcast to
+ */
+#define TVM_TYPE_AS(From, Type)
\
+ [&]() {
\
+ auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type)
\
+ << "TypeError: Expects `" << #From << "` to have type `" <<
Type::_type_key \
+ << "`, but gets: " << ((From).defined() ?
(From)->GetTypeKey() : "None"); \
+ return result;
\
+ }()
+
+} // namespace tir
+} // namespace tvm
+
+#endif // TVM_TIR_UTILS_H_
diff --git a/python/tvm/tir/schedule/block_scope.py
b/python/tvm/tir/block_scope.py
similarity index 100%
rename from python/tvm/tir/schedule/block_scope.py
rename to python/tvm/tir/block_scope.py
diff --git a/python/tvm/tir/schedule/__init__.py
b/python/tvm/tir/schedule/__init__.py
index 63638a8945..1f68c487c0 100644
--- a/python/tvm/tir/schedule/__init__.py
+++ b/python/tvm/tir/schedule/__init__.py
@@ -17,7 +17,7 @@
# pylint: disable=unused-import
"""Namespace for the TensorIR schedule API."""
-from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
+from ..block_scope import BlockScope, Dependency, DepKind, StmtSRef
from .instruction import Instruction, InstructionKind
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError
from .state import ScheduleDebugMask, ScheduleState
diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py
index 0e67411103..df2eb534e6 100644
--- a/python/tvm/tir/schedule/state.py
+++ b/python/tvm/tir/schedule/state.py
@@ -26,7 +26,7 @@ from tvm.runtime import Object
from tvm.tir import Block, BlockRealize, For, PrimFunc
from . import _ffi_api
-from .block_scope import BlockScope, StmtSRef
+from ..block_scope import BlockScope, StmtSRef
CachedFlags = namedtuple("CachedFlags", ["affine_binding", "region_cover",
"stage_pipeline"])
diff --git a/src/tir/schedule/block_scope.cc b/src/tir/ir/block_scope.cc
similarity index 89%
rename from src/tir/schedule/block_scope.cc
rename to src/tir/ir/block_scope.cc
index 31452f4a8f..6b46396317 100644
--- a/src/tir/schedule/block_scope.cc
+++ b/src/tir/ir/block_scope.cc
@@ -16,7 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include "./utils.h"
+#include <tvm/tir/block_scope.h>
+#include <tvm/tir/utils.h>
namespace tvm {
namespace tir {
@@ -141,21 +142,19 @@ TVM_REGISTER_NODE_TYPE(StmtSRefNode);
TVM_REGISTER_NODE_TYPE(DependencyNode);
TVM_REGISTER_NODE_TYPE(BlockScopeNode);
-TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefStmt")
- .set_body_typed([](StmtSRef sref) -> Optional<Stmt> {
- return GetRef<Optional<Stmt>>(sref->stmt);
- });
-TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefParent")
- .set_body_typed([](StmtSRef sref) -> Optional<StmtSRef> {
- return GetRef<Optional<StmtSRef>>(sref->parent);
- });
-TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefRootMark") //
+TVM_REGISTER_GLOBAL("tir.StmtSRefStmt").set_body_typed([](StmtSRef sref) ->
Optional<Stmt> {
+ return GetRef<Optional<Stmt>>(sref->stmt);
+});
+TVM_REGISTER_GLOBAL("tir.StmtSRefParent").set_body_typed([](StmtSRef sref) ->
Optional<StmtSRef> {
+ return GetRef<Optional<StmtSRef>>(sref->parent);
+});
+TVM_REGISTER_GLOBAL("tir.StmtSRefRootMark") //
.set_body_typed(StmtSRef::RootMark);
-TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefInlineMark") //
+TVM_REGISTER_GLOBAL("tir.StmtSRefInlineMark") //
.set_body_typed(StmtSRef::InlineMark);
-TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsBySrc")
+TVM_REGISTER_GLOBAL("tir.BlockScopeGetDepsBySrc")
.set_body_method<BlockScope>(&BlockScopeNode::GetDepsBySrc);
-TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsByDst")
+TVM_REGISTER_GLOBAL("tir.BlockScopeGetDepsByDst")
.set_body_method<BlockScope>(&BlockScopeNode::GetDepsByDst);
} // namespace tir
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index 2c4da4aaf7..1f989ef939 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -104,7 +104,7 @@ Definition of a scope that is a stage pipeline:
}
// Step 2. Handle `require_stage_pipeline`
if (require_stage_pipeline && self->enable_check) {
- bool stage_pipeline =
self->GetBlockInfo(scope_root_sref).scope->stage_pipeline;
+ bool stage_pipeline = self->GetBlockInfo(scope_root_sref).stage_pipeline;
if (stage_pipeline == false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref);
throw NotStagePipelineError(self->mod, GetRef<Block>(block));
diff --git a/src/tir/schedule/analysis/verify.cc
b/src/tir/schedule/analysis/verify.cc
index ef45f7f8c7..b29d13c3b9 100644
--- a/src/tir/schedule/analysis/verify.cc
+++ b/src/tir/schedule/analysis/verify.cc
@@ -172,10 +172,10 @@ void VerifyCachedFlags(const ScheduleState& self) {
new_block_info.region_cover,
old_block_info.region_cover);
}
- if (new_block_info.scope->stage_pipeline !=
old_block_info.scope->stage_pipeline) {
+ if (new_block_info.stage_pipeline != old_block_info.stage_pipeline) {
block_info_wrong_stage_pipeline.emplace_back(new_sref, //
-
new_block_info.scope->stage_pipeline,
-
old_block_info.scope->stage_pipeline);
+
new_block_info.stage_pipeline,
+
old_block_info.stage_pipeline);
}
}
diff --git a/src/tir/schedule/concrete_schedule.cc
b/src/tir/schedule/concrete_schedule.cc
index d485127242..2359a248fb 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -149,7 +149,6 @@ class ScheduleCopier {
scope->src2deps = Copy(old_info.scope->src2deps);
scope->dst2deps = Copy(old_info.scope->dst2deps);
scope->buffer_writers = Copy(old_info.scope->buffer_writers);
- scope->stage_pipeline = old_info.scope->stage_pipeline;
new_info.scope = BlockScope(std::move(scope));
result[Copy(old_sref)] = std::move(new_info);
}
diff --git a/src/tir/schedule/primitive/cache_index.cc
b/src/tir/schedule/primitive/cache_index.cc
index ac62d35b26..58bcd368c8 100644
--- a/src/tir/schedule/primitive/cache_index.cc
+++ b/src/tir/schedule/primitive/cache_index.cc
@@ -463,7 +463,7 @@ Array<StmtSRef> CacheIndex(ScheduleState self, const
StmtSRef& block_sref,
Array<Block> cache_stages = MakeIndexCacheStage(&info, storage_scope);
Stmt new_scope = CacheIndexRewriter::Rewrite(/*scope_sref=*/scope_sref,
/*info=*/&info);
- bool old_stage_pipeline = self->block_info[block_sref].scope->stage_pipeline;
+ bool old_stage_pipeline = self->block_info[block_sref].stage_pipeline;
// Step 3. Replacing and updating flags.
self->Replace(scope_sref, new_scope, info.block_reuse);
@@ -486,7 +486,7 @@ Array<StmtSRef> CacheIndex(ScheduleState self, const
StmtSRef& block_sref,
block_info.affine_binding = affine_binding;
block_info.region_cover = true;
- block_info.scope->stage_pipeline = old_stage_pipeline;
+ block_info.stage_pipeline = old_stage_pipeline;
}
return result_block_srefs;
diff --git a/src/tir/schedule/primitive/cache_read_write.cc
b/src/tir/schedule/primitive/cache_read_write.cc
index cf139c7df7..74a960eefb 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -1526,7 +1526,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef&
block_sref, int read_buff
BlockInfo& block_info = self->block_info[result_block_sref];
block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info.region_cover = true;
- block_info.scope->stage_pipeline = true;
+ block_info.stage_pipeline = true;
return result_block_sref;
}
@@ -1591,7 +1591,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef&
block_sref, int write_bu
BlockInfo& block_info = self->block_info[result_block_sref];
block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info.region_cover = true;
- block_info.scope->stage_pipeline = true;
+ block_info.stage_pipeline = true;
return result_block_sref;
}
@@ -1812,7 +1812,7 @@ StmtSRef ReindexCacheRead(ScheduleState self, const
StmtSRef& block_sref, int re
BlockInfo& block_info = self->block_info[result_block_sref];
block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info.region_cover = true;
- block_info.scope->stage_pipeline = true;
+ block_info.stage_pipeline = true;
return result_block_sref;
}
@@ -1876,7 +1876,7 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const
StmtSRef& block_sref, int w
BlockInfo& block_info = self->block_info[result_block_sref];
block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info.region_cover = true;
- block_info.scope->stage_pipeline = true;
+ block_info.stage_pipeline = true;
return result_block_sref;
}
@@ -1954,7 +1954,7 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const
StmtSRef& block_sref, int
BlockInfo& block_info_read = self->block_info[result_block_sref];
block_info_read.affine_binding = CalculateAffineFlag(self,
result_block_sref);
block_info_read.region_cover = true;
- block_info_read.scope->stage_pipeline = false;
+ block_info_read.stage_pipeline = false;
results_block_sref.push_back(result_block_sref);
// Do cache write
@@ -1983,7 +1983,7 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const
StmtSRef& block_sref, int
BlockInfo& block_info_write = self->block_info[result_block_sref];
block_info_write.affine_binding = CalculateAffineFlag(self,
result_block_sref);
block_info_write.region_cover = true;
- block_info_write.scope->stage_pipeline = false;
+ block_info_write.stage_pipeline = false;
results_block_sref.push_back(result_block_sref);
return results_block_sref;
@@ -2058,7 +2058,7 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef&
block_sref, int buffer_inde
BlockInfo& block_info = self->block_info[result_block_sref];
block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info.region_cover = true;
- block_info.scope->stage_pipeline = true;
+ block_info.stage_pipeline = true;
return result_block_sref;
}
diff --git a/src/tir/schedule/primitive/decompose_padding.cc
b/src/tir/schedule/primitive/decompose_padding.cc
index de7bd93094..1743a34088 100644
--- a/src/tir/schedule/primitive/decompose_padding.cc
+++ b/src/tir/schedule/primitive/decompose_padding.cc
@@ -496,7 +496,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const
StmtSRef& block_sref,
BlockInfo& block_info = self->block_info[new_block_sref];
block_info.affine_binding = true;
block_info.region_cover = true;
- block_info.scope->stage_pipeline = true;
+ block_info.stage_pipeline = true;
// If the const pad value filling block is lifted out of the original
subtree,
// set the region_cover flag as false since region_cover is the property
under the subtree.
@@ -518,7 +518,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const
StmtSRef& block_sref,
}
}
if (!preserve_stage_pipeline) {
- self->block_info[scope_root_sref].scope->stage_pipeline = false;
+ self->block_info[scope_root_sref].stage_pipeline = false;
}
return new_block_sref;
}
diff --git a/src/tir/schedule/primitive/read_write_at.cc
b/src/tir/schedule/primitive/read_write_at.cc
index 8b7d78f669..d7afd7d330 100644
--- a/src/tir/schedule/primitive/read_write_at.cc
+++ b/src/tir/schedule/primitive/read_write_at.cc
@@ -162,7 +162,7 @@ struct ReadWriteAtImpl {
BlockInfo& block_info = self_->block_info[new_block_sref];
block_info.affine_binding = affine_binding;
block_info.region_cover = true;
- block_info.scope->stage_pipeline = true;
+ block_info.stage_pipeline = true;
}
template <bool is_read>
diff --git a/src/tir/schedule/primitive/reduction.cc
b/src/tir/schedule/primitive/reduction.cc
index 9c0e45544a..6069f4289c 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -1272,7 +1272,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef&
rf_loop_sref, int factor_ax
BlockInfo& info = self->block_info[new_block_sref];
info.affine_binding = true;
info.region_cover = true;
- info.scope->stage_pipeline = true;
+ info.stage_pipeline = true;
}
return new_block_srefs[0];
}
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 45ed97af51..33f8598289 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -224,8 +224,7 @@ class BlockInfoCollector : private StmtVisitor {
// Set `region_cover` to true, will be updated on its scope block
info.region_cover = true;
// Set `stage_pipeline` and `region_cover` for its intermediate children
- info.scope->stage_pipeline =
- CheckRegionCoverAndStagePipeline(info, scope_root, child_block_srefs);
+ info.stage_pipeline = CheckRegionCoverAndStagePipeline(info, scope_root,
child_block_srefs);
}
bool CheckRegionCoverAndStagePipeline(const BlockInfo& info, const StmtSRef&
scope_root,
@@ -796,11 +795,11 @@ class SRefUpdater : public StmtVisitor {
BlockInfo& info = insert_result.first->second;
info.affine_binding = false;
info.region_cover = false;
- info.scope->stage_pipeline = false;
+ info.stage_pipeline = false;
} else {
// Insertion didn't take place, because the entry has been there before.
// In this case, we assume that flags are still valid so intentionally
keep them unchanged
- new_info.scope->stage_pipeline = info.scope->stage_pipeline;
+ new_info.stage_pipeline = info.stage_pipeline;
info.scope = std::move(new_info.scope);
}
}
@@ -1111,7 +1110,7 @@ TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState&
self, const StmtSRef& bl
const BlockInfo& info = self->GetBlockInfo(block_sref);
return {Bool(info.affine_binding), //
Bool(info.region_cover), //
- Bool(info.scope->stage_pipeline)};
+ Bool(info.stage_pipeline)};
}
/**************** FFI ****************/
diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h
index df92c7f807..7f99c92f88 100644
--- a/src/tir/schedule/utils.h
+++ b/src/tir/schedule/utils.h
@@ -31,6 +31,7 @@
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/utils.h>
#include <string>
#include <unordered_map>
@@ -50,76 +51,6 @@
namespace tvm {
namespace tir {
-
-/*!
- * \brief A helper macro to convert an sref to the statement it points to,
- * then check if the downcasting succeeded.
- * \param Result The result variable, used for checking
- * \param SRef The SRef to be cast
- * \param Type The type to be cast to, can be Block or For
- */
-#define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \
- SRef->StmtAs<Type>(); \
- ICHECK(Result)
-
-/*!
- * \brief A helper macro to convert an sref to the block it points to,
- *
- * Throws an internal error if downcasting fails. The variable name
- * in the parent scope is used for the error message.
- *
- * \param SRef The SRef to be cast
- */
-#define TVM_SREF_TO_BLOCK(SRef)
\
- [&]() {
\
- auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::BlockNode)
\
- << "TypeError: Expects StmtSRef `" << #SRef << "` points to
`Block`, but gets: " \
- << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");
\
- return result;
\
- }()
-
-/*!
- * \brief A helper macro to convert an sref to the for-loop it points to
- *
- * Throws an internal error if downcasting fails. The variable name
- * in the parent scope is used for the error message.
- *
- * \param SRef The SRef to be cast
- */
-#define TVM_SREF_TO_FOR(SRef)
\
- [&]() {
\
- auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode)
\
- << "TypeError: Expects StmtSRef `" << #SRef << "` points to
`Loop`, but gets: " \
- << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");
\
- return result;
\
- }()
-
-/*!
- * \brief Downcast a TVM ObjectRef to its corresponding container using
`ObjectRef::as<Type>`,
- * then check if the downcasting succeeded.
- * \param Result The result variable, used for checking
- * \param From The ObjectRef to be downcast
- * \param Type The type to be downcast to
- */
-#define TVM_TYPE_AS_OR_ERR(Result, From, Type) \
- From.as<Type>(); \
- ICHECK(Result)
-
-/*!
- * \brief Downcast a TVM ObjectRef to its corresponding container using
`ObjectRef::as<Type>`,
- * throwing an internal error if downcast fails.
- * \param Result The result variable, used for checking
- * \param From The ObjectRef to be downcast
- * \param Type The type to be downcast to
- */
-#define TVM_TYPE_AS(From, Type)
\
- [&]() {
\
- auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type)
\
- << "TypeError: Expects `" << #From << "` to have type `" <<
Type::_type_key \
- << "`, but gets: " << ((From).defined() ?
(From)->GetTypeKey() : "None"); \
- return result;
\
- }()
-
/*!
* \brief Convert an array of loop StmtSRefs to an array of loops
* \param loop_srefs The loop StmtSRefs to be converted