This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0a0dd3162b [S-TIR][MetaSchedule] Make evolutionary search resilient to
trace replay failures (#19438)
0a0dd3162b is described below
commit 0a0dd3162bf460ec047d95f956da309487745792
Author: Neo Chien <[email protected]>
AuthorDate: Sun Apr 26 03:01:02 2026 +0800
[S-TIR][MetaSchedule] Make evolutionary search resilient to trace replay
failures (#19438)
Hi Committers,
This PR attempts to fix issue
https://github.com/apache/tvm/issues/17934. Any suggestions are
welcome.
### Root Cause
- During `EvolutionarySearch` candidate generation,
`trace->ApplyToSchedule(...)` could throw `ScheduleError`.
- The exception propagated out of the parallel worker loop
(`support::parallel_for_dynamic`) and aborted the entire tuning run.
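- Error handling was inconsistent: the measured path
(`PickBestFromDatabase`) threw a `ValueError` when a trace could not be
postprocessed, whereas the unmeasured path (`SampleInitPopulation`)
counted and skipped such failures. In addition, replay failures were
barely visible in the logs.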
### Solutions
- Catch trace replay failures (`ScheduleError` and any other
`std::exception`) in `ThreadedTraceApply::Apply` and return
`std::nullopt` instead of crashing.
- Add a thread-safe (atomic) trace replay failure counter
(`trace_fail_counter_`) and an accessor (`TraceFailCount()`).
- Align the measured path (`PickBestFromDatabase`) with the unmeasured
path: skip invalid candidates and continue instead of throwing.
- Emit `WARNING` logs whenever trace replay failures occur, so they are
no longer silent.
---------
Co-authored-by: cchung100m <[email protected]>
---
.../search_strategy/evolutionary_search.cc | 24 +++++--
src/s_tir/meta_schedule/utils.h | 35 ++++++++---
.../test_meta_schedule_search_strategy.py | 73 ++++++++++++++++++++++
3 files changed, 120 insertions(+), 12 deletions(-)
diff --git a/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc
b/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc
index fabe50dd60..ec85ffeee5 100644
--- a/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc
+++ b/src/s_tir/meta_schedule/search_strategy/evolutionary_search.cc
@@ -498,13 +498,24 @@ std::vector<Schedule>
EvolutionarySearchNode::State::PickBestFromDatabase(int nu
TVM_FFI_ICHECK(!result.defined());
if (ffi::Optional<Schedule> sch = pp.Apply(mod, trace, rand_state)) {
result = sch.value();
- } else {
- TVM_FFI_THROW(ValueError) << "Cannot postprocess the trace:\n" << trace;
- throw;
}
};
support::parallel_for_dynamic(0, actual_num, self->ctx_->num_threads,
f_proc_measured);
- return results;
+ TVM_PY_LOG(INFO, self->ctx_->logger) << "Pick-Best-From-Database summary:\n"
+ << pp.SummarizeFailures();
+ if (pp.TraceFailCount() > 0) {
+ TVM_PY_LOG(WARNING, self->ctx_->logger)
+ << "PickBestFromDatabase skipped " << pp.TraceFailCount()
+ << " candidate(s) due to trace replay failures";
+ }
+ std::vector<Schedule> filtered;
+ filtered.reserve(actual_num);
+ for (const Schedule& sch : results) {
+ if (sch.defined()) {
+ filtered.push_back(sch);
+ }
+ }
+ return filtered;
}
std::vector<Schedule> EvolutionarySearchNode::State::SampleInitPopulation(int
num) {
@@ -538,6 +549,11 @@ std::vector<Schedule>
EvolutionarySearchNode::State::SampleInitPopulation(int nu
fail_count += !found_new;
TVM_PY_LOG(INFO, self->ctx_->logger) << "Sample-Init-Population summary:\n"
<< pp.SummarizeFailures();
+ if (pp.TraceFailCount() > 0) {
+ TVM_PY_LOG(WARNING, self->ctx_->logger)
+ << "SampleInitPopulation encountered " << pp.TraceFailCount()
+ << " trace replay failure(s); invalid candidates were skipped";
+ }
}
return out_schs;
}
diff --git a/src/s_tir/meta_schedule/utils.h b/src/s_tir/meta_schedule/utils.h
index 847adc2591..a7804361eb 100644
--- a/src/s_tir/meta_schedule/utils.h
+++ b/src/s_tir/meta_schedule/utils.h
@@ -330,14 +330,24 @@ struct ThreadedTraceApply {
*/
ffi::Optional<s_tir::Schedule> Apply(const IRModule& mod, const
s_tir::Trace& trace,
TRandState* rand_state) {
- s_tir::Schedule sch =
- s_tir::Schedule::Traced(mod,
- /*rand_state=*/ForkSeed(rand_state),
- /*debug_mode=*/0,
-
/*error_render_level=*/s_tir::ScheduleErrorRenderLevel::kNone);
-
- trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
- sch->EnterPostproc();
+ s_tir::Schedule sch{nullptr};
+ try {
+ sch = s_tir::Schedule::Traced(mod,
+ /*rand_state=*/ForkSeed(rand_state),
+ /*debug_mode=*/0,
+ /*error_render_level=*/
+ s_tir::ScheduleErrorRenderLevel::kNone);
+ trace->ApplyToSchedule(sch, /*remove_postproc=*/true);
+ sch->EnterPostproc();
+ } catch (const s_tir::ScheduleError& e) {
+ TVM_PY_LOG(WARNING, nullptr) << "Trace replay failed with ScheduleError:
" << e.what();
+ this->trace_fail_counter_++;
+ return std::nullopt;
+ } catch (const std::exception& e) {
+ TVM_PY_LOG(WARNING, nullptr) << "Trace replay failed with exception: "
<< e.what();
+ this->trace_fail_counter_++;
+ return std::nullopt;
+ }
for (int i = 0; i < n_; ++i) {
Item& item = items_[i];
@@ -364,6 +374,10 @@ struct ThreadedTraceApply {
/*! \brief Returns a string summarizing the failures on each postprocessor */
std::string SummarizeFailures() const {
std::ostringstream os;
+ os << "Trace replay failures: " << this->trace_fail_counter_.load() << "
failure(s)";
+ if (n_ > 0) {
+ os << "\n";
+ }
for (int i = 0; i < n_; ++i) {
const Item& item = items_[i];
os << "Postproc #" << i << " [" << item.postproc //
@@ -375,6 +389,9 @@ struct ThreadedTraceApply {
return os.str();
}
+ /*! \brief Returns the number of trace replay failures. */
+ int TraceFailCount() const { return this->trace_fail_counter_.load(); }
+
private:
/*! \brief A helper data structure that stores the fail count for each
postprocessor. */
struct Item {
@@ -386,6 +403,8 @@ struct ThreadedTraceApply {
/*! \brief The number of total postprocessors. */
int n_;
+ /*! \brief The thread-safe trace replay failure counter. */
+ std::atomic<int> trace_fail_counter_{0};
/*! \brief The pointer to the list of postprocessor items. */
Item* items_;
};
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
index 5df88ba7d5..370eff27c7 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
@@ -49,6 +49,22 @@ class Matmul:
C[vi, vj] = 0.0 # type: ignore
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
[email protected]_module
+class OtherBlock:
+ @T.prim_func
+ def main(a: T.handle, b: T.handle, c: T.handle) -> None: # type: ignore
+ T.func_attr({"global_symbol": "main"})
+ A = T.match_buffer(a, (32, 32), "float32")
+ B = T.match_buffer(b, (32, 32), "float32")
+ C = T.match_buffer(c, (32, 32), "float32")
+ for i, j, k in T.grid(32, 32, 32):
+ with T.sblock("other"):
+ vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+ with T.init():
+ C[vi, vj] = 0.0 # type: ignore
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
# fmt: on
# pylint:
enable=missing-class-docstring,invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
@@ -308,6 +324,63 @@ def
test_meta_schedule_evolutionary_search_fail_init_population(): # pylint: di
assert candidates is None
+def test_meta_schedule_evolutionary_search_skip_invalid_measured_trace() #
pylint: disable = invalid-name
+ # Construct an incompatible measured trace: it references block name
"other",
+ # which doesn't exist in Matmul. Replaying this trace should fail and be
skipped.
+ wrong_sch = Schedule(OtherBlock)
+ wrong_sch.get_sblock("other")
+ wrong_trace = wrong_sch.trace
+
+ database = ms.database.MemoryDatabase()
+ workload = database.commit_workload(Matmul)
+ database.commit_tuning_record(
+ ms.database.TuningRecord(
+ trace=wrong_trace,
+ workload=workload,
+ run_secs=[0.1],
+ target=tvm.target.Target("llvm"),
+ args_info=ms.arg_info.ArgInfo.from_prim_func(func=Matmul["main"]),
+ )
+ )
+
+ context = ms.TuneContext(
+ mod=Matmul,
+ space_generator=ms.space_generator.ScheduleFn(
+ sch_fn=_schedule_matmul,
+ sch_rules=[],
+ postprocs=[],
+ mutator_probs={
+ DummyMutator(): 1.0,
+ },
+ ),
+ search_strategy=ms.search_strategy.EvolutionarySearch(
+ population_size=5,
+ init_measured_ratio=1.0,
+ init_min_unmeasured=1,
+ genetic_num_iters=1,
+ genetic_mutate_prob=0.5,
+ genetic_max_fail_count=4,
+ eps_greedy=0.9,
+ ),
+ target=tvm.target.Target("llvm"),
+ num_threads=1,
+ )
+ strategy = context.search_strategy
+ strategy.pre_tuning(
+ max_trials=4,
+ num_trials_per_iter=2,
+
design_spaces=context.space_generator.generate_design_space(context.mod),
+ database=database,
+ cost_model=ms.cost_model.RandomModel(),
+ )
+
+ candidates = strategy.generate_measure_candidates()
+ strategy.post_tuning()
+
+ # Regression assertion: invalid measured trace should be skipped, not crash
+ assert candidates is not None
+
+
def test_search_strategy_abstract_class_instantiation():
"""Test that directly instantiating abstract SearchStrategy raises
TypeError instead of segfault."""
from tvm.s_tir.meta_schedule import SearchStrategy, TuneContext