This is an automated email from the ASF dual-hosted git repository. zhaowu pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new e6d5318 [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing (#7317) e6d5318 is described below commit e6d53185b96cc39f2aaec5e86ae11ca0ac675b8a Author: Cody Yu <comaniac0...@gmail.com> AuthorDate: Mon Jan 25 03:27:34 2021 -0800 [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing (#7317) * [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing * Update CI logs * lint * fix registry * add message; fix layout rewrite mismatch * update message * support other formats --- include/tvm/auto_scheduler/compute_dag.h | 7 ++ python/tvm/auto_scheduler/compute_dag.py | 35 +++--- python/tvm/auto_scheduler/measure_record.py | 126 +++++++++++++++++++-- python/tvm/auto_scheduler/relay_integration.py | 6 +- python/tvm/auto_scheduler/search_task.py | 8 +- python/tvm/auto_scheduler/utils.py | 27 +++++ python/tvm/auto_scheduler/workload_registry.py | 37 +++--- src/auto_scheduler/compute_dag.cc | 109 ++++++++++-------- .../python/unittest/test_auto_scheduler_measure.py | 33 ++++++ .../ci_logs/resnet-18-NHWC-B1-cuda.json | 48 ++++---- .../ci_logs/resnet-50-NHWC-B1-llvm.json | 55 +++++---- 11 files changed, 342 insertions(+), 149 deletions(-) diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h index 1e3f097..a87563e 100755 --- a/include/tvm/auto_scheduler/compute_dag.h +++ b/include/tvm/auto_scheduler/compute_dag.h @@ -263,6 +263,13 @@ class ComputeDAG : public ObjectRef { String PrintStepsAsPython(const Array<Step>& transform_steps) const; /*! + * \brief Print the compute DAG to a string. This is also used to generate the ComputeDAG hash. + * \param simple_mode Simple mode will only include the op names and brief compute. + * \return The ComputeDAG in a string. + */ + String PrintDAG(bool simple_mode = false) const; + + /*! * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound. * The states can lose complete bound information after some transform steps (e.g., compute_at). * We can call this function to infer and fill all the bound information. diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py index a7f200a..948f277 100755 --- a/python/tvm/auto_scheduler/compute_dag.py +++ b/python/tvm/auto_scheduler/compute_dag.py @@ -19,11 +19,11 @@ """ The auto-scheduler's computational graph and related program analyses. """ import hashlib +import json import tvm._ffi from tvm.runtime import Object from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON -from tvm.te import ComputeOp, PlaceholderOp from . import _ffi_api from .loop_state import State, StateObject @@ -220,32 +220,23 @@ class ComputeDAG(Object): state_obj = state if isinstance(state, StateObject) else state.state_object return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state_obj) - def hash_key(self): - """Return the hash key of this compute DAG. + def workload_key(self): + """Return the workload key of this compute DAG. + The workload key is a JSON string from a tuple of (hash-key, tensor shapes...) Returns ------- key: str - The hash key of this compute DAG + The workload key of this compute DAG """ - # TODO(merrymercy): Implement this more carefully and move this to c++ as a member function - # of ComputeDAG - str_key = "" - for op in self.ops: - t = op.output(0) - if isinstance(op, PlaceholderOp): - str_key += "placeholder," - str_key += str(get_const_tuple(t.shape)) + "," - str_key += t.dtype + ";" - elif isinstance(op, ComputeOp): - str_key += str(t.op.body) + "," - str_key += str(get_const_tuple(t.shape)) + "," - str_key += t.dtype + ";" - else: - raise ValueError("Invalid op: " + op) - - str_key = str_key.encode(encoding="utf-8") - return hashlib.md5(str_key).hexdigest() + str_dag = _ffi_api.ComputeDAGPrintDAG(self, True) + str_dag = str_dag.encode(encoding="utf-8") + hash_key = hashlib.md5(str_dag).hexdigest() + + io_shapes = [] + for tensor in self.tensors: + io_shapes += get_const_tuple(tensor.shape) + return json.dumps([hash_key] + io_shapes) def __str__(self): # pretty print diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py index 35e5e9b..9eaef18 100644 --- a/python/tvm/auto_scheduler/measure_record.py +++ b/python/tvm/auto_scheduler/measure_record.py @@ -27,6 +27,7 @@ import numpy as np import tvm._ffi from tvm.runtime import Object from .measure import MeasureErrorNo, MeasureCallback +from .utils import decode_workload_key from . import _ffi_api logger = logging.getLogger("auto_scheduler") @@ -59,8 +60,37 @@ class RecordReader(Object): """ def __init__(self, filename): + # a set to prevent print duplicated message + self.messages = set() + self.__init_handle_by_constructor__(_ffi_api.RecordReader, filename) + def check_workload_key(self, inputs): + """Check and throw warnings for records with old format workload key. + + Parameters + ---------- + inputs: List[MeasureInput] + The measure inputs to be checked. + + Notes + ----- + This checker could be deprecated in the future. + """ + for inp in inputs: + _, args = decode_workload_key(inp.task.workload_key) + if args is None: + continue + if not args: + msg = ( + "MeasureInput with old format workload key %s should be updated " + "using the script from https://github.com/apache/tvm/pull/7317." + % inp.task.workload_key + ) + if msg not in self.messages: + self.messages.add(msg) + logger.warning(msg) + def read_lines(self, max_lines=None, skip_lines=0): """Read multiple lines from the log file. @@ -88,6 +118,7 @@ class RecordReader(Object): inputs, results = _ffi_api.RecordReaderReadLines( self, max_lines if max_lines else -1, skip_lines ) + self.check_workload_key(inputs) return inputs, results def __iter__(self): @@ -95,9 +126,69 @@ class RecordReader(Object): ret = _ffi_api.RecordReaderReadNext(self) if not ret: break + self.check_workload_key([ret[0]]) yield ret[0], ret[1] # (input, result) +def calc_workload_dis_factor(target_workload_key, workload_key): + """Calculate the distance factor of the workload to the target workload. + If two workloads are not compatible at all (i.e., different compute DAG or function), + then the distance factor is "inf". Otherwise, we calculate the factor by traversing + the workload arguments, which are the arguments of the compute function, + or the output shapes for the ComputeDAG. The factor is calculated by the following rules: + + 1. For non-zero integer values: `product(target_arg / candidate_arg)`. + 2. For non-integer or zero values: "inf" if not equal else 1. + + As a result, factor=1 is the optimal when two workloads are identical. + + Parameters + ---------- + target_workload_key: str + The target workload key in JSON string. + + workload_key: str + The candidate workload key in JSON string. + + Returns + ------- + dis_f: float + The distance factor. + """ + + def flatten_list(inp): + ret = [] + for elt in inp: + if isinstance(elt, list): + ret += flatten_list(elt) + else: + ret.append(elt) + return ret + + target_key, target_args = decode_workload_key(target_workload_key) + target_args = flatten_list(target_args) if target_args is not None else [] + key, args = decode_workload_key(workload_key) + args = flatten_list(args) if args is not None else [] + + # Not even the same func/DAG. + if key != target_key or len(target_args) != len(args): + return float("inf") + + dis_f = 1 + for target_arg, arg in zip(target_args, args): + if isinstance(target_arg, int): + if target_arg == 0 or arg == 0: + if target_arg != arg: + return float("inf") + elif target_arg % arg != 0: + return float("inf") + else: + dis_f *= target_arg / arg + elif target_arg != arg: + return float("inf") + return dis_f + + def load_record_from_string(record): """ Load the measure record from string. @@ -174,7 +265,7 @@ def save_records(filename, inputs, results): _ffi_api.SaveRecords(filename, inputs, results) -def load_best_record(filename, workload_key=None, target=None): +def load_best_record(filename, workload_key=None, target=None, include_compatible=False): """Return the best measurement pair form a log file. This may return none results if there is no legal measure pair with the specified workload_key/target found from the log file. @@ -188,6 +279,8 @@ def load_best_record(filename, workload_key=None, target=None): target : Optional[tvm.target.Target] The target device. With `None`, this returns the best measure pair of all target devices. + include_compatible: bool + When set to True, all compatible records in the log file will be considered. Returns ------- @@ -204,13 +297,23 @@ def load_best_record(filename, workload_key=None, target=None): for inp, res in log_reader: if res.error_no != MeasureErrorNo.NO_ERROR: continue - if workload_key and inp.task.workload_key != workload_key: - continue if target and inp.task.target.kind.name != target.kind.name: continue costs = [v.value for v in res.costs] cost = np.mean(costs) + + if workload_key is not None: + dis_f = calc_workload_dis_factor(workload_key, inp.task.workload_key) + if dis_f == float("inf"): + continue + if not include_compatible and dis_f != 1: + continue + + # Since different workloads have different FLOPS, we multiply the factor to + # eliminate this difference, which is basically the concept of throughput. + cost *= dis_f + if cost < best_cost: best_cost = cost best_inp = inp @@ -267,12 +370,8 @@ def distill_record_file(in_file, out_file): logger.info("Extract %d best records from %s to %s", len(inputs), in_file, out_file) -""" -Usage: -* Distill the best entries from a large log file -e.g. python -m tvm.auto_scheduler.measure_record --mode distill --i input.json -""" -if __name__ == "__main__": +def main(): + """The main function for CLI.""" parser = argparse.ArgumentParser() parser.add_argument("--mode", choices=["distill"], required=True) parser.add_argument("--i", type=str, help="input file") @@ -285,3 +384,12 @@ if __name__ == "__main__": if args.mode == "distill": args.o = args.o or args.i + ".best.json" distill_record_file(args.i, args.o) + + +""" +Usage: +* Distill the best entries from a large log file +e.g. python -m tvm.auto_scheduler.measure_record --mode distill --i input.json +""" +if __name__ == "__main__": + main() diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index fb60da1..b39aba2 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -22,7 +22,6 @@ Integrate auto_scheduler into relay. It implements the following items: 2. Provide auto-scheduling for all TOPI compute functions """ -import json import logging import threading @@ -281,7 +280,7 @@ def auto_schedule_topi(outs): logger.info("Failed to create a ComputeDAG for auto_scheduler: %s", str(err)) return None - key = register_workload_tensors(dag.hash_key(), io_tensors) + key = register_workload_tensors(dag.workload_key(), io_tensors) target = tvm.target.Target.current() env = TracingEnvironment.current @@ -310,9 +309,8 @@ def auto_schedule_topi(outs): return None # rewrite the layout and update the context for the new dag - dag = ComputeDAG(outs) new_dag = dag.rewrite_layout_from_state(state) - new_key = json.dumps((new_dag.hash_key(),)) + new_key = new_dag.workload_key() if new_key != key: dispatch_ctx.update(target, new_key, state) else: diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py index d985ed1..83f665b 100644 --- a/python/tvm/auto_scheduler/search_task.py +++ b/python/tvm/auto_scheduler/search_task.py @@ -257,13 +257,15 @@ class SearchTask(Object): _ffi_api.AutoSchedule(search_policy, tuning_options) - def apply_best(self, log_file, layout_rewrite_option=None): + def apply_best(self, log_file, include_compatible=False, layout_rewrite_option=None): """Apply the history best from a log file and return the schedule. Parameters ---------- log_file : str The name of the log file. + include_compatible: bool + When set to True, all compatible records in the log file will be considered. layout_rewrite_option : Optional[LayoutRewriteOption] The layout rewrite option. @@ -272,7 +274,9 @@ class SearchTask(Object): ------- A `te.Schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`. """ - inp, _ = load_best_record(log_file, self.workload_key) + inp, _ = load_best_record( + log_file, self.workload_key, include_compatible=include_compatible + ) if inp is None: raise RuntimeError( "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file) diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py index 334acaf..fd25fdb 100644 --- a/python/tvm/auto_scheduler/utils.py +++ b/python/tvm/auto_scheduler/utils.py @@ -19,6 +19,7 @@ """ Common utilities for auto_scheduler. """ from typing import Hashable +import json import multiprocessing import multiprocessing.pool import queue @@ -42,6 +43,32 @@ from tvm.ir.transform import Sequential from ..te import Tensor, placeholder +def decode_workload_key(workload_key): + """Decode the workload key from a string to the name and arguments. The wokrload key + is expected to be a list of "[func_name/hash, args ...]" in a JSON string. If not, + then simply return the workload key as the name without arguments. + + Parameters + ---------- + workload_key: str + The workload key in string. Format: "[func_name/hash, args ...]". + + Returns + ------- + name: str + The workload function name or the DAG hash. + args: Optional[List[Any]] + The arguments of the workload, or None if the workload key format is not decodeable. + """ + try: + key_list = json.loads(workload_key) + if isinstance(key_list, list) and len(key_list) >= 1: + return key_list[0], key_list[1:] + except json.decoder.JSONDecodeError: + pass + return workload_key, None + + def get_func_name(func): """Get name of a function. diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py index 9a7c15c..51ae64d 100644 --- a/python/tvm/auto_scheduler/workload_registry.py +++ b/python/tvm/auto_scheduler/workload_registry.py @@ -98,14 +98,14 @@ def register_workload(func_name, f=None, override=False): return register -def register_workload_tensors(func_name, tensors, override=True): +def register_workload_tensors(workload_key, tensors, override=True): """Register a workload by provding input/output tensors. Since this function is used when extracting/deserializing tasks, it expects duplicated registrations by default. Parameters ---------- - func_name: str - The function name or the hash key of the compute DAG. + workload_key: str + The wokrload key of the compute DAG in JSON string. tensors: List[Tensor] The input/output tensors of a compute DAG override : boolean = True @@ -113,11 +113,11 @@ def register_workload_tensors(func_name, tensors, override=True): Returns ------- - key: str - The serialized JSON string as the workload key. + workload_key: str + The wokrload key of the compute DAG in JSON string. """ - register_workload(func_name, override=override)(tensors) - return json.dumps((func_name,)) + register_workload(workload_key, override=override)(tensors) + return workload_key def make_workload_key(func, args): @@ -169,7 +169,8 @@ def workload_key_to_tensors(workload_key): Parameters ---------- workload_key : str - The input workload key. + The input workload key in JSON string. The format is either (func_name, arguments...) + for compute functions, or (hash, shapes...) for ComputeDAG. Returns ------- @@ -178,16 +179,21 @@ def workload_key_to_tensors(workload_key): """ global WORKLOAD_FUNC_REGISTRY + # We register ComputeDAG with both hash and argumetns, which are fixed in ComputeDAG, + # so we use an entire workload key to query the ComputeDAG. + if workload_key in WORKLOAD_FUNC_REGISTRY: + return WORKLOAD_FUNC_REGISTRY[workload_key] + + # We register compute function with only the function name since + # it does not bind to specific arguments, so we use the function name to query + # the function and call the function with arguments to get the tensors. workload = json.loads(workload_key) name = workload[0] value = WORKLOAD_FUNC_REGISTRY[name] + assert callable(value) - # "value" can be either a function or a list of tensors - if callable(value): # if it is a func - args = deserialize_args(workload[1:]) - return value(*args) - # otherwise, it is a list of tensors - return value + args = deserialize_args(workload[1:]) + return value(*args) def serialize_workload_registry_entry(workload_key): @@ -209,6 +215,9 @@ def serialize_workload_registry_entry(workload_key): """ global WORKLOAD_FUNC_REGISTRY + if workload_key in WORKLOAD_FUNC_REGISTRY: + return (workload_key, WORKLOAD_FUNC_REGISTRY[workload_key]) + workload = json.loads(workload_key) name = workload[0] value = WORKLOAD_FUNC_REGISTRY[name] diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc old mode 100755 new mode 100644 index 735f044..4e7fb05 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -1243,6 +1243,62 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const return ss.str(); } +String ComputeDAG::PrintDAG(bool simple_mode) const { + std::stringstream ss; + + for (const auto& op : operator->()->ops) { + if (op->IsInstance<te::PlaceholderOpNode>()) { + ss << op->name << " = PLACEHOLDER "; + if (!simple_mode) { + ss << op.output(0)->shape; + } + ss << "\n"; + } else if (auto pop = op.as<te::ComputeOpNode>()) { + for (size_t k = 0; k < pop->body.size(); ++k) { + ss << op->name << "("; + for (size_t i = 0; i < pop->axis.size(); i++) { + ss << pop->axis[i]->var->name_hint; + if (i != pop->axis.size() - 1) { + ss << ", "; + } + } + ss << ")"; + if (pop->body.size() > 1) { + ss << ".v" << k; + } + if (auto preduce = pop->body[k].as<ReduceNode>()) { + ICHECK_LT(k, preduce->combiner->result.size()); + PrimExpr combiner = preduce->combiner->result[k]; + if (combiner->IsInstance<AddNode>()) { + ss << " += " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance<MaxNode>()) { + ss << " max= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance<MinNode>()) { + ss << " min= " << preduce->source[0] << "\n"; + } else if (combiner->IsInstance<SelectNode>()) { + const auto& select = combiner.as<SelectNode>(); + ss << " select(" << select->condition << ", " << select->true_value << ", " + << select->false_value << ")= " << '(' << preduce->source[0] << ',' + << preduce->source[1] << ")\n"; + } else { + ss << "reduce" << combiner << "\n"; + } + } else { + auto call = pop->body[k].as<CallNode>(); + if (simple_mode && call) { + ss << " = " << call->op << "\n"; + } else { + ss << " = " << pop->body[k] << "\n"; + } + } + } + } else { + LOG(FATAL) << "Invalid op"; + } + } + return String(ss.str()); +} + State ComputeDAG::InferBound(const State& state) const { ICHECK(state->concrete) << "Only concrete state can be processed to get bound info."; @@ -1383,51 +1439,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch<ComputeDAGNode>([](const ObjectRef& ref, ReprPrinter* p) { auto* node = static_cast<const ComputeDAGNode*>(ref.get()); - std::stringstream ss; - - for (const auto& op : node->ops) { - if (op->IsInstance<te::PlaceholderOpNode>()) { - ss << op->name << " = PLACEHOLDER " << op.output(0)->shape << "\n"; - } else if (auto pop = op.as<te::ComputeOpNode>()) { - for (size_t k = 0; k < pop->body.size(); ++k) { - ss << op->name << "("; - for (size_t i = 0; i < pop->axis.size(); i++) { - ss << pop->axis[i]->var->name_hint; - if (i != pop->axis.size() - 1) { - ss << ", "; - } - } - ss << ")"; - if (pop->body.size() > 1) { - ss << ".v" << k; - } - if (auto preduce = pop->body[k].as<ReduceNode>()) { - ICHECK_LT(k, preduce->combiner->result.size()); - PrimExpr combiner = preduce->combiner->result[k]; - if (combiner->IsInstance<AddNode>()) { - ss << " += " << preduce->source[0] << "\n"; - } else if (combiner->IsInstance<MaxNode>()) { - ss << " max= " << preduce->source[0] << "\n"; - } else if (combiner->IsInstance<MinNode>()) { - ss << " min= " << preduce->source[0] << "\n"; - } else if (combiner->IsInstance<SelectNode>()) { - const auto& select = combiner.as<SelectNode>(); - ss << " select(" << select->condition << ", " << select->true_value << ", " - << select->false_value << ")= " << '(' << preduce->source[0] << ',' - << preduce->source[1] << ")\n"; - } else { - ss << "reduce" << combiner << "\n"; - } - } else { - ss << " = " << pop->body[k] << "\n"; - } - } - } else { - LOG(FATAL) << "Invalid op"; - } - } - - p->stream << ss.str(); + auto dag = GetRef<ComputeDAG>(node); + auto dag_str = dag.PrintDAG(); + p->stream << dag_str; }); Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names) { @@ -1476,6 +1490,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGPrintPythonCodeFromState") return dag.PrintStepsAsPython(state->transform_steps); }); +TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGPrintDAG") + .set_body_typed([](const ComputeDAG& dag, bool simple_mode) { + return dag.PrintDAG(simple_mode); + }); + TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGInferBoundFromState") .set_body_typed([](const ComputeDAG& dag, const State& state) { return dag.InferBound(state); diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py index e9f1fa4..3b074b2 100644 --- a/tests/python/unittest/test_auto_scheduler_measure.py +++ b/tests/python/unittest/test_auto_scheduler_measure.py @@ -16,6 +16,7 @@ # under the License. """ Test measurement and log serialization. """ +import json import multiprocessing import tvm @@ -200,6 +201,38 @@ def test_recover_measure_input(): assert str(correct_inp.state) == str(inp.state) +def test_workload_dis_factor(): + calc = auto_scheduler.measure_record.calc_workload_dis_factor + + # Identical + target_wkl_key = json.dumps( + ["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [1, 1], "float32"] + ) + assert calc(target_wkl_key, target_wkl_key) == 1 + + # Compatible with a factor + wkl_key = json.dumps(["func1", [1, 3, 112, 112], [32, 3, 3, 3], [0, 0], [1, 1], "float32"]) + assert calc(target_wkl_key, wkl_key) == 8 * 2 * 2 + + # Incompatible argument with zeros + wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [1, 1], [1, 1], "float32"]) + assert calc(target_wkl_key, wkl_key) == float("inf") + wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [0, 0], "float32"]) + assert calc(target_wkl_key, wkl_key) == float("inf") + + # Incompatible non-integter argument + wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [1, 1], "int8"]) + assert calc(target_wkl_key, wkl_key) == float("inf") + + # Incompatible function + wkl_key = json.dumps(["func2", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [1, 1], "float32"]) + assert calc(target_wkl_key, wkl_key) == float("inf") + + # Incompatible due to non-dividable factor + wkl_key = json.dumps(["func1", [8, 3, 223, 223], [32, 3, 3, 3], [0, 0], [1, 1], "float32"]) + assert calc(target_wkl_key, wkl_key) == float("inf") + + def test_measure_local_builder_runner(): if not tvm.testing.device_enabled("llvm"): return diff --git a/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1-cuda.json b/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1-cuda.json index 8d0a6ae..7cb3a67 100644 --- a/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1-cuda.json +++ b/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1-cuda.json @@ -1,26 +1,26 @@ # Provide valid schedules for resnet-18 on GPU. # This is used to run the tutorial on the documentation web server. -{"i": [["[\"b32ed43fb351136894c322ee49097a1a\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["SP", 4, 1, 1000, [40], 1], ["AN", 4, 2, 6], ["FSP", 3, 1, 0, 1], ["AN", 3, 2, 6], ["CA", 3, 4, 0], ["CI", 2], ["FSP", 1, 1, 0, 1], ["AN", 1, 2, 6], ["CA", 1, 4, 0], ["AN", 4, 0, 5], ["PR", 1, 0, "auto_unroll_max_step$512"], ["PR", 3, 0, "auto_unroll_max_step$512"]]]], "r": [[4.87396e-06], 0, 1.30575, 1606984701], "v": "v0.3"} -{"i": [["[\"d09dc1a6bb90d59c91b68989ad3492ff\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["SP", 2, 0, 1, [1, 1, 1, 1], 1], ["SP", 2, 5, 1000, [1, 50, 1, 1], 1], ["SP", 2, 10, 512, [1, 16], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 4, 0, 0, 3], ["FSP", 4, 4, 1, 3], ["RE", 4, [0, 4, 1, 5, 2, 6, 3, 7]], ["CA", 2, 4, 5], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 6], ["CHR", 0, "shared", [3]], [" [...] -{"i": [["[\"7de313da0ca29a8c63f647791692430d\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 512, [2], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["FU", 1, [0, 1, 2, 3]], ["SP", 1, 0, 512, [32], 1], ["AN", 1, 0, 5], ["AN", 1, 1, 6], ["PR", 1, 0, "auto_unroll_max_step$64"]]]], "r": [[3.91068e-06], 0, 1.63708, 1606984742], "v": "v0.3"} -{"i": [["[\"8d5a93959138dc7b2ee1f1b3219dfa14\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 15], ["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [2], 1], ["SP", 8, 4, 512, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 16, [4, [...] -{"i": [["[\"ac6920940de3797cc3f9f9c260675e5d\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [4], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5, 4, [1, 1, 1, 2], 1], ["SP", 6, 10, 16, [4, 2, 2, 1], 1], ["SP", 6, 1 [...] -{"i": [["[\"7e83a2ee5cd5d50282ed19310700046a\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [1], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 16, [2, 1, 1, 8], 1], ["SP", 6, 15, 512, [1, [...] -{"i": [["[\"424ba83160af31badc0b098136e1a3b0\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 2], 1], ["SP", 6, 5, 4, [1, 1, 1, 2], 1], ["SP", 6, 10, 49, [1, 1, 1, 7], 1] [...] -{"i": [["[\"a169cd0053d3a7ca82998fcb62e42c58\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 49, [1, 7, 7, 1], 1], ["SP", 6, 1 [...] -{"i": [["[\"0141ffc4fbabc10cc5a94c954419055b\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [7], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 4, 1, 1], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 49, [1, 1, 7, 1], 1], ["SP", 6, 15, 256, [4, [...] -{"i": [["[\"81aae4b8e2c076a4014d403e8a2c70a1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [2, 7, 1, 1], 1], ["SP", 3, 10, 14, [1, 7, 2, 1], 1], ["SP", 3, 15, 256, [2, 2, 1, 4], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 128, [4, 1], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, [...] -{"i": [["[\"c7a6b56bdc04b94c829fb2ef9874019e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 128, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 196, [1, 7, 1, 7], [...] -{"i": [["[\"c035cc8b0568a8e054d06bd7f4950550\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 4], 1], ["SP", 6, 5, 4, [1, 1, 1, 2], 1], ["SP", 6, 10, 196, [2, 49, 1, 1], 1], ["SP", 6 [...] -{"i": [["[\"c5ee3e05edd9754492d0763aa41fd025\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 196, [1, 2, 7, 1], 1], ["SP", 6, 15, 128, [1 [...] -{"i": [["[\"022ebb6b7c55c5ed030421380ec83a04\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 1, 2, 1], 1], ["SP", 3, 10, 28, [1, 7, 2, 2], 1], ["SP", 3, 15, 128, [1, 8, 8, 1], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 1], 1], ["SP", 3, 26, 64, [4, 2], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, [...] -{"i": [["[\"de0df0893e01892cfe69f7bc2c24111f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 64, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 1, 1, 1], 1], ["SP", 6, 5, 6, [1, 1, 1, 2], 1], ["SP", 6, 10, 196, [2, 14, 1, 1], 1 [...] -{"i": [["[\"f2e3c09a00e7d0a9897f70497e089f1e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 64, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 3, 2, 1], 1], ["SP", 6, 5, 6, [1, 3, 1, 2], 1], ["SP", 6, 10, 196, [1, 1, 4, 1], 1], ["SP", 6, [...] -{"i": [["[\"fa26946d7ac51126bfa859cb183f9ca1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [49], 1], ["SP", 8, 4, 64, [2], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 2, 1, 3], 1], ["SP", 6, 5, 6, [1, 3, 1, 2], 1], ["SP", 6, 10, 196, [1, 1, 1, 4], 1], ["SP", 6, 15, 64, [1, [...] -{"i": [["[\"ba2026d923536b75e9b4faed89287d5f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 4], ["CI", 1], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 200704, [64], 1], ["AN", 5, 0, 5], ["AN", 5, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 200704, [64], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["PR", 2, 0, "auto_unroll_max_step$16"]]]], "r": [[2.00968e-05], 0, 1.53065, 1606985193], "v": "v0.3"} -{"i": [["[\"a0eb8d6048282a4a0986cc2ccf14eaa2\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 112, [1, 2, 7, 1], 1], ["SP", 3, 10, 112, [1, 7, 1, 1], 1], ["SP", 3, 15, 64, [1, 8, 4, 1], 1], ["SP", 3, 20, 7, [7, 1], 1], ["SP", 3, 23, 7, [1, 7], 1], ["SP", 3, 26, 3, [3, 1], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, [...] -{"i": [["[\"bf78a7bf0209980f72953637dfd14a6f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 56, [1, 2, 2, 2], 1], ["SP", 3, 10, 56, [1, 7, 1, 2], 1], ["SP", 3, 15, 64, [1, 16, 1, 4], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [2, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, [...] -{"i": [["[\"6630936c26852f2b89dbfa2ff37fbb9c\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 7, 1, 1], 1], ["SP", 3, 10, 28, [1, 2, 1, 7], 1], ["SP", 3, 15, 128, [8, 8, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [2, 2], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, [...] -{"i": [["[\"ba5f918733ccbbd4a1d7fd3724665a2f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 1, 1], 1], ["SP", 3, 10, 14, [2, 1, 7, 1], 1], ["SP", 3, 15, 256, [2, 64, 1, 2], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 128, [1, 2], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27 [...] -{"i": [["[\"21ad409d72953de188314010134e3acd\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 7, 1], 1], ["SP", 3, 10, 7, [1, 1, 1, 1], 1], ["SP", 3, 15, 512, [4, 128, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 256, [1, 16], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27 [...] -{"i": [["[\"1f6cd3637ec856bf5cf5010a623eed05\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [7, 1, 1, 1], 1], ["SP", 3, 10, 7, [1, 7, 1, 1], 1], ["SP", 3, 15, 512, [1, 4, 1, 1], 1], ["SP", 3, 20, 3, [1, 3], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 256, [8, 2], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 1 [...] +{"i": [["[\"d7b65649a4dd54becea0a52aabbc5af5\", 1, 1000, 1, 1000]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["SP", 4, 1, 1000, [40], 1], ["AN", 4, 2, 6], ["FSP", 3, 1, 0, 1], ["AN", 3, 2, 6], ["CA", 3, 4, 0], ["CI", 2], ["FSP", 1, 1, 0, 1], ["AN", 1, 2, 6], ["CA", 1, 4, 0], ["AN", 4, 0, 5], ["PR", 1, 0, "auto_unroll_max_step$512"], ["PR", 3, 0, "auto_unroll_max_step$512"]]]], "r": [[4.87396e-06], 0, 1.3 [...] +{"i": [["[\"9847f8cc0b305137f49f2c5c0c8ab25d\", 1, 512, 1000, 512, 1000, 1, 1000]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["SP", 2, 0, 1, [1, 1, 1, 1], 1], ["SP", 2, 5, 1000, [1, 50, 1, 1], 1], ["SP", 2, 10, 512, [1, 16], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 4, 0, 0, 3], ["FSP", 4, 4, 1, 3], ["RE", 4, [0, 4, 1, 5, 2, 6, 3, 7]], ["CA", 2, 4, 5], ["CHR", 1, "shared", [2]], [...] +{"i": [["[\"69115f188984ae34ede37c3b8ca40b43\", 1, 7, 7, 512, 1, 1, 1, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 512, [2], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["FU", 1, [0, 1, 2, 3]], ["SP", 1, 0, 512, [32], 1], ["AN", 1, 0, 5], ["AN", 1, 1, 6], ["PR", 1, 0, "auto_unroll_max_step$64"]]]], "r": [[3.91068e-06], 0, 1.63708, 1606984742], "v": "v0.5"} +{"i": [["[\"ad6cecbf5d85cb1cda3c2bb7af170211\", 1, 7, 7, 512, 4, 4, 512, 512, 1, 7, 7, 512, 1, 1, 1, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 15], ["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [2], 1], ["SP", 8, 4, 512, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "l [...] +{"i": [["[\"3a69f9fbc63760d99e36b4c17b3bfc57\", 1, 7, 7, 512, 4, 4, 512, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [4], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5 [...] +{"i": [["[\"d730bcd28f0920f6b97245e2a11bd8d6\", 1, 7, 7, 512, 4, 4, 512, 512, 1, 7, 7, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [1], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5, 4, [1, 1, [...] +{"i": [["[\"f3b6c10fcc6ce01ff01add933e4d21e9\", 1, 14, 14, 256, 4, 4, 256, 256, 1, 14, 14, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, [...] +{"i": [["[\"b8b52b9be9df6102466a22a014c44c1f\", 1, 14, 14, 256, 4, 4, 256, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", [...] +{"i": [["[\"d374e472bd9d8164892b9e28a0a8cb59\", 1, 14, 14, 256, 4, 4, 256, 256, 1, 14, 14, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [7], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 4, 1, 1], 1], ["SP", 6, 5, 4, [ [...] +{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 28, 28, 128, 3, 3, 128, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [2, 7, 1, 1], 1], ["SP", 3, 10, 14, [1, 7, 2, 1], 1], ["SP", 3, 15, 256, [2, 2, 1, 4], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 128, [4, 1], 1], ["RE", 3, [0 [...] +{"i": [["[\"c4500b4e2fd04e695c32d2f31bbdc14a\", 1, 28, 28, 128, 4, 4, 128, 128, 1, 28, 28, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 128, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0 [...] +{"i": [["[\"e4cdf917b876dbdd64488c3818d9c141\", 1, 28, 28, 128, 4, 4, 128, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 4], 1], ["SP", [...] +{"i": [["[\"dac19035dd5fe9424ee8617421b9c817\", 1, 28, 28, 128, 4, 4, 128, 128, 1, 28, 28, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [...] +{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 56, 56, 64, 3, 3, 64, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 1, 2, 1], 1], ["SP", 3, 10, 28, [1, 7, 2, 2], 1], ["SP", 3, 15, 128, [1, 8, 8, 1], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 1], 1], ["SP", 3, 26, 64, [4, 2], 1], ["RE", 3, [0, 5 [...] +{"i": [["[\"1e3c4211ffd2f2db91078ae4d04b779d\", 1, 56, 56, 64, 6, 6, 64, 64, 1, 56, 56, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 64, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, [...] +{"i": [["[\"b818b53148cd450f86569dfc3e04cb8a\", 1, 56, 56, 64, 6, 6, 64, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 64, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 3, 2, 1], 1], ["SP", 6, 5 [...] +{"i": [["[\"3ea73fb9b0364374730d09e068821f95\", 1, 56, 56, 64, 6, 6, 64, 64, 1, 56, 56, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [49], 1], ["SP", 8, 4, 64, [2], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 2, 1, 3], 1], ["SP", 6, 5, 6, [1, 3 [...] +{"i": [["[\"a5612fdeb9db4d579a75ec225ea4c06a\", 1, 112, 112, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 4], ["CI", 1], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 200704, [64], 1], ["AN", 5, 0, 5], ["AN", 5, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 200704, [64], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["PR", 2, 0, "auto_unroll_max_step$16"]]]], "r": [[2.00968e-05], 0, 1 [...] +{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 224, 224, 3, 7, 7, 3, 64, 1, 1, 1, 64, 1, 112, 112, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 112, [1, 2, 7, 1], 1], ["SP", 3, 10, 112, [1, 7, 1, 1], 1], ["SP", 3, 15, 64, [1, 8, 4, 1], 1], ["SP", 3, 20, 7, [7, 1], 1], ["SP", 3, 23, 7, [1, 7], 1], ["SP", 3, 26, 3, [3, 1], 1], ["RE", 3, [0, 5, [...] +{"i": [["[\"7006235cfc29b73be524cf390ed5a977\", 1, 56, 56, 64, 1, 1, 64, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 56, [1, 2, 2, 2], 1], ["SP", 3, 10, 56, [1, 7, 1, 2], 1], ["SP", 3, 15, 64, [1, 16, 1, 4], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [2, 8], 1], ["RE", 3, [0, 5, 10, [...] +{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 56, 56, 64, 1, 1, 64, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 7, 1, 1], 1], ["SP", 3, 10, 28, [1, 2, 1, 7], 1], ["SP", 3, 15, 128, [8, 8, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [2, 2], 1], ["RE", 3, [0, 5, 10 [...] +{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 28, 28, 128, 1, 1, 128, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 1, 1], 1], ["SP", 3, 10, 14, [2, 1, 7, 1], 1], ["SP", 3, 15, 256, [2, 64, 1, 2], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 128, [1, 2], 1], ["RE", 3, [0, 5 [...] +{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 14, 14, 256, 1, 1, 256, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 7, 1], 1], ["SP", 3, 10, 7, [1, 1, 1, 1], 1], ["SP", 3, 15, 512, [4, 128, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 256, [1, 16], 1], ["RE", 3, [0, 5, [...] +{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 14, 14, 256, 3, 3, 256, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [7, 1, 1, 1], 1], ["SP", 3, 10, 7, [1, 7, 1, 1], 1], ["SP", 3, 15, 512, [1, 4, 1, 1], 1], ["SP", 3, 20, 3, [1, 3], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 256, [8, 2], 1], ["RE", 3, [0, 5, [...] diff --git a/tutorials/auto_scheduler/ci_logs/resnet-50-NHWC-B1-llvm.json b/tutorials/auto_scheduler/ci_logs/resnet-50-NHWC-B1-llvm.json index 611f776..3dd4541 100644 --- a/tutorials/auto_scheduler/ci_logs/resnet-50-NHWC-B1-llvm.json +++ b/tutorials/auto_scheduler/ci_logs/resnet-50-NHWC-B1-llvm.json @@ -1,31 +1,28 @@ # Provide valid schedules for resnet-50 for CPU. # This is used to run the tutorial on the documentation web server. -{"i": [["[\"b32ed43fb351136894c322ee49097a1a\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 3, 1, 1000, [50], 1], ["RF", 3, 2, 1], ["RE", 3, [0, 2, 1]], ["SP", 1, 1, 1000, [20], 1], ["RF", 1, 2, 1], ["RE", 1, [0, 2, 1]], ["CR", 6], ["CA", 5, 6, 1], ["CR", 4], ["CA", 2, 3, 1], ["AN", 1, 0, 3], ["FU", 3, [0, 1]], ["AN", 3, 0, 3], ["AN", 4, 0, 3], ["FU", 6, [0, 1]], ["AN", 6, 0, 3], ["PR", 1, 0, "auto_unroll_max_step$16"], ["PR", 2, 0, "auto [...] -{"i": [["[\"6129df1a3d5f6326c8393a8d17160199\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 2, 0, 1, [1, 1, 1], 1], ["SP", 2, 4, 1000, [1, 1, 1], 1], ["SP", 2, 8, 16, [2, 2, 4], 1], ["SP", 2, 12, 128, [32], 1], ["RE", 2, [0, 4, 8, 1, 5, 9, 12, 2, 6, 10, 13, 3, 7, 11]], ["CR", 5], ["CA", 3, 5, 1], ["FU", 2, [0, 1]], ["AN", 2, 0, 3], ["FU", 5, [0, 1]], ["AN", 5, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$16"], ["PR", 3, 0, "auto_unroll_max_s [...] -{"i": [["[\"36ee2798ed60bae3bcd1bb89a0285fe8\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CA", 1, 2, 3], ["FU", 2, [0, 1, 2, 3]], ["AN", 2, 0, 3], ["PR", 1, 0, "auto_unroll_max_step$16"]]]], "r": [[6.28e-06, 8.176e-06, 8.048e-06, 7.942e-06, 7.977e-06, 8.002e-06, 8.093e-06, 7.924e-06, 7.943e-06, 7.924e-06], 0, 0.130759, 1606960900], "v": "v0.3"} -{"i": [["[\"dcf6fcf5f56fa614bf9aef0c82382caf\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 9], ["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 7, 1], 1], ["SP", 3, 8, 7, [1, 1, 1], 1], ["SP", 3, 12, 2048, [8, 2, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 10, 0, 3, 2], ["F [...] -{"i": [["[\"7657f886f5e9d8b5f19a5fd2c5b90d8d\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 1, 7], 1], ["SP", 3, 8, 7, [1, 1, 1], 1], ["SP", 3, 12, 512, [32, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["CI", 1], ["FU", 3, [0, 1, 2, 3, 4, 5, 6, 7]] [...] -{"i": [["[\"7e09b626cf077cd419190fee02091dd6\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 1, 2], 1], ["SP", 3, 8, 14, [2, 1, 1], 1], ["SP", 3, 12, 1024, [1, 1, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["CI", 1], ["FU", 3, [0, 1, 2, 3, [...] -{"i": [["[\"1dce2c5e4269b8a12dfc50cd4dd23ff1\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [2, 1, 7], 1], ["SP", 3, 8, 14, [2, 1, 1], 1], ["SP", 3, 12, 256, [16, 4, 4], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [64], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["CI", 1], ["CR", 6], ["FU", 3, [0, 1, 2, 3, [...] -{"i": [["[\"d3b36ce001dc24d693facfbdae1979b4\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [1, 1, 1], 1], ["SP", 3, 8, 28, [7, 1, 1], 1], ["SP", 3, 12, 512, [1, 2, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 128, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 8, 0, 2, 2], ["FSP", 8, 3, [...] -{"i": [["[\"a085717fb3dcb046e5c4c2c04d3dc541\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [14, 1, 2], 1], ["SP", 3, 8, 28, [2, 1, 1], 1], ["SP", 3, 12, 128, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [16], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 1], ["FSP", 6, 2, 2, 1], [ [...] -{"i": [["[\"8dd7d81db440763f622f03fdc99e6d46\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [14, 2, 2], 1], ["SP", 3, 8, 56, [2, 1, 2], 1], ["SP", 3, 12, 64, [1, 16, 4], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 1], ["FSP", 6, 2, 2, 1], ["FS [...] -{"i": [["[\"ba2026d923536b75e9b4faed89287d5f\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 4], ["CA", 2, 5, 3], ["CR", 1], ["FU", 1, [0, 1, 2]], ["AN", 1, 0, 3], ["FU", 5, [0, 1, 2]], ["AN", 5, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$64"]]]], "r": [[2.9217e-05, 3.1065e-05, 3.188e-05, 3.0897e-05, 3.1295e-05, 3.1307e-05, 3.19e-05, 3.1038e-05, 3.1919e-05, 3.2077e-05], 0, 0.217184, 1606961266], "v": "v0.3"} -{"i": [["[\"0fb1dfcdb5b755e2dab290ed0129dcf2\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [1, 1, 2], 1], ["SP", 3, 8, 28, [1, 1, 2], 1], ["SP", 3, 12, 128, [2, 2, 16], 1], ["SP", 3, 16, 3, [3], 1], ["SP", 3, 18, 3, [3], 1], ["SP", 3, 20, 128, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["F [...] -{"i": [["[\"e043f834cc7f19597227e09dc7f59503\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [1, 1, 2], 1], ["SP", 3, 8, 14, [7, 2, 1], 1], ["SP", 3, 12, 256, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], [" [...] -{"i": [["[\"a0eb8d6048282a4a0986cc2ccf14eaa2\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 112, [1, 1, 4], 1], ["SP", 3, 8, 112, [4, 2, 1], 1], ["SP", 3, 12, 64, [1, 1, 16], 1], ["SP", 3, 16, 7, [7], 1], ["SP", 3, 18, 7, [7], 1], ["SP", 3, 20, 3, [3], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FS [...] -{"i": [["[\"03614e726dc588d11887eb0953a77e53\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 1, 1], 1], ["SP", 3, 8, 7, [1, 1, 7], 1], ["SP", 3, 12, 2048, [256, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 5, 0, 0, 2], ["FSP", 5, 3, 1, 2], ["FSP", 5, 6, [...] -{"i": [["[\"b51e06c1131d4cded40d1b215f722a4e\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [4, 1, 1], 1], ["SP", 3, 8, 56, [7, 4, 1], 1], ["SP", 3, 12, 64, [4, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FS [...] -{"i": [["[\"a9e632e5167afb60fbe29e7aeef1d152\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [1, 1, 1], 1], ["SP", 3, 8, 56, [7, 1, 4], 1], ["SP", 3, 12, 64, [1, 1, 16], 1], ["SP", 3, 16, 3, [1], 1], ["SP", 3, 18, 3, [3], 1], ["SP", 3, 20, 64, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FSP [...] -{"i": [["[\"e0a9eb3795b531085e0ebb772e7e800c\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [7, 1, 1], 1], ["SP", 3, 8, 7, [1, 7, 1], 1], ["SP", 3, 12, 512, [2, 2, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 2048, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FSP [...] -{"i": [["[\"8fcee68a4342c38248a827f1c6c69177\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [4, 2, 1], 1], ["SP", 3, 8, 56, [1, 1, 1], 1], ["SP", 3, 12, 256, [2, 4, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 5, 0, 0, 2], ["FSP", 5, 3, 1, 2], ["FSP", 5, 6, 2, [...] -{"i": [["[\"4d7e646d99bfa3cea8245bd7100369cb\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 2, 1], 1], ["SP", 3, 8, 14, [14, 1, 1], 1], ["SP", 3, 12, 1024, [2, 2, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 4, 0, 1, 2], ["FSP", 4, 3 [...] -{"i": [["[\"b2010aa63c95dedf1f58f3fe8bc78634\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [2, 1, 2], 1], ["SP", 3, 8, 28, [1, 2, 1], 1], ["SP", 3, 12, 512, [16, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 4, 0, 1, 2], ["FSP", 4, 3, [...] -{"i": [["[\"537c8642716948c33a6eaaabc86b159d\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [7, 1, 1], 1], ["SP", 3, 8, 7, [1, 7, 1], 1], ["SP", 3, 12, 2048, [128, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 4, 0, 1, 2], ["FSP", 4, 3 [...] -{"i": [["[\"7e3f0cf5a6dd80d36dab1a3dad92674a\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 7, 1], 1], ["SP", 3, 8, 7, [7, 1, 1], 1], ["SP", 3, 12, 512, [4, 1, 8], 1], ["SP", 3, 16, 3, [3], 1], ["SP", 3, 18, 3, [1], 1], ["SP", 3, 20, 512, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FSP" [...] -{"i": [["[\"cd7c4a374fb2bbc0d075c8cae638ad14\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 1, 2], 1], ["SP", 3, 8, 14, [7, 2, 1], 1], ["SP", 3, 12, 1024, [16, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 5, 0, 0, 2], ["FSP", 5, 3, 1, 2], ["FSP", 5, 6 [...] -{"i": [["[\"45b4de07687dee43ee1cbde9f516b2bf\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [56, 1, 1], 1], ["SP", 3, 8, 56, [14, 1, 2], 1], ["SP", 3, 12, 256, [1, 2, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [64], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 4, 0, 1, 2], ["FSP", 4, 3 [...] -{"i": [["[\"95bf49cc8cf7a351e974b2359702aac0\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [1, 2, 1], 1], ["SP", 3, 8, 14, [1, 7, 1], 1], ["SP", 3, 12, 256, [2, 1, 8], 1], ["SP", 3, 16, 3, [1], 1], ["SP", 3, 18, 3, [3], 1], ["SP", 3, 20, 256, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FS [...] -{"i": [["[\"5e3ceb6e23ae8c351d5a1770d5fc6c7c\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [1, 1, 2], 1], ["SP", 3, 8, 28, [1, 1, 1], 1], ["SP", 3, 12, 512, [4, 1, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 128, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 5, 0, 0, 2], ["FSP", 5, 3, 1, 2], ["FSP", 5, 6, [...] -{"i": [["[\"691feef049c8693bbe91bd5e7c9cdf34\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [7, 1, 4], 1], ["SP", 3, 8, 56, [4, 2, 1], 1], ["SP", 3, 12, 256, [32, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 8, 0, 2, 2], ["FSP", 8, 3, [...] -{"i": [["[\"45acfc473c772458684f36a34549d8aa\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [7, 1, 4], 1], ["SP", 3, 8, 28, [14, 1, 1], 1], ["SP", 3, 12, 128, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], [" [...] +{"i": [["[\"d7b65649a4dd54becea0a52aabbc5af5\", 1, 1000, 1, 1000]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["SP", 3, 1, 1000, [50], 1], ["RF", 3, 2, 1], ["RE", 3, [0, 2, 1]], ["SP", 1, 1, 1000, [20], 1], ["RF", 1, 2, 1], ["RE", 1, [0, 2, 1]], ["CR", 6], ["CA", 5, 6, 1], ["CR", 4], ["CA", 2, 3, 1], ["AN", 1, 0, 3], ["FU", 3, [0, 1]], ["AN", 3, 0, 3], ["AN", 4, 0, 3], ["FU", 6, [0, 1]], ["AN", 6, 0, 3], ["PR", 1, 0, "auto_unroll_max_step$ [...] +{"i": [["[\"69115f188984ae34ede37c3b8ca40b43\", 1, 7, 7, 2048, 1, 1, 1, 2048]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CA", 1, 2, 3], ["FU", 2, [0, 1, 2, 3]], ["AN", 2, 0, 3], ["PR", 1, 0, "auto_unroll_max_step$16"]]]], "r": [[6.28e-06, 8.176e-06, 8.048e-06, 7.942e-06, 7.977e-06, 8.002e-06, 8.093e-06, 7.924e-06, 7.943e-06, 7.924e-06], 0, 0.130759, 1606960900], "v": "v0.5"} +{"i": [["[\"875556d12d0be2269206a7775d5296a6\", 1, 7, 7, 512, 1, 1, 512, 2048, 1, 7, 7, 2048, 1, 1, 1, 2048, 1, 1, 1, 2048, 1, 7, 7, 2048]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 9], ["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 7, 1], 1], ["SP", 3, 8, 7, [1, 1, 1], 1], ["SP", 3, 12, 2048, [8, 2, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [2], 1], ["RE", 3, [0, 4, 8, [...] +{"i": [["[\"de7d1695278cf52778b038e6573d7626\", 1, 14, 14, 1024, 1, 1, 1024, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 1, 7], 1], ["SP", 3, 8, 7, [1, 1, 1], 1], ["SP", 3, 12, 512, [32, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19 [...] +{"i": [["[\"1b524af89dd867d26059e1f621cf987c\", 1, 14, 14, 256, 1, 1, 256, 1024, 1, 14, 14, 1024, 1, 1, 1, 1024, 1, 14, 14, 1024]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 1, 2], 1], ["SP", 3, 8, 14, [2, 1, 1], 1], ["SP", 3, 12, 1024, [1, 1, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, [...] +{"i": [["[\"de7d1695278cf52778b038e6573d7626\", 1, 28, 28, 512, 1, 1, 512, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [2, 1, 7], 1], ["SP", 3, 8, 14, [2, 1, 1], 1], ["SP", 3, 12, 256, [16, 4, 4], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [64], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, [...] +{"i": [["[\"1b524af89dd867d26059e1f621cf987c\", 1, 28, 28, 128, 1, 1, 128, 512, 1, 28, 28, 512, 1, 1, 1, 512, 1, 28, 28, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [1, 1, 1], 1], ["SP", 3, 8, 28, [7, 1, 1], 1], ["SP", 3, 12, 512, [1, 2, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 128, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, [...] +{"i": [["[\"de7d1695278cf52778b038e6573d7626\", 1, 56, 56, 256, 1, 1, 256, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [14, 1, 2], 1], ["SP", 3, 8, 28, [2, 1, 1], 1], ["SP", 3, 12, 128, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [16], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, [...] +{"i": [["[\"6b7583cf23c7c37d3212cad9d06e58c1\", 1, 56, 56, 64, 1, 1, 64, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [14, 2, 2], 1], ["SP", 3, 8, 56, [2, 1, 2], 1], ["SP", 3, 12, 64, [1, 16, 4], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, [...] +{"i": [["[\"a5612fdeb9db4d579a75ec225ea4c06a\", 1, 112, 112, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 4], ["CA", 2, 5, 3], ["CR", 1], ["FU", 1, [0, 1, 2]], ["AN", 1, 0, 3], ["FU", 5, [0, 1, 2]], ["AN", 5, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$64"]]]], "r": [[2.9217e-05, 3.1065e-05, 3.188e-05, 3.0897e-05, 3.1295e-05, 3.1307e-05, 3.19e-05, 3.1038e-05, 3.1919e-05, 3.2077e-05], 0, 0.217184, 1606961 [...] +{"i": [["[\"6b7583cf23c7c37d3212cad9d06e58c1\", 1, 14, 14, 1024, 1, 1, 1024, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [1, 1, 2], 1], ["SP", 3, 8, 14, [7, 2, 1], 1], ["SP", 3, 12, 256, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17 [...] +{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 224, 224, 3, 7, 7, 3, 64, 1, 1, 1, 64, 1, 112, 112, 64]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 112, [1, 1, 4], 1], ["SP", 3, 8, 112, [4, 2, 1], 1], ["SP", 3, 12, 64, [1, 1, 16], 1], ["SP", 3, 16, 7, [7], 1], ["SP", 3, 18, 7, [7], 1], ["SP", 3, 20, 3, [3], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 2 [...] +{"i": [["[\"1cc666833c122282e3fcf3595901b12b\", 1, 7, 7, 512, 1, 1, 512, 2048, 1, 7, 7, 2048, 1, 7, 7, 2048]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 1, 1], 1], ["SP", 3, 8, 7, [1, 1, 7], 1], ["SP", 3, 12, 2048, [256, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, [...] +{"i": [["[\"6b7583cf23c7c37d3212cad9d06e58c1\", 1, 56, 56, 256, 1, 1, 256, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [4, 1, 1], 1], ["SP", 3, 8, 56, [7, 4, 1], 1], ["SP", 3, 12, 64, [4, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 2 [...] +{"i": [["[\"2350d19dc42a0665244368384c66b3a5\", 1, 56, 56, 64, 3, 3, 64, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [1, 1, 1], 1], ["SP", 3, 8, 56, [7, 1, 4], 1], ["SP", 3, 12, 64, [1, 1, 16], 1], ["SP", 3, 16, 3, [1], 1], ["SP", 3, 18, 3, [3], 1], ["SP", 3, 20, 64, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, [...] +{"i": [["[\"6b7583cf23c7c37d3212cad9d06e58c1\", 1, 7, 7, 2048, 1, 1, 2048, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [7, 1, 1], 1], ["SP", 3, 8, 7, [1, 7, 1], 1], ["SP", 3, 12, 512, [2, 2, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 2048, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 2 [...] +{"i": [["[\"1cc666833c122282e3fcf3595901b12b\", 1, 56, 56, 64, 1, 1, 64, 256, 1, 56, 56, 256, 1, 56, 56, 256]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [4, 2, 1], 1], ["SP", 3, 8, 56, [1, 1, 1], 1], ["SP", 3, 12, 256, [2, 4, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, [...] +{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 28, 28, 512, 1, 1, 512, 1024, 1, 14, 14, 1024]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 2, 1], 1], ["SP", 3, 8, 14, [14, 1, 1], 1], ["SP", 3, 12, 1024, [2, 2, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 1 [...] +{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 56, 56, 256, 1, 1, 256, 512, 1, 28, 28, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [2, 1, 2], 1], ["SP", 3, 8, 28, [1, 2, 1], 1], ["SP", 3, 12, 512, [16, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, [...] +{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 14, 14, 1024, 1, 1, 1024, 2048, 1, 7, 7, 2048]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [7, 1, 1], 1], ["SP", 3, 8, 7, [1, 7, 1], 1], ["SP", 3, 12, 2048, [128, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 1 [...] +{"i": [["[\"2350d19dc42a0665244368384c66b3a5\", 1, 7, 7, 512, 3, 3, 512, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 7, 1], 1], ["SP", 3, 8, 7, [7, 1, 1], 1], ["SP", 3, 12, 512, [4, 1, 8], 1], ["SP", 3, 16, 3, [3], 1], ["SP", 3, 18, 3, [1], 1], ["SP", 3, 20, 512, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, [...] +{"i": [["[\"1cc666833c122282e3fcf3595901b12b\", 1, 14, 14, 256, 1, 1, 256, 1024, 1, 14, 14, 1024, 1, 14, 14, 1024]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 1, 2], 1], ["SP", 3, 8, 14, [7, 2, 1], 1], ["SP", 3, 12, 1024, [16, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 2 [...] +{"i": [["[\"7006235cfc29b73be524cf390ed5a977\", 1, 56, 56, 64, 1, 1, 64, 256, 1, 56, 56, 256]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [56, 1, 1], 1], ["SP", 3, 8, 56, [14, 1, 2], 1], ["SP", 3, 12, 256, [1, 2, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [64], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 2 [...] +{"i": [["[\"1cc666833c122282e3fcf3595901b12b\", 1, 28, 28, 128, 1, 1, 128, 512, 1, 28, 28, 512, 1, 28, 28, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [1, 1, 2], 1], ["SP", 3, 8, 28, [1, 1, 1], 1], ["SP", 3, 12, 512, [4, 1, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 128, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, [...] +{"i": [["[\"1b524af89dd867d26059e1f621cf987c\", 1, 56, 56, 64, 1, 1, 64, 256, 1, 56, 56, 256, 1, 1, 1, 256, 1, 56, 56, 256]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [7, 1, 4], 1], ["SP", 3, 8, 56, [4, 2, 1], 1], ["SP", 3, 12, 256, [32, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, [...] +{"i": [["[\"6b7583cf23c7c37d3212cad9d06e58c1\", 1, 28, 28, 512, 1, 1, 512, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [7, 1, 4], 1], ["SP", 3, 8, 28, [14, 1, 1], 1], ["SP", 3, 12, 128, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, [...]