This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit efc42d7b7d3f7e529400cac3e4fa4be8ad895a02
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Apr 20 17:41:35 2024 +0200

    [SYSTEMDS-3528] Fix documentation and efficiency confusionMatrix
    
    * Fix misleading documentation, the inputs are not one-hot encoded,
      but only recoded (see the stop conditions and table)
    * Move the dimension computation into a previous basic block in order
      to make this information available (as a scalar) during recompilation
---
 scripts/builtin/confusionMatrix.dml                | 36 +++++++++++-----------
 .../builtin/part1/BuiltinConfusionMatrixTest.java  | 10 ++++--
 2 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/scripts/builtin/confusionMatrix.dml 
b/scripts/builtin/confusionMatrix.dml
index 18228d14c2..14d176706c 100644
--- a/scripts/builtin/confusionMatrix.dml
+++ b/scripts/builtin/confusionMatrix.dml
@@ -19,10 +19,9 @@
 #
 #-------------------------------------------------------------
 
-# Accepts a vector for prediction and a one-hot-encoded matrix
-# Then it computes the max value of each vector and compare them
-# After which, it calculates and returns the sum of classifications
-# and the average of each true class.
+# Computes the confusion matrix for input vectors of predictions
+# and actual labels. We return both the counts and relative frequency
+# (normalized by sum of true labels)
 #
 # .. code-block::
 #
@@ -33,30 +32,31 @@
 #                 2   FN | TN
 #
 # INPUT:
-# 
--------------------------------------------------------------------------------
-# P     vector of Predictions
-# Y     vector of Golden standard One Hot Encoded; the one hot
-#       encoded vector of actual labels
-# 
--------------------------------------------------------------------------------
+# 
------------------------------------------------------------------------------
+# P              vector of predictions (1-based, recoded)
+# Y              vector of actual labels (1-based, recoded)
+# 
------------------------------------------------------------------------------
 #
 # OUTPUT:
-# 
------------------------------------------------------------------------------------------------
-# confusionSum   The Confusion Matrix Sums of classifications
-# confusionAvg   The Confusion Matrix averages of each true class
-# 
------------------------------------------------------------------------------------------------
+# 
------------------------------------------------------------------------------
+# confusionSum   the confusion matrix as absolute counts
+# confusionAvg   the confusion matrix as relative frequencies
+# 
------------------------------------------------------------------------------
 
 m_confusionMatrix = function(Matrix[Double] P, Matrix[Double] Y)
   return(Matrix[Double] confusionSum, Matrix[Double] confusionAvg)
 {
+  dim = max(max(Y), max(P)) #ensure known dim
+
   if(ncol(P) > 1  | ncol(Y) > 1)
-    stop("CONFUSION MATRIX: Invalid input number of cols should be 1 in both P 
["+ncol(P)+"] and Y ["+ncol(Y)+"]")
+    stop("confusionMatrix: Invalid input number of cols should be 1 in both P 
["+ncol(P)+"] and Y ["+ncol(Y)+"]")
   if(nrow(P) != nrow(Y))
-    stop("CONFUSION MATRIX: The number of rows have to be equal in both P 
["+nrow(P)+"] and Y ["+nrow(Y)+"]")
+    stop("confusionMatrix: The number of rows have to be equal in both P 
["+nrow(P)+"] and Y ["+nrow(Y)+"]")
   if(min(P) < 1 | min(Y) < 1)
-    stop("CONFUSION MATRIX: All Values in P and Y should be abore or equal to 
1, min(P):" + min(P) + " min(Y):" + min(Y) )
+    stop("confusionMatrix: All Values in P and Y should be abore or equal to 
1, min(P):" + min(P) + " min(Y):" + min(Y) )
 
-  dim = max(max(Y),max(P))
-  confusionSum = table(P, Y,  dim, dim)
+  confusionSum = table(P, Y, dim, dim)
   # max to avoid division by 0, in case a column contains no entries.
   confusionAvg = confusionSum / max(1,colSums(confusionSum))
 }
+
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinConfusionMatrixTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinConfusionMatrixTest.java
index 3ea84cd183..4bc298338d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinConfusionMatrixTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinConfusionMatrixTest.java
@@ -29,6 +29,8 @@ import 
org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
 import org.junit.Test;
 
 public class BuiltinConfusionMatrixTest extends AutomatedTestBase {
@@ -114,8 +116,9 @@ public class BuiltinConfusionMatrixTest extends 
AutomatedTestBase {
                runConfusionMatrixTest(y, p, res, ExecType.CP);
        }
 
-       private void runConfusionMatrixTest(double[][] y, double[][] p, 
HashMap<MatrixValue.CellIndex, Double> res,
-               ExecType instType) {
+       private void runConfusionMatrixTest(double[][] y, double[][] p,
+               HashMap<MatrixValue.CellIndex, Double> res, ExecType instType)
+       {
                ExecMode platformOld = setExecMode(instType);
 
                try {
@@ -131,6 +134,9 @@ public class BuiltinConfusionMatrixTest extends 
AutomatedTestBase {
 
                        HashMap<MatrixValue.CellIndex, Double> dmlResult = 
readDMLMatrixFromOutputDir("B");
                        TestUtils.compareMatrices(dmlResult, res, eps, 
"DML_Result", "Expected");
+                       
+                       if( instType != ExecType.SPARK )
+                               Assert.assertEquals(0, 
Statistics.getNoOfExecutedSPInst());
                }
                finally {
                        rtplatform = platformOld;

Reply via email to