This is an automated email from the ASF dual-hosted git repository. wuwei pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 4d2766409f [AutoTVM] Fix `None` feature in AutoTVM tuning (#12760) 4d2766409f is described below commit 4d2766409f1b95504aac171649367c2df2813029 Author: Junru Shao <junrushao1...@gmail.com> AuthorDate: Mon Sep 12 15:06:16 2022 -0800 [AutoTVM] Fix `None` feature in AutoTVM tuning (#12760) This PR introduces a couple of fixes to make AutoTVM work more robustly: - Fixed a very rare case where `None` could pop up in AutoTVM features; - Fixed a misuse of `ARGS` in the testing script; - Fixed the filename for caching. --- python/tvm/autotvm/testing/tune_relay.py | 13 +++++++------ python/tvm/autotvm/tuner/xgboost_cost_model.py | 7 +++---- python/tvm/meta_schedule/testing/relay_workload.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/tvm/autotvm/testing/tune_relay.py b/python/tvm/autotvm/testing/tune_relay.py index e474596374..743127ec1d 100644 --- a/python/tvm/autotvm/testing/tune_relay.py +++ b/python/tvm/autotvm/testing/tune_relay.py @@ -139,12 +139,6 @@ def _parse_args(): tracker_key=parsed.rpc_key, session_timeout_sec=600, ) - if ARGS.target.kind.name != "llvm" and ARGS.graph_tuner: - raise ValueError("GraphTuner only supports llvm target") - if ARGS.target.kind.name != "llvm" and ARGS.cpu_flush: - raise ValueError("cpu_flush only supports llvm target") - if ARGS.target.kind.name == "llvm" and not ARGS.cpu_flush: - warnings.warn("cpu_flush is not enabled for llvm target") return parsed @@ -152,6 +146,13 @@ ARGS = _parse_args() def main(): + if ARGS.target.kind.name != "llvm" and ARGS.graph_tuner: + raise ValueError("GraphTuner only supports llvm target") + if ARGS.target.kind.name != "llvm" and ARGS.cpu_flush: + raise ValueError("cpu_flush only supports llvm target") + if ARGS.target.kind.name == "llvm" and not ARGS.cpu_flush: + warnings.warn("cpu_flush is not enabled for llvm target") + log_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}.json") 
graph_opt_sch_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}_graph_opt.log") measure_option = autotvm.measure_option( diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py index d4942ce6a4..6fa04f336f 100644 --- a/python/tvm/autotvm/tuner/xgboost_cost_model.py +++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py @@ -21,12 +21,11 @@ import logging import time import numpy as np - from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind from .. import feature from ..utils import get_rank -from .metric import max_curve, recall_curve, cover_curve +from .metric import cover_curve, max_curve, recall_curve from .model_based_tuner import CostModel, FeatureCache xgb = None @@ -346,7 +345,7 @@ class XGBoostCostModel(CostModel): ret = np.empty((len(indexes), feature_len), dtype=np.float32) for i, ii in enumerate(indexes): t = fea_cache[ii] - if t.shape[0] < feature_len: + if t is not None and t.shape[0] < feature_len: t = np.pad(t, (0, feature_len - t.shape[0])) ret[i, :] = t if t is not None else 0 return ret @@ -449,8 +448,8 @@ def custom_callback( ): """callback function for xgboost to support multiple custom evaluation functions""" # pylint: disable=import-outside-toplevel - from xgboost.core import EarlyStopException from xgboost.callback import _fmt_metric + from xgboost.core import EarlyStopException try: from xgboost.training import aggcv diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py index f4f6336df3..98bb995120 100644 --- a/python/tvm/meta_schedule/testing/relay_workload.py +++ b/python/tvm/meta_schedule/testing/relay_workload.py @@ -230,7 +230,7 @@ def get_network( inputs: Tuple[str, List[int], str] params_bytearray: bytearray - filename = f'relay-{name}-{",".join(str(i) for i in input_shape)}.json' + filename = f'relay-{name}-{layout}-{",".join(str(i) for i in input_shape)}.json' cached = _load_cache(cache_dir, filename) if 
cached is None: with multiprocessing.Pool(processes=1) as pool: