roywei commented on a change in pull request #14685: [Fit API] improve event 
handlers
URL: https://github.com/apache/incubator-mxnet/pull/14685#discussion_r277096082
 
 

 ##########
 File path: python/mxnet/gluon/contrib/estimator/event_handler.py
 ##########
 @@ -16,85 +16,169 @@
 # under the License.
 
 # coding: utf-8
-# pylint: disable=wildcard-import
+# pylint: disable=wildcard-import, unused-argument
 """Gluon EventHandlers for Estimators"""
 
-__all__ = ['EventHandler', 'LoggingHandler']
 import logging
 import os
 import time
 import warnings
 
 import numpy as np
 
+from ....metric import EvalMetric, Loss
 
-class EventHandler(object):
-    """Basic for event handlers
 
-        :py:class:`EventHandler` can perform user defined functions at
-        different stages of training: train begin, epoch begin, batch begin,
-        batch end, epoch end, train end.
-
-        Parameters
-        ----------
-        estimator : Estimator
-            The :py:class:`Estimator` to get training statistics
-        """
+class TrainBegin(object):
+    def train_begin(self, estimator, *args, **kwargs):
+        pass
 
-    def __init__(self):
-        self._estimator = None
 
-    @property
-    def estimator(self):
-        return self._estimator
+class TrainEnd(object):
+    def train_end(self, estimator, *args, **kwargs):
+        pass
 
-    @estimator.setter
-    def estimator(self, estimator):
-        self._estimator = estimator
 
-    def train_begin(self):
+class EpochBegin(object):
+    def epoch_begin(self, estimator, *args, **kwargs):
         pass
 
-    def train_end(self):
-        pass
 
-    def batch_begin(self):
-        pass
+class EpochEnd(object):
+    def epoch_end(self, estimator, *args, **kwargs):
+        return False
 
-    def batch_end(self):
-        pass
 
-    def epoch_begin(self):
+class BatchBegin(object):
+    def batch_begin(self, estimator, *args, **kwargs):
         pass
 
-    def epoch_end(self):
-        pass
 
+class BatchEnd(object):
+    def batch_end(self, estimator, *args, **kwargs):
+        return False
+
+
+class MetricHandler(EpochBegin, BatchEnd):
+    """Metric Handler that updates metric values at batch end
+
+    :py:class:`MetricHandler` takes model predictions and true labels
+    and updates the metrics; it also updates the metric wrapper for loss
+    with loss values.
+    Validation loss and metrics will be handled by
+    :py:class:`ValidationHandler`.
+
+    Parameters
+    ----------
+    train_metrics : List of EvalMetrics
+        training metrics to be updated at batch end
+    """
+
+    def __init__(self, train_metrics):
+        self.train_metrics = train_metrics or []
+        # order to be called among all callbacks
+        # metrics need to be calculated before other callbacks can access them
+        self.priority = -np.Inf
+
+    def epoch_begin(self, estimator, *args, **kwargs):
+        for metric in self.train_metrics:
+            metric.reset()
+
+    def batch_end(self, estimator, *args, **kwargs):
+        pred = kwargs['pred']
+        label = kwargs['label']
+        loss = kwargs['loss']
+        for metric in self.train_metrics:
+            if isinstance(metric, Loss):
+                # metric wrapper for loss values
+                metric.update(0, loss)
+            else:
+                metric.update(label, pred)
 
-class LoggingHandler(EventHandler):
+
+class ValidationHandler(BatchEnd, EpochEnd):
+    """Validation Handler that evaluates the model on a validation dataset
+
+    :py:class:`ValidationHandler` takes a validation dataset, an evaluation
+    function, metrics to be evaluated, and how often to run the validation.
+    You can provide a custom evaluation function or use the one provided
+    by :py:class:`Estimator`.
+
+    Parameters
+    ----------
+    val_data : DataLoader
+        validation data set to run evaluation
+    eval_fn : function
+        a function defines how to run evaluation and
+        calculate loss and metrics
+    val_metrics : List of EvalMetrics
+        validation metrics to be updated
+    epoch_period : int, default 1
+        how often to run validation at epoch end, by default
+        validate every epoch
+    batch_period : int, default None
+        how often to run validation at batch end, by default
+        does not validate at batch end
+    """
+
+    def __init__(self,
+                 val_data,
+                 eval_fn,
+                 val_metrics=None,
+                 epoch_period=1,
+                 batch_period=None):
+        self.val_data = val_data
+        self.eval_fn = eval_fn
+        self.epoch_period = epoch_period
+        self.batch_period = batch_period
+        self.val_metrics = val_metrics
+        self.num_batches = 0
+        self.num_epochs = 0
+        # order to be called among all callbacks
+        # validation metrics need to be calculated before other callbacks can 
access them
+        self.priority = -np.Inf
+
+    def batch_end(self, estimator, *args, **kwargs):
+        if self.batch_period and self.num_batches % self.batch_period == 0:
+            self.eval_fn(val_data=self.val_data,
+                         val_metrics=self.val_metrics)
+        self.num_batches += 1
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        if self.num_epochs % self.epoch_period == 0:
+            self.eval_fn(val_data=self.val_data,
+                         val_metrics=self.val_metrics)
+
+        self.num_epochs += 1
+
+
+class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, 
BatchEnd):
     """Basic Logging Handler that applies to every Gluon estimator by default.
 
     :py:class:`LoggingHandler` logs hyper-parameters, training statistics,
     and other useful information during training
 
     Parameters
     ----------
-    estimator : Estimator
-        The :py:class:`Estimator` to get training statistics
     file_name : str
         file name to save the logs
-    file_location: str
+    file_location : str
         file location to save the logs
-    verbose: int, default LOG_VERBOSITY_PER_EPOCH
+    verbose : int, default LOG_VERBOSITY_PER_EPOCH
         Limit the granularity of metrics displayed during training process
         verbose=LOG_VERBOSITY_PER_EPOCH: display metrics every epoch
         verbose=LOG_VERBOSITY_PER_BATCH: display metrics every batch
+    train_metrics : list of EvalMetrics
+        training metrics to be logged, logged at batch end, epoch end, train 
end
+    val_metrics : list of EvalMetrics
+        validation metrics to be logged, logged at epoch end, train end
     """
 
     LOG_VERBOSITY_PER_EPOCH = 1
     LOG_VERBOSITY_PER_BATCH = 2
 
-    def __init__(self, file_name=None, file_location=None, 
verbose=LOG_VERBOSITY_PER_EPOCH):
+    def __init__(self, file_name=None,
+                 file_location=None,
+                 verbose=LOG_VERBOSITY_PER_EPOCH,
+                 train_metrics=None,
+                 val_metrics=None):
 
 Review comment:
   yes that's correct

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to