This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 5015f63a79 [SYSTEMDS-3709] Additional tests for UDF backwards compatibility 5015f63a79 is described below commit 5015f63a7980f36e832bdffcbebba575cf8ddd62 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Fri Jun 7 09:44:45 2024 +0200 [SYSTEMDS-3709] Additional tests for UDF backwards compatibility This patch adds tests for the old SystemML UDF MultiInputCbind, ensuring that the related DML script is properly compiled to an nary cbind and that, if the inputs are already vectors and are reshaped to vectors again, the unnecessary reshape is eliminated. --- .../matrix/UDFBackwardsCompatibilityTest.java | 48 ++++++++++++++++++---- ...owClassMeetTest.dml => MultiInputCbindTest.dml} | 10 ++++- .../functions/binary/matrix/RowClassMeetTest.dml | 2 +- 3 files changed, 48 insertions(+), 12 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java index f4961efc55..44cca625ea 100644 --- a/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java +++ b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java @@ -19,16 +19,20 @@ package org.apache.sysds.test.functions.binary.matrix; +import org.junit.Assert; import org.junit.Test; + import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.Statistics; public class UDFBackwardsCompatibilityTest extends AutomatedTestBase { private final static String TEST_NAME1 = "RowClassMeetTest"; + private final static String TEST_NAME2 = "MultiInputCbindTest"; private final static String TEST_DIR = "functions/binary/matrix/"; private final static String 
TEST_CLASS_DIR = TEST_DIR + UDFBackwardsCompatibilityTest.class.getSimpleName() + "/"; @@ -44,29 +48,46 @@ public class UDFBackwardsCompatibilityTest extends AutomatedTestBase public void setUp() { addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "C" }) ); + addTestConfiguration( TEST_NAME2, + new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "C" }) ); } @Test public void testRowClassMeetDenseDense() { - runUDFTest(TEST_NAME1, false, false, ExecType.CP); + runUDFTest(TEST_NAME1, false, false, false, false, ExecType.CP); } @Test public void testRowClassMeetDenseSparse() { - runUDFTest(TEST_NAME1, false, true, ExecType.CP); + runUDFTest(TEST_NAME1, false, true, false, false, ExecType.CP); } @Test public void testRowClassMeetSparseDense() { - runUDFTest(TEST_NAME1, true, false, ExecType.CP); + runUDFTest(TEST_NAME1, true, false, false, false, ExecType.CP); } @Test public void testRowClassMeetSparseSparse() { - runUDFTest(TEST_NAME1, true, true, ExecType.CP); + runUDFTest(TEST_NAME1, true, true, false, false, ExecType.CP); + } + + @Test + public void testMultiInputCBindDenseDenseMatrixMatrix() { + runUDFTest(TEST_NAME2, false, false, false, false, ExecType.CP); + } + + @Test + public void testMultiInputCBindDenseDenseMatrixVector() { + runUDFTest(TEST_NAME2, false, false, false, true, ExecType.CP); + } + + @Test + public void testMultiInputCBindDenseDenseVectorVector() { + runUDFTest(TEST_NAME2, false, false, true, true, ExecType.CP); } - private void runUDFTest(String testname, boolean sparseM1, boolean sparseM2, ExecType instType) + private void runUDFTest(String testname, boolean sparseM1, boolean sparseM2, boolean vectorData, boolean vectorize, ExecType instType) { ExecMode platformOld = setExecMode(instType); String TEST_NAME = testname; @@ -76,18 +97,27 @@ public class UDFBackwardsCompatibilityTest extends AutomatedTestBase String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + TEST_NAME + 
".dml"; - programArgs = new String[]{"-explain","-args", input("A"), input("B"), output("C")}; + programArgs = new String[]{"-stats", "-explain","-args", + input("A"), input("B"), String.valueOf(vectorize).toUpperCase(), output("C")}; //generate actual dataset + int nr = vectorData ? rows*cols : rows; + int nc = vectorData ? 1 : cols; + double[][] A = TestUtils.round( - getRandomMatrix(rows, cols, 0, 10, sparseM1?sparsity2:sparsity1, 7)); + getRandomMatrix(nr, nc, 0, 10, sparseM1?sparsity2:sparsity1, 7)); writeInputMatrixWithMTD("A", A, false); double[][] B = TestUtils.round( - getRandomMatrix(rows, cols, 0, 10, sparseM2?sparsity2:sparsity1, 3)); + getRandomMatrix(nr, nc, 0, 10, sparseM2?sparsity2:sparsity1, 3)); writeInputMatrixWithMTD("B", B, false); //run test case - runTest(true, false, null, -1); + runTest(true, false, null, -1); + + if( TEST_NAME.equals(TEST_NAME2) ) //check nary cbind + Assert.assertEquals(1, Statistics.getCPHeavyHitterCount("cbind")); + if( vectorData && vectorize ) //check eliminated reshape + Assert.assertFalse(heavyHittersContainsString("rshape")); } finally { rtplatform = platformOld; diff --git a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml b/src/test/scripts/functions/binary/matrix/MultiInputCbindTest.dml similarity index 86% copy from src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml copy to src/test/scripts/functions/binary/matrix/MultiInputCbindTest.dml index 9975f8d99d..77445023f6 100644 --- a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml +++ b/src/test/scripts/functions/binary/matrix/MultiInputCbindTest.dml @@ -21,6 +21,12 @@ A = read($1); B = read($2); -[C,N] = rowClassMeet(A, B); -write(C, $3); + +if( as.logical($3) ) { + A = matrix(A, rows=length(A), cols=1) + B = matrix(B, rows=length(B), cols=1) +} + +R = cbind(cbind(A, B), A); +write(R, $4); diff --git a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml b/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml 
index 9975f8d99d..f2d9da3ae8 100644 --- a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml +++ b/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml @@ -22,5 +22,5 @@ A = read($1); B = read($2); [C,N] = rowClassMeet(A, B); -write(C, $3); +write(C, $4);