This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit efc42d7b7d3f7e529400cac3e4fa4be8ad895a02 Author: Matthias Boehm <[email protected]> AuthorDate: Sat Apr 20 17:41:35 2024 +0200 [SYSTEMDS-3528] Fix documentation and efficiency confusionMatrix * Fix misleading documentation, the inputs are not one-hot encoded, but only recoded (see the stop conditions and table) * Move the dimension computation into a previous basic block in order to make this information available (as a scalar) during recompilation --- scripts/builtin/confusionMatrix.dml | 36 +++++++++++----------- .../builtin/part1/BuiltinConfusionMatrixTest.java | 10 ++++-- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/scripts/builtin/confusionMatrix.dml b/scripts/builtin/confusionMatrix.dml index 18228d14c2..14d176706c 100644 --- a/scripts/builtin/confusionMatrix.dml +++ b/scripts/builtin/confusionMatrix.dml @@ -19,10 +19,9 @@ # #------------------------------------------------------------- -# Accepts a vector for prediction and a one-hot-encoded matrix -# Then it computes the max value of each vector and compare them -# After which, it calculates and returns the sum of classifications -# and the average of each true class. +# Computes the confusion matrix for input vectors of predictions +# and actual labels. We return both the counts and relative frequency +# (normalized by sum of true labels) # # .. 
code-block:: # @@ -33,30 +32,31 @@ # 2 FN | TN # # INPUT: -# -------------------------------------------------------------------------------- -# P vector of Predictions -# Y vector of Golden standard One Hot Encoded; the one hot -# encoded vector of actual labels -# -------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------ +# P vector of predictions (1-based, recoded) +# Y vector of actual labels (1-based, recoded) +# ------------------------------------------------------------------------------ # # OUTPUT: -# ------------------------------------------------------------------------------------------------ -# confusionSum The Confusion Matrix Sums of classifications -# confusionAvg The Confusion Matrix averages of each true class -# ------------------------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------ +# confusionSum the confusion matrix as absolute counts +# confusionAvg the confusion matrix as relative frequencies +# ------------------------------------------------------------------------------ m_confusionMatrix = function(Matrix[Double] P, Matrix[Double] Y) return(Matrix[Double] confusionSum, Matrix[Double] confusionAvg) { + dim = max(max(Y), max(P)) #ensure known dim + if(ncol(P) > 1 | ncol(Y) > 1) - stop("CONFUSION MATRIX: Invalid input number of cols should be 1 in both P ["+ncol(P)+"] and Y ["+ncol(Y)+"]") + stop("confusionMatrix: Invalid input number of cols should be 1 in both P ["+ncol(P)+"] and Y ["+ncol(Y)+"]") if(nrow(P) != nrow(Y)) - stop("CONFUSION MATRIX: The number of rows have to be equal in both P ["+nrow(P)+"] and Y ["+nrow(Y)+"]") + stop("confusionMatrix: The number of rows have to be equal in both P ["+nrow(P)+"] and Y ["+nrow(Y)+"]") if(min(P) < 1 | min(Y) < 1) - stop("CONFUSION MATRIX: All Values in P and Y should be abore or equal 
to 1, min(P):" + min(P) + " min(Y):" + min(Y) ) + stop("confusionMatrix: All Values in P and Y should be abore or equal to 1, min(P):" + min(P) + " min(Y):" + min(Y) ) - dim = max(max(Y),max(P)) - confusionSum = table(P, Y, dim, dim) + confusionSum = table(P, Y, dim, dim) # max to avoid division by 0, in case a colum contain no entries. confusionAvg = confusionSum / max(1,colSums(confusionSum)) } + diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinConfusionMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinConfusionMatrixTest.java index 3ea84cd183..4bc298338d 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinConfusionMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinConfusionMatrixTest.java @@ -29,6 +29,8 @@ import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; +import org.junit.Assert; import org.junit.Test; public class BuiltinConfusionMatrixTest extends AutomatedTestBase { @@ -114,8 +116,9 @@ public class BuiltinConfusionMatrixTest extends AutomatedTestBase { runConfusionMatrixTest(y, p, res, ExecType.CP); } - private void runConfusionMatrixTest(double[][] y, double[][] p, HashMap<MatrixValue.CellIndex, Double> res, - ExecType instType) { + private void runConfusionMatrixTest(double[][] y, double[][] p, + HashMap<MatrixValue.CellIndex, Double> res, ExecType instType) + { ExecMode platformOld = setExecMode(instType); try { @@ -131,6 +134,9 @@ public class BuiltinConfusionMatrixTest extends AutomatedTestBase { HashMap<MatrixValue.CellIndex, Double> dmlResult = readDMLMatrixFromOutputDir("B"); TestUtils.compareMatrices(dmlResult, res, eps, "DML_Result", "Expected"); + + if( instType != ExecType.SPARK ) + Assert.assertEquals(0, 
Statistics.getNoOfExecutedSPInst()); } finally { rtplatform = platformOld;
