merrymercy commented on a change in pull request #6710:
URL: https://github.com/apache/incubator-tvm/pull/6710#discussion_r511955715



##########
File path: python/tvm/auto_scheduler/relay_integration.py
##########
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-variable,invalid-name
+
+"""
+Integrate auto_scheduler into relay. It implements the following items:
+1. Extract search tasks from a relay program
+2. Provide auto-scheduling for all TOPI compute functions
+"""
+
+import os
+import threading
+
+import tvm
+from tvm import te, transform
+from tvm.te.tensor import PlaceholderOp, ComputeOp
+from .dispatcher import DispatchContext, FallbackConfig
+from .workload_registry import register_workload_tensors
+from .compute_dag import ComputeDAG
+
+
def call_all_topi_funcs(mod, params, target):
    """Compile a relay program far enough to invoke its TOPI compute/schedule functions.

    The program is optimized under a ``PassContext`` with ``opt_level=3``. The
    optimized ``main`` function is then handed to the graph-runtime code
    generator so the TOPI compute and schedule functions get called. The
    generated code itself is discarded; callers rely only on the side effects
    of those calls (e.g. task extraction).

    Parameters
    ----------
    mod: tvm.IRModule or relay.function.Function
        The program to compile
    params: dict of str to numpy array
        The associated parameters of the program
    target: tvm.target.Target
        The compilation target
    """
    # pylint: disable=import-outside-toplevel
    from tvm import relay
    from tvm.relay.backend import graph_runtime_codegen

    with transform.PassContext(opt_level=3):
        optimized, _ = relay.optimize(mod, target, params)
        main_func = optimized["main"]
        codegen = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
        # Only the side effects of lowering matter here; the result is unused.
        codegen.codegen(main_func)
+
+
def extract_tasks(mod, params, target):
    """Extract tuning tasks from a relay program.

    Parameters
    ----------
    mod: tvm.IRModule or relay.function.Function
        The module or function to tune
    params: dict of str to numpy array
        The associated parameters of the program
    target: tvm.target.Target
        The compilation target

    Returns
    -------
    wkl_keys: List[str]
        The hash key of extracted workloads
    wkl_weights: List[int]
        The weight (i.e. the number of appearance) of extracted workload

    Raises
    ------
    Exception
        Any exception raised while compiling the program in the build thread
        is re-raised here, instead of being silently dropped.
    """
    # pylint: disable=import-outside-toplevel
    from tvm import relay

    build_errors = []

    def _build():
        # An exception raised inside a thread does not propagate through
        # join(); capture it so a failed compilation is not mistaken for a
        # program with no tasks.
        try:
            call_all_topi_funcs(mod, params, target)
        except Exception as err:  # pylint: disable=broad-except
            build_errors.append(err)

    env = TracingEnvironment()
    with env:
        # Run the compiler to collect all TOPI calls during compilation.

        # Wrap build call in a new thread to avoid the conflict
        # between python's multiprocessing and tvm's thread pool
        build_thread = threading.Thread(target=_build)
        build_thread.start()
        build_thread.join()
        relay.backend.compile_engine.get().clear()

    # Re-raise only after the cache is cleared and the tracing environment
    # has been restored, so a failure leaves no tracing state behind.
    if build_errors:
        raise build_errors[0]

    # Keys come out in first-seen order (dict insertion order).
    wkl_keys = list(env.wkl_key_collection)
    wkl_weights = list(env.wkl_key_collection.values())
    return wkl_keys, wkl_weights
+
+
class TracingEnvironment:
    """Global environment for tracing all topi function calls.

    Used as a context manager. While active, the environment:

    - is published as ``TracingEnvironment.current``;
    - forces ``TVM_RELAY_DISABLE_BUILD_CACHE`` to ``"true"``;
    - counts the workload keys reported through :meth:`add_workload_key`.

    On exit, both the environment variable and ``TracingEnvironment.current``
    are restored to exactly what they were before entering. An unset variable
    is removed again, and an enclosing environment becomes current again.
    """

    # The environment that is currently active, or None when not tracing.
    current = None

    def __init__(self):
        # Value of TVM_RELAY_DISABLE_BUILD_CACHE before entering ("false" if unset).
        self.relay_disable_build_cache = "false"
        # Whether TVM_RELAY_DISABLE_BUILD_CACHE was set at all before entering.
        self._build_cache_was_set = False
        # The environment that was current before entering (for nesting).
        self._previous = None
        self.layout_rewrite_success_ct = 0
        # Maps workload key -> number of times it was reported.
        self.wkl_key_collection = {}

    def __enter__(self):
        previous_value = os.environ.get("TVM_RELAY_DISABLE_BUILD_CACHE")
        self._build_cache_was_set = previous_value is not None
        self.relay_disable_build_cache = "false" if previous_value is None else previous_value
        os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = "true"
        self._previous = TracingEnvironment.current
        TracingEnvironment.current = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._build_cache_was_set:
            os.environ["TVM_RELAY_DISABLE_BUILD_CACHE"] = self.relay_disable_build_cache
        else:
            # The variable did not exist before entering; remove it again
            # rather than leaving a synthesized "false" behind.
            os.environ.pop("TVM_RELAY_DISABLE_BUILD_CACHE", None)
        TracingEnvironment.current = self._previous
        self._previous = None

    def add_workload_key(self, key):
        """Add the workload key of an Ansor search task

        Parameters
        ----------
        key: str
            The workload key of a task
        """
        self.wkl_key_collection[key] = self.wkl_key_collection.get(key, 0) + 1
+
+
def traverse_to_get_io_tensors(outs):
    """Collect the I/O tensors of the computational DAG that produces ``outs``.

    Starting from each output tensor, the producers of every ``ComputeOp`` are
    followed depth-first. Each ``PlaceholderOp`` tensor reached is recorded once,
    in discovery order. Tensors produced by any other kind of op are neither
    descended into nor recorded.

    Returns the recorded placeholder tensors followed by the output tensors.
    """
    placeholders = []
    seen = set()

    def _visit(tensor):
        if tensor in seen:
            return
        op = tensor.op
        if isinstance(op, PlaceholderOp):
            placeholders.append(tensor)
        elif isinstance(op, ComputeOp):
            for producer in op.input_tensors:
                _visit(producer)
        seen.add(tensor)

    for out in outs:
        _visit(out)

    io_tensors = list(placeholders)
    io_tensors.extend(outs)
    return io_tensors
+
+
# Name suffix that marks an OpStrategy implementation as one scheduled by the
# auto-scheduler.
auto_schedule_impl_suffix = ".auto_scheduler"
+
+
+def auto_schedule_topi(outs):

Review comment:
       We cannot call it `auto_schedule_te`. For your use case, there is already a separate function, `auto_schedule` (https://github.com/apache/incubator-tvm/blob/e59c603515befb02035e237794aa0645dbfbaf09/python/tvm/auto_scheduler/auto_schedule.py#L161).
   This function, in contrast, is designed to be used as a TOPI schedule function for Relay, because it does a lot of other things.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to