This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 0664e1fd78 [MINOR] Add a few MatrixMult and asFrame Tests
0664e1fd78 is described below

commit 0664e1fd782dd34a5d3abce6db0f0a652bf9f0d3
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Tue Oct 17 14:02:43 2023 +0200

    [MINOR] Add a few MatrixMult and asFrame Tests
---
 src/test/java/org/apache/sysds/test/TestUtils.java |   2 +-
 .../component/frame/FrameFromMatrixBlockTest.java  |  51 +++++-
 .../test/component/matrix/MatrixMultiplyTest.java  | 179 +++++++++++++++++++++
 3 files changed, 225 insertions(+), 7 deletions(-)

diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java
index 907c9adab8..9e866e5b33 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -1396,7 +1396,7 @@ public class TestUtils {
                        if(countErrors != 0)
                                fail(message + "\n" + countErrors + " values 
are not in equal");
                        if(avgDistance > maxAveragePercentDistance)
-                               fail(message + "\nThe avg distance in bits: " + 
avgDistance + " was higher than max: " + maxAveragePercentDistance);
+                               fail(message + "\nThe avg distance in percent: 
" + avgDistance + " was higher than max: " + maxAveragePercentDistance);
                }
        }
 
diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java
index 76c7197322..bc0e242f9d 100644
--- a/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java
+++ b/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java
@@ -134,20 +134,48 @@ public class FrameFromMatrixBlockTest {
                verifyEquivalence(mb, fb, ValueType.FP64);
        }
 
+       /** Convert a random 100x10 matrix (sparsity 1.0, values between 0 and 199) to a frame and compare cell-wise. */
+       @Test
+       public void random() {
+               final MatrixBlock input = TestUtils.generateTestMatrixBlock(100, 10, 0, 199, 1.0, 213);
+               verifyEquivalence(input, FrameFromMatrixBlock.convertToFrameBlock(input, 1));
+       }
+
+       /** Same input as {@code random()} but rounded up with {@code TestUtils.ceil} before conversion. */
+       @Test
+       public void randomRounded() {
+               final MatrixBlock generated = TestUtils.generateTestMatrixBlock(100, 10, 0, 199, 1.0, 213);
+               final MatrixBlock rounded = TestUtils.ceil(generated);
+               final FrameBlock frame = FrameFromMatrixBlock.convertToFrameBlock(rounded, 1);
+               verifyEquivalence(rounded, frame);
+       }
+
+       /** Rounded-up random 100x10 matrix generated with sparsity 0.1, converted to a frame and compared cell-wise. */
+       @Test
+       public void randomSparse() {
+               final MatrixBlock sparseInput = TestUtils.ceil(TestUtils.generateTestMatrixBlock(100, 10, 0, 199, 0.1, 213));
+               verifyEquivalence(sparseInput, FrameFromMatrixBlock.convertToFrameBlock(sparseInput, 1));
+       }
+
+       /** Rounded-up random 100x1000 matrix generated with sparsity 0.01, converted to a frame and compared cell-wise. */
+       @Test
+       public void randomVerySparse() {
+               final MatrixBlock generated = TestUtils.generateTestMatrixBlock(100, 1000, 0, 199, 0.01, 213);
+               final MatrixBlock verySparse = TestUtils.ceil(generated);
+               final FrameBlock frame = FrameFromMatrixBlock.convertToFrameBlock(verySparse, 1);
+               verifyEquivalence(verySparse, frame);
+       }
+
        @Test
        public void timeChange() {
                // MatrixBlock mb = TestUtils.generateTestMatrixBlock(64000, 
2000, 1, 1, 0.5, 2340);
 
                // for(int i = 0; i < 10; i++) {
-               //      Timing time = new Timing(true);
-               //      FrameFromMatrixBlock.convertToFrameBlock(mb, 
ValueType.BOOLEAN, 1);
-               //      LOG.error(time.stop());
+               // Timing time = new Timing(true);
+               // FrameFromMatrixBlock.convertToFrameBlock(mb, 
ValueType.BOOLEAN, 1);
+               // LOG.error(time.stop());
                // }
 
                // for(int i = 0; i < 10; i++) {
-               //      Timing time = new Timing(true);
-               //      FrameFromMatrixBlock.convertToFrameBlock(mb, 
ValueType.BOOLEAN, 16);
-               //      LOG.error(time.stop());
+               // Timing time = new Timing(true);
+               // FrameFromMatrixBlock.convertToFrameBlock(mb, 
ValueType.BOOLEAN, 16);
+               // LOG.error(time.stop());
                // }
 
                // for(int i = 0; i < 10; i ++){
@@ -176,6 +204,17 @@ public class FrameFromMatrixBlockTest {
 
        }
 
+       private void verifyEquivalence(MatrixBlock mb, FrameBlock fb) {
+               int nRow = mb.getNumRows();
+               int nCol = mb.getNumColumns();
+               assertEquals(mb.getNumColumns(), fb.getSchema().length);
+
+               for(int i = 0; i < nRow; i++)
+                       for(int j = 0; j < nCol; j++)
+                               assertEquals(i + " " + j, mb.getValue(i, j), 
fb.getDouble(i, j), 0.0000001);
+
+       }
+
        private MatrixBlock mock(MatrixBlock m) {
                MatrixBlock ret = new MatrixBlock(m.getNumRows(), 
m.getNumColumns(),
                        new DenseBlockFP64Mock(new int[] {m.getNumRows(), 
m.getNumColumns()}, m.getDenseBlockValues()));
diff --git a/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java b/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java
new file mode 100644
index 0000000000..d35b47a4c6
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/matrix/MatrixMultiplyTest.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrix;
+
+import static org.junit.Assert.fail;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+/**
+ * Parameterized component test of dense/sparse matrix multiplication. Each case multiplies a (i x j) left matrix
+ * with a (j x k) right matrix using a given parallelization degree, and compares the result against a single
+ * threaded reference computed with the same operator.
+ */
+@RunWith(value = Parameterized.class)
+public class MatrixMultiplyTest {
+       protected static final Log LOG = LogFactory.getLog(MatrixMultiplyTest.class.getName());
+
+       // left side
+       private final MatrixBlock left;
+       // right side
+       private final MatrixBlock right;
+       // expected result
+       private final MatrixBlock exp;
+       // parallelization degree
+       private final int k;
+
+       /**
+        * Create one test case.
+        *
+        * @param i  Number of rows of the left input
+        * @param j  Number of columns of the left input, and rows of the right input
+        * @param k  Number of columns of the right input
+        * @param s  Sparsity of the left input (replaced by 1 if the left input is 1x1)
+        * @param s2 Sparsity of the right input (replaced by 1 if the right input is 1x1)
+        * @param p  Parallelization degree used for the multiplication under test
+        */
+       public MatrixMultiplyTest(int i, int j, int k, double s, double s2, int p) {
+               try {
+                       // A 1x1 input is always generated with sparsity 1 (presumably so it is not empty).
+                       this.left = TestUtils.ceil(TestUtils.generateTestMatrixBlock(i, j, -10, 10, i == 1 && j == 1 ? 1 : s, 13));
+                       // The right input is j x k, so its 1x1 check must test both j and k.
+                       this.right = TestUtils.ceil(TestUtils.generateTestMatrixBlock(j, k, -10, 10, j == 1 && k == 1 ? 1 : s2, 14));
+
+                       // Reference result, computed single threaded.
+                       this.exp = multiply(left, right, 1);
+                       this.k = p;
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       throw new RuntimeException(e);
+               }
+       }
+
+       /**
+        * Build the test parameters.
+        *
+        * @return Every combination of left sparsity, right sparsity, parallelization degree and dimensions i, j, k,
+        *         each as {i, j, k, s, s2, p}
+        */
+       @Parameters
+       public static Collection<Object[]> data() {
+
+               List<Object[]> tests = new ArrayList<>();
+               try {
+                       double[] sparsities = new double[] {0.001, 0.1, 0.5};
+                       int[] is = new int[] {1, 3, 1024};
+                       int[] js = new int[] {1, 3, 1024};
+                       int[] ks = new int[] {1, 3, 1024};
+                       int[] par = new int[] {1, 4};
+
+                       for(int s = 0; s < sparsities.length; s++) {
+                               for(int s2 = 0; s2 < sparsities.length; s2++) {
+                                       for(int p = 0; p < par.length; p++) {
+                                               for(int i = 0; i < is.length; i++) {
+                                                       for(int j = 0; j < js.length; j++) {
+                                                               for(int k = 0; k < ks.length; k++) {
+                                                                       tests.add(new Object[] {is[i], js[j], ks[k], sparsities[s], sparsities[s2], par[p]});
+                                                               }
+                                                       }
+                                               }
+                                       }
+                               }
+                       }
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail("failed constructing tests");
+               }
+
+               return tests;
+       }
+
+       /** Multiply the inputs in the format they were generated in. */
+       @Test
+       public void testMultiplicationAsIs() {
+               test(left, right);
+       }
+
+       /** Multiply after converting the left input to dense format. */
+       @Test
+       public void testLeftForceDense() {
+               left.sparseToDense();
+               test(left, right);
+       }
+
+       /** Multiply after converting the right input to dense format. */
+       @Test
+       public void testRightForceDense() {
+               right.sparseToDense();
+               test(left, right);
+       }
+
+       /** Multiply after converting both inputs to dense format. */
+       @Test
+       public void testBothForceDense() {
+               left.sparseToDense();
+               right.sparseToDense();
+               test(left, right);
+       }
+
+       /** Multiply after converting the left input to sparse format. */
+       @Test
+       public void testLeftForceSparse() {
+               left.denseToSparse(true);
+               test(left, right);
+       }
+
+       /** Multiply after converting the right input to sparse format. */
+       @Test
+       public void testRightForceSparse() {
+               right.denseToSparse(true);
+               test(left, right);
+       }
+
+       /** Multiply after converting both inputs to sparse format. */
+       @Test
+       public void testBothForceSparse() {
+               left.denseToSparse(true);
+               right.denseToSparse(true);
+               test(left, right);
+       }
+
+       /**
+        * Multiply a and b with the configured parallelization degree and compare against the reference result. The
+        * failure message contains the input/output sizes and formats, plus the full expected and actual matrices if
+        * the result has fewer than 1000 cells.
+        *
+        * @param a Left input
+        * @param b Right input
+        */
+       private void test(MatrixBlock a, MatrixBlock b) {
+               try {
+                       MatrixBlock ret = multiply(a, b, k);
+
+                       boolean sparseLeft = a.isInSparseFormat();
+                       boolean sparseRight = b.isInSparseFormat();
+                       boolean sparseOut = exp.isInSparseFormat();
+                       String sparseErrMessage = "SparseLeft:" + sparseLeft + " SparseRight: " + sparseRight + " SparseOut:"
+                               + sparseOut;
+                       String sizeErrMessage = size(a) + "  " + size(b) + "  " + size(exp);
+
+                       String totalMessage = "\n\n" + sizeErrMessage + "\n" + sparseErrMessage;
+
+                       if(ret.getNumRows() * ret.getNumColumns() < 1000) {
+                               totalMessage += "\n\nExp" + exp;
+                               totalMessage += "\n\nAct" + ret;
+                       }
+
+                       TestUtils.compareMatricesPercentageDistance(exp, ret, 0.999, 0.99999, totalMessage, false);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
+       /** @return A compact "rows x cols n nonZeros" description of a, used in failure messages. */
+       private static String size(MatrixBlock a) {
+               return a.getNumRows() + "x" + a.getNumColumns() + "n" + a.getNonZeros();
+       }
+
+       /**
+        * Matrix multiply a %*% b through {@code aggregateBinaryOperations}, with a Multiply/Plus aggregate binary
+        * operator (aggregate initial value 0).
+        *
+        * @param a Left input
+        * @param b Right input
+        * @param k Parallelization degree passed to the operator
+        * @return The product
+        */
+       private static MatrixBlock multiply(MatrixBlock a, MatrixBlock b, int k) {
+               AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
+               AggregateBinaryOperator mult = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg, k);
+               return a.aggregateBinaryOperations(a, b, mult);
+       }
+}

Reply via email to