roywei commented on a change in pull request #14685: [Fit API] improve event handlers URL: https://github.com/apache/incubator-mxnet/pull/14685#discussion_r277094415
########## File path: python/mxnet/gluon/contrib/estimator/estimator.py ########## @@ -226,111 +230,72 @@ def fit(self, train_data, custom batch function to extract data and label from a data batch and load into contexts(devices) """ - - self.max_epoch = epochs - self.stop_training = False - self.processed_samples = None - self.batch_idx = 0 - + self.max_epochs = epochs event_handlers = event_handlers or [] # provide default logging handler - if not event_handlers or \ - not any(isinstance(handler, LoggingHandler) for handler in event_handlers): - event_handlers.append(LoggingHandler()) - warnings.warn("No Event Handler specified, default `LoggingHandler()` " - "is used with verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH. " - "Please look at gluon.estimator.event_handler for more detail.") + if not event_handlers: + train_metrics, val_metrics = self.prepare_loss_and_metrics() + event_handlers.append(MetricHandler(train_metrics=train_metrics)) + if val_data: + event_handlers.append(ValidationHandler(val_data=val_data, eval_fn=self.evaluate, + val_metrics=val_metrics)) + event_handlers.append(LoggingHandler(train_metrics=train_metrics, + val_metrics=val_metrics)) + warnings.warn("No Event Handler specified, default %s are used. " + "Please look at gluon.contrib.estimator.event_handler for more detail." % + ", ".join([handler.__class__.__name__ for handler in event_handlers])) + + event_handlers.sort(key=lambda handler: getattr(handler, 'rank', 0), reverse=True) train_begin, epoch_begin, batch_begin, \ batch_end, epoch_end, train_end = self._categorize_handlers(event_handlers) - # passing estimator to event handlers so they can access estimator information - # when a event is triggered - for handler in event_handlers: - handler.estimator = self - + # only pass a weak reference to all event handlers + estimator_ref = weakref.proxy(self) # training begin for handler in train_begin: - handler.train_begin() + handler.train_begin(estimator_ref) - for epoch in range(self.max_epoch): + for epoch in range(epochs): # epoch begin - self.current_epoch = epoch - # Number of samples trained after every batch - completed_samples = 0 - for handler in epoch_begin: - handler.epoch_begin() - - for metric in self.train_metrics + self.train_loss_metrics: - metric.reset() + handler.epoch_begin(estimator_ref) for i, batch in enumerate(train_data): - if not batch_fn: - if isinstance(train_data, gluon.data.DataLoader): - data, label = self._batch_fn(batch, self.context) - else: - raise ValueError("You are using a custom iteration, please also provide " - "batch_fn to extract data and label. Alternatively, you " - "can provide the data as gluon.data.DataLoader") - else: - data, label = batch_fn(batch, self.context) + if not isinstance(train_data, gluon.data.DataLoader): + raise ValueError("Estimator only support input as Gluon DataLoader. Alternatively, you " + "can transform your DataIter or any NDArray into Gluon DataLoader. " + "Refer to gluon.data.dataloader") + data, label = self._get_data_and_label(batch, self.context) batch_size = batch[0].shape[0] # batch begin for handler in batch_begin: - handler.batch_begin() + handler.batch_begin(estimator_ref, batch=batch) with autograd.record(): pred = [self.net(x) for x in data] - losses = [] - for loss in self.loss: - losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)]) + loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, label)] Review comment: as above https://issues.apache.org/jira/browse/MXNET-1395 ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services