roywei commented on a change in pull request #14494: [MXNet-1340][Fit API]Adding train history class
URL: https://github.com/apache/incubator-mxnet/pull/14494#discussion_r269801433
########## File path: python/mxnet/gluon/estimator/estimator.py ##########
@@ -231,37 +221,45 @@ def fit(self, train_data,
             validation data with data and labels
         epochs : int, default 1
             number of epochs to iterate on the training data.
-        batch_size : int
-            number of samples per gradient update.
-            default will be 32 per device
         event_handlers : EventHandler or list of EventHandler
             list of EventHandlers to apply during training
         batch_fn : function
             custom batch function to extract data and label
             from a data batch and load into contexts(devices)
         """
-
-        self.epochs = epochs
+        self.max_epoch = epochs
         if not batch_size:
-            batch_size = 32 * len(self.context)
+            self.batch_size = 32 * len(self.context)
+        else:
+            self.batch_size = batch_size
+        self.stop_training = False
+        self.samples = None
+        self.batch_idx = 0
         event_handlers = event_handlers or []
         # provide default logging handler
         if not event_handlers or \
                 not any(isinstance(handler, LoggingHandler) for handler in event_handlers):
-            event_handlers.append(LoggingHandler(self))
+            event_handlers.append(LoggingHandler())
-        # training begin
+        train_begin, epoch_begin, batch_begin, \
+            batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers)
+
+        # passing estimator to event handlers so they can access estimator information
+        # when a event is triggered
         for handler in event_handlers:
+            handler.estimator = self

Review comment:
@nswamy This avoids requiring users to pass the estimator when they construct an event handler. Reference: https://github.com/apache/incubator-mxnet/pull/14462#discussion_r267125641
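The review comment describes assigning the estimator to each handler inside `fit` instead of passing it to the handler's constructor. Below is a minimal, framework-free sketch of that pattern, combined with grouping handlers by event in the spirit of `_categorize_handlers`. Every class here (`EventHandler`, `TrainBegin`, `EpochEnd`, `ToyLoggingHandler`, `ToyEstimator`) is a hypothetical stand-in written for illustration; none of them reproduces MXNet's actual classes or signatures.

```python
# Minimal sketch of post-construction estimator injection.
# All classes below are hypothetical illustrations, not MXNet's API.


class EventHandler:
    """Base class. `estimator` is assigned by the estimator inside fit(),
    so handler constructors do not need an estimator argument."""
    estimator = None


class TrainBegin(EventHandler):
    """Mixin marking handlers that react to the start of training."""
    def train_begin(self):
        raise NotImplementedError


class EpochEnd(EventHandler):
    """Mixin marking handlers that react to the end of each epoch."""
    def epoch_end(self):
        raise NotImplementedError


class ToyLoggingHandler(TrainBegin, EpochEnd):
    """Records messages; reads estimator state only when an event fires."""

    def __init__(self):
        self.records = []

    def train_begin(self):
        self.records.append("start: max_epoch=%d" % self.estimator.max_epoch)

    def epoch_end(self):
        self.records.append("epoch %d done" % self.estimator.current_epoch)


class ToyEstimator:
    def _categorize_handlers(self, event_handlers):
        # Group handlers by the events they implement, using the mixins.
        train_begin = [h for h in event_handlers if isinstance(h, TrainBegin)]
        epoch_end = [h for h in event_handlers if isinstance(h, EpochEnd)]
        return train_begin, epoch_end

    def fit(self, epochs=1, event_handlers=None):
        self.max_epoch = epochs
        self.current_epoch = 0
        event_handlers = event_handlers or []
        # Provide a default logging handler if none was supplied.
        if not any(isinstance(h, ToyLoggingHandler) for h in event_handlers):
            event_handlers.append(ToyLoggingHandler())
        train_begin, epoch_end = self._categorize_handlers(event_handlers)
        # Inject the estimator so every handler can read its state later.
        for handler in event_handlers:
            handler.estimator = self
        for handler in train_begin:
            handler.train_begin()
        for epoch in range(epochs):
            self.current_epoch = epoch
            # A real estimator would run forward/backward passes here.
            for handler in epoch_end:
                handler.epoch_end()
        return event_handlers


if __name__ == "__main__":
    handlers = ToyEstimator().fit(epochs=2)
    print(handlers[0].records)
    # ['start: max_epoch=2', 'epoch 0 done', 'epoch 1 done']
```

Because `fit` sets `handler.estimator = self`, a handler such as `ToyLoggingHandler()` can be built with no arguments and still read `max_epoch` or other state when an event fires. One trade-off of this sketch: a handler instance reused across two estimators stays bound to whichever estimator called `fit` last.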