comaniac commented on a change in pull request #7197:
URL: https://github.com/apache/tvm/pull/7197#discussion_r551199732



##########
File path: python/tvm/auto_scheduler/task_scheduler.py
##########
@@ -82,11 +82,12 @@ def make_search_policies(
     if isinstance(search_policy, str):
         policy_type, model_type = search_policy.split(".")
         if model_type == "xgb":
-            cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measures_per_round)
-            if load_model_file:
-                logger.info("TaskScheduler: Load pretrained model...")
-                cost_model.load(load_model_file)
-            elif load_log_file:
+            cost_model = XGBModel(
+                num_warmup_sample=len(tasks) * num_measures_per_round,
+                model_file=load_model_file,
+            )
+            if load_log_file:
+                logger.info("TaskScheduler: Reload measured states and pretrain model...")

Review comment:
       ```suggestion
               logger.info("TaskScheduler: Reload measured states and pretrained model...")
   ```
   

##########
File path: src/auto_scheduler/feature.cc
##########
@@ -1462,12 +1462,18 @@ void GetPerStoreFeaturesFromMeasurePairs(const Array<MeasureInput>& inputs,
     if (find_res == task_cache.end()) {
       if (inputs[i]->task->compute_dag.defined()) {  // the measure input is complete
         task = inputs[i]->task;
-      } else {  // the measure input is incomplete
-        // rebuild task for incomplete measure pairs read from file
-        Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
-        task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
-                          inputs[i]->task->target_host, inputs[i]->task->hardware_params,
-                          inputs[i]->task->layout_rewrite_option);
+      } else {
+        // The measure input is incomplete, rebuild task for incomplete measure pairs read from file
+        try {
+          Array<te::Tensor> tensors = (*workload_key_to_tensors)(workload_key);
+          task = SearchTask(ComputeDAG(tensors), workload_key, inputs[i]->task->target,
+                            inputs[i]->task->target_host, inputs[i]->task->hardware_params,
+                            inputs[i]->task->layout_rewrite_option);
+        } catch (std::exception& e) {
+          // Cannot build ComputeDAG from workload key, the task may have not been registered in
+          // this search round
+          continue;

Review comment:
       Should we have a warning here? Otherwise it may be confusing.
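One way to address the warning question is to log once per unresolvable workload key and keep skipping its records, so a user loading an old log sees why some records were ignored. The sketch below is Python rather than C++ for brevity and only illustrates the pattern; `rebuild_tasks`, its arguments and the `"auto_scheduler"` logger name are hypothetical stand-ins for the loop in `GetPerStoreFeaturesFromMeasurePairs` and the `workload_key_to_tensors` lookup, not TVM APIs.

```python
import logging

# Hypothetical logger name; TVM's C++ code would use its own logging macros.
logger = logging.getLogger("auto_scheduler")


def rebuild_tasks(workload_keys, workload_key_to_tensors):
    """Hypothetical sketch: resolve workload keys read from a log file.

    Keys that cannot be resolved are skipped, and a warning is emitted once
    per key instead of silently continuing.
    """
    resolved = {}    # workload key -> result of the lookup (e.g. tensors)
    skipped = set()  # keys that failed and have already been warned about
    for key in workload_keys:
        if key in resolved or key in skipped:
            continue  # already handled; acts like the task cache in the diff
        try:
            resolved[key] = workload_key_to_tensors(key)
        except Exception as err:  # mirrors the std::exception catch in the diff
            skipped.add(key)
            logger.warning(
                "Skipping measure records of workload %s: cannot rebuild its "
                "ComputeDAG (%s). The task may not be registered in this search round.",
                key,
                err,
            )
    return resolved, skipped
```

Warning once per key, rather than once per record, keeps the log readable when a large log file contains many records for a workload that is not registered in the current session.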

##########
File path: python/tvm/auto_scheduler/cost_model/xgb_model.py
##########
@@ -141,6 +146,12 @@ def update(self, inputs, results):
         self.inputs.extend(inputs)
         self.results.extend(results)
 
+        if len(self.inputs) - self.last_train_length < self.last_train_length / 5:

Review comment:
       ```suggestion
           if len(inputs) < self.last_train_length / 5:
   ```
   Could you explain a bit more on this logic or make it more straightforward?
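The condition in the diff appears to retrain only when the samples gathered since the last training amount to at least one fifth of the previous training set. A minimal Python sketch of that rule follows; `RetrainPolicy`, `divisor` and `num_trainings` are hypothetical names, not part of `XGBModel`, and the sketch assumes `last_train_length` is updated only when a training actually runs.

```python
class RetrainPolicy:
    """Hypothetical sketch of 'retrain only after enough new data'.

    Assumes last_train_length changes only when a training actually runs.
    """

    def __init__(self, divisor=5):
        self.inputs = []
        self.last_train_length = 0
        self.divisor = divisor  # the diff under review uses 5 (i.e. 20%)
        self.num_trainings = 0

    def update(self, inputs):
        self.inputs.extend(inputs)
        # Samples collected since the last training; this can span several
        # update() calls, because skipped calls do not reset the baseline.
        new_samples = len(self.inputs) - self.last_train_length
        if new_samples < self.last_train_length / self.divisor:
            return False  # too little new data, skip the expensive fit
        self._train()
        return True

    def _train(self):
        # Placeholder for the real model fit.
        self.num_trainings += 1
        self.last_train_length = len(self.inputs)
```

For example, a first batch of 100 samples trains (the baseline is 0); a following batch of 10 is skipped because 10 < 100 / 5; a second batch of 10 brings the backlog to 20 and triggers training again. With the suggested per-batch check `len(inputs) < self.last_train_length / 5`, each 10-sample batch would be compared with 20 on its own, so under this sketch's assumption neither would ever trigger a retrain.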



