This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit c7f53c0b22494a290bfd8f251040edc566e85af6 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Sat Apr 20 18:00:35 2024 +0200 [SYSTEMDS-3538] Add missing f1Score builtin function We simply compute the F1 score based on the confusion matrix output, and only support binary labels. In the future, we could adopt a one-against-the-rest approach for multi-class settings. --- scripts/builtin/confusionMatrix.dml | 2 +- .../builtin/{confusionMatrix.dml => f1Score.dml} | 39 ++++++---------------- .../java/org/apache/sysds/common/Builtins.java | 1 + 3 files changed, 12 insertions(+), 30 deletions(-) diff --git a/scripts/builtin/confusionMatrix.dml b/scripts/builtin/confusionMatrix.dml index 14d176706c..3ac70fb3f8 100644 --- a/scripts/builtin/confusionMatrix.dml +++ b/scripts/builtin/confusionMatrix.dml @@ -34,7 +34,7 @@ # INPUT: # ------------------------------------------------------------------------------ # P vector of predictions (1-based, recoded) -# Y vector of actual labels (1-based, recorded) +# Y vector of actual labels (1-based, recoded) # ------------------------------------------------------------------------------ # # OUTPUT: diff --git a/scripts/builtin/confusionMatrix.dml b/scripts/builtin/f1Score.dml similarity index 50% copy from scripts/builtin/confusionMatrix.dml copy to scripts/builtin/f1Score.dml index 14d176706c..0407998ac1 100644 --- a/scripts/builtin/confusionMatrix.dml +++ b/scripts/builtin/f1Score.dml @@ -19,44 +19,25 @@ # #------------------------------------------------------------- -# Computes the confusion matrix for input vectors of predictions -# and actual labels. We return both the counts and relative frequency -# (normalized by sum of true labels) -# -# .. code-block:: -# -# True Labels -# 1 2 -# 1 TP | FP -# Predictions ----+---- -# 2 FN | TN +# Computes the F1 score as the harmonic mean of precision and recall. 
+# F1 = 2TP / (2TP + FP + FN) # # INPUT: # ------------------------------------------------------------------------------ # P vector of predictions (1-based, recoded) -# Y vector of actual labels (1-based, recorded) +# Y vector of actual labels (1-based, recoded) # ------------------------------------------------------------------------------ # # OUTPUT: # ------------------------------------------------------------------------------ -# confusionSum the confusion matrix as absolute counts -# confusionAvg the confusion matrix as relative frequencies +# score the F1 score # ------------------------------------------------------------------------------ -m_confusionMatrix = function(Matrix[Double] P, Matrix[Double] Y) - return(Matrix[Double] confusionSum, Matrix[Double] confusionAvg) +m_f1Score = function(Matrix[Double] P, Matrix[Double] Y) + return(Double score) { - dim = max(max(Y), max(P)) #ensure known dim - - if(ncol(P) > 1 | ncol(Y) > 1) - stop("confusionMatrix: Invalid input number of cols should be 1 in both P ["+ncol(P)+"] and Y ["+ncol(Y)+"]") - if(nrow(P) != nrow(Y)) - stop("confusionMatrix: The number of rows have to be equal in both P ["+nrow(P)+"] and Y ["+nrow(Y)+"]") - if(min(P) < 1 | min(Y) < 1) - stop("confusionMatrix: All Values in P and Y should be abore or equal to 1, min(P):" + min(P) + " min(Y):" + min(Y) ) - - confusionSum = table(P, Y, dim, dim) - # max to avoid division by 0, in case a colum contain no entries. 
- confusionAvg = confusionSum / max(1,colSums(confusionSum)) + [cS, cA] = confusionMatrix(P, Y); + if(nrow(cS)>2 | ncol(cS)>2) + stop("f1Score: currently only supported for binary class labels."); + score = as.scalar(2*cS[1,1] / (2*cS[1,1] + cS[2,1] + cS[1,2])); } - diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index b510698cd4..61faccc0bc 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -130,6 +130,7 @@ public enum Builtins { EXP("exp", false), EVAL("eval", false), EVALLIST("evalList", false), + F1SCORE("f1Score", true), FIT_PIPELINE("fit_pipeline", true), FIX_INVALID_LENGTHS("fixInvalidLengths", true), FIX_INVALID_LENGTHS_APPLY("fixInvalidLengthsApply", true),