This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new e708382fa0 [SYSTEMDS-3496] Improved auc via ordered cumsum
implementation
e708382fa0 is described below
commit e708382fa08c04240f82a96b679741a16509d9c4
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Feb 7 20:12:27 2023 +0100
[SYSTEMDS-3496] Improved auc via ordered cumsum implementation
This patch replaced the naive ROC curve computation (that scans for
each unique threshold the number of true and false positives), with an
ordered cumsum implementation that extracts these values from a few
vectorized scans of the ordered scores and responses.
This new implementation also scales more robustly to large,
distributed vectors (supported sparse instruction, hardened impl).
On a scenario of uniformly distributed vectors Y and P, this patch
improved the auc() execution time as follows:
10M rows, 10K distinct: crash -> 1.97s
10M rows, 1K distinct: 251.6s -> 1.93s
10M rows, 100 distinct: 18.8s -> 1.94s
50M rows, 100 distinct: 217.9s -> 10.8s
---
scripts/builtin/auc.dml | 46 ++++++++++++++++++----
.../functions/builtin/part1/BuiltinAucTest.java | 6 +++
2 files changed, 45 insertions(+), 7 deletions(-)
diff --git a/scripts/builtin/auc.dml b/scripts/builtin/auc.dml
index 1084b62eba..7a87403419 100644
--- a/scripts/builtin/auc.dml
+++ b/scripts/builtin/auc.dml
@@ -49,12 +49,48 @@ m_auc = function(Matrix[Double] Y, Matrix[Double] P)
# convert -1/1 to 0/1 if necessary
if( minv < 0 )
Y = (Y+1) != 0;
+
+ # compute true and false positive rates per unique threshold
+ [tpr, fpr] = cumsumROC(Y, P);
+
+ # compute AUC via Trapezoidal rule
+ nd = nrow(tpr);
+ auc = as.scalar(tpr[1] * fpr[1])
+ + sum((fpr[2:nd]-fpr[1:(nd-1)]) * (tpr[2:nd]+tpr[1:(nd-1)])/2);
+}
+
+cumsumROC = function(Matrix[Double] Y, Matrix[Double] P)
+ return(Matrix[Double] tpr, Matrix[Double] fpr)
+{
+ pos = sum(Y);
+ neg = nrow(Y) - pos;
+
+ # compute ROC curve for distinct threshold scores
+ # (cut-offs > and <= chosen to match R-pROC-package behavior)
+ # vectorized implementation via cumsum of ordered scores P
+ YP = order(target=cbind(Y, P), by=2);
+ oY = YP[,1]; oP = YP[,2];
+ tp = pos - cumsum(oY); # true positives until certain threshold (row)
+ fp = cumsum(!oY); # false positives until certain threshold
+
+ # indicator of unique thresholds, set at the end of each range of equal scores
+ uI = (oP != rbind(oP[2:nrow(oP)],as.matrix(0)));
+
+ # extract true/false positives for unique thresholds
+ tp = removeEmpty(target=tp, margin="rows", select=uI);
+ fp = removeEmpty(target=fp, margin="rows", select=uI);
+ tpr = tp / pos; # true positive rate, decreasing
+ fpr = fp / neg; # false positive rate, increasing
+}
+
+naiveROC = function(Matrix[Double] Y, Matrix[Double] P)
+ return(Matrix[Double] tpr, Matrix[Double] fpr)
+{
pos = sum(Y);
neg = nrow(Y) - pos;
# compute ROC curve for distinct threshold scores
# (cut-offs > and <= choosen to match R-pROC-package behavior)
- # TODO vectorize via ordering + cumsum (but indexes of unique missing)
dP = order(target=unique(P)); # distinct P thresholds, increasing
nd = nrow(dP)
tp = matrix(0, nd, 1);
@@ -63,10 +99,6 @@ m_auc = function(Matrix[Double] Y, Matrix[Double] P)
tp[i] = sum(P>dP[i] & Y)
fp[i] = sum(P<=dP[i] & !Y)
}
- tpr = tp / pos; # true positive rate, increasing
- fpr = fp / neg; # false postive rate, increasing
-
- # compute AUC via Trapezoidal rule
- auc = as.scalar(tpr[1] * fpr[1])
- + sum((fpr[2:nd]-fpr[1:(nd-1)]) * (tpr[2:nd]+tpr[1:(nd-1)])/2);
+ tpr = tp / pos; # true positive rate, decreasing
+ fpr = fp / neg; # false positive rate, increasing
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java
index fec9e94bdf..ac79eb6f7e 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java
@@ -49,6 +49,12 @@ public class BuiltinAucTest extends AutomatedTestBase
new double[]{0.1,0.2,0.3,0.4,0.55,0.56});
}
+ @Test
+ public void testPerfectSeparationOrderedDups() {
+ runAucTest(1.0, new double[]{0,0,0,0,0,0,1,1,1,1,1,1},
+ new
double[]{0.1,0.2,0.3,0.1,0.2,0.3,0.4,0.55,0.56,0.4,0.55,0.56});
+ }
+
@Test
public void testPerfectSeparationUnordered() {
runAucTest(1.0, new double[]{0,1,0,1,0,1},