This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d2766409f [AutoTVM] Fix `None` feature in AutoTVM tuning (#12760)
4d2766409f is described below

commit 4d2766409f1b95504aac171649367c2df2813029
Author: Junru Shao <junrushao1...@gmail.com>
AuthorDate: Mon Sep 12 15:06:16 2022 -0800

    [AutoTVM] Fix `None` feature in AutoTVM tuning (#12760)
    
    This PR introduces a couple of fixes to make AutoTVM work more
    robustly:
    - Fixed a very rare case where `None` could pop up in AutoTVM features;
    - Fixed a misuse of `ARGS` in the testing script;
    - Fixed the filename for caching.
---
 python/tvm/autotvm/testing/tune_relay.py           | 13 +++++++------
 python/tvm/autotvm/tuner/xgboost_cost_model.py     |  7 +++----
 python/tvm/meta_schedule/testing/relay_workload.py |  2 +-
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/python/tvm/autotvm/testing/tune_relay.py b/python/tvm/autotvm/testing/tune_relay.py
index e474596374..743127ec1d 100644
--- a/python/tvm/autotvm/testing/tune_relay.py
+++ b/python/tvm/autotvm/testing/tune_relay.py
@@ -139,12 +139,6 @@ def _parse_args():
         tracker_key=parsed.rpc_key,
         session_timeout_sec=600,
     )
-    if ARGS.target.kind.name != "llvm" and ARGS.graph_tuner:
-        raise ValueError("GraphTuner only supports llvm target")
-    if ARGS.target.kind.name != "llvm" and ARGS.cpu_flush:
-        raise ValueError("cpu_flush only supports llvm target")
-    if ARGS.target.kind.name == "llvm" and not ARGS.cpu_flush:
-        warnings.warn("cpu_flush is not enabled for llvm target")
     return parsed
 
 
@@ -152,6 +146,13 @@ ARGS = _parse_args()
 
 
 def main():
+    if ARGS.target.kind.name != "llvm" and ARGS.graph_tuner:
+        raise ValueError("GraphTuner only supports llvm target")
+    if ARGS.target.kind.name != "llvm" and ARGS.cpu_flush:
+        raise ValueError("cpu_flush only supports llvm target")
+    if ARGS.target.kind.name == "llvm" and not ARGS.cpu_flush:
+        warnings.warn("cpu_flush is not enabled for llvm target")
+
     log_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}.json")
     graph_opt_sch_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}_graph_opt.log")
     measure_option = autotvm.measure_option(
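
Why the checks moved: inside `_parse_args()`, the module-level `ARGS` is not
yet bound, since `ARGS = _parse_args()` only assigns after the call returns,
so the old placement raised `NameError` at startup. Relocating the checks
into `main()` guarantees `ARGS` exists by the time they run. A minimal,
runnable sketch of the original failure pattern (illustrative names, not the
exact script):

    def _parse_args():
        parsed = object()
        # Bug pattern: reading the module-level ARGS here fails, because
        # ARGS is only assigned once this function has already returned.
        if ARGS is not parsed:  # NameError: name 'ARGS' is not defined
            pass
        return parsed

    ARGS = _parse_args()  # raises NameError inside _parse_args()
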
diff --git a/python/tvm/autotvm/tuner/xgboost_cost_model.py b/python/tvm/autotvm/tuner/xgboost_cost_model.py
index d4942ce6a4..6fa04f336f 100644
--- a/python/tvm/autotvm/tuner/xgboost_cost_model.py
+++ b/python/tvm/autotvm/tuner/xgboost_cost_model.py
@@ -21,12 +21,11 @@ import logging
 import time
 
 import numpy as np
-
 from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind
 
 from .. import feature
 from ..utils import get_rank
-from .metric import max_curve, recall_curve, cover_curve
+from .metric import cover_curve, max_curve, recall_curve
 from .model_based_tuner import CostModel, FeatureCache
 
 xgb = None
@@ -346,7 +345,7 @@ class XGBoostCostModel(CostModel):
         ret = np.empty((len(indexes), feature_len), dtype=np.float32)
         for i, ii in enumerate(indexes):
             t = fea_cache[ii]
-            if t.shape[0] < feature_len:
+            if t is not None and t.shape[0] < feature_len:
                 t = np.pad(t, (0, feature_len - t.shape[0]))
             ret[i, :] = t if t is not None else 0
         return ret
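
Here `fea_cache[ii]` can be `None` when feature extraction failed for a
candidate, and the old code dereferenced `t.shape` unconditionally, raising
`AttributeError`. With the added guard, a `None` entry skips the padding and
falls through to the existing zero-fill on the next line. A runnable sketch
of the repaired pattern, using made-up cache contents:

    import numpy as np

    feature_len = 4
    fea_cache = [np.array([1.0, 2.0]), None]  # None: feature extraction failed

    ret = np.empty((len(fea_cache), feature_len), dtype=np.float32)
    for i, t in enumerate(fea_cache):
        if t is not None and t.shape[0] < feature_len:
            t = np.pad(t, (0, feature_len - t.shape[0]))  # right-pad with zeros
        ret[i, :] = t if t is not None else 0  # a None row becomes all zeros
    print(ret)  # [[1. 2. 0. 0.], [0. 0. 0. 0.]]
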
@@ -449,8 +448,8 @@ def custom_callback(
 ):
     """callback function for xgboost to support multiple custom evaluation 
functions"""
     # pylint: disable=import-outside-toplevel
-    from xgboost.core import EarlyStopException
     from xgboost.callback import _fmt_metric
+    from xgboost.core import EarlyStopException
 
     try:
         from xgboost.training import aggcv
diff --git a/python/tvm/meta_schedule/testing/relay_workload.py b/python/tvm/meta_schedule/testing/relay_workload.py
index f4f6336df3..98bb995120 100644
--- a/python/tvm/meta_schedule/testing/relay_workload.py
+++ b/python/tvm/meta_schedule/testing/relay_workload.py
@@ -230,7 +230,7 @@ def get_network(
     inputs: Tuple[str, List[int], str]
     params_bytearray: bytearray
 
-    filename = f'relay-{name}-{",".join(str(i) for i in input_shape)}.json'
+    filename = f'relay-{name}-{layout}-{",".join(str(i) for i in input_shape)}.json'
     cached = _load_cache(cache_dir, filename)
     if cached is None:
         with multiprocessing.Pool(processes=1) as pool:
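
Without `layout` in the filename, two `get_network` calls that differ only
in layout would read and write the same cache entry, so a cached module for
one layout could be returned for a request with another. A small sketch with
hypothetical values showing the collision the extra field removes:

    # Hypothetical inputs, only to illustrate the filename collision.
    name, input_shape = "resnet_18", [1, 3, 224, 224]
    for layout in ("NCHW", "NHWC"):
        old = f'relay-{name}-{",".join(str(i) for i in input_shape)}.json'
        new = f'relay-{name}-{layout}-{",".join(str(i) for i in input_shape)}.json'
        print(old, "->", new)
    # old: relay-resnet_18-1,3,224,224.json for both layouts (collision)
    # new: relay-resnet_18-NCHW-... vs relay-resnet_18-NHWC-... (distinct)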
