This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 20e7dd0800fb69f73b770463b3b90fde283219f5 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Wed Apr 17 20:59:59 2024 +0200 [SYSTEMDS-3538] New word-error-rate builtin function for TPCx-AI This patch adds the word-error-rate (WER) builtin function, which is derived from the Levenshtein distance but on words instead of characters. Part of the feature pack for implementing TPCx-AI on SystemDS. --- scripts/builtin/wer.dml | 43 +++++++++++++ .../java/org/apache/sysds/common/Builtins.java | 1 + .../apache/sysds/runtime/util/UtilFunctions.java | 66 ++++++++++++++++++- .../functions/builtin/part2/BuiltinWerTest.java | 75 ++++++++++++++++++++++ src/test/scripts/functions/builtin/wer.dml | 37 +++++++++++ 5 files changed, 219 insertions(+), 3 deletions(-) diff --git a/scripts/builtin/wer.dml b/scripts/builtin/wer.dml new file mode 100644 index 0000000000..fd48f63dde --- /dev/null +++ b/scripts/builtin/wer.dml @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +# This built-in function computes the word error rate (WER) +# defined as wer = (numSubst + numDel + numIns) / max(length(r), length(h)) +# +# INPUT: +# ------------------------------------------------------------------------------ +# R Input frame of reference strings, shape: [N x 1] +# H Input frame of hypothesis strings, shape: [N x 1] +# ------------------------------------------------------------------------------ +# +# OUTPUT: +# ------------------------------------------------------------------------------ +# W Output matrix of word error rate per pair of strings, +# shape: [N x 1], where W[i,1] = wer(R[i,1], H[i,1]) +# ------------------------------------------------------------------------------ + +f_wer = function(Frame[String] R, Frame[String] H) + return (Matrix[Double] W) +{ + F = cbind(R, H); + T = map(F, "x -> new String[]{String.valueOf(UtilFunctions.getWordErrorRate(x[0], x[1])),\"\"}", margin=1); + W = as.matrix(T[,1]); +} diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 82b6441f58..b510698cd4 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -328,6 +328,7 @@ public enum Builtins { VAR("var", false), VALUE_SWAP("valueSwap", false), VECTOR_TO_CSV("vectorToCsv", true), + WER("wer", true, false), WINSORIZE("winsorize", true, false), //TODO parameterize w/ prob, min/max val WINSORIZEAPPLY("winsorizeApply", true, false), //TODO parameterize w/ prob, min/max val WOE("WoE", true ), diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java index 619dd4467e..b4a809a280 100644 --- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java @@ -445,7 +445,7 @@ public class UtilFunctions { * * @param
low lower bound (inclusive) * @param up upper bound (inclusive) - * @param incr increment + * @param incr increment * @return list of integers */ public static List<Integer> getSeqList(int low, int up, int incr) { @@ -460,7 +460,7 @@ public class UtilFunctions { * * @param low lower bound (inclusive) * @param up upper bound (inclusive) - * @param incr increment + * @param incr increment * @return array of integers */ public static int[] getSeqArray(int low, int up, int incr) { @@ -778,7 +778,7 @@ public class UtilFunctions { byte ret = Byte.MIN_VALUE; for( int i=0; i<array.length; i++ ) ret = (array[i]>ret)?array[i]:ret; - return ret; + return ret; } public static String unquote(String s) { @@ -1363,4 +1363,64 @@ public class UtilFunctions { return Arrays.stream(original).mapToDouble(Double::parseDouble).toArray(); } + + /** + * Computes the word error rate (Levenshtein distance at word level): + * wer = (numSubst + numDel + numIns) / max(length(r), length(h)) + * + * This code has been adapted from Apache Commons Lang 3.12 + * (getLevenshteinDistance, but for words instead of characters).
+ * + * @param r reference string + * @param h hypothesis string + * @return word error rate (WER) + */ + public static double getWordErrorRate(String r, String h) { + if (r == null || h == null) { + throw new IllegalArgumentException("Strings must not be null"); + } + + //prepare string sequences (NOTE(review): splits on single spaces, so leading/repeated spaces yield empty-string tokens that are counted as words) + String[] s = r.split(" "); + String[] t = h.split(" "); + int n = s.length; + int m = t.length; + + //basic size handling (NOTE(review): reachable only when an input consists solely of spaces; returns a word count, not a rate, so the result can exceed 1) + if( n == 0 || m == 0 ) + return Math.max(n, m); + if (n > m) { + // swap the input strings to consume less memory + String[] tmp = s; + s = t; + t = tmp; + n = m; + m = t.length; + } + + final int[] p = new int[n + 1]; + int i; // iterates through s + int j; // iterates through t + int upper_left; + int upper; + + String t_j; // jth word of t + int cost; + for (i = 0; i <= n; i++) { + p[i] = i; + } + for (j = 1; j <= m; j++) { + upper_left = p[0]; + t_j = t[j - 1]; + p[0] = j; + for (i = 1; i <= n; i++) { + upper = p[i]; + cost = s[i - 1].equals(t_j) ? 0 : 1; + p[i] = Math.min(Math.min(p[i - 1] + 1, p[i] + 1), upper_left + cost); + upper_left = upper; + } + } + //wer = number of edits / max(#words in r, #words in h) + return (double)p[n] / Math.max(n, m); + } } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinWerTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinWerTest.java new file mode 100644 index 0000000000..eb322f6973 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinWerTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.builtin.part2; + +import java.util.HashMap; + +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +public class BuiltinWerTest extends AutomatedTestBase { + private final static String TEST_NAME = "wer"; + private final static String TEST_DIR = "functions/builtin/"; + private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinWerTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"W"})); + } + + @Test + public void testCP() { + runWerTest(ExecType.CP); + } + +// @Test +// public void testSpark() { +// runWerTest(ExecType.SPARK); +// } + + private void runWerTest(ExecType instType) { + ExecMode platformOld = setExecMode(instType); + + try { + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-explain","-args", output("W")}; + runTest(null); + + HashMap<CellIndex, Double> ret = readDMLMatrixFromOutputDir("W"); + Assert.assertEquals(ret.get(new CellIndex(1,1)), 0, 1e-14); + Assert.assertEquals(ret.get(new CellIndex(2,1)),
1d/4, 1e-14); + Assert.assertEquals(ret.get(new CellIndex(3,1)), 2d/5, 1e-14); + Assert.assertEquals(ret.get(new CellIndex(4,1)), 5d/5, 1e-14); + } + finally { + rtplatform = platformOld; + } + } +} diff --git a/src/test/scripts/functions/builtin/wer.dml b/src/test/scripts/functions/builtin/wer.dml new file mode 100644 index 0000000000..2a77cfc477 --- /dev/null +++ b/src/test/scripts/functions/builtin/wer.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# basic WER comparison (expected WER per row: 0, 1/4, 2/5, 5/5, as asserted in BuiltinWerTest) +R = rbind( + as.frame("SystemDS is great"), + as.frame("SystemDS is great"), + as.frame("SystemDS is great"), + as.frame("SystemDS is great")); +H = rbind( + as.frame("SystemDS is great"), + as.frame("SystemDS is really great"), + as.frame("SystemDS is really really great"), + as.frame("This test should be extended")); + +# WER plus epsilon to ensure non-zero representation +W = wer(R, H) + 1e-16; + +write(W, $1)