This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 7c4f3455958e660669c7cb42cb7f3884314c71fa
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Feb 6 19:03:39 2023 +0100

    [SYSTEMDS-3496] New auc() builtin function for the area under ROC curves
    
    This patch introduces the new auc() builtin function that takes a
    response vector Y and probabilities P (e.g., from multiLogRegPredict)
    and computes the area under the Receiver-Operating-Characteristic curve.
    The current implementation naively computes the distinct probabilities
    and then evaluates the true and false positive rates for all these
    possible thresholds, with semantics equivalent to the R pROC package.
    
    Next steps include fixes for compiling unique operations (currently
    requires forced single node), missing unique spark operations, and a
    more efficient vectorized auc() implementation via cumsum (and unique
    extensions to obtain the last indexes of unique values).
---
 scripts/builtin/auc.dml                            |  72 +++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../functions/builtin/part1/BuiltinAucTest.java    | 133 +++++++++++++++++++++
 src/test/scripts/functions/builtin/auc.dml         |  25 ++++
 4 files changed, 231 insertions(+)

diff --git a/scripts/builtin/auc.dml b/scripts/builtin/auc.dml
new file mode 100644
index 0000000000..1084b62eba
--- /dev/null
+++ b/scripts/builtin/auc.dml
@@ -0,0 +1,72 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# This builtin function computes the area under the ROC curve (AUC)
+# for binary classifiers.
+#
+# INPUT:
+# ------------------------------------------------------------------------------
+# Y            Binary response vector (shape: n x 1), in -1/+1 or 0/1 encoding
+# P            Prediction scores (predictor such as estimated probabilities)
+#              for true class (shape: n x 1), assumed in [0,1]
+# ------------------------------------------------------------------------------
+#
+# OUTPUT:
+# ------------------------------------------------------------------------------
+# auc          Area under the ROC curve (AUC)
+# ------------------------------------------------------------------------------
+
+m_auc = function(Matrix[Double] Y, Matrix[Double] P)
+  return(Double auc)
+{
+  # Computes the area under the ROC curve for a two-class label vector Y
+  # and a score vector P: every distinct score is used as a cut-off, and
+  # sensitivity is integrated over specificity with the trapezoidal rule.
+  minv = min(Y)
+  maxv = max(Y)
+
+  # check input parameter assertions
+  if(minv == maxv)
+    stop("AUC: stopping because only one class label existing in Y")
+  if(sum(Y==minv) + sum(Y==maxv) < nrow(Y))
+    stop("AUC: stopping because more than two class labels existing in Y")
+
+  # map the larger label to 1 and the smaller one to 0; this covers the
+  # -1/+1 and 0/1 encodings and any other two-valued encoding that passed
+  # the checks above (where sum(Y) would otherwise not count positives)
+  Y = (Y == maxv);
+  pos = sum(Y);
+  neg = nrow(Y) - pos;
+
+  # compute ROC curve for distinct threshold scores
+  # (cut-offs > and <= chosen to match R-pROC-package behavior)
+  # TODO vectorize via ordering + cumsum (but indexes of unique missing)
+  dP = order(target=unique(P)); # distinct P thresholds, increasing
+  nd = nrow(dP)
+  tp = matrix(0, nd, 1); # positives scored above the cut-off
+  tn = matrix(0, nd, 1); # negatives scored at or below the cut-off
+  parfor(i in 1:nd) {
+    tp[i] = sum(P>dP[i] & Y)
+    tn[i] = sum(P<=dP[i] & !Y)
+  }
+  sens = tp / pos; # sensitivity, non-increasing, 0 at the largest cut-off
+  spec = tn / neg; # specificity, non-decreasing, 1 at the largest cut-off
+
+  # compute AUC via Trapezoidal rule; the curve starts at the implicit
+  # point (spec=0, sens=1) below the smallest score, so the first segment
+  # is a trapezoid with heights 1 and sens[1] (a plain rectangle
+  # sens[1]*spec[1] is only correct if the smallest score is not shared
+  # by both classes)
+  auc = as.scalar(spec[1] * (1 + sens[1]) / 2);
+  # remaining segments exist only for two or more distinct scores
+  # (with a single distinct score the range 2:nd would be invalid)
+  if( nd > 1 ) {
+    auc = auc + sum((spec[2:nd]-spec[1:(nd-1)])
+      * (sens[2:nd]+sens[1:(nd-1)])/2);
+  }
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index 24215ffdab..f7cbb972df 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -51,6 +51,7 @@ public enum Builtins {
        ARIMA("arima", true),
        ASIN("asin", false),
        ATAN("atan", false),
+       AUC("auc", true),
        AUTOENCODER2LAYER("autoencoder_2layer", true),
        AVG_POOL("avg_pool", false),
        AVG_POOL_BACKWARD("avg_pool_backward", false),
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java
new file mode 100644
index 0000000000..49502e8371
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part1;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+
+public class BuiltinAucTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME = "auc";
+       private final static String TEST_DIR = "functions/builtin/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
BuiltinAucTest.class.getSimpleName() + "/";
+
+       private double eps = 0.01;
+       
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"}));
+       }
+
+       //FIXME missing spark instruction unique
+       
+       @Test
+       public void testPerfectSeparationOrdered() {
+               runAucTest(1.0, new double[]{0,0,0,1,1,1},
+                       new double[]{0.1,0.2,0.3,0.4,0.55,0.56});
+       }
+       
+       @Test
+       public void testPerfectSeparationUnordered() {
+               runAucTest(1.0, new double[]{0,1,0,1,0,1},
+                       new double[]{0.1,0.5,0.2,0.55,0.3,0.56});
+       }
+       
+       @Test
+       public void testPerfectSeparationUnorderedDups() {
+               runAucTest(1.0, new double[]{0,1,0,1,0,1,0,1,0,1,0,1},
+                       new 
double[]{0.1,0.5,0.2,0.55,0.3,0.56,0.1,0.5,0.2,0.55,0.3,0.56});
+       }
+
+       //selected cases, double checked with R pROC (but not explicitly 
compared to avoid dependency)
+       
+       @Test
+       public void testMisc1() {
+               runAucTest(0.8899, new double[]{0,0,1,0,1,1},
+                       new double[]{0.1,0.2,0.3,0.4,0.5,0.55});
+       }
+       
+       @Test
+       public void testMisc2() {
+               runAucTest(0.8899, new double[]{-1,-1,1,-1,1,1},
+                       new double[]{0.1,0.2,0.3,0.4,0.5,0.55});
+       }
+       
+       @Test
+       public void testMisc3() {
+               runAucTest(0.75, new double[]{0,0,1,0,1,1,0,1},
+                       new double[]{0.1,0.2,0.2,0.21,0.7,0.7,0.7,0.7});
+       }
+       
+       @Test
+       public void testMisc4() {
+               runAucTest(0.6, new double[]{0,0,1,0,1,1,0,1,0},
+                       new double[]{0.1,0.2,0.2,0.21,0.7,0.7,0.7,0.7,0.9});
+       }
+       
+       @Test
+       public void testMisc5() {
+               runAucTest(0.6, new double[]{0,0,0,1,0,1,1,0,1},
+                       new double[]{0.9,0.1,0.2,0.2,0.21,0.7,0.7,0.7,0.7});
+       }
+       
+       @Test
+       public void testMisc6() {
+               runAucTest(0.5, new double[]{0,0,1,0,1,1,0,1,0,0},
+                       new double[]{0.1,0.2,0.2,0.21,0.7,0.7,0.7,0.7,0.9,0.9});
+       }
+       
+       @Test
+       public void testMisc7() {
+               runAucTest(0.4286, new double[]{0,0,1,0,1,1,0,1,0,0,0},
+                       new 
double[]{0.1,0.2,0.2,0.21,0.7,0.7,0.7,0.7,0.9,0.9,0.99});
+       }
+       
+       private void runAucTest(double auc, double[] Y, double[] P)
+       {
+               ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
+
+               try
+               {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-args", input("Yt"), 
input("Pt"), output("C") };
+
+                       //generate actual dataset 
+                       writeInputMatrixWithMTD("Yt", new double[][]{Y}, false);
+                       writeInputMatrixWithMTD("Pt", new double[][]{P}, false);
+
+                       //execute test
+                       runTest(true, false, null, -1);
+
+                       //compare matrices 
+                       double val = readDMLMatrixFromOutputDir("C").get(new 
CellIndex(1,1));
+                       Assert.assertEquals("Incorrect values: ", auc, val, 
eps);
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/builtin/auc.dml b/src/test/scripts/functions/builtin/auc.dml
new file mode 100644
index 0000000000..065546f45d
--- /dev/null
+++ b/src/test/scripts/functions/builtin/auc.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# $1: label matrix, $2: score matrix, $3: output location for the result
+labels = read($1);
+scores = read($2);
+# both inputs are transposed before being handed to auc(); the scalar
+# result is wrapped into a matrix so that it can be written out
+result = as.matrix(auc(t(labels), t(scores)));
+write(result, $3)

Reply via email to