This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit aab900c181fc5b33483d527792fec889d86862cb
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Oct 26 21:35:27 2023 +0200

    [SYSTEMDS-3636] Sparse Transpose-Self MatMult w/ Sparse Outputs
    
    So far all dense and sparse tsmm operations always worked with
    dense outputs and only finally converted the output to sparse where
    needed. On large graph operations like G %*% t(G) this quickly runs
    out of memory. This patch adds for tsmm right a dedicated kernel
    that directly outputs sparse representations because we can perform
    sparse dot products for row-column combinations.
---
 .../sysds/runtime/matrix/data/LibMatrixMult.java   | 165 ++++++++++++++++++---
 .../matrix/TransposeMatrixMultiplicationTest.java  |   1 -
 .../FullMatrixMultiplicationTransposeSelfTest.java |  38 +++++
 .../TransposeSelfMatrixMultiplication1.dml         |   2 +-
 .../TransposeSelfMatrixMultiplication2.dml         |   2 +-
 5 files changed, 185 insertions(+), 23 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index 711197352a..03dcbc359b 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -50,6 +50,7 @@ import org.apache.sysds.runtime.data.SparseBlockCSR;
 import org.apache.sysds.runtime.data.SparseBlockFactory;
 import org.apache.sysds.runtime.data.SparseBlockMCSR;
 import org.apache.sysds.runtime.data.SparseRowScalar;
+import org.apache.sysds.runtime.data.SparseRowVector;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
@@ -434,13 +435,13 @@ public class LibMatrixMult
                //Timing time = new Timing(true);
                
                //pre-processing
-               ret.sparse = false;
-               ret.allocateDenseBlock();
+               double sp = m1.getSparsity();
+               double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, 
m1.clen, m1.rlen, false);
+               ret.sparse = !leftTranspose && m1.sparse && osp < 
MatrixBlock.SPARSITY_TURN_POINT;
+               ret.allocateBlock();
 
-               if( m1.sparse )
-                       matrixMultTransposeSelfSparse(m1, ret, leftTranspose, 
0, ret.rlen);
-               else 
-                       matrixMultTransposeSelfDense(m1, ret, leftTranspose, 0, 
ret.rlen );
+               //core tsmm operation
+               matrixMultTransposeSelf(m1, ret, leftTranspose, 0, m1.rlen);
 
                //post-processing
                if(copyToLowerTriangle){
@@ -476,15 +477,18 @@ public class LibMatrixMult
                //Timing time = new Timing(true);
                
                //pre-processing (no need to check isThreadSafe)
-               ret.sparse = false;
-               ret.allocateDenseBlock();
-       
+               double sp = m1.getSparsity();
+               ret.sparse = !leftTranspose && m1.sparse && 
MatrixBlock.SPARSITY_TURN_POINT >
+                       OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, 
m1.clen, m1.rlen, false);
+               
+               ret.allocateBlock();
+
                //core multi-threaded matrix mult computation
                ExecutorService pool = CommonThreadPool.get(k);
                try {
                        ArrayList<MatrixMultTransposeTask> tasks = new 
ArrayList<>();
-                       //load balance via #tasks=2k due to triangular shape 
-                       int blklen = (int)(Math.ceil((double)ret.rlen / (2 * 
k)));
+                       //load balance via #tasks=4k due to triangular shape 
+                       int blklen = (int)(Math.ceil((double)ret.rlen / (4 * 
k)));
                        for(int i = 0; i < ret.rlen; i += blklen)
                                tasks.add(new MatrixMultTransposeTask(m1, ret, 
leftTranspose, i, Math.min(i+blklen, ret.rlen)));
                        for( Future<Object> rtask :  pool.invokeAll(tasks) )
@@ -500,7 +504,7 @@ public class LibMatrixMult
                //post-processing
                long nnz = copyUpperToLowerTriangle(ret);
                ret.setNonZeros(nnz);
-               ret.examSparsity();     
+               ret.examSparsity();
                
                //System.out.println("TSMM k="+k+" 
("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+","+leftTranspose+")
 in "+time.stop());
        }
@@ -2236,6 +2240,15 @@ public class LibMatrixMult
                }
        }
 
+       private static void matrixMultTransposeSelf(MatrixBlock m1, MatrixBlock 
ret, boolean leftTranspose, int rl, int ru) {
+               if(m1.sparse && ret.sparse)
+                       matrixMultTransposeSelfUltraSparse(m1, ret, 
leftTranspose, rl, ru);
+               else if( m1.sparse )
+                       matrixMultTransposeSelfSparse(m1, ret, leftTranspose, 
rl, ru);
+               else 
+                       matrixMultTransposeSelfDense(m1, ret, leftTranspose, 
rl, ru );
+       }
+       
        private static void matrixMultTransposeSelfSparse( MatrixBlock m1, 
MatrixBlock ret, boolean leftTranspose, int rl, int ru ) {
                //2) transpose self matrix multiply sparse
                // (compute only upper-triangular matrix due to symmetry)
@@ -2357,6 +2370,46 @@ public class LibMatrixMult
                        }
                }
        }
+       
+       private static void matrixMultTransposeSelfUltraSparse( MatrixBlock m1, 
MatrixBlock ret, boolean leftTranspose, int rl, int ru ) {
+               if( leftTranspose )
+                       throw new DMLRuntimeException("Left tsmm with sparse 
output not supported");
+
+               // Operation X%*%t(X), sparse input and output
+               SparseBlock a = m1.sparseBlock;
+               SparseBlock c = ret.sparseBlock;
+               int m = m1.rlen;
+               
+               final int blocksize = 256;
+               for(int bi=rl; bi<ru; bi+=blocksize) { //blocking rows in X
+                       int bimin = Math.min(bi+blocksize, ru);
+                       for(int i=bi; i<bimin; i++) //preallocation
+                               if( !a.isEmpty(i) )
+                                       c.allocate(i, 
8*SparseRowVector.initialCapacity); //heuristic
+                       for(int bj=bi; bj<m; bj+=blocksize ) { //blocking cols 
in t(X) 
+                               int bjmin = Math.min(bj+blocksize, m);
+                               for(int i=bi; i<bimin; i++) { //rows in X
+                                       if( a.isEmpty(i) ) continue;
+                                       int apos = a.pos(i);
+                                       int alen = a.size(i);
+                                       int[] aix = a.indexes(i);
+                                       double[] avals = a.values(i);
+                                       for(int j=Math.max(bj,i); j<bjmin; j++) 
{ //cols in t(X)
+                                               if( a.isEmpty(j) ) continue;
+                                               int bpos = a.pos(j);
+                                               int blen = a.size(j);
+                                               int[] bix = a.indexes(j);
+                                               double[] bvals = a.values(j);
+                                               
+                                               //compute sparse dot product 
and append
+                                               double v = dotProduct(avals, 
aix, apos, alen, bvals, bix, bpos, blen);
+                                               if( v != 0 )
+                                                       c.append(i, j, v);
+                                       }
+                               }
+                       }
+               }
+       }
 
        private static void matrixMultPermuteDense(MatrixBlock pm1, MatrixBlock 
m2, MatrixBlock ret1, MatrixBlock ret2, int rl, int ru) {
                double[] a = pm1.getDenseBlockValues();
@@ -3482,6 +3535,36 @@ public class LibMatrixMult
                //scalar result
                return val; 
        }
+       
+       private static double dotProduct(double[] a, int[] aix, final int apos, 
final int alen, double[] b, int bix[], final int bpos, final int blen) {
+               final int asize = apos+alen;
+               final int bsize = bpos+blen;
+               int k = apos, k2 = bpos;
+               
+               //pruning filter
+               if(aix[apos]>bix[bsize-1] || aix[asize-1]<bix[bpos] )
+                       return 0;
+               
+               //sorted set intersection
+               double v = 0;
+               while( k<asize & k2<bsize ) {
+                       int aixk = aix[k];
+                       int bixk = aix[k2];
+                       if( aixk < bixk )
+                               k++;
+                       else if( aixk > bixk )
+                               k2++;
+                       else { // ===
+                               v += a[k] * b[k2];
+                               k++; k2++;
+                       }
+                       //note: branchless version slower
+                       //v += (aixk==bixk) ? a[k] * b[k2] : 0;
+                       //k += (aixk <= bixk) ? 1 : 0;
+                       //k2 += (aixk >= bixk) ? 1 : 0;
+               }
+               return v;
+       }
 
        //note: public for use by codegen for consistency
        public static void vectMultiplyAdd( final double aval, double[] b, 
double[] c, int bi, int ci, final int len )
@@ -4025,6 +4108,13 @@ public class LibMatrixMult
                return val;
        }
        
+       public static long copyUpperToLowerTriangle( MatrixBlock ret ) {
+               return ret.sparse ?
+                       copyUpperToLowerTriangleSparse(ret) :
+                       copyUpperToLowerTriangleDense(ret);
+       }
+       
+       
        /**
         * Used for all version of TSMM where the result is known to be 
symmetric.
         * Hence, we compute only the upper triangular matrix and copy this 
partial
@@ -4033,7 +4123,7 @@ public class LibMatrixMult
         * @param ret matrix
         * @return number of non zeros
         */
-       public static long copyUpperToLowerTriangle( MatrixBlock ret )
+       public static long copyUpperToLowerTriangleDense( MatrixBlock ret )
        {
                //ret is guaranteed to be a squared, symmetric matrix
                if( ret.rlen != ret.clen )
@@ -4074,18 +4164,56 @@ public class LibMatrixMult
                return nnz;
        }
 
+       public static long copyUpperToLowerTriangleSparse( MatrixBlock ret )
+       {
+               //ret is guaranteed to be a squared, symmetric matrix
+               if( ret.rlen != ret.clen )
+                       throw new RuntimeException("Invalid non-squared input 
matrix.");
+               
+               SparseBlock c = ret.getSparseBlock();
+               int n = ret.rlen;
+               long nnz = 0;
+               
+               //copy non-diagonal values from upper-triangular matrix
+               for(int i=0; i<n; i++) {
+                       if(c.isEmpty(i)) continue;
+                       int cpos = c.pos(i);
+                       //int cpos2 = c.posFIndexGTE(i, i);
+                       //if( cpos2 < 0 ) continue;
+                       int clen = c.size(i);
+                       int[] cix = c.indexes(i);
+                       double[] cvals = c.values(i);
+                       for(int k=cpos; k<cpos+clen; k++) {
+                               if( cix[k] == i )
+                                       nnz ++;
+                               else if( cix[k] > i ) {
+                                       c.append(cix[k], i, cvals[k]);
+                                       nnz += 2;
+                               }
+                       }
+               }
+               
+               //sort sparse rows (because append out of order)
+               c.sort();
+               
+               return nnz;
+       }
+       
        public static MatrixBlock prepMatrixMultTransposeSelfInput( MatrixBlock 
m1, boolean leftTranspose, boolean par ) {
                MatrixBlock ret = m1;
                final int rlen = m1.rlen;
                final int clen = m1.clen;
+               double sp = m1.getSparsity();
+               double osp = OptimizerUtils.getMatMultSparsity(sp, sp, m1.rlen, 
m1.clen, m1.rlen, false);
+               boolean retSparse = !leftTranspose && m1.sparse && osp < 
MatrixBlock.SPARSITY_TURN_POINT;
                
-               if( !leftTranspose && m1.sparse && rlen > 1) { //X%*%t(X) 
SPARSE MATRIX
+               if( !leftTranspose && !retSparse && m1.sparse && rlen > 1) { 
//X%*%t(X) SPARSE MATRIX
                        //directly via LibMatrixReorg in order to prevent 
sparsity change
                        MatrixBlock tmpBlock = new MatrixBlock(clen, rlen, 
m1.sparse);
                        LibMatrixReorg.reorg(m1, tmpBlock, new 
ReorgOperator(SwapIndex.getSwapIndexFnObject()));
                        ret = tmpBlock;
                }
-               else if( leftTranspose && m1.sparse && m1.sparseBlock 
instanceof SparseBlockCSR ) {
+               else if( leftTranspose && !retSparse && m1.sparse && 
m1.sparseBlock instanceof SparseBlockCSR ) {
                        //for a special case of CSR inputs where all non-empty 
rows are dense, we can
                        //create a shallow copy of the values arrays to a 
"dense" block and perform
                        //tsmm with the existing dense block operations w/o 
unnecessary gather/scatter
@@ -4158,7 +4286,7 @@ public class LibMatrixMult
                        (sharedTP ? PAR_MINFLOP_THRESHOLD2 : 
PAR_MINFLOP_THRESHOLD1));
        }
        
-       private static boolean 
satisfiesMultiThreadingConstraintsTSMM(MatrixBlock m1, boolean leftTranspose, 
long FPfactor, int k) {
+       private static boolean 
satisfiesMultiThreadingConstraintsTSMM(MatrixBlock m1, boolean leftTranspose, 
double FPfactor, int k) {
                boolean sharedTP = 
(InfrastructureAnalyzer.getLocalParallelism() == k);
                double threshold = sharedTP ? PAR_MINFLOP_THRESHOLD2 : 
PAR_MINFLOP_THRESHOLD1;
                return k > 1 && LOW_LEVEL_OPTIMIZATION && 
(leftTranspose?m1.clen:m1.rlen)!=1
@@ -4425,10 +4553,7 @@ public class LibMatrixMult
                
                @Override
                public Object call() {
-                       if( _m1.sparse )
-                               matrixMultTransposeSelfSparse(_m1, _ret, _left, 
_rl, _ru);
-                       else
-                               matrixMultTransposeSelfDense(_m1, _ret, _left, 
_rl, _ru);
+                       matrixMultTransposeSelf(_m1, _ret, _left, _rl, _ru);
                        return null;
                }
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/TransposeMatrixMultiplicationTest.java
 
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/TransposeMatrixMultiplicationTest.java
index 6a73499f90..ba65889a95 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/binary/matrix/TransposeMatrixMultiplicationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/binary/matrix/TransposeMatrixMultiplicationTest.java
@@ -285,5 +285,4 @@ public class TransposeMatrixMultiplicationTest extends 
AutomatedTestBase
                        DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
                }
        }
-
 }
\ No newline at end of file
diff --git 
a/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java
 
b/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java
index 6413db4e06..1599bdac86 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/binary/matrix_full_other/FullMatrixMultiplicationTransposeSelfTest.java
@@ -119,6 +119,10 @@ public class FullMatrixMultiplicationTransposeSelfTest 
extends AutomatedTestBase
                runTransposeSelfVectorMultiplicationTest(MMTSJType.RIGHT, 
ExecType.CP, true);
        }
        
+       @Test
+       public void testRightUltraSparseCP() {
+               runTransposeSelfUltraSparseTest(MMTSJType.RIGHT);
+       }
        
        private void runTransposeSelfMatrixMultiplicationTest( MMTSJType type, 
ExecType instType, boolean sparse )
        {
@@ -261,4 +265,38 @@ public class FullMatrixMultiplicationTransposeSelfTest 
extends AutomatedTestBase
                        rtplatform = platformOld;
                }
        }
+       
+       private void runTransposeSelfUltraSparseTest( MMTSJType type )
+       {
+               //rtplatform for MR
+               ExecMode platformOld = rtplatform;
+               rtplatform = ExecMode.SINGLE_NODE;
+       
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME2));
+                       int dim = 10000;
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME2 + ".dml";
+                       programArgs = new String[]{"-stats","-args", input("A"),
+                               String.valueOf(dim), String.valueOf(dim), 
output("B") };
+                       fullRScriptName = HOME + TEST_NAME2 + ".R";
+                       rCmd = "Rscript" + " " + fullRScriptName + " " + 
inputDir() + " " + expectedDir();
+       
+                       //generate actual dataset
+                       double[][] A = getRandomMatrix(dim, dim, 0, 1, 0.0002, 
7); 
+                       writeInputMatrix("A", A, true);
+                       
+                       runTest(true, false, null, -1); 
+                       //runRScript(true); 
+                       
+                       //compare matrices 
+                       //HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("B");
+                       //HashMap<CellIndex, Double> rfile  = 
readRMatrixFromExpectedDir("B");
+                       //TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
 }
diff --git 
a/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication1.dml
 
b/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication1.dml
index 9e8779f41c..562fafa7bf 100644
--- 
a/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication1.dml
+++ 
b/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication1.dml
@@ -24,4 +24,4 @@
 A = read($1, rows=$2, cols=$3, format="text");
 B = t(A) %*% A;
 
-write(B, $4, format="text");
\ No newline at end of file
+write(B, $4, format="text");
diff --git 
a/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication2.dml
 
b/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication2.dml
index 4679990870..91ecd00f69 100644
--- 
a/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication2.dml
+++ 
b/src/test/scripts/functions/binary/matrix_full_other/TransposeSelfMatrixMultiplication2.dml
@@ -24,4 +24,4 @@
 A = read($1, rows=$2, cols=$3, format="text");
 B = A %*% t(A);
 
-write(B, $4, format="text");
\ No newline at end of file
+write(B, $4, format="text");

Reply via email to