This is an automated email from the ASF dual-hosted git repository.
echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 050b23ff06 [Relay]Disable InferType if it was done and no changes
after previous pass (#17585)
050b23ff06 is described below
commit 050b23ff069ed810b424106df74c06562b07f34f
Author: Andrey Malyshev <[email protected]>
AuthorDate: Mon Jan 27 17:28:00 2025 +0200
[Relay]Disable InferType if it was done and no changes after previous pass
(#17585)
Skip InferType if it has already been run and the previous pass made no changes.
This optimization speeds up PatternRewriter transformations by
reusing the previously type-inferred expression instead of running
InferType multiple times.
---
src/relay/ir/dataflow_matcher.cc | 22 +++++++++++++++-------
1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 3e86e1c8ea..9d117adbbc 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -851,24 +851,32 @@ Expr PatternRewriter::Rewrite(const
Array<DFPatternCallback>& callbacks, const E
std::unordered_map<DFPatternCallback, bool, ObjectPtrHash, ObjectPtrEqual>
done;
do {
last = post;
+ // We don't have to call InferType if previous pass has not modified
anything
+ // We can just take previous typed state of the expression
+ bool types_invalidated = true;
for (auto callback : callbacks) {
if (!done[callback]) {
auto before = post;
+ auto post_typed = post;
callback_ = callback;
- if (callback_->require_type) {
- post = InferTypeWithModule(post, mod_);
+ if (callback_->require_type && types_invalidated) {
+ post_typed = InferTypeWithModule(post, mod_);
}
auto grouper = PatternGrouper();
- groups_ = grouper.GroupMatches(callback_->pattern, post);
+ groups_ = grouper.GroupMatches(callback_->pattern, post_typed);
gid_assignments_ = grouper.GetGIDAssignments();
memo_.clear();
VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
- post = this->VisitExpr(post);
+ post = this->VisitExpr(post_typed);
VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
count++;
- if (callback_->rewrite_once) {
- bool current_equal = (*structural_equal)(before, post, false, true);
- if (!current_equal) {
+ bool current_equal = (*structural_equal)(before, post, false, true);
+ if (callback_->require_type && current_equal) {
+ types_invalidated = false;
+ post = post_typed;
+ } else {
+ types_invalidated = true;
+ if (callback_->rewrite_once) {
done[callback] = true;
}
}