This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new e708382fa0 [SYSTEMDS-3496] Improved auc via ordered cumsum implementation
e708382fa0 is described below

commit e708382fa08c04240f82a96b679741a16509d9c4
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Feb 7 20:12:27 2023 +0100

    [SYSTEMDS-3496] Improved auc via ordered cumsum implementation
    
    This patch replaces the naive ROC curve computation (which, for each
    unique threshold, scans all rows to count the true and false positives)
    with an ordered cumsum implementation that extracts these values from a
    few vectorized scans of the ordered scores and responses.
    
    This new implementation also scales more robustly to large,
    distributed vectors (supported sparse instructions, hardened implementation).
    
    On a scenario of uniformly distributed vectors Y and P, this patch
    improved the auc() execution time as follows:
    
    10M rows, 10K distinct: crash ->  1.97s
    10M rows, 1K distinct: 251.6s ->  1.93s
    10M rows, 100 distinct: 18.8s ->  1.94s
    50M rows, 100 distinct: 217.9s -> 10.8s
---
 scripts/builtin/auc.dml                            | 46 ++++++++++++++++++----
 .../functions/builtin/part1/BuiltinAucTest.java    |  6 +++
 2 files changed, 45 insertions(+), 7 deletions(-)

diff --git a/scripts/builtin/auc.dml b/scripts/builtin/auc.dml
index 1084b62eba..7a87403419 100644
--- a/scripts/builtin/auc.dml
+++ b/scripts/builtin/auc.dml
@@ -49,12 +49,48 @@ m_auc = function(Matrix[Double] Y, Matrix[Double] P)
   # convert -1/1 to 0/1 if necessary
   if( minv < 0 )
     Y = (Y+1) != 0;
+
+  # compute true and false positive rates per unique threshold
+  [tpr, fpr] = cumsumROC(Y, P);
+
+  # compute AUC via Trapezoidal rule
+  nd = nrow(tpr);
+  auc = as.scalar(tpr[1] * fpr[1])
+      + sum((fpr[2:nd]-fpr[1:(nd-1)]) * (tpr[2:nd]+tpr[1:(nd-1)])/2);
+}
+
+cumsumROC = function(Matrix[Double] Y, Matrix[Double] P)
+  return(Matrix[Double] tpr, Matrix[Double] fpr)
+{
+  pos = sum(Y);        # number of positives (expects 0/1 labels; m_auc converts -1/1)
+  neg = nrow(Y) - pos; # number of negatives
+  # ROC curve for distinct threshold scores (cut-offs > and <= chosen to
+  # match R-pROC-package behavior), vectorized via cumsum over the scores
+  # P in ascending order instead of one full scan per threshold
+  YP = order(target=cbind(Y, P), by=2);
+  oY = YP[,1]; oP = YP[,2];
+  tp = pos - cumsum(oY); # positives with score > oP (exact at end of a tie run)
+  fp = cumsum(!oY);      # negatives with score <= oP (exact at end of a tie run)
+  # indicator of the last row of each run of equal scores; the last row
+  # always ends a run (a sentinel of 0 dropped it whenever max(P) == 0)
+  n = nrow(oP);
+  uI = (oP != rbind(oP[2:n], as.matrix(0)));
+  uI[n,1] = 1;
+  # extract the counts at the unique thresholds
+  tp = removeEmpty(target=tp, margin="rows", select=uI);
+  fp = removeEmpty(target=fp, margin="rows", select=uI);
+  tpr = tp / pos; # true positive rate (sensitivity), decreasing
+  fpr = fp / neg; # share of negatives with score <= threshold (specificity), increasing
+}
+
+naiveROC = function(Matrix[Double] Y, Matrix[Double] P)
+  return(Matrix[Double] tpr, Matrix[Double] fpr)
+{
   pos = sum(Y);
   neg = nrow(Y) - pos;
 
   # compute ROC curve for distinct threshold scores
   # (cut-offs > and <= choosen to match R-pROC-package behavior)
-  # TODO vectorize via ordering + cumsum (but indexes of unique missing)
   dP = order(target=unique(P)); # distinct P thresholds, increasing
   nd = nrow(dP)
   tp = matrix(0, nd, 1);
@@ -63,10 +99,6 @@ m_auc = function(Matrix[Double] Y, Matrix[Double] P)
     tp[i] = sum(P>dP[i] & Y)
     fp[i] = sum(P<=dP[i] & !Y)
   }
-  tpr = tp / pos; # true positive rate, increasing
-  fpr = fp / neg; # false postive rate, increasing
-
-  # compute AUC via Trapezoidal rule
-  auc = as.scalar(tpr[1] * fpr[1])
-      + sum((fpr[2:nd]-fpr[1:(nd-1)]) * (tpr[2:nd]+tpr[1:(nd-1)])/2);
+  tpr = tp / pos; # true positive rate (sensitivity), decreasing
+  fpr = fp / neg; # share of negatives with P <= threshold (specificity), increasing
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java
index fec9e94bdf..ac79eb6f7e 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java
@@ -49,6 +49,12 @@ public class BuiltinAucTest extends AutomatedTestBase
                        new double[]{0.1,0.2,0.3,0.4,0.55,0.56});
        }
        
+       @Test
+       public void testPerfectSeparationOrderedDups() {
+               runAucTest(1.0, new double[]{0,0,0,0,0,0,1,1,1,1,1,1},
+                       new double[]{0.1,0.2,0.3,0.1,0.2,0.3,0.4,0.55,0.56,0.4,0.55,0.56});
+       }
+       
        @Test
        public void testPerfectSeparationUnordered() {
                runAucTest(1.0, new double[]{0,1,0,1,0,1},

Reply via email to