This is an automated email from the ASF dual-hosted git repository.

zaleslaw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git


The following commit(s) were added to refs/heads/master by this push:
     new 579036d  IGNITE-10145: [ML] Implement ROC AUC metric (#6394)
579036d is described below

commit 579036d2ebdae7c6372a925589f33884536eea0b
Author: Alexey Zinoviev <zaleslaw....@gmail.com>
AuthorDate: Thu Apr 4 14:24:56 2019 +0300

    IGNITE-10145: [ML] Implement ROC AUC metric (#6394)
---
 .../BinaryClassificationMetricValues.java          |  12 +-
 .../BinaryClassificationMetrics.java               |  47 ++++++-
 .../scoring/metric/classification/ROCAUC.java      | 141 +++++++++++++++++++++
 .../BinaryClassificationMetricsValuesTest.java     |   3 +-
 .../scoring/metric/classification/ROCAUCTest.java  | 121 ++++++++++++++++++
 5 files changed, 319 insertions(+), 5 deletions(-)

diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricValues.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricValues.java
index e04ddc0..841faf6 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricValues.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricValues.java
@@ -65,6 +65,9 @@ public class BinaryClassificationMetricValues implements 
MetricValues {
     /** F1-Score is the harmonic mean of Precision and Sensitivity. */
     private double f1Score;
 
+    /** Area under the ROC curve (ROC AUC), stored exactly as passed to the constructor. */
+    private double rocauc;
+
     /**
      * Initialize an example by 4 metrics.
      *
@@ -73,11 +76,12 @@ public class BinaryClassificationMetricValues implements 
MetricValues {
      * @param fp False Positive (FP).
      * @param fn False Negative (FN).
      */
-    public BinaryClassificationMetricValues(long tp, long tn, long fp, long 
fn) {
+    public BinaryClassificationMetricValues(long tp, long tn, long fp, long 
fn, double rocauc) {
         this.tp = tp;
         this.tn = tn;
         this.fp = fp;
         this.fn = fn;
+        this.rocauc = rocauc;
 
         long p = tp + fn;
         long n = tn + fp;
@@ -168,4 +172,10 @@ public class BinaryClassificationMetricValues implements 
MetricValues {
     public double f1Score() {
         return f1Score;
     }
+
+    /**
+     * Gets the area under the ROC curve.
+     *
+     * @return ROC AUC value supplied at construction time; may be {@code NaN} if it was not calculated.
+     */
+    public double rocauc() {
+        return this.rocauc;
+    }
+
 }
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetrics.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetrics.java
index bfa1cf3..53c21a0 100644
--- 
a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetrics.java
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetrics.java
@@ -17,12 +17,14 @@
 
 package org.apache.ignite.ml.selection.scoring.metric.classification;
 
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.PriorityQueue;
+import org.apache.commons.math3.util.Pair;
 import org.apache.ignite.ml.selection.scoring.LabelPair;
 import org.apache.ignite.ml.selection.scoring.metric.AbstractMetrics;
 import 
org.apache.ignite.ml.selection.scoring.metric.exceptions.UnknownClassLabelException;
 
-import java.util.Iterator;
-
 /**
  * Binary classification metrics calculator.
  * It could be used in two ways: to caculate all binary classification metrics 
or specific metric.
@@ -34,6 +36,9 @@ public class BinaryClassificationMetrics extends 
AbstractMetrics<BinaryClassific
     /** Negative class label. Default value is 0.0. */
     private double negativeClsLb;
 
+    /**
+     * Flag that enables ROC AUC calculation. Disabled by default because ROC AUC requires
+     * buffering and sorting every (prediction, truth) pair, which is expensive.
+     */
+    private boolean enableROCAUC;
+
     {
         metric = BinaryClassificationMetricValues::accuracy;
     }
@@ -48,6 +53,12 @@ public class BinaryClassificationMetrics extends 
AbstractMetrics<BinaryClassific
         long tn = 0;
         long fp = 0;
         long fn = 0;
+        double rocauc = Double.NaN;
+
+        // for ROC AUC calculation
+        long pos = 0;
+        long neg = 0;
+        PriorityQueue<Pair<Double, Double>> queue = new 
PriorityQueue<>(Comparator.comparingDouble(Pair::getKey));
 
         while (iter.hasNext()) {
             LabelPair<Double> e = iter.next();
@@ -64,9 +75,26 @@ public class BinaryClassificationMetrics extends 
AbstractMetrics<BinaryClassific
             else if (truth == positiveClsLb && prediction == negativeClsLb) 
fn++;
             else if (truth == negativeClsLb && prediction == negativeClsLb) 
tn++;
             else if (truth == negativeClsLb && prediction == positiveClsLb) 
fp++;
+
+
+            if(enableROCAUC) {
+                queue.add(new Pair<>(prediction, truth));
+
+                if (truth == positiveClsLb)
+                    pos++;
+                else if (truth == negativeClsLb)
+                    neg++;
+                else
+                    throw new UnknownClassLabelException(truth, positiveClsLb, 
negativeClsLb);
+            }
+
         }
 
-        return new BinaryClassificationMetricValues(tp, tn, fp, fn);
+        if (enableROCAUC)
+            rocauc = ROCAUC.calculateROCAUC(queue, pos, neg, positiveClsLb);
+
+
+        return new BinaryClassificationMetricValues(tp, tn, fp, fn, rocauc);
     }
 
     /** */
@@ -93,6 +121,19 @@ public class BinaryClassificationMetrics extends 
AbstractMetrics<BinaryClassific
         return this;
     }
 
+    /**
+     * Enables or disables ROC AUC calculation together with the other binary classification metrics.
+     * ROC AUC requires collecting and sorting all (prediction, truth) pairs, so it is disabled by default.
+     *
+     * @param enableROCAUC {@code true} to calculate ROC AUC, {@code false} to skip it (the reported value is then NaN).
+     * @return This instance, for call chaining.
+     */
+    public BinaryClassificationMetrics withEnablingROCAUC(boolean enableROCAUC) {
+        this.enableROCAUC = enableROCAUC;
+        return this;
+    }
+
+
+    /**
+     * Checks whether ROC AUC is calculated together with the other binary classification metrics.
+     *
+     * @return {@code true} if ROC AUC calculation is enabled.
+     */
+    public boolean isROCAUCenabled() {
+        return this.enableROCAUC;
+    }
+
+
     /** {@inheritDoc} */
     @Override public String name() {
         return "Binary classification metrics";
diff --git 
a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUC.java
 
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUC.java
new file mode 100644
index 0000000..5911a28
--- /dev/null
+++ 
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUC.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.scoring.metric.classification;
+
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.PriorityQueue;
+import org.apache.commons.math3.util.Pair;
+import org.apache.ignite.ml.selection.scoring.LabelPair;
+import org.apache.ignite.ml.selection.scoring.metric.Metric;
+import 
org.apache.ignite.ml.selection.scoring.metric.exceptions.UnknownClassLabelException;
+
+/**
+ * ROC AUC score calculator.
+ * <p>
+ * The calculation of AUC is based on Mann-Whitney U test
+ * (https://en.wikipedia.org/wiki/Mann-Whitney_U_test).
+ */
+public class ROCAUC implements Metric<Double> {
+    /** Positive class label. */
+    private double positiveClsLb = 1.0;
+
+    /** Negative class label. Default value is 0.0. */
+    private double negativeClsLb;
+
+    /** {@inheritDoc} */
+    @Override public double score(Iterator<LabelPair<Double>> iter) {
+        //TODO: It should work with not binary values only, see IGNITE-11680.
+
+        PriorityQueue<Pair<Double, Double>> queue = new 
PriorityQueue<>(Comparator.comparingDouble(Pair::getKey));
+
+        long pos = 0;
+        long neg = 0;
+
+        while (iter.hasNext()) {
+            LabelPair<Double> e = iter.next();
+
+            Double prediction = e.getPrediction();
+            Double truth = e.getTruth();
+
+            queue.add(new Pair<>(prediction, truth));
+
+            if (truth == positiveClsLb)
+                pos++;
+            else if (truth == negativeClsLb)
+                neg++;
+            else
+                throw new UnknownClassLabelException(truth, positiveClsLb, 
negativeClsLb);
+
+        }
+
+        return calculateROCAUC(queue, pos, neg, positiveClsLb);
+    }
+
+    /**
+     * Calculates the ROC AUC value based on queue of pairs,
+     * amount of positive/negative cases and label of positive class.
+     */
+    public static double calculateROCAUC(PriorityQueue<Pair<Double, Double>> 
queue, long pos, long neg, double positiveClsLb) {
+        double[] lb = new double[queue.size()];
+        double[] prediction = new double[queue.size()];
+        int cnt = 0;
+
+        while (!queue.isEmpty()) {
+            Pair<Double, Double> elem = queue.poll();
+            lb[cnt] = elem.getValue();
+            prediction[cnt] = elem.getKey();
+            cnt++;
+        }
+
+        double[] rank = new double[lb.length];
+        for (int i = 0; i < prediction.length; i++) {
+            if (i == prediction.length - 1 || prediction[i] != prediction[i + 
1])
+                rank[i] = i + 1;
+            else {
+                int j = i + 1;
+                for (; j < prediction.length && prediction[j] == 
prediction[i]; j++);
+                double r = (i + 1 + j) / 2.0;
+                for (int k = i; k < j; k++)
+                    rank[k] = r;
+                i = j - 1;
+            }
+        }
+
+        double auc = 0.0;
+        for (int i = 0; i < lb.length; i++) {
+            if (lb[i] == positiveClsLb)
+                auc += rank[i];
+        }
+
+        if (pos == 0L) return Double.NaN;
+        else if (neg == 0L) return Double.NaN;
+
+        auc = (auc - (pos * (pos + 1) / 2.0)) / (pos * neg);
+        return auc;
+    }
+
+    /** Get the positive label. */
+    public double positiveClsLb() {
+        return positiveClsLb;
+    }
+
+    /** Set the positive label. */
+    public ROCAUC withPositiveClsLb(double positiveClsLb) {
+        if (Double.isFinite(positiveClsLb))
+            this.positiveClsLb = positiveClsLb;
+        return this;
+    }
+
+    /** Get the negative label. */
+    public double negativeClsLb() {
+        return negativeClsLb;
+    }
+
+    /** Set the negative label. */
+    public ROCAUC withNegativeClsLb(double negativeClsLb) {
+        if (Double.isFinite(negativeClsLb))
+            this.negativeClsLb = negativeClsLb;
+        return this;
+    }
+
+    /** {@inheritDoc} */
+    @Override public String name() {
+        return "ROC AUC";
+    }
+}
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsValuesTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsValuesTest.java
index f513ce3..4e7b03a 100644
--- 
a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsValuesTest.java
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsValuesTest.java
@@ -28,7 +28,7 @@ public class BinaryClassificationMetricsValuesTest {
     /** */
     @Test
     public void testDefaultBehaviour() {
-        BinaryClassificationMetricValues metricValues = new 
BinaryClassificationMetricValues(10, 10, 5, 5);
+        BinaryClassificationMetricValues metricValues = new 
BinaryClassificationMetricValues(10, 10, 5, 5, 0.5);
 
         assertEquals(10, metricValues.tp(), 1e-2);
         assertEquals(10, metricValues.tn(), 1e-2);
@@ -44,5 +44,6 @@ public class BinaryClassificationMetricsValuesTest {
         assertEquals(0.66, metricValues.precision(), 1e-2);
         assertEquals(0.66, metricValues.recall(), 1e-2);
         assertEquals(0.66, metricValues.specificity(), 1e-2);
+        assertEquals(0.5, metricValues.rocauc(), 1e-2);
     }
 }
diff --git 
a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUCTest.java
 
b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUCTest.java
new file mode 100644
index 0000000..99451a3
--- /dev/null
+++ 
b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUCTest.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.scoring.metric.classification;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.selection.scoring.TestLabelPairCursor;
+import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor;
+import org.apache.ignite.ml.selection.scoring.metric.Metric;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link ROCAUC} using the default class labels (positive {@code 1.0}, negative {@code 0.0}).
+ * Each cursor is built from a list of ground-truth labels followed by a list of predictions.
+ */
+public class ROCAUCTest {
+    /** Absolute tolerance for score comparisons. */
+    private static final double EPS = 1e-12;
+
+    /** Only positive labels are present, so the score is undefined ({@code NaN}). */
+    @Test
+    public void testTotalTruth() {
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0),
+            Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0)
+        );
+
+        assertEquals(Double.NaN, scoreOf(cursor), EPS);
+    }
+
+    /** Only positive labels are present (all predicted as 0.0), so the score is undefined ({@code NaN}). */
+    @Test
+    public void testTotalUntruth() {
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0),
+            Arrays.asList(0.0, 0.0, 0.0, 0.0, 0.0)
+        );
+
+        assertEquals(Double.NaN, scoreOf(cursor), EPS);
+    }
+
+    /** Three positives and one negative, all predicted as 1.0: every pair is tied, giving 0.5. */
+    @Test
+    public void testOneDifferent() {
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(1.0, 1.0, 1.0, 0.0),
+            Arrays.asList(1.0, 1.0, 1.0, 1.0)
+        );
+
+        assertEquals(0.5, scoreOf(cursor), EPS);
+    }
+
+    /** Two positives and two negatives, one negative predicted as positive: 0.75. */
+    @Test
+    public void testOneDifferentButBalanced() {
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(1.0, 1.0, 0.0, 0.0),
+            Arrays.asList(1.0, 1.0, 0.0, 1.0)
+        );
+
+        assertEquals(0.75, scoreOf(cursor), EPS);
+    }
+
+    /** Two positives and two negatives with one mistake in each class: 0.5. */
+    @Test
+    public void testTwoDifferentAndBalanced() {
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(1.0, 1.0, 0.0, 0.0),
+            Arrays.asList(1.0, 0.0, 0.0, 1.0)
+        );
+
+        assertEquals(0.5, scoreOf(cursor), EPS);
+    }
+
+    /** Continuous predictions: 10 of the 16 positive/negative pairs are ordered correctly, giving 0.625. */
+    @Test
+    public void testNotOnlyBinaryValues() {
+        LabelPairCursor<Double> cursor = new TestLabelPairCursor<>(
+            Arrays.asList(1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0),
+            Arrays.asList(0.40209054, 0.33697626, 0.5449324, 0.13010869, 0.19019675,
+                0.39767829, 0.9686739, 0.91783275, 0.7503783, 0.5306605)
+        );
+
+        assertEquals(0.625, scoreOf(cursor), EPS);
+    }
+
+    /**
+     * Scores the pairs of the given cursor with a default {@link ROCAUC} instance.
+     *
+     * @param cursor Cursor over (truth, prediction) pairs.
+     * @return ROC AUC score.
+     */
+    private static double scoreOf(LabelPairCursor<Double> cursor) {
+        Metric<Double> rocAuc = new ROCAUC();
+
+        return rocAuc.score(cursor.iterator());
+    }
+}

Reply via email to