shingjan commented on code in PR #12141:
URL: https://github.com/apache/tvm/pull/12141#discussion_r930543004


##########
python/tvm/meta_schedule/cost_model/xgb_model.py:
##########
@@ -763,3 +768,162 @@ def callback(env: "xgb.core.CallbackEnv"):
             raise EarlyStopException(best_iteration)
 
     return callback
+
+
class XGBoostCallback(TrainingCallback):
    """Base class for XGBoost callbacks.

    Bridges the pre-1.3 callable-callback protocol to the modern
    ``TrainingCallback.after_iteration`` hook, so one implementation
    works across xgboost versions.
    """

    def __call__(self, env: "xgb.core.CallbackEnv"):
        # xgboost < 1.3 invokes callbacks as plain callables with a
        # CallbackEnv namedtuple; unpack it and forward to the modern hook.
        booster = env.model
        epoch = env.iteration
        evals_log = env.evaluation_result_list
        return self.after_iteration(booster, epoch, evals_log)

    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: Dict):
        """Per-iteration hook; subclasses must override."""
        raise NotImplementedError
+
+
+class XGBoostCustomCallback(XGBoostCallback):
+    """Custom callback class for xgboost to support multiple custom evaluation 
functions"""
+
+    def __init__(
+        self,
+        early_stopping_rounds: int,
+        verbose_eval: int,
+        fevals: List[Callable],
+        evals: List[Tuple["xgb.DMatrix", str]],
+        focused_metric: str = "tr-p-rmse",
+        cvfolds: List["xgb.training.CVPack"] = None,
+    ):
+        self.early_stopping_rounds = early_stopping_rounds
+        self.verbose_eval = verbose_eval
+        self.fevals = fevals
+        self.evals = evals
+        self.state: Dict[str, Any] = {}
+        self.focused_metric = focused_metric
+        self.sort_key = make_metric_sorter(focused_metric=focused_metric)
+        self.cvfolds = cvfolds
+        if cvfolds is not None:
+            self.aggregated_cv = None
+
+    def init(self, model: "xgb.Booster"):
+        """Internal function for intialization"""
+        booster: "xgb.Booster" = model
+        self.state["best_iteration"] = 0
+        self.state["best_score"] = float("inf")
+        if booster is None:
+            assert self.cvfolds is not None
+            return
+        if booster.attr("best_score") is not None:
+            self.state["best_score"] = float(booster.attr("best_score"))
+            self.state["best_iteration"] = int(booster.attr("best_iteration"))
+            self.state["best_msg"] = booster.attr("best_msg")
+        else:
+            booster.set_attr(best_iteration=str(self.state["best_iteration"]))
+            booster.set_attr(best_score=str(self.state["best_score"]))
+
+    def after_iteration(self, model: "xgb.Booster", epoch: int, evals_log: 
Dict):
+        """Internal function for after_iteration"""
+        # pylint:disable = import-outside-toplevel

Review Comment:
   I think we will need to keep this pylint check disabled, since this specific function contains other imports outside of the top level as well.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to