leezu commented on a change in pull request #18083:
URL: https://github.com/apache/incubator-mxnet/pull/18083#discussion_r412605078



##########
File path: python/mxnet/gluon/metric.py
##########
@@ -619,90 +583,100 @@ def update_binary_stats(self, label, pred):
         """
         pred = pred.asnumpy()
         label = label.asnumpy().astype('int32')
-        pred_label = numpy.argmax(pred, axis=1)
-
-        check_label_shapes(label, pred)
-        if len(numpy.unique(label)) > 2:
-            raise ValueError("%s currently only supports binary classification."
-                             % self.__class__.__name__)
+        if self.class_type == "binary":
+            self._set(1)
+            if len(numpy.unique(label)) > 2:
+                raise ValueError("Wrong label for binary classification.")
+            if pred.shape == label.shape:
+                pass
+            elif pred.shape[-1] > 2:
+                raise ValueError("The shape of prediction {} is wrong for binary classification.".format(pred.shape))
+            elif pred.shape[-1] == 2:
+                pred = pred.reshape(-1, 2)[:, 1]     
+            pred_label = predict_with_threshold(pred, self.threshold).flat
+            label = label.flat
+            
+        elif self.class_type == "multiclass":
+            num = pred.shape[-1]
+            self._set(num)
+            assert label.max() < num, "pred contains fewer classes than label!"
+            pred_label = one_hot(pred.argmax(axis=-1).reshape(-1), num)        
 
+            label = one_hot(label.reshape(-1), num)
+            
+        elif self.class_type == "multilabel":
+            num = pred.shape[-1]
+            self._set(num)
+            assert pred.shape == label.shape, "The shape of label should be same as that of prediction for multilabel classification."
+            pred_label = predict_with_threshold(pred, self.threshold).reshape(-1, num)
+            label = label.reshape(-1, num)
+        else:
+            raise ValueError("Wrong class_type {}! Only supports ['binary', 'multiclass', 'multilabel']".format(self.class_type))
+            
+        check_label_shapes(label, pred_label)
+        
         pred_true = (pred_label == 1)
         pred_false = 1 - pred_true
         label_true = (label == 1)
         label_false = 1 - label_true
 
-        true_pos = (pred_true * label_true).sum()
-        false_pos = (pred_true * label_false).sum()
-        false_neg = (pred_false * label_true).sum()
-        true_neg = (pred_false * label_false).sum()
+        true_pos = (pred_true * label_true).sum(0)
+        false_pos = (pred_true * label_false).sum(0)
+        false_neg = (pred_false * label_true).sum(0)
+        true_neg = (pred_false * label_false).sum(0)
         self.true_positives += true_pos
-        self.global_true_positives += true_pos
         self.false_positives += false_pos
-        self.global_false_positives += false_pos
         self.false_negatives += false_neg
-        self.global_false_negatives += false_neg
         self.true_negatives += true_neg
-        self.global_true_negatives += true_neg
 
     @property
     def precision(self):
-        if self.true_positives + self.false_positives > 0:
-            return float(self.true_positives) / (self.true_positives + self.false_positives)
+        if self.num_classes is not None:
+            return self.true_positives / numpy.maximum(self.true_positives + self.false_positives, 1e-12)
         else:
             return 0.
 
     @property
     def global_precision(self):
-        if self.global_true_positives + self.global_false_positives > 0:
-            return float(self.global_true_positives) / (self.global_true_positives + self.global_false_positives)
+        if self.num_classes is not None:
+            return self.true_positives.sum() / numpy.maximum(self.true_positives.sum() + self.false_positives.sum(), 1e-12)
         else:
             return 0.
-
+            
     @property
     def recall(self):
-        if self.true_positives + self.false_negatives > 0:
-            return float(self.true_positives) / (self.true_positives + self.false_negatives)
+        if self.num_classes is not None:
+            return self.true_positives / numpy.maximum(self.true_positives + self.false_negatives, 1e-12)
         else:
             return 0.
 
     @property
     def global_recall(self):
-        if self.global_true_positives + self.global_false_negatives > 0:
-            return float(self.global_true_positives) / (self.global_true_positives + self.global_false_negatives)
+        if self.num_classes is not None:
+            return self.true_positives.sum() / numpy.maximum(self.true_positives.sum() + self.false_negatives.sum(), 1e-12)
         else:
             return 0.
-
+            
     @property
     def fscore(self):
-        if self.precision + self.recall > 0:
-            return 2 * self.precision * self.recall / (self.precision + self.recall)
-        else:
-            return 0.
+        return (1 + self.beta ** 2) * self.precision * self.recall / numpy.maximum(self.beta ** 2 * self.precision + self.recall, 1e-12)
 
     @property
     def global_fscore(self):

Review comment:
       Should this method be removed, given that you dropped the global states?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to