abhinavs95 commented on a change in pull request #14442: [MXNet-1349][Fit API]Add validation support and unit tests for fit() API
URL: https://github.com/apache/incubator-mxnet/pull/14442#discussion_r267095204
 
 

 ##########
 File path: python/mxnet/gluon/estimator/estimator.py
 ##########
 @@ -156,7 +168,33 @@ def _batch_fn(self, batch, ctx, is_iterator=False):
         label = gluon.utils.split_and_load(label, ctx_list=ctx, batch_axis=0)
         return data, label
 
+    def _evaluate(self, val_data, batch_fn=None):
+        for metric in self.test_metrics + self.test_loss_metrics:
+            metric.reset()
+
+        for _, batch in enumerate(val_data):
+            if not batch_fn:
+                if isinstance(val_data, gluon.data.DataLoader):
+                    data, label = self._batch_fn(batch, self.context)
+                elif isinstance(val_data, DataIter):
+                    data, label = self._batch_fn(batch, self.context, is_iterator=True)
+                else:
+                    raise ValueError("You are using a custom iteration, please also provide "
+                                     "batch_fn to extract data and label")
+            else:
+                data, label = batch_fn(batch, self.context)
+            pred = [self.net(x) for x in data]
+            losses = []
+            for loss in self.loss:
+                losses.append([loss(y_hat, y) for y_hat, y in zip(pred, label)])
+            # update metrics
+            for metric in self.test_metrics:
+                metric.update(label, pred)
+            for loss, loss_metric, in zip(losses, self.test_loss_metrics):
+                loss_metric.update(0, [l for l in loss])
+
     def fit(self, train_data,
+            val_data=None,
 
 Review comment:
   Users might want to train without a validation set.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to