nswamy commented on a change in pull request #14685: [Fit API] improve event handlers URL: https://github.com/apache/incubator-mxnet/pull/14685#discussion_r277079156
########## File path: python/mxnet/gluon/contrib/estimator/event_handler.py ########## @@ -112,71 +196,83 @@ def __init__(self, file_name=None, file_location=None, verbose=LOG_VERBOSITY_PER file_location = file_location or './' file_handler = logging.FileHandler(os.path.join(file_location, file_name)) self.logger.addHandler(file_handler) - - def train_begin(self): + self.train_metrics = train_metrics or [] + self.val_metrics = val_metrics or [] + self.batch_index = 0 + self.current_epoch = 0 + self.processed_samples = 0 + # logging handler need to be called at last to make sure all states are updated + # it will also shut down logging at train end + self.priority = np.Inf + + def train_begin(self, estimator, *args, **kwargs): self.train_start = time.time() + trainer = estimator.trainer + optimizer = trainer.optimizer.__class__.__name__ + lr = trainer.learning_rate self.logger.info("Training begin: using optimizer %s " "with current learning rate %.4f ", - self.estimator.trainer.optimizer.__class__.__name__, - self.estimator.trainer.learning_rate) - self.logger.info("Train for %d epochs.", self.estimator.max_epoch) + optimizer, lr) + self.logger.info("Train for %d epochs.", estimator.max_epochs) - def train_end(self): + def train_end(self, estimator, *args, **kwargs): train_time = time.time() - self.train_start - epoch = self.estimator.current_epoch - msg = 'Train finished using total %ds at epoch %d. ' % (train_time, epoch) + msg = 'Train finished using total %ds with %d epochs.' 
% (train_time, self.current_epoch) # log every result in train stats including train/validation loss & metrics - for key in self.estimator.train_stats: - msg += '%s : %.4f ' % (key, self.estimator.train_stats[key]) + for metric in self.train_metrics + self.val_metrics: + name, value = metric.get() + msg += '%s : %.4f ' % (name, value) self.logger.info(msg) + for handler in self.logger.handlers: + handler.close() + self.logger.removeHandler(handler) + logging.shutdown() - def batch_begin(self): + def batch_begin(self, estimator, *args, **kwargs): if self.verbose == self.LOG_VERBOSITY_PER_BATCH: self.batch_start = time.time() - def batch_end(self): + def batch_end(self, estimator, *args, **kwargs): if self.verbose == self.LOG_VERBOSITY_PER_BATCH: batch_time = time.time() - self.batch_start - epoch = self.estimator.current_epoch - batch = self.estimator.batch_idx - msg = '[Epoch %d] [Batch %d] ' % (epoch, batch) - if self.estimator.processed_samples: - msg += '[Samples %s] ' % (self.estimator.processed_samples) + msg = '[Epoch %d] [Batch %d] ' % (self.current_epoch, self.batch_index) + self.processed_samples += kwargs['batch'][0].shape[0] + msg += '[Samples %s] ' % (self.processed_samples) msg += 'time/batch: %.3fs ' % batch_time - for key in self.estimator.train_stats: + for metric in self.train_metrics: # only log current training loss & metric after each batch - if key.startswith('train_'): - msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key] + name, value = metric.get() + msg += '%s : %.4f ' % (name, value) self.logger.info(msg) + self.batch_index += 1 Review comment: shouldn't this be in the estimator itself, why should all handlers maintain this? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
For queries about this service, please contact Infrastructure at: us...@infra.apache.org. With regards, Apache Git Services