This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fd39122  [AutoScheduler] Enable schedule sharing in dispatch context (#7344)
fd39122 is described below

commit fd391223c19bec454f488f8a976a0766fadb0db3
Author: Cody Yu <comaniac0...@gmail.com>
AuthorDate: Wed Jan 27 14:54:43 2021 -0800

    [AutoScheduler] Enable schedule sharing in dispatch context (#7344)
    
    * [AutoScheduler] Enable schedule sharing in dispatch context
    
    * Update python/tvm/auto_scheduler/dispatcher.py
---
 python/tvm/auto_scheduler/dispatcher.py            | 135 ++++++++++++++++-----
 python/tvm/auto_scheduler/measure_record.py        |  65 +---------
 python/tvm/auto_scheduler/utils.py                 |  65 +++++++++-
 .../python/unittest/test_auto_scheduler_measure.py |  18 +--
 4 files changed, 178 insertions(+), 105 deletions(-)

diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py
index b0b98d8..f2d7536 100644
--- a/python/tvm/auto_scheduler/dispatcher.py
+++ b/python/tvm/auto_scheduler/dispatcher.py
@@ -30,6 +30,7 @@ import numpy as np
 
 from tvm.tir.expr import FloatImm
 from .measure_record import load_records
+from .utils import calc_workload_dis_factor, decode_workload_key
 
 logger = logging.getLogger("auto_scheduler")
 
@@ -126,18 +127,53 @@ class ApplyHistoryBest(DispatchContext):
         If is str, then it should be the filename of a records log file.
         Each row of this file is an encoded record pair. Otherwise, it is an 
iterator.
     n_lines: Optional[int]
-        if it is not None, only load the first `n_lines` lines of log
+        if it is not None, only load the first `n_lines` lines of log.
+    include_compatible: bool
+        When set to True, compatible records will also be considered.
     """
 
-    def __init__(self, records, n_lines=None):
+    def __init__(self, records, n_lines=None, include_compatible=False):
         super(ApplyHistoryBest, self).__init__()
+        self.include_compatible = include_compatible
 
+        # Dict[str (target key),
+        #   Dict[str (workload hash),
+        #     Dict[tuple (workload args), tuple (State, cost)]]]
         self.best_by_targetkey = {}
         self.best_by_model = {}
         self._best_user_defined = {}
 
         self.load(records, n_lines)
 
+    @staticmethod
+    def get_workload_entry(best_records, target_key, workload_key):
+        """Get the entry of the target key and workload key hash in the given 
best record map.
+
+        Parameters
+        ----------
+        best_records: Dict[str, Dict[str, Dict[str, Any]]]
+            The best record map.
+        target_key: str
+            The first key to the best_records.
+        workload_key: str
+            The workload key that can be decoded to workload hash and args.
+
+        Returns
+        -------
+        entry: Dict[str, Any]
+            The entry in best_records with target key and workload hash.
+        workload_hash: str
+            The workload hash decoded from workload_key.
+        workload_args: Tuple[Any, ...]
+            The hashable tuple of workload args decoded from workload_key.
+        """
+        workload_hash, workload_args = decode_workload_key(workload_key)
+        if target_key not in best_records:
+            best_records[target_key] = {}
+        if workload_hash not in best_records[target_key]:
+            best_records[target_key][workload_hash] = {}
+        return best_records[target_key][workload_hash], workload_hash, 
workload_args
+
     def load(self, records, n_lines=None):
         """Load records to this dispatch context
 
@@ -171,29 +207,32 @@ class ApplyHistoryBest(DispatchContext):
             if res.error_no != 0:
                 continue
 
+            costs = [x.value for x in res.costs if isinstance(x, FloatImm)]
+            cost = np.mean(costs)
+
             # use target keys in tvm target system as key to build best map
             for k in inp.task.target.keys:
-                key = (k, inp.task.workload_key)
-                if key not in best_by_targetkey:
-                    best_by_targetkey[key] = (inp, res)
+                entry, _, workload_args = self.get_workload_entry(
+                    best_by_targetkey, k, inp.task.workload_key
+                )
+                if workload_args not in entry:
+                    entry[workload_args] = (inp.state, cost)
                 else:
-                    _, other_res = best_by_targetkey[key]
-                    other_costs = [x.value for x in other_res.costs if 
isinstance(x, FloatImm)]
-                    costs = [x.value for x in res.costs if isinstance(x, 
FloatImm)]
-                    if np.mean(other_costs) > np.mean(costs):
-                        best_by_targetkey[key] = (inp, res)
+                    _, other_cost = entry[workload_args]
+                    if other_cost > cost:
+                        entry[workload_args] = (inp.state, cost)
 
             # use model as key to build best map
-            key = (inp.task.target.model, inp.task.workload_key)
-            if key not in best_by_model:
+            entry, _, workload_args = self.get_workload_entry(
+                best_by_model, inp.task.target.model, inp.task.workload_key
+            )
+            if workload_args not in entry:
                 if inp.task.target.model != "unknown":
-                    best_by_model[key] = (inp, res)
+                    entry[workload_args] = (inp.state, cost)
             else:
-                _, other_res = best_by_model[key]
-                other_costs = [x.value for x in other_res.costs if 
isinstance(x, FloatImm)]
-                costs = [x.value for x in res.costs if isinstance(x, FloatImm)]
-                if np.mean(other_costs) > np.mean(costs):
-                    best_by_model[key] = (inp, res)
+                _, other_cost = entry[workload_args]
+                if other_cost > cost:
+                    entry[workload_args] = (inp.state, cost)
 
         logger.debug("Finish loading %d records", counter)
 
@@ -205,31 +244,61 @@ class ApplyHistoryBest(DispatchContext):
                 " above the dispatcher call. So does other target. "
             )
 
+        def match_record(best_records, target_key, workload_key):
+            """The helper function to match the record in the given map
+            and return the matched state, or None if no match.
+            """
+            ret = None
+
+            entry, workload_hash, workload_args = self.get_workload_entry(
+                best_records, target_key, workload_key
+            )
+            if workload_args in entry:
+                ret = entry[workload_args][0]
+            elif self.include_compatible:
+                best_cost = float("inf")
+                for args, val in entry.items():
+                    dis_f = calc_workload_dis_factor(
+                        (workload_hash, workload_args), (workload_hash, args)
+                    )
+                    if dis_f == float("inf"):
+                        continue
+
+                    state, cost = val
+                    cost *= dis_f
+                    if ret is None or cost < best_cost:
+                        best_cost = cost
+                        ret = state
+            return ret
+
         # first try matching by model
-        key = (target.model, workload_key)
-        if key in self._best_user_defined:
-            return self._best_user_defined[key]
-        if key in self.best_by_model:
-            return self.best_by_model[key][0].state
+        ret = match_record(self._best_user_defined, target.model, workload_key)
+        if ret is not None:
+            return ret
+        ret = match_record(self.best_by_model, target.model, workload_key)
+        if ret is not None:
+            return ret
 
         # then try matching by target key
         for k in target.keys:
-            key = (k, workload_key)
-            if key in self._best_user_defined:
-                return self._best_user_defined[key]
-            if key in self.best_by_targetkey:
-                return self.best_by_targetkey[key][0].state
+            ret = match_record(self._best_user_defined, k, workload_key)
+            if ret is not None:
+                return ret
+            ret = match_record(self.best_by_targetkey, k, workload_key)
+            if ret is not None:
+                return ret
 
         return None
 
     def update(self, target, workload_key, state):
-        model = target.model
-        key = (model, workload_key)
-        self._best_user_defined[key] = state
+        entry, _, workload_args = self.get_workload_entry(
+            self._best_user_defined, target.model, workload_key
+        )
+        entry[workload_args] = (state, 1)
 
         for k in target.keys:
-            key = (k, workload_key)
-            self._best_user_defined[key] = state
+            entry, _, _ = self.get_workload_entry(self._best_user_defined, k, 
workload_key)
+            entry[workload_args] = (state, 1)
 
 
 class FallbackContext(DispatchContext):
diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py
index 9eaef18..200d24f 100644
--- a/python/tvm/auto_scheduler/measure_record.py
+++ b/python/tvm/auto_scheduler/measure_record.py
@@ -27,7 +27,7 @@ import numpy as np
 import tvm._ffi
 from tvm.runtime import Object
 from .measure import MeasureErrorNo, MeasureCallback
-from .utils import decode_workload_key
+from .utils import calc_workload_dis_factor, decode_workload_key
 from . import _ffi_api
 
 logger = logging.getLogger("auto_scheduler")
@@ -130,65 +130,6 @@ class RecordReader(Object):
             yield ret[0], ret[1]  # (input, result)
 
 
-def calc_workload_dis_factor(target_workload_key, workload_key):
-    """Calculate the distance factor of the workload to the target workload.
-    If two workloads are not compatible at all (i.e., different compute DAG or 
function),
-    then the distance factor is "inf". Otherwise, we calculate the factor by 
traversing
-    the workload arguments, which are the arguments of the compute function,
-    or the output shapes for the ComputeDAG. The factor is calculated by the 
following rules:
-
-    1. For non-zero integer values: `product(target_arg / candidate_arg)`.
-    2. For non-integer or zero values: "inf" if not equal else 1.
-
-    As a result, factor=1 is the optimal when two workloads are identical.
-
-    Parameters
-    ----------
-    target_workload_key: str
-        The target workload key in JSON string.
-
-    workload_key: str
-        The candidate workload key in JSON string.
-
-    Returns
-    -------
-    dis_f: float
-        The distance factor.
-    """
-
-    def flatten_list(inp):
-        ret = []
-        for elt in inp:
-            if isinstance(elt, list):
-                ret += flatten_list(elt)
-            else:
-                ret.append(elt)
-        return ret
-
-    target_key, target_args = decode_workload_key(target_workload_key)
-    target_args = flatten_list(target_args) if target_args is not None else []
-    key, args = decode_workload_key(workload_key)
-    args = flatten_list(args) if args is not None else []
-
-    # Not even the same func/DAG.
-    if key != target_key or len(target_args) != len(args):
-        return float("inf")
-
-    dis_f = 1
-    for target_arg, arg in zip(target_args, args):
-        if isinstance(target_arg, int):
-            if target_arg == 0 or arg == 0:
-                if target_arg != arg:
-                    return float("inf")
-            elif target_arg % arg != 0:
-                return float("inf")
-            else:
-                dis_f *= target_arg / arg
-        elif target_arg != arg:
-            return float("inf")
-    return dis_f
-
-
 def load_record_from_string(record):
     """
     Load the measure record from string.
@@ -304,7 +245,9 @@ def load_best_record(filename, workload_key=None, target=None, include_compatibl
         cost = np.mean(costs)
 
         if workload_key is not None:
-            dis_f = calc_workload_dis_factor(workload_key, 
inp.task.workload_key)
+            dis_f = calc_workload_dis_factor(
+                decode_workload_key(workload_key), 
decode_workload_key(inp.task.workload_key)
+            )
             if dis_f == float("inf"):
                 continue
             if not include_compatible and dis_f != 1:
diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py
index fd25fdb..8aa33e6 100644
--- a/python/tvm/auto_scheduler/utils.py
+++ b/python/tvm/auto_scheduler/utils.py
@@ -57,18 +57,77 @@ def decode_workload_key(workload_key):
     -------
     name: str
         The workload function name or the DAG hash.
-    args: Optional[List[Any]]
-        The arguments of the workload, or None if the workload key format is 
not decodeable.
+    args: Optional[Tuple[Any, ...]]
+        The flatten arguments in a tuple, or None if the workload key format 
is not decodeable.
     """
+
+    def flatten_list(inp):
+        ret = []
+        for elt in inp:
+            if isinstance(elt, list):
+                ret += flatten_list(elt)
+            else:
+                ret.append(elt)
+        return ret
+
     try:
         key_list = json.loads(workload_key)
         if isinstance(key_list, list) and len(key_list) >= 1:
-            return key_list[0], key_list[1:]
+            return key_list[0], tuple(flatten_list(key_list[1:]))
     except json.decoder.JSONDecodeError:
         pass
     return workload_key, None
 
 
+def calc_workload_dis_factor(target_workload_pair, workload_pair):
+    """Calculate the distance factor of the workload to the target workload.
+    If two workloads are not compatible at all (i.e., different compute DAG or 
function),
+    then the distance factor is "inf". Otherwise, we calculate the factor by 
traversing
+    the workload arguments, which are the arguments of the compute function,
+    or the output shapes for the ComputeDAG. The factor is calculated by the 
following rules:
+
+    1. For non-zero integer values: `product(target_arg / candidate_arg)`.
+    2. For non-integer or zero values: "inf" if not equal else 1.
+
+    As a result, factor=1 is the optimal when two workloads are identical.
+
+    Parameters
+    ----------
+    target_workload_pair: Tuple[str, Optional[Tuple[Any, ...]]]
+        The target workload pair: (hash, argument tuple).
+
+    workload_pair: Tuple[str, Optional[Tuple[Any, ...]]]
+        The candidate workload pair: (hash, argument tuple).
+
+    Returns
+    -------
+    dis_f: float
+        The distance factor.
+    """
+    target_key, target_args = target_workload_pair
+    target_args = target_args if target_args is not None else []
+    key, args = workload_pair
+    args = args if args is not None else []
+
+    # Not even the same func/DAG.
+    if key != target_key or len(target_args) != len(args):
+        return float("inf")
+
+    dis_f = 1
+    for target_arg, arg in zip(target_args, args):
+        if isinstance(target_arg, int):
+            if target_arg == 0 or arg == 0:
+                if target_arg != arg:
+                    return float("inf")
+            elif target_arg % arg != 0:
+                return float("inf")
+            else:
+                dis_f *= target_arg / arg
+        elif target_arg != arg:
+            return float("inf")
+    return dis_f
+
+
 def get_func_name(func):
     """Get name of a function.
 
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 3b074b2..041fb7e 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -202,35 +202,36 @@ def test_recover_measure_input():
 
 
 def test_workload_dis_factor():
-    calc = auto_scheduler.measure_record.calc_workload_dis_factor
+    calc = auto_scheduler.utils.calc_workload_dis_factor
+    decode = auto_scheduler.utils.decode_workload_key
 
     # Identical
     target_wkl_key = json.dumps(
         ["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [1, 1], "float32"]
     )
-    assert calc(target_wkl_key, target_wkl_key) == 1
+    assert calc(decode(target_wkl_key), decode(target_wkl_key)) == 1
 
     # Compatible with a factor
     wkl_key = json.dumps(["func1", [1, 3, 112, 112], [32, 3, 3, 3], [0, 0], 
[1, 1], "float32"])
-    assert calc(target_wkl_key, wkl_key) == 8 * 2 * 2
+    assert calc(decode(target_wkl_key), decode(wkl_key)) == 8 * 2 * 2
 
     # Incompatible argument with zeros
     wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [1, 1], 
[1, 1], "float32"])
-    assert calc(target_wkl_key, wkl_key) == float("inf")
+    assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf")
     wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], 
[0, 0], "float32"])
-    assert calc(target_wkl_key, wkl_key) == float("inf")
+    assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf")
 
     # Incompatible non-integter argument
     wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], 
[1, 1], "int8"])
-    assert calc(target_wkl_key, wkl_key) == float("inf")
+    assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf")
 
     # Incompatible function
     wkl_key = json.dumps(["func2", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], 
[1, 1], "float32"])
-    assert calc(target_wkl_key, wkl_key) == float("inf")
+    assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf")
 
     # Incompatible due to non-dividable factor
     wkl_key = json.dumps(["func1", [8, 3, 223, 223], [32, 3, 3, 3], [0, 0], 
[1, 1], "float32"])
-    assert calc(target_wkl_key, wkl_key) == float("inf")
+    assert calc(decode(target_wkl_key), decode(wkl_key)) == float("inf")
 
 
 def test_measure_local_builder_runner():
@@ -322,6 +323,7 @@ if __name__ == "__main__":
     test_record_follow_split_follow_fused_split()
     test_record_pragma_storage_align_rfactor()
     test_recover_measure_input()
+    test_workload_dis_factor()
     test_measure_local_builder_runner()
     test_measure_local_builder_rpc_runner()
     test_measure_target_host()

Reply via email to