merrymercy commented on a change in pull request #6710:
URL: https://github.com/apache/incubator-tvm/pull/6710#discussion_r511954470



##########
File path: python/tvm/auto_scheduler/relay_integration.py
##########
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-variable,invalid-name
+
+"""
+Integrate auto_scheduler into relay. It implements the following items:
+1. Extract search tasks from a relay program
+2. Provide auto-scheduling for all TOPI compute functions
+"""
+
+import os
+import threading
+
+import tvm
+from tvm import te, transform
+from tvm.te.tensor import PlaceholderOp, ComputeOp
+from .dispatcher import DispatchContext, FallbackConfig
+from .workload_registry import register_workload_tensors
+from .compute_dag import ComputeDAG
+
+
+def call_all_topi_funcs(mod, params, target):
+    """Call all TOPI compute + schedule to extract tasks in a relay program"""
+    # pylint: disable=import-outside-toplevel
+    from tvm import relay
+    from tvm.relay.backend import graph_runtime_codegen
+
+    with transform.PassContext(opt_level=3):
+        opt_mod, _ = relay.optimize(mod, target, params)
+        grc = graph_runtime_codegen.GraphRuntimeCodegen(None, target)
+        grc.codegen(opt_mod["main"])
+
+
+def extract_tasks(mod, params, target):
+    """Extract tuning tasks from a relay program.
+
+    Parameters
+    ----------
+    mod: tvm.IRModule or relay.function.Function
+        The module or function to tune
+    params: dict of str to numpy array
+        The associated parameters of the program
+    target: tvm.target.Target
+        The compilation target
+
+    Returns
+    -------
+    wkl_keys: List[str]
+        The hash key of extracted workloads
+    wkl_weights: List[int]
+        The weight (i.e. the number of appearance) of extracted workload
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import relay

Review comment:
       Otherwise, importing `relay` at module level would lead to a circular import.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to