roywei commented on a change in pull request #14494: [MXNet-1340][Fit API]Adding train history class
URL: https://github.com/apache/incubator-mxnet/pull/14494#discussion_r269801433
 
 

 ##########
 File path: python/mxnet/gluon/estimator/estimator.py
 ##########
 @@ -231,37 +221,45 @@ def fit(self, train_data,
             validation data with data and labels
         epochs : int, default 1
             number of epochs to iterate on the training data.
-        batch_size : int
-            number of samples per gradient update.
-            default will be 32 per device
         event_handlers : EventHandler or list of EventHandler
             list of EventHandlers to apply during training
         batch_fn : function
             custom batch function to extract data and label
             from a data batch and load into contexts(devices)
         """
 
-
-        self.epochs = epochs
+        self.max_epoch = epochs
         if not batch_size:
-            batch_size = 32 * len(self.context)
+            self.batch_size = 32 * len(self.context)
+        else:
+            self.batch_size = batch_size
+        self.stop_training = False
+        self.samples = None
+        self.batch_idx = 0
 
         event_handlers = event_handlers or []
         # provide default logging handler
         if not event_handlers or \
                 not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
-            event_handlers.append(LoggingHandler(self))
+            event_handlers.append(LoggingHandler())
 
-        # training begin
+        train_begin, epoch_begin, batch_begin, \
+        batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)
+
+        # passing estimator to event handlers so they can access estimator information
+        # when an event is triggered
         for handler in event_handlers:
+            handler.estimator = self
 
 Review comment:
  @nswamy This avoids asking users to pass the estimator when they construct event handlers: `fit()` now sets `handler.estimator = self` on every handler before training starts. Reference: https://github.com/apache/incubator-mxnet/pull/14462#discussion_r267125641
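The following sketch shows the injection pattern described above. It uses hypothetical, stripped-down stand-ins for `Estimator`, `EventHandler` and `LoggingHandler`, not the real MXNet classes. Handlers are constructed with no arguments, and `fit()` attaches the estimator to each one. The default batch size follows the diff: 32 samples per device. Two details are assumptions made for illustration: the `train_begin` return value and the list copy.

```python
class EventHandler:
    """Hypothetical base class: the estimator is attached later, not at construction."""
    def __init__(self):
        self.estimator = None


class LoggingHandler(EventHandler):
    def train_begin(self):
        # Reads estimator state that fit() injected; no constructor argument needed.
        return "training for %d epochs, batch size %d" % (
            self.estimator.max_epoch, self.estimator.batch_size)


class Estimator:
    def __init__(self, context):
        self.context = context  # list of devices, e.g. ["gpu(0)", "gpu(1)"]

    def fit(self, epochs=1, batch_size=None, event_handlers=None):
        self.max_epoch = epochs
        # Default from the diff: 32 samples per device (a falsy value means "unset").
        self.batch_size = batch_size or 32 * len(self.context)
        # Copy so the caller's list is not mutated when the default handler is added.
        event_handlers = list(event_handlers or [])
        if not any(isinstance(h, LoggingHandler) for h in event_handlers):
            event_handlers.append(LoggingHandler())
        # Inject the estimator so handlers can read its state when events fire.
        for handler in event_handlers:
            handler.estimator = self
        # Fire the train_begin event on handlers that implement it.
        return [h.train_begin() for h in event_handlers if hasattr(h, "train_begin")]
```

For example, `Estimator(context=["gpu(0)", "gpu(1)"]).fit(epochs=3)` returns `["training for 3 epochs, batch size 64"]`. The trade-off is that a handler's `estimator` stays `None` until `fit()` runs. Any handler logic that touches the estimator must therefore run inside event callbacks, not in `__init__`.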

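The diff unpacks `self._categorize_handlers(event_handlers)` into six per-event lists, but the helper's body is not shown. The following is a minimal sketch under one assumption: each handler is bucketed by which hook methods (`train_begin` through `train_end`) it defines. The actual PR may classify handlers by base class or by another mechanism. `categorize_handlers` is a hypothetical stand-in.

```python
HOOKS = ("train_begin", "epoch_begin", "batch_begin",
         "batch_end", "epoch_end", "train_end")


def categorize_handlers(event_handlers):
    """Hypothetical sketch: bucket handlers by the hook methods they implement."""
    buckets = {hook: [] for hook in HOOKS}
    for handler in event_handlers:
        for hook in HOOKS:
            # A handler joins a bucket only if it defines a callable for that hook.
            if callable(getattr(handler, hook, None)):
                buckets[hook].append(handler)
    # Same order as the unpacking in fit(): train_begin, ..., train_end.
    return tuple(buckets[hook] for hook in HOOKS)
```

Pre-sorting the handlers this way lets the training loop iterate only over the handlers relevant to each event, instead of checking every handler at every batch.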