This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 15bef0394d [SYSTEMDS-3636] Fix new ultra-sparse tsmm right, and new tests
15bef0394d is described below

commit 15bef0394d71eefee98769a1f50eb3dadb336aaa
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Oct 27 14:28:01 2023 +0200

    [SYSTEMDS-3636] Fix new ultra-sparse tsmm right, and new tests
---
 .../sysds/runtime/matrix/data/LibMatrixMult.java   | 21 ++++-----
 .../FullMatrixMultiplicationTransposeSelfTest.java | 55 ++++++++++------------
 2 files changed, 35 insertions(+), 41 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index 0f08176fe1..0b8bd216f4 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -435,9 +435,7 @@ public class LibMatrixMult
                //Timing time = new Timing(true);
                
                //pre-processing
-               double sp = m1.getSparsity();
-               double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, m1.clen, m1.rlen, false);
-               ret.sparse = !leftTranspose && m1.sparse && osp < MatrixBlock.SPARSITY_TURN_POINT;
+               ret.sparse = isSparseOutputTSMM(m1, leftTranspose);
                ret.allocateBlock();
 
                //core tsmm operation
@@ -477,10 +475,7 @@ public class LibMatrixMult
                //Timing time = new Timing(true);
                
                //pre-processing (no need to check isThreadSafe)
-               double sp = m1.getSparsity();
-               ret.sparse = !leftTranspose && m1.sparse && MatrixBlock.SPARSITY_TURN_POINT >
-                       OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, m1.clen, m1.rlen, false);
-               
+               ret.sparse = isSparseOutputTSMM(m1, leftTranspose);
                ret.allocateBlock();
 
                //core multi-threaded matrix mult computation
@@ -3549,7 +3544,7 @@ public class LibMatrixMult
                double v = 0;
                while( k<asize & k2<bsize ) {
                        int aixk = aix[k];
-                       int bixk = aix[k2];
+                       int bixk = bix[k2];
                        if( aixk < bixk )
                                k++;
                        else if( aixk > bixk )
@@ -4203,9 +4198,7 @@ public class LibMatrixMult
                MatrixBlock ret = m1;
                final int rlen = m1.rlen;
                final int clen = m1.clen;
-               double sp = m1.getSparsity();
-               double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, m1.clen, m1.rlen, false);
-               boolean retSparse = !leftTranspose && m1.sparse && osp < MatrixBlock.SPARSITY_TURN_POINT;
+               boolean retSparse = isSparseOutputTSMM(m1, leftTranspose);
                
                if( !leftTranspose && !retSparse && m1.sparse && rlen > 1) { //X%*%t(X) SPARSE MATRIX
                        //directly via LibMatrixReorg in order to prevent sparsity change
@@ -4323,6 +4316,12 @@ public class LibMatrixMult
                boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(m1.rlen, m2.clen, estNnz);
                return m2.clen < 4*1024 && sparseOut;
        }
+       
+       public static boolean isSparseOutputTSMM(MatrixBlock m1, boolean leftTranspose) {
+               double sp = m1.getSparsity();
+               double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, m1.clen, m1.rlen, false);
+               return !leftTranspose && m1.sparse && osp < MatrixBlock.ULTRA_SPARSITY_TURN_POINT2;
+       }
 
        public static boolean isOuterProductTSMM(int rlen, int clen, boolean left) {
                return left ? rlen == 1 & clen > 1 : rlen > 1 & clen == 1;
diff --git a/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java b/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java
index 1599bdac86..3d1b4c239f 100644
--- a/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java
@@ -22,11 +22,16 @@ package org.apache.sysds.test.functions.binary.matrix_full_other;
 import java.util.HashMap;
 
 import org.junit.AfterClass;
+import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.lops.MMTSJ.MMTSJType;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
@@ -268,35 +273,25 @@ public class FullMatrixMultiplicationTransposeSelfTest extends AutomatedTestBase
        
        private void runTransposeSelfUltraSparseTest( MMTSJType type )
        {
-               //rtplatform for MR
-               ExecMode platformOld = rtplatform;
-               rtplatform = ExecMode.SINGLE_NODE;
-       
-               try {
-                       loadTestConfiguration(getTestConfiguration(TEST_NAME2));
-                       int dim = 10000;
-                       
-                       String HOME = SCRIPT_DIR + TEST_DIR;
-                       fullDMLScriptName = HOME + TEST_NAME2 + ".dml";
-                       programArgs = new String[]{"-stats","-args", input("A"),
-                               String.valueOf(dim), String.valueOf(dim), output("B") };
-                       fullRScriptName = HOME + TEST_NAME2 + ".R";
-                       rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
-       
-                       //generate actual dataset
-                       double[][] A = getRandomMatrix(dim, dim, 0, 1, 0.0002, 7); 
-                       writeInputMatrix("A", A, true);
-                       
-                       runTest(true, false, null, -1); 
-                       //runRScript(true); 
-                       
-                       //compare matrices 
-                       //HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("B");
-                       //HashMap<CellIndex, Double> rfile  = readRMatrixFromExpectedDir("B");
-                       //TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
-               }
-               finally {
-                       rtplatform = platformOld;
-               }
+               //compare sparse tsmm and gemm directly to avoid unnecessary overhead (e.g., R) 
+               int dim = 10000;
+               
+               MatrixBlock G = MatrixBlock.randOperations(dim, dim, 0.0002, 0, 1, "uniform", 7);
+               MatrixBlock Gt = LibMatrixReorg.transpose(G);
+               MatrixBlock Gtt = LibMatrixReorg.transpose(Gt);
+               TestUtils.compareMatrices(G, Gtt, 1e-16);
+               
+               //single-threaded core operations
+               MatrixBlock R11 = G.transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.RIGHT);
+               MatrixBlock R12 = LibMatrixMult.matrixMult(G, Gt);
+               Assert.assertEquals(R11.getNonZeros(), R12.getNonZeros());
+               TestUtils.compareMatrices(R11, R12, 1e-8);
+               
+               //multi-threaded core operations
+               int k = InfrastructureAnalyzer.getLocalParallelism();
+               MatrixBlock R21 = G.transposeSelfMatrixMultOperations(new MatrixBlock(), MMTSJType.RIGHT, k);
+               MatrixBlock R22 = LibMatrixMult.matrixMult(G, Gt, k);
+               Assert.assertEquals(R21.getNonZeros(), R22.getNonZeros());
+               TestUtils.compareMatrices(R21, R22, 1e-8);
        }
 }

Reply via email to