nswamy commented on a change in pull request #14685: [Fit API] improve event handlers URL: https://github.com/apache/incubator-mxnet/pull/14685#discussion_r277221859
########## File path: python/mxnet/gluon/contrib/estimator/event_handler.py ########## @@ -16,85 +16,169 @@ # under the License. # coding: utf-8 -# pylint: disable=wildcard-import +# pylint: disable=wildcard-import, unused-argument """Gluon EventHandlers for Estimators""" -__all__ = ['EventHandler', 'LoggingHandler'] import logging import os import time import warnings import numpy as np +from ....metric import EvalMetric, Loss -class EventHandler(object): - """Basic for event handlers - :py:class:`EventHandler` can perform user defined functions at - different stages of training: train begin, epoch begin, batch begin, - batch end, epoch end, train end. - - Parameters - ---------- - estimator : Estimator - The :py:class:`Estimator` to get training statistics - """ +class TrainBegin(object): + def train_begin(self, estimator, *args, **kwargs): + pass - def __init__(self): - self._estimator = None - @property - def estimator(self): - return self._estimator +class TrainEnd(object): + def train_end(self, estimator, *args, **kwargs): + pass - @estimator.setter - def estimator(self, estimator): - self._estimator = estimator - def train_begin(self): +class EpochBegin(object): + def epoch_begin(self, estimator, *args, **kwargs): pass - def train_end(self): - pass - def batch_begin(self): - pass +class EpochEnd(object): + def epoch_end(self, estimator, *args, **kwargs): + return False - def batch_end(self): - pass - def epoch_begin(self): +class BatchBegin(object): + def batch_begin(self, estimator, *args, **kwargs): pass - def epoch_end(self): - pass +class BatchEnd(object): + def batch_end(self, estimator, *args, **kwargs): + return False + + +class MetricHandler(EpochBegin, BatchEnd): + """Metric Handler that update metric values at batch end + + :py:class:`MetricHandler` takes model predictions and true labels + and update the metrics, it also update metric wrapper for loss with loss values + Validation loss and metrics will be handled by :py:class:`ValidationHandler` + 
+ Parameters + ---------- + train_metrics : List of EvalMetrics + training metrics to be updated at batch end + """ + + def __init__(self, train_metrics): + self.train_metrics = train_metrics or [] + # order to be called among all callbacks + # metrics need to be calculated before other callbacks can access them + self.priority = -np.Inf Review comment: let's call this out explicitly in the documentation. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services