This is an automated email from the ASF dual-hosted git repository. zaleslaw pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push: new 579036d IGNITE-10145: [ML] Implement ROC AUC metric (#6394) 579036d is described below commit 579036d2ebdae7c6372a925589f33884536eea0b Author: Alexey Zinoviev <zaleslaw....@gmail.com> AuthorDate: Thu Apr 4 14:24:56 2019 +0300 IGNITE-10145: [ML] Implement ROC AUC metric (#6394) --- .../BinaryClassificationMetricValues.java | 12 +- .../BinaryClassificationMetrics.java | 47 ++++++- .../scoring/metric/classification/ROCAUC.java | 141 +++++++++++++++++++++ .../BinaryClassificationMetricsValuesTest.java | 3 +- .../scoring/metric/classification/ROCAUCTest.java | 121 ++++++++++++++++++ 5 files changed, 319 insertions(+), 5 deletions(-) diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricValues.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricValues.java index e04ddc0..841faf6 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricValues.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricValues.java @@ -65,6 +65,9 @@ public class BinaryClassificationMetricValues implements MetricValues { /** F1-Score is the harmonic mean of Precision and Sensitivity. */ private double f1Score; + /** ROC AUC. */ + private double rocauc; + /** * Initialize an example by 4 metrics. * @@ -73,11 +76,12 @@ public class BinaryClassificationMetricValues implements MetricValues { * @param fp False Positive (FP). * @param fn False Negative (FN). 
*/ - public BinaryClassificationMetricValues(long tp, long tn, long fp, long fn) { + public BinaryClassificationMetricValues(long tp, long tn, long fp, long fn, double rocauc) { this.tp = tp; this.tn = tn; this.fp = fp; this.fn = fn; + this.rocauc = rocauc; long p = tp + fn; long n = tn + fp; @@ -168,4 +172,10 @@ public class BinaryClassificationMetricValues implements MetricValues { public double f1Score() { return f1Score; } + + /** Returns ROCAUC value. */ + public double rocauc() { + return rocauc; + } + } diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetrics.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetrics.java index bfa1cf3..53c21a0 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetrics.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetrics.java @@ -17,12 +17,14 @@ package org.apache.ignite.ml.selection.scoring.metric.classification; +import java.util.Comparator; +import java.util.Iterator; +import java.util.PriorityQueue; +import org.apache.commons.math3.util.Pair; import org.apache.ignite.ml.selection.scoring.LabelPair; import org.apache.ignite.ml.selection.scoring.metric.AbstractMetrics; import org.apache.ignite.ml.selection.scoring.metric.exceptions.UnknownClassLabelException; -import java.util.Iterator; - /** * Binary classification metrics calculator. * It could be used in two ways: to caculate all binary classification metrics or specific metric. @@ -34,6 +36,9 @@ public class BinaryClassificationMetrics extends AbstractMetrics<BinaryClassific /** Negative class label. Default value is 0.0. */ private double negativeClsLb; + /** This flag enables ROC AUC calculation that is hard for performance due to internal implementation.
*/ + private boolean enableROCAUC; + { metric = BinaryClassificationMetricValues::accuracy; } @@ -48,6 +53,12 @@ public class BinaryClassificationMetrics extends AbstractMetrics<BinaryClassific long tn = 0; long fp = 0; long fn = 0; + double rocauc = Double.NaN; + + // for ROC AUC calculation + long pos = 0; + long neg = 0; + PriorityQueue<Pair<Double, Double>> queue = new PriorityQueue<>(Comparator.comparingDouble(Pair::getKey)); while (iter.hasNext()) { LabelPair<Double> e = iter.next(); @@ -64,9 +75,26 @@ public class BinaryClassificationMetrics extends AbstractMetrics<BinaryClassific else if (truth == positiveClsLb && prediction == negativeClsLb) fn++; else if (truth == negativeClsLb && prediction == negativeClsLb) tn++; else if (truth == negativeClsLb && prediction == positiveClsLb) fp++; + + + if(enableROCAUC) { + queue.add(new Pair<>(prediction, truth)); + + if (truth == positiveClsLb) + pos++; + else if (truth == negativeClsLb) + neg++; + else + throw new UnknownClassLabelException(truth, positiveClsLb, negativeClsLb); + } + } - return new BinaryClassificationMetricValues(tp, tn, fp, fn); + if (enableROCAUC) + rocauc = ROCAUC.calculateROCAUC(queue, pos, neg, positiveClsLb); + + + return new BinaryClassificationMetricValues(tp, tn, fp, fn, rocauc); } /** */ @@ -93,6 +121,19 @@ public class BinaryClassificationMetrics extends AbstractMetrics<BinaryClassific return this; } + /** Enables or disables the ROC AUC calculation. */ + public BinaryClassificationMetrics withEnablingROCAUC(boolean enableROCAUC) { + this.enableROCAUC = enableROCAUC; + return this; + } + + + /** Returns whether the ROC AUC calculation is enabled. */ + public boolean isROCAUCenabled() { + return enableROCAUC; + } + + /** {@inheritDoc} */ @Override public String name() { return "Binary classification metrics"; diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUC.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUC.java new file mode 100644 index 0000000..5911a28 --- /dev/null +++
b/modules/ml/src/main/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUC.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.scoring.metric.classification; + +import java.util.Comparator; +import java.util.Iterator; +import java.util.PriorityQueue; +import org.apache.commons.math3.util.Pair; +import org.apache.ignite.ml.selection.scoring.LabelPair; +import org.apache.ignite.ml.selection.scoring.metric.Metric; +import org.apache.ignite.ml.selection.scoring.metric.exceptions.UnknownClassLabelException; + +/** + * ROC AUC score calculator. + * <p> + * The calculation of AUC is based on Mann-Whitney U test + * (https://en.wikipedia.org/wiki/Mann-Whitney_U_test). + */ +public class ROCAUC implements Metric<Double> { + /** Positive class label. */ + private double positiveClsLb = 1.0; + + /** Negative class label. Default value is 0.0. */ + private double negativeClsLb; + + /** {@inheritDoc} */ + @Override public double score(Iterator<LabelPair<Double>> iter) { + //TODO: It should work with not binary values only, see IGNITE-11680. 
+ + PriorityQueue<Pair<Double, Double>> queue = new PriorityQueue<>(Comparator.comparingDouble(Pair::getKey)); + + long pos = 0; + long neg = 0; + + while (iter.hasNext()) { + LabelPair<Double> e = iter.next(); + + Double prediction = e.getPrediction(); + Double truth = e.getTruth(); + + queue.add(new Pair<>(prediction, truth)); + + if (truth == positiveClsLb) + pos++; + else if (truth == negativeClsLb) + neg++; + else + throw new UnknownClassLabelException(truth, positiveClsLb, negativeClsLb); + + } + + return calculateROCAUC(queue, pos, neg, positiveClsLb); + } + + /** + * Calculates the ROC AUC value based on queue of pairs, + * amount of positive/negative cases and label of positive class. + */ + public static double calculateROCAUC(PriorityQueue<Pair<Double, Double>> queue, long pos, long neg, double positiveClsLb) { + double[] lb = new double[queue.size()]; + double[] prediction = new double[queue.size()]; + int cnt = 0; + + while (!queue.isEmpty()) { + Pair<Double, Double> elem = queue.poll(); + lb[cnt] = elem.getValue(); + prediction[cnt] = elem.getKey(); + cnt++; + } + + double[] rank = new double[lb.length]; + for (int i = 0; i < prediction.length; i++) { + if (i == prediction.length - 1 || prediction[i] != prediction[i + 1]) + rank[i] = i + 1; + else { + int j = i + 1; + for (; j < prediction.length && prediction[j] == prediction[i]; j++); + double r = (i + 1 + j) / 2.0; + for (int k = i; k < j; k++) + rank[k] = r; + i = j - 1; + } + } + + double auc = 0.0; + for (int i = 0; i < lb.length; i++) { + if (lb[i] == positiveClsLb) + auc += rank[i]; + } + + if (pos == 0L) return Double.NaN; + else if (neg == 0L) return Double.NaN; + + auc = (auc - (pos * (pos + 1) / 2.0)) / (pos * neg); + return auc; + } + + /** Get the positive label. */ + public double positiveClsLb() { + return positiveClsLb; + } + + /** Set the positive label. 
*/ + public ROCAUC withPositiveClsLb(double positiveClsLb) { + if (Double.isFinite(positiveClsLb)) + this.positiveClsLb = positiveClsLb; + return this; + } + + /** Get the negative label. */ + public double negativeClsLb() { + return negativeClsLb; + } + + /** Set the negative label. */ + public ROCAUC withNegativeClsLb(double negativeClsLb) { + if (Double.isFinite(negativeClsLb)) + this.negativeClsLb = negativeClsLb; + return this; + } + + /** {@inheritDoc} */ + @Override public String name() { + return "ROC AUC"; + } +} diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsValuesTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsValuesTest.java index f513ce3..4e7b03a 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsValuesTest.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/BinaryClassificationMetricsValuesTest.java @@ -28,7 +28,7 @@ public class BinaryClassificationMetricsValuesTest { /** */ @Test public void testDefaultBehaviour() { - BinaryClassificationMetricValues metricValues = new BinaryClassificationMetricValues(10, 10, 5, 5); + BinaryClassificationMetricValues metricValues = new BinaryClassificationMetricValues(10, 10, 5, 5, 0.5); assertEquals(10, metricValues.tp(), 1e-2); assertEquals(10, metricValues.tn(), 1e-2); @@ -44,5 +44,6 @@ public class BinaryClassificationMetricsValuesTest { assertEquals(0.66, metricValues.precision(), 1e-2); assertEquals(0.66, metricValues.recall(), 1e-2); assertEquals(0.66, metricValues.specificity(), 1e-2); + assertEquals(0.5, metricValues.rocauc(), 1e-2); } } diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUCTest.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUCTest.java new file mode 100644 index 0000000..99451a3 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/scoring/metric/classification/ROCAUCTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.scoring.metric.classification; + +import java.util.Arrays; +import org.apache.ignite.ml.selection.scoring.TestLabelPairCursor; +import org.apache.ignite.ml.selection.scoring.cursor.LabelPairCursor; +import org.apache.ignite.ml.selection.scoring.metric.Metric; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link ROCAUC}. 
+ */ +public class ROCAUCTest { + /** */ + @Test + public void testTotalTruth() { + Metric<Double> scoreCalculator = new ROCAUC(); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0), + Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + assertEquals(Double.NaN, score, 1e-12); + } + + /** */ + @Test + public void testTotalUntruth() { + Metric<Double> scoreCalculator = new ROCAUC(); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(1.0, 1.0, 1.0, 1.0, 1.0), + Arrays.asList(0.0, 0.0, 0.0, 0.0, 0.0) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + assertEquals(Double.NaN, score, 1e-12); + } + + /** */ + @Test + public void testOneDifferent() { + Metric<Double> scoreCalculator = new ROCAUC(); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(1.0, 1.0, 1.0, 0.0), + Arrays.asList(1.0, 1.0, 1.0, 1.0) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + assertEquals(0.5, score, 1e-12); + } + + /** */ + @Test + public void testOneDifferentButBalanced() { + Metric<Double> scoreCalculator = new ROCAUC(); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(1.0, 1.0, 0.0, 0.0), + Arrays.asList(1.0, 1.0, 0.0, 1.0) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + assertEquals(0.75, score, 1e-12); + } + + /** */ + @Test + public void testTwoDifferentAndBalanced() { + Metric<Double> scoreCalculator = new ROCAUC(); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + Arrays.asList(1.0, 1.0, 0.0, 0.0), + Arrays.asList(1.0, 0.0, 0.0, 1.0) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + assertEquals(0.5, score, 1e-12); + } + + /** */ + @Test + public void testNotOnlyBinaryValues() { + Metric<Double> scoreCalculator = new ROCAUC(); + + LabelPairCursor<Double> cursor = new TestLabelPairCursor<>( + 
Arrays.asList(1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0), + Arrays.asList(0.40209054, 0.33697626, 0.5449324 , 0.13010869, 0.19019675, 0.39767829, 0.9686739 , 0.91783275, 0.7503783 , 0.5306605) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + assertEquals(0.625, score, 1e-12); + } +}