roywei commented on a change in pull request #14685: [Fit API] improve event handlers
URL: https://github.com/apache/incubator-mxnet/pull/14685#discussion_r277094415
 
 

 ##########
 File path: python/mxnet/gluon/contrib/estimator/estimator.py
 ##########
 @@ -226,111 +230,72 @@ def fit(self, train_data,
             custom batch function to extract data and label
             from a data batch and load into contexts(devices)
         """
-
-        self.max_epoch = epochs
-        self.stop_training = False
-        self.processed_samples = None
-        self.batch_idx = 0
-
+        self.max_epochs = epochs
         event_handlers = event_handlers or []
         # provide default logging handler
-        if not event_handlers or \
-                not any(isinstance(handler, LoggingHandler) for handler in 
event_handlers):
-            event_handlers.append(LoggingHandler())
-            warnings.warn("No Event Handler specified, default 
`LoggingHandler()` "
-                          "is used with 
verbose=LoggingHandler.LOG_VERBOSITY_PER_EPOCH. "
-                          "Please look at gluon.estimator.event_handler for 
more detail.")
+        if not event_handlers:
+            train_metrics, val_metrics = self.prepare_loss_and_metrics()
+            event_handlers.append(MetricHandler(train_metrics=train_metrics))
+            if val_data:
+                event_handlers.append(ValidationHandler(val_data=val_data, 
eval_fn=self.evaluate,
+                                                        
val_metrics=val_metrics))
+            event_handlers.append(LoggingHandler(train_metrics=train_metrics,
+                                                 val_metrics=val_metrics))
+            warnings.warn("No Event Handler specified, default %s are used. "
+                          "Please look at 
gluon.contrib.estimator.event_handler for more detail." %
+                          ", ".join([handler.__class__.__name__ for handler in 
event_handlers]))
+
+        event_handlers.sort(key=lambda handler: getattr(handler, 'rank', 0), 
reverse=True)
 
         train_begin, epoch_begin, batch_begin, \
         batch_end, epoch_end, train_end = 
self._categorize_handlers(event_handlers)
 
-        # passing estimator to event handlers so they can access estimator 
information
-        # when a event is triggered
-        for handler in event_handlers:
-            handler.estimator = self
-
+        # only pass a weak reference to all event handlers
+        estimator_ref = weakref.proxy(self)
         # training begin
         for handler in train_begin:
-            handler.train_begin()
+            handler.train_begin(estimator_ref)
 
-        for epoch in range(self.max_epoch):
+        for epoch in range(epochs):
             # epoch begin
-            self.current_epoch = epoch
-            # Number of samples trained after every batch
-            completed_samples = 0
-
             for handler in epoch_begin:
-                handler.epoch_begin()
-
-            for metric in self.train_metrics + self.train_loss_metrics:
-                metric.reset()
+                handler.epoch_begin(estimator_ref)
 
             for i, batch in enumerate(train_data):
-                if not batch_fn:
-                    if isinstance(train_data, gluon.data.DataLoader):
-                        data, label = self._batch_fn(batch, self.context)
-                    else:
-                        raise ValueError("You are using a custom iteration, 
please also provide "
-                                         "batch_fn to extract data and label. 
Alternatively, you "
-                                         "can provide the data as 
gluon.data.DataLoader")
-                else:
-                    data, label = batch_fn(batch, self.context)
+                if not isinstance(train_data, gluon.data.DataLoader):
+                    raise ValueError("Estimator only support input as Gluon 
DataLoader. Alternatively, you "
+                                     "can transform your DataIter or any 
NDArray into Gluon DataLoader. "
+                                     "Refer to gluon.data.dataloader")
+                data, label = self._get_data_and_label(batch, self.context)
 
                 batch_size = batch[0].shape[0]
 
                 # batch begin
                 for handler in batch_begin:
-                    handler.batch_begin()
+                    handler.batch_begin(estimator_ref, batch=batch)
 
                 with autograd.record():
                     pred = [self.net(x) for x in data]
-                    losses = []
-                    for loss in self.loss:
-                        losses.append([loss(y_hat, y) for y_hat, y in 
zip(pred, label)])
+                    loss = [self.loss[0](y_hat, y) for y_hat, y in zip(pred, 
label)]
 
 Review comment:
   Same as above; see https://issues.apache.org/jira/browse/MXNET-1395

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to