nswamy commented on a change in pull request #14685: [Fit API] improve event handlers URL: https://github.com/apache/incubator-mxnet/pull/14685#discussion_r277079658
##########
File path: python/mxnet/gluon/contrib/estimator/event_handler.py
##########
@@ -219,55 +320,61 @@ def __init__(self,
             self.best = -np.Inf
         else:
             # use greater for accuracy and less otherwise
-            if 'acc' in self.monitor:
+            if 'acc' in self.monitor.get()[0].lower():
                 self.monitor_op = np.greater
                 self.best = -np.Inf
             else:
                 self.monitor_op = np.less
                 self.best = np.Inf
 
-    def epoch_end(self, ):
-        epoch = self.estimator.current_epoch
+    def batch_end(self, estimator, *args, **kwargs):
+        self._save_checkpoint(estimator.net, "Batch", self.num_batches)
+        self.num_batches += 1
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        self._save_checkpoint(estimator.net, "Epoch", self.num_epochs)
+        self.num_epochs += 1
+
+    def _save_checkpoint(self, net, period_name, period_value):
         # add extension for weights
         if '.params' not in self.filepath:
             self.filepath += '.params'
-        self.epochs_since_last_save += 1
-        if self.epochs_since_last_save >= self.period:
-            self.epochs_since_last_save = 0
+        if self.num_epochs % self.epoch_period == 0:
             if self.save_best_only:
+                monitor_name, monitor_value = self.monitor.get()
                 # check if monitor exists in train stats
-                if self.monitor not in self.estimator.train_stats:
-                    warnings.warn(RuntimeWarning('Unable to find %s in training statistics, make sure the monitor value'
-                                                 'starts with `train_ `or `val_` and contains loss/metric name, ',
-                                                 'for example val_accuracy', self.monitor))
-                    self.estimator.net.save_parameters(self.filepath)
+                if np.isnan(monitor_value):
+                    warnings.warn(RuntimeWarning('%s is not updated, make sure you pass one of the metric objects'

Review comment:
use logger, so the user can control.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at: us...@infra.apache.org

With regards,
Apache Git Services