This is an automated email from the ASF dual-hosted git repository.
xiyou pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 96a513cd97 Patch replay trace. (#11621) 96a513cd97 is described below commit 96a513cd97be4b42acb51d1c9b73288820e90185 Author: Xiyou Zhou <xi...@octoml.ai> AuthorDate: Wed Jun 8 11:39:42 2022 -0700 Patch replay trace. (#11621) --- include/tvm/meta_schedule/search_strategy.h | 4 +++- .../tvm/meta_schedule/search_strategy/replay_trace.py | 8 +++++++- src/meta_schedule/search_strategy/replay_trace.cc | 18 +++++++++++++++--- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index baae22f0d9..5e249850f5 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -211,8 +211,10 @@ class SearchStrategy : public runtime::ObjectRef { * \brief Constructor of replay trace search strategy. * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size. * \param max_trials_per_task The total number of trials for trace replaying. + * \param max_fail_count The max number of failures during trace replaying. */ - TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task); + TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task, + int max_fail_count); /*! * \brief Constructor of replay func search strategy. diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index 70461d65f7..36dbb8734e 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -33,15 +33,21 @@ class ReplayTrace(SearchStrategy): Number of trials per iteration. max_trials_per_task : int Total number of trials for one task + max_fail_count : int + Max number of failures during trace replaying. 
""" num_trials_per_iter: int max_trials_per_task: int + max_fail_count: int - def __init__(self, num_trials_per_iter: int, max_trials_per_task: int): + def __init__( + self, num_trials_per_iter: int, max_trials_per_task: int, max_fail_count: int = 100 + ): """Constructor""" self.__init_handle_by_constructor__( _ffi_api.SearchStrategyReplayTrace, # type: ignore # pylint: disable=no-member num_trials_per_iter, max_trials_per_task, + max_fail_count, ) diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 13f32a744e..355f71455d 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -60,6 +60,8 @@ class ReplayTraceNode : public SearchStrategyNode { int num_trials_per_iter; /*! \brief The number of total trials. */ int max_trials_per_task; + /*! \brief The max number of failures during trace replaying. */ + int max_fail_count; /*! \brief The tuning context of the search strategy. */ const TuneContextNode* context_{nullptr}; @@ -71,6 +73,7 @@ class ReplayTraceNode : public SearchStrategyNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("num_trials_per_iter", &num_trials_per_iter); v->Visit("max_trials_per_task", &max_trials_per_task); + v->Visit("max_fail_count", &max_fail_count); // `context_` is not visited. 
// `rand_state_` is not visited // `state_` is not visited @@ -136,7 +139,8 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure int task_id) -> void { TRandState& rand_state = per_thread_rand_state[thread_id]; IRModule mod = this->per_thread_mod_[thread_id]; - for (;;) { + + for (int fail_count = 0; fail_count < self->max_fail_count; fail_count++) { int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size()); tir::Trace trace = design_spaces[design_space_index]; tir::Trace new_trace = tir::Trace(trace->insts, {}); @@ -147,7 +151,13 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure } }; support::parallel_for_dynamic(0, ed - st, ctx->num_threads, f_worker); - return per_task_result; + Array<MeasureCandidate> filtered; + filtered.reserve(ed - st); + for (MeasureCandidate result : per_task_result) + if (result.defined()) { + filtered.push_back(result); + } + return filtered; } inline void ReplayTraceNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) { @@ -155,10 +165,12 @@ inline void ReplayTraceNode::State::NotifyRunnerResults(const Array<RunnerResult ed += self->num_trials_per_iter; } -SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int max_trials_per_task) { +SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int max_trials_per_task, + int max_fail_count) { ObjectPtr<ReplayTraceNode> n = make_object<ReplayTraceNode>(); n->num_trials_per_iter = num_trials_per_iter; n->max_trials_per_task = max_trials_per_task; + n->max_fail_count = max_fail_count; return SearchStrategy(n); }