This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0a0dd3162b [S-TIR][MetaSchedule] Make evolutionary search resilient to trace replay failures (#19438)
0a0dd3162b is described below

commit 0a0dd3162bf460ec047d95f956da309487745792
Author: Neo Chien <[email protected]>
AuthorDate: Sun Apr 26 03:01:02 2026 +0800

    [S-TIR][MetaSchedule] Make evolutionary search resilient to trace replay failures (#19438)
    
    Hi Committers,
    
    This PR fixes issue https://github.com/apache/tvm/issues/17934. Any
    suggestions would be appreciated.
    
    ### Root Cause
    - During `EvolutionarySearch` candidate generation,
    `trace->ApplyToSchedule(...)` could throw `ScheduleError`.
    - The exception was propagated through parallel execution and aborted
    tuning.
    - Error handling was inconsistent between measured and unmeasured paths,
    and failure visibility was limited.
    
    ### Solutions
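    - Catch trace replay failures (`ScheduleError` as well as any other
    `std::exception`) in `ThreadedTraceApply::Apply` and return `nullopt`
    instead of crashing.
    - Add a thread-safe (atomic) trace replay failure counter
    (`trace_fail_counter_`) with an accessor (`TraceFailCount()`), and report
    the count in `SummarizeFailures()`.
    - Align the measured path (`PickBestFromDatabase`) with the unmeasured
    path: candidates whose trace replay or postprocessing fails are skipped
    (previously a `ValueError` was raised), and only the valid schedules are
    returned.
    - Add visible `WARNING` logs when trace replay failures occur (to avoid
    silent failures).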
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 .../search_strategy/evolutionary_search.cc         | 24 +++++--
 src/s_tir/meta_schedule/utils.h                    | 35 ++++++++---
 .../test_meta_schedule_search_strategy.py          | 73 ++++++++++++++++++++++
 3 files changed, 120 insertions(+), 12 deletions(-)

diff --git a/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc 
b/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc
index fabe50dd60..ec85ffeee5 100644
--- a/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc
+++ b/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc
@@ -498,13 +498,24 @@ std::vector<Schedule> 
EvolutionarySearchNode::State::PickBestFromDatabase(int nu
     TVM_FFI_ICHECK(!result.defined());
     if (ffi::Optional<Schedule> sch = pp.Apply(mod, trace, rand_state)) {
       result = sch.value();
-    } else {
-      TVM_FFI_THROW(ValueError) << "Cannot postprocess the trace:\n" << trace;
-      throw;
     }
   };
   support::parallel_for_dynamic(0, actual_num, self->ctx_->num_threads, 
f_proc_measured);
-  return results;
+  TVM_PY_LOG(INFO, self->ctx_->logger) << "Pick-Best-From-Database summary:\n"
+                                       << pp.SummarizeFailures();
+  if (pp.TraceFailCount() > 0) {
+    TVM_PY_LOG(WARNING, self->ctx_->logger)
+        << "PickBestFromDatabase skipped " << pp.TraceFailCount()
+        << " candidate(s) due to trace replay failures";
+  }
+  std::vector<Schedule> filtered;
+  filtered.reserve(actual_num);
+  for (const Schedule& sch : results) {
+    if (sch.defined()) {
+      filtered.push_back(sch);
+    }
+  }
+  return filtered;
 }
 
 std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int 
num) {
@@ -538,6 +549,11 @@ std::vector<Schedule> 
EvolutionarySearchNode::State::SampleInitPopulation(int nu
     fail_count += !found_new;
     TVM_PY_LOG(INFO, self->ctx_->logger) << "Sample-Init-Population summary:\n"
                                          << pp.SummarizeFailures();
+    if (pp.TraceFailCount() > 0) {
+      TVM_PY_LOG(WARNING, self->ctx_->logger)
+          << "SampleInitPopulation encountered " << pp.TraceFailCount()
+          << " trace replay failure(s); invalid candidates were skipped";
+    }
   }
   return out_schs;
 }
diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h
index 847adc2591..a7804361eb 100644
--- a/src/s_tir/meta_schedule/utils.h
+++ b/src/s_tir/meta_schedule/utils.h
@@ -330,14 +330,24 @@ struct ThreadedTraceApply {
    */
   ffi::Optional<s_tir::Schedule> Apply(const IRModule& mod, const 
s_tir::Trace& trace,
                                        TRandState* rand_state) {
-    s_tir::Schedule sch =
-        s_tir::Schedule::Traced(mod,
-                                /*rand_state=*/ForkSeed(rand_state),
-                                /*debug_mode=*/0,
-                                
/*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kNone);
-
-    trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
-    sch->EnterPostproc();
+    s_tir::Schedule sch{nullptr};
+    try {
+      sch = s_tir::Schedule::Traced(mod,
+                                    /*rand_state=*/ForkSeed(rand_state),
+                                    /*debug_mode=*/0,
+                                    /*error_render_level=*/
+                                    s_tir::ScheduleErrorRenderLevel::kNone);
+      trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
+      sch->EnterPostproc();
+    } catch (const s_tir::ScheduleError& e) {
+      TVM_PY_LOG(WARNING, nullptr) << "Trace replay failed with ScheduleError: " << e.what();
+      this->trace_fail_counter_++;
+      return std::nullopt;
+    } catch (const std::exception& e) {
+      TVM_PY_LOG(WARNING, nullptr) << "Trace replay failed with exception: " << e.what();
+      this->trace_fail_counter_++;
+      return std::nullopt;
+    }
 
     for (int i = 0; i < n_; ++i) {
       Item& item = items_[i];
@@ -364,6 +374,10 @@ struct ThreadedTraceApply {
   /*! \brief Returns a string summarizing the failures on each postprocessor */
   std::string SummarizeFailures() const {
     std::ostringstream os;
+    os << "Trace replay failures: " << this->trace_fail_counter_.load() << " failure(s)";
+    if (n_ > 0) {
+      os << "\n";
+    }
     for (int i = 0; i < n_; ++i) {
       const Item& item = items_[i];
       os << "Postproc #" << i << " [" << item.postproc  //
@@ -375,6 +389,9 @@ struct ThreadedTraceApply {
     return os.str();
   }
 
+  /*! \brief Returns the number of trace replay failures. */
+  int TraceFailCount() const { return this->trace_fail_counter_.load(); }
+
  private:
   /*! \brief A helper data structure that stores the fail count for each 
postprocessor. */
   struct Item {
@@ -386,6 +403,8 @@ struct ThreadedTraceApply {
 
   /*! \brief The number of total postprocessors. */
   int n_;
+  /*! \brief The thread-safe trace replay failure counter. */
+  std::atomic<int> trace_fail_counter_{0};
   /*! \brief The pointer to the list of postprocessor items. */
   Item* items_;
 };
diff --git 
a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
index 5df88ba7d5..370eff27c7 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
@@ -49,6 +49,22 @@ class Matmul:
                     C[vi, vj] = 0.0 # type: ignore
                 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
 
+
+@tvm.script.ir_module
+class OtherBlock:
+    @T.prim_func
+    def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore
+        T.func_attr({"global_symbol": "main"})
+        A = T.match_buffer(a, (32, 32), "float32")
+        B = T.match_buffer(b, (32, 32), "float32")
+        C = T.match_buffer(c, (32, 32), "float32")
+        for i, j, k in T.grid(32, 32, 32):
+            with T.sblock("other"):
+                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                with T.init():
+                    C[vi, vj] = 0.0 # type: ignore
+                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
 # fmt: on
 # pylint: 
enable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
 
@@ -308,6 +324,63 @@ def 
test_meta_schedule_evolutionary_search_fail_init_population():  # pylint: di
     assert candidates is None
 
 
+def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace():  # pylint: disable = invalid-name
+    # Construct an incompatible measured trace: it references block name "other",
+    # which doesn't exist in Matmul. Replaying this trace should fail and be skipped.
+    wrong_sch = Schedule(OtherBlock)
+    wrong_sch.get_sblock("other")
+    wrong_trace = wrong_sch.trace
+
+    database = ms.database.MemoryDatabase()
+    workload = database.commit_workload(Matmul)
+    database.commit_tuning_record(
+        ms.database.TuningRecord(
+            trace=wrong_trace,
+            workload=workload,
+            run_secs=[0.1],
+            target=tvm.target.Target("llvm"),
+            args_info=ms.arg_info.ArgInfo.from_prim_func(func=Matmul["main"]),
+        )
+    )
+
+    context = ms.TuneContext(
+        mod=Matmul,
+        space_generator=ms.space_generator.ScheduleFn(
+            sch_fn=_schedule_matmul,
+            sch_rules=[],
+            postprocs=[],
+            mutator_probs={
+                DummyMutator(): 1.0,
+            },
+        ),
+        search_strategy=ms.search_strategy.EvolutionarySearch(
+            population_size=5,
+            init_measured_ratio=1.0,
+            init_min_unmeasured=1,
+            genetic_num_iters=1,
+            genetic_mutate_prob=0.5,
+            genetic_max_fail_count=4,
+            eps_greedy=0.9,
+        ),
+        target=tvm.target.Target("llvm"),
+        num_threads=1,
+    )
+    strategy = context.search_strategy
+    strategy.pre_tuning(
+        max_trials=4,
+        num_trials_per_iter=2,
+        design_spaces=context.space_generator.generate_design_space(context.mod),
+        database=database,
+        cost_model=ms.cost_model.RandomModel(),
+    )
+
+    candidates = strategy.generate_measure_candidates()
+    strategy.post_tuning()
+
+    # Regression assertion: invalid measured trace should be skipped, not crash
+    assert candidates is not None
+
+
 def test_search_strategy_abstract_class_instantiation():
     """Test that directly instantiating abstract SearchStrategy raises 
TypeError instead of segfault."""
     from tvm.s_tir.meta_schedule import SearchStrategy, TuneContext

Reply via email to