nswamy commented on a change in pull request #14685: [Fit API] improve event handlers
URL: https://github.com/apache/incubator-mxnet/pull/14685#discussion_r277079156
 
 

 ##########
 File path: python/mxnet/gluon/contrib/estimator/event_handler.py
 ##########
 @@ -112,71 +196,83 @@ def __init__(self, file_name=None, file_location=None, verbose=LOG_VERBOSITY_PER
             file_location = file_location or './'
             file_handler = logging.FileHandler(os.path.join(file_location, file_name))
             self.logger.addHandler(file_handler)
-
-    def train_begin(self):
+        self.train_metrics = train_metrics or []
+        self.val_metrics = val_metrics or []
+        self.batch_index = 0
+        self.current_epoch = 0
+        self.processed_samples = 0
+        # the logging handler needs to be called last to make sure all states are updated
+        # it will also shut down logging at train end
+        self.priority = np.Inf
+
+    def train_begin(self, estimator, *args, **kwargs):
         self.train_start = time.time()
+        trainer = estimator.trainer
+        optimizer = trainer.optimizer.__class__.__name__
+        lr = trainer.learning_rate
         self.logger.info("Training begin: using optimizer %s "
                          "with current learning rate %.4f ",
-                         self.estimator.trainer.optimizer.__class__.__name__,
-                         self.estimator.trainer.learning_rate)
-        self.logger.info("Train for %d epochs.", self.estimator.max_epoch)
+                         optimizer, lr)
+        self.logger.info("Train for %d epochs.", estimator.max_epochs)
 
-    def train_end(self):
+    def train_end(self, estimator, *args, **kwargs):
         train_time = time.time() - self.train_start
-        epoch = self.estimator.current_epoch
-        msg = 'Train finished using total %ds at epoch %d. ' % (train_time, epoch)
+        msg = 'Train finished using total %ds with %d epochs.' % (train_time, self.current_epoch)
         # log every result in train stats including train/validation loss & metrics
-        for key in self.estimator.train_stats:
-            msg += '%s : %.4f ' % (key, self.estimator.train_stats[key])
+        for metric in self.train_metrics + self.val_metrics:
+            name, value = metric.get()
+            msg += '%s : %.4f ' % (name, value)
         self.logger.info(msg)
+        for handler in self.logger.handlers:
+            handler.close()
+            self.logger.removeHandler(handler)
+        logging.shutdown()
 
-    def batch_begin(self):
+    def batch_begin(self, estimator, *args, **kwargs):
         if self.verbose == self.LOG_VERBOSITY_PER_BATCH:
             self.batch_start = time.time()
 
-    def batch_end(self):
+    def batch_end(self, estimator, *args, **kwargs):
         if self.verbose == self.LOG_VERBOSITY_PER_BATCH:
             batch_time = time.time() - self.batch_start
-            epoch = self.estimator.current_epoch
-            batch = self.estimator.batch_idx
-            msg = '[Epoch %d] [Batch %d] ' % (epoch, batch)
-            if self.estimator.processed_samples:
-                msg += '[Samples %s] ' % (self.estimator.processed_samples)
+            msg = '[Epoch %d] [Batch %d] ' % (self.current_epoch, self.batch_index)
+            self.processed_samples += kwargs['batch'][0].shape[0]
+            msg += '[Samples %s] ' % (self.processed_samples)
             msg += 'time/batch: %.3fs ' % batch_time
-            for key in self.estimator.train_stats:
+            for metric in self.train_metrics:
                 # only log current training loss & metric after each batch
-                if key.startswith('train_'):
-                    msg += key + ': ' + '%.4f ' % self.estimator.train_stats[key]
+                name, value = metric.get()
+                msg += '%s : %.4f ' % (name, value)
             self.logger.info(msg)
+            self.batch_index += 1
 
 Review comment:
   Shouldn't this be in the estimator itself? Why should all handlers maintain this?
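
For illustration, here is a minimal sketch of the alternative the reviewer seems to be suggesting: the estimator owns the epoch/batch/sample counters and event handlers only read them. The class and attribute names below are hypothetical and are not the actual Fit API.

```python
# Hypothetical sketch, not the actual Gluon Fit API.

class SketchEstimator:
    """Owns the training bookkeeping that handlers would otherwise each duplicate."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.current_epoch = 0
        self.batch_index = 0
        self.processed_samples = 0

    def fit(self, batches, handlers):
        for epoch in range(self.max_epochs):
            self.current_epoch = epoch
            self.batch_index = 0
            for batch in batches:
                # the estimator updates the shared counters exactly once per batch
                self.processed_samples += len(batch)
                for handler in handlers:
                    handler.batch_end(self)  # handlers only read estimator state
                self.batch_index += 1


class SketchLoggingHandler:
    def batch_end(self, estimator, *args, **kwargs):
        print('[Epoch %d] [Batch %d] [Samples %d]' %
              (estimator.current_epoch, estimator.batch_index,
               estimator.processed_samples))


est = SketchEstimator(max_epochs=2)
est.fit(batches=[[0, 1, 2, 3]] * 3, handlers=[SketchLoggingHandler()])
```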

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
