This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new da7889eb1f [SYSTEMDS-3709] Fix backwards compatibility (rowClassMeet
UDF)
da7889eb1f is described below
commit da7889eb1f7e1a831903fc513aa3906f995e2dfd
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Jun 6 13:21:49 2024 +0200
[SYSTEMDS-3709] Fix backwards compatibility (rowClassMeet UDF)
This patch reintroduces a former UDF (shipped in a UDF library) as a
multi-return builtin function. Since even in SystemML 1.2 tests were
missing for this UDF, we now also add various tests for combinations
of dense/sparse inputs.
---
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../java/org/apache/sysds/hops/FunctionOp.java | 11 +-
.../sysds/parser/BuiltinFunctionExpression.java | 17 +++
.../org/apache/sysds/parser/DMLTranslator.java | 1 +
.../runtime/instructions/CPInstructionParser.java | 3 +-
.../cp/MultiReturnBuiltinCPInstruction.java | 8 +-
...ltiReturnComplexMatrixBuiltinCPInstruction.java | 8 ++
.../sysds/runtime/matrix/data/LibCommonsMath.java | 138 ++++++++++++++++++++-
.../matrix/UDFBackwardsCompatibilityTest.java | 96 ++++++++++++++
.../functions/binary/matrix/RowClassMeetTest.dml | 26 ++++
10 files changed, 302 insertions(+), 7 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index b9cda07791..a358b2e2bc 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -277,6 +277,7 @@ public enum Builtins {
RANGE("range", false),
RASELECTION("raSelection", true),
RBIND("rbind", false),
+ RCM("rowClassMeet", "rcm", false, false, ReturnType.MULTI_RETURN),
REMOVE("remove", false, ReturnType.MULTI_RETURN),
REV("rev", false),
ROUND("round", false),
diff --git a/src/main/java/org/apache/sysds/hops/FunctionOp.java
b/src/main/java/org/apache/sysds/hops/FunctionOp.java
index 7f424d36d0..f612440075 100644
--- a/src/main/java/org/apache/sysds/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysds/hops/FunctionOp.java
@@ -251,8 +251,13 @@ public class FunctionOp extends MultiThreadedHop
long outputV =
OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(),
getOutputs().get(2).getDim2(), 1.0);
return outputU+outputSigma+outputV;
}
+ else if( getFunctionName().equalsIgnoreCase("rcm") ) {
+ long nr = Math.max(getInput(0).getDim1(),
getInput(1).getDim1());
+ long nc = Math.max(getInput(0).getDim2(),
getInput(1).getDim2());
+ return
2*OptimizerUtils.estimateSizeExactSparsity(nr, nc, 1.0);
+ }
else
- throw new RuntimeException("Invalid call of
computeOutputMemEstimate in FunctionOp.");
+ throw new RuntimeException("Invalid call of
computeOutputMemEstimate in FunctionOp: "+getFunctionName());
}
}
@@ -299,7 +304,9 @@ public class FunctionOp extends MultiThreadedHop
getFunctionName().equalsIgnoreCase("batch_norm2d_train") ||
getFunctionName().equalsIgnoreCase("batch_norm2d_test")) {
return 0;
}
- else if ( getFunctionName().equalsIgnoreCase("lstm") ||
getFunctionName().equalsIgnoreCase("lstm_backward") ) {
+ else if ( getFunctionName().equalsIgnoreCase("lstm")
+ ||
getFunctionName().equalsIgnoreCase("lstm_backward")
+ ||
getFunctionName().equalsIgnoreCase("rcm")) {
// TODO: To allow for initial version to always
run on the GPU
return 0;
}
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index c3f1026627..ec9b4a4bbd 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -380,6 +380,23 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
break;
}
+ case RCM: {
+ checkNumParameters(2);
+ checkMatrixParam(getFirstExpr());
+ checkMatrixParam(getSecondExpr());
+ long nr = Math.max(getFirstExpr().getOutput().getDim1(),
+ getSecondExpr().getOutput().getDim1());
+ long nc = Math.max(getFirstExpr().getOutput().getDim2(),
+ getSecondExpr().getOutput().getDim2());
+ for(int i=0; i<2; i++) {
+ DataIdentifier out = (DataIdentifier)
getOutputs()[i];
+ out.setDataType(DataType.MATRIX);
+ out.setValueType(ValueType.FP64);
+ out.setDimensions(nr, nc);
+
out.setBlocksize(getFirstExpr().getOutput().getBlocksize());
+ }
+ break;
+ }
case FFT: {
Expression expressionOne = getFirstExpr();
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 2b876c12be..5ff351da4c 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2247,6 +2247,7 @@ public class DMLTranslator
case BATCH_NORM2D_BACKWARD:
case REMOVE:
case SVD:
+ case RCM:
// Number of outputs = size of targetList = #of
identifiers in source.getOutputs
String[] outputNames = new
String[targetList.size()];
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 994c1cd51a..792eace24e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -340,7 +340,8 @@ public class CPInstructionParser extends InstructionParser {
String2CPInstructionType.put( "ifft_linearized",
CPType.MultiReturnComplexMatrixBuiltin);
String2CPInstructionType.put( "stft",
CPType.MultiReturnComplexMatrixBuiltin);
String2CPInstructionType.put( "svd",
CPType.MultiReturnBuiltin);
-
+ String2CPInstructionType.put( "rcm",
CPType.MultiReturnComplexMatrixBuiltin);
+
String2CPInstructionType.put( "partition", CPType.Partition);
String2CPInstructionType.put( Compression.OPCODE,
CPType.Compression);
String2CPInstructionType.put( DeCompression.OPCODE,
CPType.DeCompression);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java
index e1e0420513..cb4651182b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnBuiltinCPInstruction.java
@@ -47,6 +47,13 @@ public class MultiReturnBuiltinCPInstruction extends
ComputationCPInstruction {
_numThreads = threads;
}
+ /**
+  * Creates a multi-return builtin instruction with two input operands.
+  * The first entry of {@code outputs} is handed to the base instruction as
+  * its output operand; the full list is kept for access via getOutput(int).
+  *
+  * @param op      operator passed through to the base instruction
+  * @param input1  first input operand
+  * @param input2  second input operand
+  * @param outputs all output operands (must contain at least one entry)
+  * @param opcode  instruction opcode
+  * @param istr    instruction string
+  * @param threads number of threads stored for this instruction
+  */
+ private MultiReturnBuiltinCPInstruction(Operator op, CPOperand input1, CPOperand input2,
+     ArrayList<CPOperand> outputs, String opcode, String istr, int threads)
+ {
+     super(CPType.MultiReturnBuiltin, op, input1, input2, outputs.get(0), opcode, istr);
+     this._numThreads = threads;
+     this._outputs = outputs;
+ }
+
public CPOperand getOutput(int i) {
return _outputs.get(i);
}
@@ -151,7 +158,6 @@ public class MultiReturnBuiltinCPInstruction extends
ComputationCPInstruction {
int threads = Integer.parseInt(parts[5]);
return new MultiReturnBuiltinCPInstruction(null, in1,
outputs, opcode, str, threads);
-
}
else {
throw new DMLRuntimeException("Invalid opcode in
MultiReturnBuiltin instruction: " + opcode);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java
index 3238b9381a..40b9191044 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/MultiReturnComplexMatrixBuiltinCPInstruction.java
@@ -139,6 +139,14 @@ public class MultiReturnComplexMatrixBuiltinCPInstruction
extends ComputationCPI
return new
MultiReturnComplexMatrixBuiltinCPInstruction(null, in1, in2, windowSize,
overlap, outputs, opcode,
str, threads);
}
+ else if ( opcode.equalsIgnoreCase("rcm") ) {
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ outputs.add ( new CPOperand(parts[3], ValueType.FP64,
DataType.MATRIX) );
+ outputs.add ( new CPOperand(parts[4], ValueType.FP64,
DataType.MATRIX) );
+ int threads = Integer.parseInt(parts[5]);
+ return new
MultiReturnComplexMatrixBuiltinCPInstruction(null, in1, in2, outputs, opcode,
str, threads);
+ }
else {
throw new DMLRuntimeException("Invalid opcode in
MultiReturnBuiltin instruction: " + opcode);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
index d53b30e713..8757fda822 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
@@ -24,6 +24,10 @@ import static
org.apache.sysds.runtime.matrix.data.LibMatrixFourier.fft_lineariz
import static org.apache.sysds.runtime.matrix.data.LibMatrixFourier.ifft;
import static
org.apache.sysds.runtime.matrix.data.LibMatrixFourier.ifft_linearized;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Map.Entry;
+
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.exception.MaxCountExceededException;
@@ -37,6 +41,9 @@ import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.codegen.CodegenUtils;
+import org.apache.sysds.runtime.codegen.SpoofOperator.SideInput;
+import org.apache.sysds.runtime.compress.utils.IntArrayList;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
@@ -53,6 +60,7 @@ import
org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.runtime.util.UtilFunctions;
/**
* Library for matrix operations that need invocation of
@@ -78,13 +86,14 @@ public class LibCommonsMath
public static boolean isSupportedMultiReturnOperation( String opcode ) {
switch (opcode) {
- case "qr":
- case "lu":
case "eigen":
case "fft":
- case "ifft":
case "fft_linearized":
+ case "ifft":
case "ifft_linearized":
+ case "lu":
+ case "qr":
+ case "rcm":
case "stft":
case "svd": return true;
default: return false;
@@ -158,6 +167,8 @@ public class LibCommonsMath
return computeIFFT(in1, in2, threads);
case "ifft_linearized":
return computeIFFT_LINEARIZED(in1, in2,
threads);
+ case "rcm":
+ return computeRCM(in1, in2);
default:
return null;
}
@@ -816,4 +827,125 @@ public class LibCommonsMath
evec.init(eVectors, n, n);
return new MatrixBlock[] {eval, evec};
}
+
+ /**
+  * Computes the row-wise intersection ("meet") of equivalence classes for
+  * each row of A and B, excluding 0-valued cells.
+  * INPUT:
+  *   A, B = matrices whose rows contain that row's class labels;
+  *          for each i, rows A [i, ] and B [i, ] define two
+  *          equivalence relations on some of the columns, which
+  *          we want to intersect
+  *   A [i, j] == A [i, k] != 0 if and only if (j ~ k) as defined
+  *          by row A [i, ];
+  *   A [i, j] == 0 means that j is excluded by A [i, ]
+  *   B [i, j] is analogous
+  *   NOTE 1: Either nrow(A) == nrow(B), or exactly one of A or B
+  *          has one row that "applies" to each row of the other matrix.
+  *   NOTE 2: If ncol(A) != ncol(B), we pad extra 0-columns up to
+  *          max (ncol(A), ncol(B)).
+  *   NOTE 3: Labels are truncated to int before comparison; a cell whose
+  *          truncated label is 0 is treated as excluded, independent of
+  *          whether A is stored in dense or sparse format.
+  * OUTPUT:
+  *   Both C and N have the same size as (the max of) A and B.
+  *   C = matrix whose rows contain class labels that represent
+  *       the intersection (coarsest common refinement) of the
+  *       corresponding rows of A and B.
+  *   C [i, j] == C [i, k] != 0 if and only if (j ~ k) as defined
+  *       by both A [i, ] and B [i, ]
+  *   C [i, j] == 0 if and only if A [i, j] == 0 or B [i, j] == 0
+  *   Additionally, we guarantee that non-0 labels in C [i, ]
+  *       will be integers from 1 to max (C [i, ]) without gaps.
+  *       For A and B the labels can be arbitrary.
+  *   N = matrix with class-size information for C-cells
+  *   N [i, j] = count of {C [i, k] | C [i, j] == C [i, k] != 0}
+  *
+  * @param A first input matrix
+  * @param B second input matrix
+  * @return output matrices C and N
+  */
+ private static MatrixBlock[] computeRCM(MatrixBlock A, MatrixBlock B) {
+     int nr = Math.max(A.getNumRows(), B.getNumRows());
+     int nc = Math.max(A.getNumColumns(), B.getNumColumns());
+     MatrixBlock C = new MatrixBlock(nr, nc, false).allocateBlock();
+     MatrixBlock N = new MatrixBlock(nr, nc, false).allocateBlock();
+     double[] dC = C.getDenseBlockValues();
+     double[] dN = N.getDenseBlockValues();
+     //wrap B into a side input for cell lookups; A is scanned directly
+     //according to its own (dense or sparse) representation
+     SideInput sB = CodegenUtils.createSideInput(B);
+     boolean rowVectorA = (A.getNumRows() == 1);
+     boolean rowVectorB = (B.getNumRows() == 1);
+     //a single-row A is applied to every output row (see NOTE 1);
+     //otherwise we iterate over the rows of A as before
+     int numRows = rowVectorA ? nr : A.getNumRows();
+     //columns beyond the narrower input are implicitly 0 (see NOTE 2)
+     int numCols = Math.min(A.getNumColumns(), B.getNumColumns());
+     int clenA = A.getNumColumns();
+
+     Map<ClassLabel, IntArrayList> classLabelMapping = new HashMap<>();
+     for(int i=0; i<numRows; i++) {
+         int ia = rowVectorA ? 0 : i; //row of A used for output row i
+         int ib = rowVectorB ? 0 : i; //row of B used for output row i
+         classLabelMapping.clear(); sB.reset();
+
+         //step 1: group the columns of this row by their (A,B) label pair
+         if( A.isInSparseFormat() ) {
+             if( A.getSparseBlock()==null || A.getSparseBlock().isEmpty(ia) )
+                 continue;
+             int alen = A.getSparseBlock().size(ia);
+             int apos = A.getSparseBlock().pos(ia);
+             int[] aix = A.getSparseBlock().indexes(ia);
+             double[] avals = A.getSparseBlock().values(ia);
+             for(int k=apos; k<apos+alen; k++) {
+                 if( aix[k] >= numCols ) break;
+                 int aVal = (int) avals[k];
+                 int bVal = (int) sB.getValue(ib, aix[k]);
+                 //check the truncated A label too, consistent with the dense path
+                 if( aVal != 0 && bVal != 0 )
+                     addColumnToClass(classLabelMapping, aVal, bVal, aix[k]);
+             }
+         }
+         else {
+             double[] denseBlk = A.getDenseBlockValues();
+             if( denseBlk == null ) break; //empty A: nothing to intersect
+             int ai = ia * clenA;
+             for(int j=0; j<numCols; j++) {
+                 int aVal = (int) denseBlk[ai+j];
+                 int bVal = (int) sB.getValue(ib, j);
+                 if( aVal != 0 && bVal != 0 )
+                     addColumnToClass(classLabelMapping, aVal, bVal, j);
+             }
+         }
+
+         //step 2: assign consecutive labels 1..#classes and the class sizes
+         int labelID = 1;
+         int off = i * nc;
+         for(Entry<ClassLabel, IntArrayList> entry : classLabelMapping.entrySet()) {
+             int nVal = entry.getValue().size();
+             int[] list = entry.getValue().extractValues();
+             for(int k=0; k<nVal; k++) {
+                 dN[off+list[k]] = nVal;
+                 dC[off+list[k]] = labelID;
+             }
+             labelID++;
+         }
+     }
+
+     //prepare outputs
+     C.recomputeNonZeros(); C.examSparsity();
+     N.recomputeNonZeros(); N.examSparsity();
+     return new MatrixBlock[] {C, N};
+ }
+
+ /**
+  * Appends a column index to the list of columns that share the given
+  * (A label, B label) pair, creating the list on first use.
+  */
+ private static void addColumnToClass(Map<ClassLabel, IntArrayList> mapping, int aVal, int bVal, int col) {
+     ClassLabel key = new ClassLabel(aVal, bVal);
+     IntArrayList cols = mapping.get(key);
+     if( cols == null ) {
+         cols = new IntArrayList();
+         mapping.put(key, cols);
+     }
+     cols.appendValue(col);
+ }
+
+ /**
+  * Immutable pair of integer class labels (one from A, one from B) that
+  * serves as hash-map key; two keys are equal iff both labels are equal.
+  * The fields are final so a key cannot change while stored in a map.
+  */
+ private static final class ClassLabel {
+     private final int aVal;
+     private final int bVal;
+     public ClassLabel(int aVal, int bVal) {
+         this.aVal = aVal;
+         this.bVal = bVal;
+     }
+     @Override
+     public int hashCode() {
+         return UtilFunctions.intHashCode(aVal, bVal);
+     }
+     @Override
+     public boolean equals(Object o) {
+         if( this == o )
+             return true;
+         if( !(o instanceof ClassLabel) )
+             return false;
+         ClassLabel that = (ClassLabel) o;
+         return aVal == that.aVal && bVal == that.bVal;
+     }
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java
new file mode 100644
index 0000000000..f4961efc55
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.binary.matrix;
+
+import org.junit.Test;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+/**
+ * Runs the RowClassMeetTest DML script (rowClassMeet builtin) in CP for all
+ * four combinations of dense/sparse inputs A and B. The outputs are not
+ * compared against a reference in this class.
+ */
+public class UDFBackwardsCompatibilityTest extends AutomatedTestBase
+{
+ private final static String TEST_NAME1 = "RowClassMeetTest";
+ private final static String TEST_DIR = "functions/binary/matrix/";
+ private final static String TEST_CLASS_DIR =
+     TEST_DIR + UDFBackwardsCompatibilityTest.class.getSimpleName() + "/";
+
+ //input dimensions
+ private final static int rows = 1267;
+ private final static int cols = 56;
+
+ //sparsity1 is used for "dense" inputs, sparsity2 for "sparse" inputs
+ private final static double sparsity1 = 0.7;
+ private final static double sparsity2 = 0.1;
+
+ @Override
+ public void setUp() {
+     TestConfiguration config =
+         new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "C" });
+     addTestConfiguration(TEST_NAME1, config);
+ }
+
+ @Test
+ public void testRowClassMeetDenseDense() {
+     runUDFTest(TEST_NAME1, false, false, ExecType.CP);
+ }
+
+ @Test
+ public void testRowClassMeetDenseSparse() {
+     runUDFTest(TEST_NAME1, false, true, ExecType.CP);
+ }
+
+ @Test
+ public void testRowClassMeetSparseDense() {
+     runUDFTest(TEST_NAME1, true, false, ExecType.CP);
+ }
+
+ @Test
+ public void testRowClassMeetSparseSparse() {
+     runUDFTest(TEST_NAME1, true, true, ExecType.CP);
+ }
+
+ /**
+  * Generates rounded random inputs A and B with the requested sparsity,
+  * runs the script, and finally restores the previous execution mode.
+  */
+ private void runUDFTest(String testname, boolean sparseM1, boolean sparseM2, ExecType instType)
+ {
+     ExecMode platformOld = setExecMode(instType);
+
+     try {
+         getAndLoadTestConfiguration(testname);
+
+         fullDMLScriptName = SCRIPT_DIR + TEST_DIR + testname + ".dml";
+         programArgs = new String[]{"-explain","-args",
+             input("A"), input("B"), output("C")};
+
+         //generate actual dataset (fixed seeds 7 and 3)
+         double spA = sparseM1 ? sparsity2 : sparsity1;
+         double spB = sparseM2 ? sparsity2 : sparsity1;
+         double[][] A = TestUtils.round(getRandomMatrix(rows, cols, 0, 10, spA, 7));
+         writeInputMatrixWithMTD("A", A, false);
+         double[][] B = TestUtils.round(getRandomMatrix(rows, cols, 0, 10, spB, 3));
+         writeInputMatrixWithMTD("B", B, false);
+
+         //run test case
+         runTest(true, false, null, -1);
+     }
+     finally {
+         rtplatform = platformOld;
+     }
+ }
+}
diff --git a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
b/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
new file mode 100644
index 0000000000..9975f8d99d
--- /dev/null
+++ b/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reads the two label matrices passed as the first and second script argument.
+A = read($1);
+B = read($2);
+# rowClassMeet returns two matrices; only C is written out below, N is unused here.
+[C,N] = rowClassMeet(A, B);
+# Writes C to the location passed as the third script argument.
+write(C, $3);
+