This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ad3f0d8  revert acc changes (#9731)
ad3f0d8 is described below

commit ad3f0d8faeaabf302923796d1379bec6040e8d53
Author: Sheng Zha <s...@users.noreply.github.com>
AuthorDate: Wed Feb 7 10:07:58 2018 -0800

    revert acc changes (#9731)
    
    * Revert "avoid per-batch blocking in metric (#9636)"
    
    This reverts commit 3fe694e7b1ed7fa6a2dcfeddeac44c14ab77b015.
    
    * Revert "proper flatten in acc (#9619)"
    
    This reverts commit ed823b2e187eb859d9475eb651465edf714c6c5f.
    
    * Revert "use nd for accuracy calculation (#9583)"
    
    This reverts commit f5f1b91ff972ad70e9131d3cd1d7408ddddb7684.
    
    * keep doc change
---
 python/mxnet/metric.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index e91fd3b..8bb3f6e 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -28,7 +28,6 @@ import numpy
 from .base import numeric_types, string_types
 from . import ndarray
 from . import registry
-from .context import cpu
 
 
 def check_label_shapes(labels, preds, shape=0):
@@ -389,22 +388,16 @@ class Accuracy(EvalMetric):
         """
         check_label_shapes(labels, preds)
 
-        results = []
         for label, pred_label in zip(labels, preds):
             if pred_label.shape != label.shape:
                 pred_label = ndarray.argmax(pred_label, axis=self.axis)
-            pred_label = pred_label.astype('int32')
-            label = label.astype('int32')
+            pred_label = pred_label.asnumpy().astype('int32')
+            label = label.asnumpy().astype('int32')
 
             check_label_shapes(label, pred_label)
 
-            if pred_label.context != label.context:
-                pred_label = pred_label.as_in_context(label.context)
-
-            self.num_inst += pred_label.size
-            results.append((pred_label.reshape((-1,)) == label.reshape((-1,)))
-                           .sum().as_in_context(cpu()))
-        self.sum_metric += ndarray.add_n(*results).asscalar()
+            self.sum_metric += (pred_label.flat == label.flat).sum()
+            self.num_inst += len(pred_label.flat)
 
 
 @register

-- 
To stop receiving notification emails like this one, please contact
hai...@apache.org.

Reply via email to