merrymercy commented on a change in pull request #6663:
URL: https://github.com/apache/incubator-tvm/pull/6663#discussion_r507182331



##########
File path: python/tvm/auto_scheduler/task_scheduler.py
##########
@@ -0,0 +1,452 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+
+""" The task scheduler that allocates the time resources when tuning multiple 
tasks together
+
+The details of the "gradient" strategy below can be found in the section 6 of 
this paper:
+L. Zheng, C. Jia, M. Sun, Z. Wu, C. Yu, et al. "Ansor : Generating 
High-Performance Tensor
+Programs for Deep Learning." (OSDI 2020).
+"""
+
+import time
+import math
+import logging
+
+import numpy as np
+
+from .search_policy import SearchPolicy, SketchPolicy
+from .cost_model import RandomModel, XGBModel
+from .utils import array_mean, to_str_round
+from .measure import ProgramMeasurer
+from .measure_record import RecordReader
+
+logger = logging.getLogger("auto_scheduler")
+
+
+class TaskScheduler:
+    """Allocate the time resources when tuning multiple tasks together
+
+    Parameters
+    ----------
+    tasks: List[SearchTask]
+        The list of all tasks
+    objective_func: Callable[List[float] -> float]
+        The objective function to be optimized. If None, the sum of all task costs is used.
+    """
+
+    def __init__(self, tasks, objective_func):
+        self.tasks = tasks
+        self.objective_func = objective_func or sum
+
+    def compute_score(self, costs) -> float:
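+        """Compute the objective score from the costs of all tasks."""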
+        return self.objective_func(costs)
+
+
+def make_search_policies(
+    search_policy, tasks, num_measures_per_round, load_model_file=None, load_log_file=None
+):
+    """Make a list of search policies for a list of search tasks.
+    It creates one policy per task.
+
+    Parameters
+    ----------
+    search_policy: Union[str, List[SearchPolicy]]
+        The name of the search policy, or a list of pre-built search policies (one per task).
+    tasks: List[SearchTask]
+        The list of all tasks
+    num_measures_per_round: int
+        The number of schedules to be measured at each search round.
+        This should be the same as `TuningOptions.num_measures_per_round`
+    load_model_file: Optional[str]
+        Load pre-trained model from this file
+    load_log_file: Optional[str]
+        Load measurement records from this file
+
+    Returns
+    -------
+    policies: List[SearchPolicy]
+        The list of search policies
+    """
+    if search_policy == "default":
+        search_policy = "sketch.xgb"
+
+    if isinstance(search_policy, str):
+        policy_type, model_type = search_policy.split(".")
+        if model_type == "xgb":
+            cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round)
+            if load_model_file:
+                logger.info("Load pretrained model...")
+                cost_model.load(load_model_file)
+            elif load_log_file:
+                cost_model.load_log_file(load_log_file)
+        elif model_type == "random":
+            cost_model = RandomModel()
+        else:
+            raise ValueError("Invalid search policy: " + search_policy)
+
+        if policy_type == "sketch":
+            search_policies = [SketchPolicy(task, cost_model) for task in tasks]
+        else:
+            raise ValueError("Invalid search policy: " + search_policy)
+    else:
+        # check type
+        assert isinstance(search_policy, (tuple, list))
+        for item in search_policy:
+            assert isinstance(item, SearchPolicy)
+        search_policies = search_policy
+
+    return search_policies
+
+
+def derive_similarity_tag(dag, log_base=1.618):
+    """Derive the tag for similarity check from one computational DAG.
+    DAGs with the same tag are considered similar tasks.
+
+    Parameters
+    ----------
+    dag: ComputeDAG
+        The input computational DAG
+    log_base: float = 1.618
+        The base of the logarithm used to normalize FLOPS
+
+    Returns
+    -------
+    tag: str
+        The tag of this computational DAG.
+    """
+    ret = ""
+    for op in dag.ops:
+        tag = op.attrs.get("ansor_task_scheduler_tag", None)
+        if tag:
+            ret += tag + "_"
+    if ret != "":
+        ret += "%d" % int(math.log(dag.flop_ct + 1, log_base))
+    return ret

Review comment:
       If the return value for a task is `""`, the task is not considered similar to any other task.
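
To make this concrete, here is a small self-contained sketch (not part of the PR) that mirrors the tag-derivation logic above. `FakeOp` and `FakeDAG` are hypothetical stand-ins for TVM's real op and `ComputeDAG` objects, used only so the snippet runs on its own:

```python
import math

def derive_similarity_tag(dag, log_base=1.618):
    # Concatenate the per-op scheduler tags, then append a log-normalized
    # FLOP-count bucket so that only similarly sized DAGs share a tag.
    ret = ""
    for op in dag.ops:
        tag = op.attrs.get("ansor_task_scheduler_tag", None)
        if tag:
            ret += tag + "_"
    if ret != "":
        ret += "%d" % int(math.log(dag.flop_ct + 1, log_base))
    return ret

class FakeOp:  # hypothetical stand-in for an op with string attributes
    def __init__(self, attrs):
        self.attrs = attrs

class FakeDAG:  # hypothetical stand-in for ComputeDAG
    def __init__(self, ops, flop_ct):
        self.ops, self.flop_ct = ops, flop_ct

tagged = FakeDAG([FakeOp({"ansor_task_scheduler_tag": "conv2d"})], flop_ct=2 ** 20)
untagged = FakeDAG([FakeOp({})], flop_ct=2 ** 20)

print(derive_similarity_tag(tagged))    # "conv2d_28": groups with similarly sized convs
print(derive_similarity_tag(untagged))  # "": matches no other task's tag
```

Per the comment above, an untagged DAG yields `""` and thus belongs to no similarity group, rather than being lumped together with all other untagged tasks.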




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

