This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 5015f63a79 [SYSTEMDS-3709] Additional tests for UDF backwards compatibility
5015f63a79 is described below

commit 5015f63a7980f36e832bdffcbebba575cf8ddd62
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Fri Jun 7 09:44:45 2024 +0200

    [SYSTEMDS-3709] Additional tests for UDF backwards compatibility
    
    This patch adds tests for the old SystemML UDF MultiInputCbind,
    ensuring that the related DML script is properly compiled to an n-ary
    cbind and, if the inputs are already vectors and are reshaped to
    vectors, that the unnecessary reshape is also eliminated.
---
 .../matrix/UDFBackwardsCompatibilityTest.java      | 48 ++++++++++++++++++----
 ...owClassMeetTest.dml => MultiInputCbindTest.dml} | 10 ++++-
 .../functions/binary/matrix/RowClassMeetTest.dml   |  2 +-
 3 files changed, 48 insertions(+), 12 deletions(-)

diff --git a/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java
index f4961efc55..44cca625ea 100644
--- a/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/binary/matrix/UDFBackwardsCompatibilityTest.java
@@ -19,16 +19,20 @@
 
 package org.apache.sysds.test.functions.binary.matrix;
 
+import org.junit.Assert;
 import org.junit.Test;
+
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
 
 public class UDFBackwardsCompatibilityTest extends AutomatedTestBase 
 {
        private final static String TEST_NAME1 = "RowClassMeetTest";
+       private final static String TEST_NAME2 = "MultiInputCbindTest";
        private final static String TEST_DIR = "functions/binary/matrix/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
                UDFBackwardsCompatibilityTest.class.getSimpleName() + "/";
@@ -44,29 +48,46 @@ public class UDFBackwardsCompatibilityTest extends AutomatedTestBase
        public void setUp() {
                addTestConfiguration( TEST_NAME1,
                        new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new 
String[] { "C" }) );
+               addTestConfiguration( TEST_NAME2,
+                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new 
String[] { "C" }) );
        }
 
        @Test
        public void testRowClassMeetDenseDense() {
-               runUDFTest(TEST_NAME1, false, false, ExecType.CP);
+               runUDFTest(TEST_NAME1, false, false, false, false, ExecType.CP);
        }
        
        @Test
        public void testRowClassMeetDenseSparse() {
-               runUDFTest(TEST_NAME1, false, true, ExecType.CP);
+               runUDFTest(TEST_NAME1, false, true, false, false, ExecType.CP);
        }
        
        @Test
        public void testRowClassMeetSparseDense() {
-               runUDFTest(TEST_NAME1, true, false, ExecType.CP);
+               runUDFTest(TEST_NAME1, true, false, false, false, ExecType.CP);
        }
        
        @Test
        public void testRowClassMeetSparseSparse() {
-               runUDFTest(TEST_NAME1, true, true, ExecType.CP);
+               runUDFTest(TEST_NAME1, true, true, false, false, ExecType.CP);
+       }
+       
+       @Test
+       public void testMultiInputCBindDenseDenseMatrixMatrix() {
+               runUDFTest(TEST_NAME2, false, false, false, false, ExecType.CP);
+       }
+       
+       @Test
+       public void testMultiInputCBindDenseDenseMatrixVector() {
+               runUDFTest(TEST_NAME2, false, false, false, true, ExecType.CP);
+       }
+       
+       @Test
+       public void testMultiInputCBindDenseDenseVectorVector() {
+               runUDFTest(TEST_NAME2, false, false, true, true, ExecType.CP);
        }
        
-       private void runUDFTest(String testname, boolean sparseM1, boolean 
sparseM2, ExecType instType)
+       private void runUDFTest(String testname, boolean sparseM1, boolean 
sparseM2, boolean vectorData, boolean vectorize, ExecType instType)
        {
                ExecMode platformOld = setExecMode(instType);
                String TEST_NAME = testname;
@@ -76,18 +97,27 @@ public class UDFBackwardsCompatibilityTest extends AutomatedTestBase
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + TEST_NAME + ".dml";
-                       programArgs = new String[]{"-explain","-args", 
input("A"), input("B"), output("C")};
+                       programArgs = new String[]{"-stats", "-explain","-args",
+                               input("A"), input("B"), 
String.valueOf(vectorize).toUpperCase(), output("C")};
                        
                        //generate actual dataset
+                       int nr = vectorData ? rows*cols : rows;
+                       int nc = vectorData ? 1 : cols;
+                       
                        double[][] A = TestUtils.round(
-                               getRandomMatrix(rows, cols, 0, 10, 
sparseM1?sparsity2:sparsity1, 7));
+                               getRandomMatrix(nr, nc, 0, 10, 
sparseM1?sparsity2:sparsity1, 7));
                        writeInputMatrixWithMTD("A", A, false);
                        double[][] B = TestUtils.round(
-                               getRandomMatrix(rows, cols, 0, 10, 
sparseM2?sparsity2:sparsity1, 3));
+                               getRandomMatrix(nr, nc, 0, 10, 
sparseM2?sparsity2:sparsity1, 3));
                        writeInputMatrixWithMTD("B", B, false);
        
                        //run test case
-                       runTest(true, false, null, -1); 
+                       runTest(true, false, null, -1);
+                       
+                       if( TEST_NAME.equals(TEST_NAME2) ) //check nary cbind
+                               Assert.assertEquals(1, 
Statistics.getCPHeavyHitterCount("cbind"));
+                       if( vectorData && vectorize ) //check eliminated reshape
+                               
Assert.assertFalse(heavyHittersContainsString("rshape"));
                }
                finally {
                        rtplatform = platformOld;
diff --git a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml b/src/test/scripts/functions/binary/matrix/MultiInputCbindTest.dml
similarity index 86%
copy from src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
copy to src/test/scripts/functions/binary/matrix/MultiInputCbindTest.dml
index 9975f8d99d..77445023f6 100644
--- a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
+++ b/src/test/scripts/functions/binary/matrix/MultiInputCbindTest.dml
@@ -21,6 +21,12 @@
 
 A = read($1);
 B = read($2);
-[C,N] = rowClassMeet(A, B);
-write(C, $3);
+
+if( as.logical($3) ) {
+  A = matrix(A, rows=length(A), cols=1)
+  B = matrix(B, rows=length(B), cols=1)
+}
+
+R = cbind(cbind(A, B), A);
+write(R, $4);
 
diff --git a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml b/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
index 9975f8d99d..f2d9da3ae8 100644
--- a/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
+++ b/src/test/scripts/functions/binary/matrix/RowClassMeetTest.dml
@@ -22,5 +22,5 @@
 A = read($1);
 B = read($2);
 [C,N] = rowClassMeet(A, B);
-write(C, $3);
+write(C, $4);
 

Reply via email to