This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 93d79ba  [AutoScheduler][Relay] Control compile engine cache via PassContext (#7220)
93d79ba is described below

commit 93d79bafcf854a928d248aab92782da36eec3b4a
Author: Cody Yu <comaniac0...@gmail.com>
AuthorDate: Wed Jan 6 17:50:21 2021 -0800

    [AutoScheduler][Relay] Control compile engine cache via PassContext (#7220)
    
    * [AutoScheduler][Relay] Control compile engine cache via PassContext
    
    * lint
    
    * lint
---
 python/tvm/auto_scheduler/relay_integration.py | 35 +++++++++-----------------
 src/relay/backend/compile_engine.cc            |  5 +++-
 src/relay/backend/utils.h                      |  9 +++++++
 3 files changed, 25 insertions(+), 24 deletions(-)

diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index eecf88b..ea1a8cc 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -56,7 +56,10 @@ def call_all_topi_funcs(mod, params, target):
 
     with transform.PassContext(
         opt_level=3,
-        config={"relay.backend.use_auto_scheduler": True},
+        config={
+            "relay.backend.use_auto_scheduler": True,
+            "relay.backend.disable_compile_engine_cache": True,
+        },
         disabled_pass={"AutoSchedulerLayoutRewrite"},
     ):
         try:
@@ -105,7 +108,6 @@ def extract_tasks(
         The weight (i.e. the number of appearance) of extracted tasks
     """
     # pylint: disable=import-outside-toplevel
-    from tvm import relay
 
     if isinstance(target, str):
         target = tvm.target.Target(target)
@@ -123,17 +125,10 @@ def extract_tasks(
         build_thread.start()
         build_thread.join()
 
-    # query the compile engine to get the number of occurrence of all tasks
-    engine = relay.backend.compile_engine.get()
-    use_count_dict = {}
-    for k, v in engine.items():
-        use_count_dict[k] = v.use_count
-
     # create search tasks
     tasks = []
     weights = []
-    for wkl_key, ccache_key in env.wkl_key_to_ccache_key.items():
-        dag = ComputeDAG(wkl_key)
+    for wkl_key, weight in env.wkl_key_to_weight.items():
         tasks.append(
             SearchTask(
                 workload_key=wkl_key,
@@ -145,10 +140,7 @@ def extract_tasks(
                 layout_rewrite_option=LayoutRewriteOption.get_target_default(target, True),
             )
         )
-        weights.append(use_count_dict[ccache_key] + 1)
-
-    # clean the cached lowering results
-    engine.clear()
+        weights.append(weight)
 
     return tasks, weights
 
@@ -169,7 +161,7 @@ class TracingEnvironment:
     def __init__(self, tracing_mode):
         self.tracing_mode = tracing_mode
         self.relay_disable_build_cache = "false"
-        self.wkl_key_to_ccache_key = {}
+        self.wkl_key_to_weight = {}
 
     def __enter__(self):
         TracingEnvironment.current = self
@@ -178,17 +170,17 @@ class TracingEnvironment:
     def __exit__(self, exc_type, exc_val, exc_tb):
         TracingEnvironment.current = None
 
-    def add_workload_key(self, workload_key, ccache_key):
+    def add_workload_key(self, workload_key):
         """Add the workload key of a search task
 
         Parameters
         ----------
         workload_key: str
             The workload key of a task
-        ccache_key: CCacheKey
-            The corresponding ccache_key of the task
         """
-        self.wkl_key_to_ccache_key[workload_key] = ccache_key
+        if workload_key not in self.wkl_key_to_weight:
+            self.wkl_key_to_weight[workload_key] = 0
+        self.wkl_key_to_weight[workload_key] += 1
 
 
 @tvm._ffi.register_func("auto_scheduler.enter_layout_rewrite")
@@ -278,7 +270,6 @@ def auto_schedule_topi(outs):
         An initial schdule in the tracing mode.
     """
     # pylint: disable=import-outside-toplevel
-    from tvm import relay
 
     io_tensors, has_layout_free, has_complex_op = traverse_to_get_io_tensors(outs)
     if not io_tensors:  # The compute includes dynamic shapes which are not supported yet.
@@ -305,9 +296,7 @@ def auto_schedule_topi(outs):
     elif env.tracing_mode in [TracingMode.EXTRACT_TASK, TracingMode.EXTRACT_COMPLEX_TASK_ONLY]:
         # in the task extraction mode
         if has_complex_op or env.tracing_mode == TracingMode.EXTRACT_TASK:
-            engine = relay.backend.compile_engine.get()
-            ccache_key = engine.get_current_ccache_key()
-            env.add_workload_key(key, ccache_key)
+            env.add_workload_key(key)
         schedule = te.create_schedule([x.op for x in outs])
     elif env.tracing_mode == TracingMode.PREPARE_LAYOUT_REWRITE:
         # in prepare_layout_rewrite mode
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 789f39d..c969c3b 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -701,7 +701,9 @@ class CompileEngineImpl : public CompileEngineNode {
     } else {
       value = CCacheValue(make_object<CCacheValueNode>());
       value->use_count = 0;
-      cache_[key] = value;
+      if (!backend::IsCompileEngineCacheDisabled()) {
+        cache_[key] = value;
+      }
     }
     cur_ccache_key_ = key;
 
@@ -832,6 +834,7 @@ CompileEngine& CompileEngine::Global() {
 }
 
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.disable_compile_engine_cache", Bool);
 
 TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
     .set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h
index e167720..6908ca8 100644
--- a/src/relay/backend/utils.h
+++ b/src/relay/backend/utils.h
@@ -303,6 +303,15 @@ inline bool IsAutoSchedulerEnabled() {
       .value();
 }
 
+/*!
+ * \brief Return whether the compile engine cache is disabled in the pass context.
+ */
+inline bool IsCompileEngineCacheDisabled() {
+  return transform::PassContext::Current()
+      ->GetConfig<Bool>("relay.backend.disable_compile_engine_cache", Bool(false))
+      .value();
+}
+
 }  // namespace backend
 }  // namespace relay
 }  // namespace tvm

Reply via email to