This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git

The following commit(s) were added to refs/heads/main by this push:
     new 93d79ba  [AutoScheduler][Relay] Control compile engine cache via PassContext (#7220)
93d79ba is described below

commit 93d79bafcf854a928d248aab92782da36eec3b4a
Author: Cody Yu <comaniac0...@gmail.com>
AuthorDate: Wed Jan 6 17:50:21 2021 -0800

    [AutoScheduler][Relay] Control compile engine cache via PassContext (#7220)

    * [AutoScheduler][Relay] Control compile engine cache via PassContext

    * lint

    * lint
---
 python/tvm/auto_scheduler/relay_integration.py | 35 +++++++++-----------------
 src/relay/backend/compile_engine.cc            |  5 +++-
 src/relay/backend/utils.h                      |  9 +++++++
 3 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index eecf88b..ea1a8cc 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -56,7 +56,10 @@ def call_all_topi_funcs(mod, params, target):
 
     with transform.PassContext(
         opt_level=3,
-        config={"relay.backend.use_auto_scheduler": True},
+        config={
+            "relay.backend.use_auto_scheduler": True,
+            "relay.backend.disable_compile_engine_cache": True,
+        },
         disabled_pass={"AutoSchedulerLayoutRewrite"},
     ):
         try:
@@ -105,7 +108,6 @@ def extract_tasks(
         The weight (i.e. the number of appearance) of extracted tasks
     """
     # pylint: disable=import-outside-toplevel
-    from tvm import relay
 
     if isinstance(target, str):
         target = tvm.target.Target(target)
@@ -123,17 +125,10 @@ def extract_tasks(
         build_thread.start()
         build_thread.join()
 
-    # query the compile engine to get the number of occurrence of all tasks
-    engine = relay.backend.compile_engine.get()
-    use_count_dict = {}
-    for k, v in engine.items():
-        use_count_dict[k] = v.use_count
-
     # create search tasks
     tasks = []
     weights = []
-    for wkl_key, ccache_key in env.wkl_key_to_ccache_key.items():
-        dag = ComputeDAG(wkl_key)
+    for wkl_key, weight in env.wkl_key_to_weight.items():
         tasks.append(
             SearchTask(
                 workload_key=wkl_key,
@@ -145,10 +140,7 @@ def extract_tasks(
                 layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
             )
         )
-        weights.append(use_count_dict[ccache_key] + 1)
-
-    # clean the cached lowering results
-    engine.clear()
+        weights.append(weight)
 
     return tasks, weights
 
@@ -169,7 +161,7 @@ class TracingEnvironment:
     def __init__(self, tracing_mode):
         self.tracing_mode = tracing_mode
         self.relay_disable_build_cache = "false"
-        self.wkl_key_to_ccache_key = {}
+        self.wkl_key_to_weight = {}
 
     def __enter__(self):
         TracingEnvironment.current = self
@@ -178,17 +170,17 @@ class TracingEnvironment:
     def __exit__(self, exc_type, exc_val, exc_tb):
         TracingEnvironment.current = None
 
-    def add_workload_key(self, workload_key, ccache_key):
+    def add_workload_key(self, workload_key):
         """Add the workload key of a search task
 
         Parameters
         ----------
         workload_key: str
             The workload key of a task
-        ccache_key: CCacheKey
-            The corresponding ccache_key of the task
         """
-        self.wkl_key_to_ccache_key[workload_key] = ccache_key
+        if workload_key not in self.wkl_key_to_weight:
+            self.wkl_key_to_weight[workload_key] = 0
+        self.wkl_key_to_weight[workload_key] += 1
 
 
 @tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
@@ -278,7 +270,6 @@ def auto_schedule_topi(outs):
         An initial schdule in the tracing mode.
""" # pylint: disable=import-outside-toplevel - from tvm import relay io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs) if not io_tensors: # The compute includes dynamic shapes which are not supported yet. @@ -305,9 +296,7 @@ def auto_schedule_topi(outs): elif env.tracing_mode in [TracingMode.EXTRACT_TASK, TracingMode.EXTRACT_COMPLEX_TASK_ONLY]: # in the task extraction mode if has_complex_op or env.tracing_mode == TracingMode.EXTRACT_TASK: - engine = relay.backend.compile_engine.get() - ccache_key = engine.get_current_ccache_key() - env.add_workload_key(key, ccache_key) + env.add_workload_key(key) schedule = te.create_schedule([x.op for x in outs]) elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE: # in prepare_layout_rewrite mode diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 789f39d..c969c3b 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -701,7 +701,9 @@ class CompileEngineImpl : public CompileEngineNode { } else { value = CCacheValue(make_object<CCacheValueNode>()); value->use_count = 0; - cache_[key] = value; + if (!backend::IsCompileEngineCacheDisabled()) { + cache_[key] = value; + } } cur_ccache_key_ = key; @@ -832,6 +834,7 @@ CompileEngine& CompileEngine::Global() { } TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.disable_compile_engine_cache", Bool); TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") .set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) { diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index e167720..6908ca8 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -303,6 +303,15 @@ inline bool IsAutoSchedulerEnabled() { .value(); } +/*! + * \brief Return whether the compile engine cache is disabled in the pass context. + */ +inline bool IsCompileEngineCacheDisabled() { + return transform::PassContext::Current() + ->GetConfig<Bool>("relay.backend.disable_compile_engine_cache", Bool(false)) + .value(); +} + } // namespace backend } // namespace relay } // namespace tvm