This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 20e7dd0800fb69f73b770463b3b90fde283219f5
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Wed Apr 17 20:59:59 2024 +0200

    [SYSTEMDS-3538] New word-error-rate builtin function for TPCx-AI
    
    This patch adds the word-error-rate (WER) builtin function, which is
    derived from the Levenshtein distance but on words instead of
    characters.
    
    Part of the feature pack for implementing TPCx-AI on SystemDS.
---
 scripts/builtin/wer.dml                            | 43 +++++++++++++
 .../java/org/apache/sysds/common/Builtins.java     |  1 +
 .../apache/sysds/runtime/util/UtilFunctions.java   | 66 ++++++++++++++++++-
 .../functions/builtin/part2/BuiltinWerTest.java    | 75 ++++++++++++++++++++++
 src/test/scripts/functions/builtin/wer.dml         | 37 +++++++++++
 5 files changed, 219 insertions(+), 3 deletions(-)

diff --git a/scripts/builtin/wer.dml b/scripts/builtin/wer.dml
new file mode 100644
index 0000000000..fd48f63dde
--- /dev/null
+++ b/scripts/builtin/wer.dml
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# This built-in function computes the word error rate (WER)
+# defined as wer = (numSubst + numDel + numIns) / max(length(r), length(h))
+#
+# INPUT:
+# ------------------------------------------------------------------------------
+# R     Input frame of reference strings, shape: [N x 1]
+# H     Input frame of hypothesis strings, shape: [N x 1]
+# ------------------------------------------------------------------------------
+#
+# OUTPUT:
+# ------------------------------------------------------------------------------
+# W     Output matrix of word error rate per pair of strings,
+#       shape: [N x 1], where W[i,1] = wer(R[i,1], H[i,1])
+# ------------------------------------------------------------------------------
+
+f_wer = function(Frame[String] R, Frame[String] H)
+  return (Matrix[Double] W)
+{
+  F = cbind(R, H); # two-column frame: reference in column 1, hypothesis in column 2
+  T = map(F, "x -> new 
String[]{String.valueOf(UtilFunctions.getWordErrorRate(x[0], x[1])),\"\"}", 
margin=1); # lambda reads x[0]/x[1] and returns the WER string as first element
+  W = as.matrix(T[,1]); # first column of T converted to the numeric result
+}
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 82b6441f58..b510698cd4 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -328,6 +328,7 @@ public enum Builtins {
        VAR("var", false),
        VALUE_SWAP("valueSwap", false),
        VECTOR_TO_CSV("vectorToCsv", true),
+       WER("wer", true, false),
        WINSORIZE("winsorize", true, false), //TODO parameterize w/ prob, 
min/max val
        WINSORIZEAPPLY("winsorizeApply", true, false), //TODO parameterize w/ 
prob, min/max val
        WOE("WoE", true ),
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java 
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index 619dd4467e..b4a809a280 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -445,7 +445,7 @@ public class UtilFunctions {
         * 
         * @param low   lower bound (inclusive)
         * @param up    upper bound (inclusive)
-        * @param incr  increment 
+        * @param incr  increment
         * @return list of integers
         */
        public static List<Integer> getSeqList(int low, int up, int incr) {
@@ -460,7 +460,7 @@ public class UtilFunctions {
         * 
         * @param low   lower bound (inclusive)
         * @param up    upper bound (inclusive)
-        * @param incr  increment 
+        * @param incr  increment
         * @return array of integers
         */
        public static int[] getSeqArray(int low, int up, int incr) {
@@ -778,7 +778,7 @@ public class UtilFunctions {
                byte ret = Byte.MIN_VALUE;
                for( int i=0; i<array.length; i++ )
                        ret = (array[i]>ret)?array[i]:ret;
-               return ret;     
+               return ret;
        }
        
        public static String unquote(String s) {
@@ -1363,4 +1363,64 @@ public class UtilFunctions {
 
                return 
Arrays.stream(original).mapToDouble(Double::parseDouble).toArray();
        }
+       
+       /**
+        * Computes the word error rate (Levenshtein distance at word level),
+        * normalized by the longer of the two word sequences:
+        * wer = (numSubst + numDel + numIns) / max(length(r), length(h))
+        * 
+        * This code has been adapted from Apache Commons Lang 3.12 
+        * (getLevenshteinDistance, but for words instead of characters).
+        * Words are obtained by splitting the inputs on single blanks.
+        * 
+        * @param r reference string
+        * @param h hypothesis string
+        * @return word error rate (WER) in [0,1]
+        * @throws IllegalArgumentException if r or h is null
+        */
+       public static double getWordErrorRate(String r, String h) {
+               if (r == null || h == null) {
+                       throw new IllegalArgumentException("Strings must not be null");
+               }
+
+               //prepare string sequences
+               String[] s = r.split(" ");
+               String[] t = h.split(" ");
+               int n = s.length;
+               int m = t.length;
+
+               //basic size handling (blank-only inputs split into zero words):
+               //all max(n,m) words are edits, so the normalized rate is 1
+               //unless both sequences are empty
+               if( n == 0 || m == 0 )
+                       return (n == m) ? 0 : 1;
+               if (n > m) {
+                       // swap the input strings to consume less memory
+                       String[] tmp = s;
+                       s = t;
+                       t = tmp;
+                       n = m;
+                       m = t.length;
+               }
+
+               //single-row dynamic program over the words of t (s is the shorter input)
+               final int[] p = new int[n + 1];
+               for (int i = 0; i <= n; i++) {
+                       p[i] = i;
+               }
+               for (int j = 1; j <= m; j++) {
+                       int upper_left = p[0]; //diagonal cell of the previous row
+                       String t_j = t[j - 1]; //jth word of t
+                       p[0] = j;
+                       for (int i = 1; i <= n; i++) {
+                               int upper = p[i]; //previous-row value before overwrite
+                               int cost = s[i - 1].equals(t_j) ? 0 : 1;
+                               //minimum over insertion, deletion, and substitution/match
+                               p[i] = Math.min(Math.min(p[i - 1] + 1, p[i] + 1), upper_left + cost);
+                               upper_left = upper;
+                       }
+               }
+               //wer = number of edits / length of the longer sequence (m >= n here)
+               return (double)p[n] / Math.max(n, m);
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinWerTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinWerTest.java
new file mode 100644
index 0000000000..eb322f6973
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinWerTest.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part2;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Runs the wer test script and checks the four word error rates it writes. */
+public class BuiltinWerTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "wer";
+       private final static String TEST_DIR = "functions/builtin/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinWerTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"W"}));
+       }
+
+       @Test
+       public void testCP() {
+               runWerTest(ExecType.CP);
+       }
+
+//     @Test
+//     public void testSpark() {
+//             runWerTest(ExecType.SPARK);
+//     }
+
+       private void runWerTest(ExecType instType) {
+               ExecMode platformOld = setExecMode(instType);
+
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+                       fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-explain","-args", output("W")};
+                       runTest(null);
+                       //expected WER (edits / longer word count) goes first per JUnit convention
+                       HashMap<CellIndex, Double> ret = readDMLMatrixFromOutputDir("W");
+                       Assert.assertEquals(0d, ret.get(new CellIndex(1,1)), 1e-14);
+                       Assert.assertEquals(1d/4, ret.get(new CellIndex(2,1)), 1e-14);
+                       Assert.assertEquals(2d/5, ret.get(new CellIndex(3,1)), 1e-14);
+                       Assert.assertEquals(5d/5, ret.get(new CellIndex(4,1)), 1e-14);
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/builtin/wer.dml 
b/src/test/scripts/functions/builtin/wer.dml
new file mode 100644
index 0000000000..2a77cfc477
--- /dev/null
+++ b/src/test/scripts/functions/builtin/wer.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# basic WER comparison: identical, 1 inserted word, 2 inserted words, no common words
+R = rbind(
+  as.frame("SystemDS is great"),
+  as.frame("SystemDS is great"),
+  as.frame("SystemDS is great"),
+  as.frame("SystemDS is great"));
+H = rbind(
+  as.frame("SystemDS is great"),
+  as.frame("SystemDS is really great"),
+  as.frame("SystemDS is really really great"),
+  as.frame("This test should be extended"));
+
+# WER plus epsilon to ensure a non-zero representation even for a WER of exactly 0
+W = wer(R, H) + 1e-16;
+
+write(W, $1)

Reply via email to