nswamy commented on a change in pull request #14685: [Fit API] improve event handlers
URL: https://github.com/apache/incubator-mxnet/pull/14685#discussion_r277079658
 
 

 ##########
 File path: python/mxnet/gluon/contrib/estimator/event_handler.py
 ##########
 @@ -219,55 +320,61 @@ def __init__(self,
             self.best = -np.Inf
         else:
             # use greater for accuracy and less otherwise
-            if 'acc' in self.monitor:
+            if 'acc' in self.monitor.get()[0].lower():
                 self.monitor_op = np.greater
                 self.best = -np.Inf
             else:
                 self.monitor_op = np.less
                 self.best = np.Inf
 
-    def epoch_end(self, ):
-        epoch = self.estimator.current_epoch
+    def batch_end(self, estimator, *args, **kwargs):
+        self._save_checkpoint(estimator.net, "Batch", self.num_batches)
+        self.num_batches += 1
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        self._save_checkpoint(estimator.net, "Epoch", self.num_epochs)
+        self.num_epochs += 1
+
+    def _save_checkpoint(self, net, period_name, period_value):
         # add extension for weights
         if '.params' not in self.filepath:
             self.filepath += '.params'
-        self.epochs_since_last_save += 1
-        if self.epochs_since_last_save >= self.period:
-            self.epochs_since_last_save = 0
+        if self.num_epochs % self.epoch_period == 0:
             if self.save_best_only:
+                monitor_name, monitor_value = self.monitor.get()
                 # check if monitor exists in train stats
-                if self.monitor not in self.estimator.train_stats:
-                    warnings.warn(RuntimeWarning('Unable to find %s in 
training statistics, make sure the monitor value'
-                                                 'starts with `train_ `or 
`val_` and contains loss/metric name, ',
-                                                 'for example val_accuracy', 
self.monitor))
-                    self.estimator.net.save_parameters(self.filepath)
+                if np.isnan(monitor_value):
+                    warnings.warn(RuntimeWarning('%s is not updated, make sure 
you pass one of the metric objects'
 
 Review comment:
   Use a logger instead of `warnings.warn`, so the user can control whether (and where) this message is emitted.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to