This is an automated email from the ASF dual-hosted git repository. liuyizhi pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 8df20a2 [scala] EvalMetric sumMetric is now a Double instead of a Float (#8297) 8df20a2 is described below commit 8df20a2bd074c4ab55a9b61e0ec04da48bec6426 Author: Benoît Quartier <benoit.quart...@a3.epfl.ch> AuthorDate: Wed Nov 22 02:44:15 2017 +0100 [scala] EvalMetric sumMetric is now a Double instead of a Float (#8297) When the difference in magnitude between the accumulated accuracy sum and 1 becomes too big, adding 1 no longer changes the sum because of the low precision of float numbers (for a Float this happens once the sum reaches 2^24 = 16,777,216), so the accuracy is not updated anymore. --- .../core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala index 6b993d7..98a09d2 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/EvalMetric.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer abstract class EvalMetric(protected val name: String) { protected var numInst: Int = 0 - protected var sumMetric: Float = 0.0f + protected var sumMetric: Double = 0.0d /** * Update the internal evaluation.
@@ -41,7 +41,7 @@ abstract class EvalMetric(protected val name: String) { */ def reset(): Unit = { this.numInst = 0 - this.sumMetric = 0.0f + this.sumMetric = 0.0d } /** @@ -50,7 +50,7 @@ abstract class EvalMetric(protected val name: String) { * value, Value of the evaluation */ def get: (Array[String], Array[Float]) = { - (Array(this.name), Array(this.sumMetric / this.numInst)) + (Array(this.name), Array((this.sumMetric / this.numInst).toFloat)) } } @@ -111,11 +111,10 @@ class Accuracy extends EvalMetric("accuracy") { require(label.shape == predLabel.shape, s"label ${label.shape} and prediction ${predLabel.shape}" + s"should have the same length.") - for ((labelElem, predElem) <- label.toArray zip predLabel.toArray) { - if (labelElem == predElem) { - this.sumMetric += 1 - } - } + + this.sumMetric += label.toArray.zip(predLabel.toArray) + .filter{ case (labelElem: Float, predElem: Float) => labelElem == predElem } + .size this.numInst += predLabel.shape(0) predLabel.dispose() } -- To stop receiving notification emails like this one, please contact ['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].