This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 96a513cd97 Patch replay trace. (#11621)
96a513cd97 is described below

commit 96a513cd97be4b42acb51d1c9b73288820e90185
Author: Xiyou Zhou <xi...@octoml.ai>
AuthorDate: Wed Jun 8 11:39:42 2022 -0700

    Patch replay trace. (#11621)
---
 include/tvm/meta_schedule/search_strategy.h            |  4 +++-
 .../tvm/meta_schedule/search_strategy/replay_trace.py  |  8 +++++++-
 src/meta_schedule/search_strategy/replay_trace.cc      | 18 +++++++++++++++---
 3 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h
index baae22f0d9..5e249850f5 100644
--- a/include/tvm/meta_schedule/search_strategy.h
+++ b/include/tvm/meta_schedule/search_strategy.h
@@ -211,8 +211,10 @@ class SearchStrategy : public runtime::ObjectRef {
    * \brief Constructor of replay trace search strategy.
    * \param num_trials_per_iter The number of trials per iteration, i.e., the batch size.
    * \param max_trials_per_task The total number of trials for trace replaying.
+   * \param max_fail_count The max number of failures during trace replaying.
    */
-  TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task);
+  TVM_DLL static SearchStrategy ReplayTrace(int num_trials_per_iter, int max_trials_per_task,
+                                            int max_fail_count);
 
   /*!
    * \brief Constructor of replay func search strategy.
diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py
index 70461d65f7..36dbb8734e 100644
--- a/python/tvm/meta_schedule/search_strategy/replay_trace.py
+++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py
@@ -33,15 +33,21 @@ class ReplayTrace(SearchStrategy):
         Number of trials per iteration.
     max_trials_per_task : int
         Total number of trials for one task
+    max_fail_count : int
+        Max number of failures during trace replaying.
     """
 
     num_trials_per_iter: int
     max_trials_per_task: int
+    max_fail_count: int
 
-    def __init__(self, num_trials_per_iter: int, max_trials_per_task: int):
+    def __init__(
+        self, num_trials_per_iter: int, max_trials_per_task: int, max_fail_count: int = 100
+    ):
         """Constructor"""
         self.__init_handle_by_constructor__(
             _ffi_api.SearchStrategyReplayTrace,  # type: ignore # pylint: disable=no-member
             num_trials_per_iter,
             max_trials_per_task,
+            max_fail_count,
         )
diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc
index 13f32a744e..355f71455d 100644
--- a/src/meta_schedule/search_strategy/replay_trace.cc
+++ b/src/meta_schedule/search_strategy/replay_trace.cc
@@ -60,6 +60,8 @@ class ReplayTraceNode : public SearchStrategyNode {
   int num_trials_per_iter;
   /*! \brief The number of total trials. */
   int max_trials_per_task;
+  /*! \brief The max number of failures during trace replaying. */
+  int max_fail_count;
 
   /*! \brief The tuning context of the search strategy. */
   const TuneContextNode* context_{nullptr};
@@ -71,6 +73,7 @@ class ReplayTraceNode : public SearchStrategyNode {
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("num_trials_per_iter", &num_trials_per_iter);
     v->Visit("max_trials_per_task", &max_trials_per_task);
+    v->Visit("max_fail_count", &max_fail_count);
     // `context_` is not visited.
     // `rand_state_` is not visited
     // `state_` is not visited
@@ -136,7 +139,8 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure
                                                                         int 
task_id) -> void {
     TRandState& rand_state = per_thread_rand_state[thread_id];
     IRModule mod = this->per_thread_mod_[thread_id];
-    for (;;) {
+
+    for (int fail_count = 0; fail_count < self->max_fail_count; fail_count++) {
       int design_space_index = tir::SampleInt(&rand_state, 0, design_spaces.size());
       tir::Trace trace = design_spaces[design_space_index];
       tir::Trace new_trace = tir::Trace(trace->insts, {});
@@ -147,7 +151,13 @@ inline Optional<Array<MeasureCandidate>> ReplayTraceNode::State::GenerateMeasure
     }
   };
   support::parallel_for_dynamic(0, ed - st, ctx->num_threads, f_worker);
-  return per_task_result;
+  Array<MeasureCandidate> filtered;
+  filtered.reserve(ed - st);
+  for (MeasureCandidate result : per_task_result)
+    if (result.defined()) {
+      filtered.push_back(result);
+    }
+  return filtered;
 }
 
 inline void ReplayTraceNode::State::NotifyRunnerResults(const Array<RunnerResult>& results) {
@@ -155,10 +165,12 @@ inline void ReplayTraceNode::State::NotifyRunnerResults(const Array<RunnerResult
   ed += self->num_trials_per_iter;
 }
 
-SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int max_trials_per_task) {
+SearchStrategy SearchStrategy::ReplayTrace(int num_trials_per_iter, int max_trials_per_task,
+                                           int max_fail_count) {
   ObjectPtr<ReplayTraceNode> n = make_object<ReplayTraceNode>();
   n->num_trials_per_iter = num_trials_per_iter;
   n->max_trials_per_task = max_trials_per_task;
+  n->max_fail_count = max_fail_count;
   return SearchStrategy(n);
 }
 

Reply via email to