This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit c7f53c0b22494a290bfd8f251040edc566e85af6
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Sat Apr 20 18:00:35 2024 +0200

    [SYSTEMDS-3538] Add missing f1Score builtin function
    
    We simply compute the F1 score based on the confusion matrix output,
    and only support binary labels. In the future, we could adopt a
    one-against-the-rest approach for multi-class settings.
---
 scripts/builtin/confusionMatrix.dml                |  2 +-
 .../builtin/{confusionMatrix.dml => f1Score.dml}   | 39 ++++++----------------
 .../java/org/apache/sysds/common/Builtins.java     |  1 +
 3 files changed, 12 insertions(+), 30 deletions(-)

diff --git a/scripts/builtin/confusionMatrix.dml b/scripts/builtin/confusionMatrix.dml
index 14d176706c..3ac70fb3f8 100644
--- a/scripts/builtin/confusionMatrix.dml
+++ b/scripts/builtin/confusionMatrix.dml
@@ -34,7 +34,7 @@
 # INPUT:
 # ------------------------------------------------------------------------------
 # P              vector of predictions (1-based, recoded)
-# Y              vector of actual labels (1-based, recorded)
+# Y              vector of actual labels (1-based, recoded)
 # ------------------------------------------------------------------------------
 #
 # OUTPUT:
diff --git a/scripts/builtin/confusionMatrix.dml b/scripts/builtin/f1Score.dml
similarity index 50%
copy from scripts/builtin/confusionMatrix.dml
copy to scripts/builtin/f1Score.dml
index 14d176706c..0407998ac1 100644
--- a/scripts/builtin/confusionMatrix.dml
+++ b/scripts/builtin/f1Score.dml
@@ -19,44 +19,25 @@
 #
 #-------------------------------------------------------------
 
-# Computes the confusion matrix for input vectors of predictions
-# and actual labels. We return both the counts and relative frequency
-# (normalized by sum of true labels)
-#
-# .. code-block::
-#
-#                   True Labels
-#                     1    2
-#                 1   TP | FP
-#   Predictions      ----+----
-#                 2   FN | TN
+# Computes the F1 score as the harmonic mean of precision and recall.
+# F1 = 2TP / (2TP + FP + FN)
 #
 # INPUT:
 # ------------------------------------------------------------------------------
 # P              vector of predictions (1-based, recoded)
-# Y              vector of actual labels (1-based, recorded)
+# Y              vector of actual labels (1-based, recoded)
 # ------------------------------------------------------------------------------
 #
 # OUTPUT:
 # ------------------------------------------------------------------------------
-# confusionSum   the confusion matrix as absolute counts
-# confusionAvg   the confusion matrix as relative frequencies
+# score          the F1 score
 # ------------------------------------------------------------------------------
 
-m_confusionMatrix = function(Matrix[Double] P, Matrix[Double] Y)
-  return(Matrix[Double] confusionSum, Matrix[Double] confusionAvg)
+m_f1Score = function(Matrix[Double] P, Matrix[Double] Y)
+  return(Double score)
 {
-  dim = max(max(Y), max(P)) #ensure known dim
-
-  if(ncol(P) > 1  | ncol(Y) > 1)
-    stop("confusionMatrix: Invalid input number of cols should be 1 in both P ["+ncol(P)+"] and Y ["+ncol(Y)+"]")
-  if(nrow(P) != nrow(Y))
-    stop("confusionMatrix: The number of rows have to be equal in both P ["+nrow(P)+"] and Y ["+nrow(Y)+"]")
-  if(min(P) < 1 | min(Y) < 1)
-    stop("confusionMatrix: All Values in P and Y should be abore or equal to 1, min(P):" + min(P) + " min(Y):" + min(Y) )
-
-  confusionSum = table(P, Y, dim, dim)
-  # max to avoid division by 0, in case a colum contain no entries.
-  confusionAvg = confusionSum / max(1,colSums(confusionSum))
+  [cS, cA] = confusionMatrix(P, Y);
+  if(nrow(cS)>2 | ncol(cS)>2)
+    stop("f1Score: currently only supported for binary class labels.");
+  score = as.scalar(2*cS[1,1] / (2*cS[1,1] + cS[2,1] + cS[1,2]));
 }
-
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java
index b510698cd4..61faccc0bc 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -130,6 +130,7 @@ public enum Builtins {
        EXP("exp", false),
        EVAL("eval", false),
        EVALLIST("evalList", false),
+       F1SCORE("f1Score", true),
        FIT_PIPELINE("fit_pipeline", true),
        FIX_INVALID_LENGTHS("fixInvalidLengths", true),
        FIX_INVALID_LENGTHS_APPLY("fixInvalidLengthsApply", true),

Reply via email to