This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 80025ddf3e8409d238c06084445dafa55e7a8579
Author: baunsgaard <[email protected]>
AuthorDate: Sun Jan 16 13:56:01 2022 +0100

    [SYSTEMDS-3243] Consistent allocation of MatrixBlock for MM
    
    This commit changes the matrix multiplication to not allocate or
    analyze the output and inputs before calls to the libraries, removing
    an unnecessary analysis step from MatrixBlock and avoiding cases
    where a sparse block was allocated only to be converted into a dense
    allocation.
    
    The MM is now consolidated into a single code path (covering both
    single- and multi-threaded execution) that checks for output
    allocation, making the API more robust and removing code duplication.
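    
    As an illustrative sketch (not part of this patch), the consolidated
    entry point can be used as follows; 'ret' may be null, in which case
    the output block is allocated internally and returned. The
    randOperations calls are assumed here only as a convenient way to
    create test inputs.
    
        // create two random dense input blocks (illustrative helper usage)
        MatrixBlock a = MatrixBlock.randOperations(1000, 100, 0.9, 0, 1, "uniform", 7);
        MatrixBlock b = MatrixBlock.randOperations(100, 50, 0.9, 0, 1, "uniform", 3);
        // single-threaded: output is allocated internally and returned
        MatrixBlock c = LibMatrixMult.matrixMult(a, b, null);
        // multi-threaded (up to 8 threads): an existing ret is reset and reused
        c = LibMatrixMult.matrixMult(a, b, c, 8);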
---
 .../sysds/runtime/matrix/data/LibMatrixMult.java   | 228 ++++++++++++---------
 .../sysds/runtime/matrix/data/LibMatrixNative.java | 135 ++++++------
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  66 +++---
 3 files changed, 217 insertions(+), 212 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index a503085..8384dd2 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -88,14 +88,14 @@ public class LibMatrixMult
         * 
         * All variants use a IKJ access pattern, and internally use dense output. After the
         * actual computation, we recompute nnz and check for sparse/dense representation.
-        *  
         * 
         * @param m1 first matrix
         * @param m2 second matrix
         * @param ret result matrix
+        * @return the result MatrixBlock
         */
-       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret) {
-               matrixMult(m1, m2, ret, 0, m1.rlen);
+       public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret) {
+               return matrixMult(m1, m2, ret, false, 1);
        }
        
        /**
@@ -109,141 +109,165 @@ public class LibMatrixMult
         * @param m2 second matrix
         * @param ret result matrix
         * @param fixedRet if true, output representation is fixed and nnzs not recomputed
+        * @return the result MatrixBlock
         */
-       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean fixedRet) {
-               matrixMult(m1, m2, ret, 0, m1.rlen, fixedRet);
+       public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean fixedRet) {
+               return matrixMult(m1, m2, ret, fixedRet, 1);
        }
        
-       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru) {
-               matrixMult(m1, m2, ret, rl, ru, false);
+       /**
+        * Performs a multi-threaded matrix multiplication and stores the result in the output matrix.
+        * The parameter k (k&gt;=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen).
+        * 
+        * @param m1 first matrix
+        * @param m2 second matrix
+        * @param ret result matrix
+        * @param k maximum parallelism
+        * @return the result MatrixBlock
+        */
+       public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
+               return matrixMult(m1, m2, ret, false, k);
        }
        
-       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, boolean fixedRet) {
-               //check inputs / outputs
-               if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
-                       ret.examSparsity(); //turn empty dense into sparse
-                       return;
-               }
+       /**
+        * Performs a matrix multiplication and stores the result in the output matrix.
+        * 
+        * All variants use an IKJ access pattern, and internally use dense output. After the
+        * actual computation, we recompute nnz and check for sparse/dense representation.
+        * 
+        * This method allows one to disable the exam-sparsity check. This feature is useful if matrixMult is used as an
+        * intermediate operation (for example: LibMatrixDNN). It makes sense for LibMatrixDNN because the output is
+        * internally consumed by another dense instruction, which makes repeated conversion to sparse wasteful.
+        * This should be used only in rare cases and, if you are unsure,
+        * use the method 'matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret)' instead.
+        * 
+        * The parameter k (k&gt;=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen).
+        * 
+        * @param m1 first matrix
+        * @param m2 second matrix
+        * @param ret result matrix
+        * @param fixedRet if true, output representation is fixed and nnzs not recomputed
+        * @param k maximum parallelism
+        * @return the result MatrixBlock
+        */
+       public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, boolean fixedRet, int k) {
+               if(m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) 
+                       return emptyMatrixMult(m1, m2, ret);
                
-               //Timing time = new Timing(true);
+               // Timing time = new Timing(true);
                
-               //pre-processing: output allocation
+               // pre analysis
                boolean m1Perm = m1.isSparsePermutationMatrix();
-               boolean ultraSparse = (fixedRet && ret.sparse)
-                       || (!fixedRet && isUltraSparseMatrixMult(m1, m2, m1Perm));
-               boolean sparse = !m1Perm && !ultraSparse && !fixedRet 
+               boolean ultraSparse = (fixedRet && ret.sparse) ||
+                       (!fixedRet && isUltraSparseMatrixMult(m1, m2, m1Perm));
+               boolean sparse = !fixedRet && !ultraSparse && !m1Perm
                        && isSparseOutputMatrixMult(m1, m2);
-               boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
-               m2 = prepMatrixMultRightInput(m1, m2);
-               ret.sparse = ultraSparse | sparse;
-               ret.allocateBlock();
                
-               //prepare row-upper for special cases of vector-matrix
-               boolean pm2 = !ultraSparse &&
-                       checkParMatrixMultRightInputRows(m1, m2, Integer.MAX_VALUE);
-               int ru2 = (pm2 && ru==m1.rlen) ? m2.rlen : ru; 
-               int cu = m2.clen;
+               // allocate output
+               if(ret == null)
+                       ret = new MatrixBlock(m1.rlen, m2.clen, ultraSparse | sparse);
+               else 
+                       ret.reset(m1.rlen, m2.clen, ultraSparse | sparse);
+               ret.allocateBlock();
                
-               //core matrix mult computation
-               if( ultraSparse )
+               // Detect if we should transpose skinny right side.
+               boolean tm2 = !fixedRet && checkPrepMatrixMultRightInput(m1,m2);
+               m2 = prepMatrixMultRightInput(m1, m2, tm2);
+
+               // check for multi-threading
+               if (!ret.isThreadSafe() 
+                               || !satisfiesMultiThreadingConstraints(m1, m2, m1.rlen==1, true, 2, k)
+                               || fixedRet) // Fixed ret not supported in multithreaded execution yet
+                       k = 1;
+
+               if(k <= 1)
+                       singleThreadMatrixMult(m1, m2, ret, ultraSparse, sparse, tm2, m1Perm, fixedRet);
+               else
+                       parallelMatrixMult(m1, m2, ret, k, ultraSparse, sparse, tm2, m1Perm);
+
+               //System.out.println("MM "+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" +
+               //              "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop());
+       
+               return ret;
+       }
+
+       private static void singleThreadMatrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
+               boolean ultraSparse, boolean sparse, boolean tm2, boolean m1Perm, boolean fixedRet){
+               // prepare row-upper for special cases of vector-matrix
+               final boolean pm2 = !ultraSparse && checkParMatrixMultRightInputRows(m1, m2, Integer.MAX_VALUE);
+               final int ru2 = (pm2) ? m2.rlen : m1.rlen;
+
+               // core matrix mult computation
+               if(ultraSparse)
                        matrixMultUltraSparse(m1, m2, ret, m1Perm, 0, ru2);
                else if(!m1.sparse && !m2.sparse)
-                       matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, cu);
+                       matrixMultDenseDense(m1, m2, ret, tm2, pm2, 0, ru2, 0, m2.clen);
                else if(m1.sparse && m2.sparse)
                        matrixMultSparseSparse(m1, m2, ret, pm2, sparse, 0, ru2);
                else if(m1.sparse)
                        matrixMultSparseDense(m1, m2, ret, pm2, 0, ru2);
                else
                        matrixMultDenseSparse(m1, m2, ret, pm2, 0, ru2);
-               
-               //post-processing: nnz/representation
-               if( !fixedRet ) {
-                       if( !ret.sparse )
+
+               // post-processing: nnz/representation
+               if(!fixedRet) {
+                       if(!ret.sparse)
                                ret.recomputeNonZeros();
                        ret.examSparsity();
                }
-               
-               //System.out.println("MM ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" +
-               //              "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop());
        }
-       
-       /**
-        * Performs a multi-threaded matrix multiplication and stores the result in the output matrix.
-        * The parameter k (k&gt;=1) determines the max parallelism k' with k'=min(k, vcores, m1.rlen).
-        * 
-        * @param m1 first matrix
-        * @param m2 second matrix
-        * @param ret result matrix
-        * @param k maximum parallelism
-        */
-       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
-               //check inputs / outputs
-               if( m1.isEmptyBlock(false) || m2.isEmptyBlock(false) ) {
-                       ret.examSparsity(); //turn empty dense into sparse
-                       return;
-               }
-               
-               //check too small workload and fallback to sequential if needed
-               if( !satisfiesMultiThreadingConstraints(m1, m2, m1.rlen==1, true, 2, k) ) {
-                       matrixMult(m1, m2, ret);
-                       return;
-               }
-               
-               //Timing time = new Timing(true);
-               
-               //pre-processing: output allocation (in contrast to single-threaded,
-               //we need to allocate sparse as well in order to prevent synchronization)
-               boolean m1Perm = m1.isSparsePermutationMatrix();
-               boolean ultraSparse = isUltraSparseMatrixMult(m1, m2, m1Perm);
-               boolean sparse = !ultraSparse && !m1Perm && isSparseOutputMatrixMult(m1, m2);
-               boolean tm2 = checkPrepMatrixMultRightInput(m1,m2);
-               m2 = prepMatrixMultRightInput(m1, m2);
-               ret.sparse = ultraSparse | sparse;
-               ret.allocateBlock();
-               
-               if (!ret.isThreadSafe()) {
-                       matrixMult(m1, m2, ret);
-                       return;
-               }
-               
-               //prepare row-upper for special cases of vector-matrix / matrix-matrix
+
+       private static void parallelMatrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k,
+               boolean ultraSparse, boolean sparse, boolean tm2, boolean m1Perm){
+               // prepare row-upper for special cases of vector-matrix / matrix-matrix
                boolean pm2r = !ultraSparse && !sparse && checkParMatrixMultRightInputRows(m1, m2, k);
                boolean pm2c = !ultraSparse && checkParMatrixMultRightInputCols(m1, m2, k, pm2r);
-               int num = pm2r ? m2.rlen : pm2c ? m2.clen : m1.rlen; 
-               
-               //core multi-threaded matrix mult computation
-               //(currently: always parallelization over number of rows)
+               int num = pm2r ? m2.rlen : pm2c ? m2.clen : m1.rlen;
+
+               // core multi-threaded matrix mult computation
+               // (currently: always parallelization over number of rows)
                try {
                        ExecutorService pool = CommonThreadPool.get(k);
                        ArrayList<MatrixMultTask> tasks = new ArrayList<>();
-                       ArrayList<Integer> blklens = UtilFunctions.getBalancedBlockSizesDefault(num, k, (pm2r||pm2c));
-                       for( int i=0, lb=0; i<blklens.size(); lb+=blklens.get(i), i++ )
-                               tasks.add(new MatrixMultTask(m1, m2, ret, tm2, pm2r, pm2c, m1Perm, sparse, lb, lb+blklens.get(i)));
-                       //execute tasks
+                       ArrayList<Integer> blklens = UtilFunctions.getBalancedBlockSizesDefault(num, k, (pm2r || pm2c));
+                       for(int i = 0, lb = 0; i < blklens.size(); lb += blklens.get(i), i++)
+                               tasks.add(new MatrixMultTask(m1, m2, ret, tm2, pm2r, pm2c, m1Perm, sparse, lb, lb + blklens.get(i)));
+                       // execute tasks
                        List<Future<Object>> taskret = pool.invokeAll(tasks);
                        pool.shutdown();
-                       //aggregate partial results (nnz, ret for vector/matrix)
-                       ret.nonZeros = 0; //reset after execute
-                       for( Future<Object> task : taskret ) {
-                               if( pm2r ) //guaranteed single block
-                                       vectAdd((double[])task.get(), ret.getDenseBlockValues(), 0, 0, ret.rlen*ret.clen);
+                       // aggregate partial results (nnz, ret for vector/matrix)
+                       ret.nonZeros = 0; // reset after execute
+                       for(Future<Object> task : taskret) {
+                               if(pm2r) // guaranteed single block
+                                       vectAdd((double[]) task.get(), ret.getDenseBlockValues(), 0, 0, ret.rlen * ret.clen);
                                else
-                                       ret.nonZeros += (Long)task.get();
+                                       ret.nonZeros += (Long) task.get();
                        }
-                       if( pm2r )
+                       if(pm2r)
                                ret.recomputeNonZeros();
                }
                catch(Exception ex) {
                        throw new DMLRuntimeException(ex);
                }
-               
-               //post-processing (nnz maintained in parallel)
+
+               // post-processing (nnz maintained in parallel)
                ret.examSparsity();
-               
-               //System.out.println("MM k="+k+" ("+m1.isInSparseFormat()+","+m1.getNumRows()+","+m1.getNumColumns()+","+m1.getNonZeros()+")x" +
-               //              "("+m2.isInSparseFormat()+","+m2.getNumRows()+","+m2.getNumColumns()+","+m2.getNonZeros()+") in "+time.stop());
        }
-       
+
+       public static MatrixBlock emptyMatrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret){
+               final int rl = m1.rlen;
+               final int cl = m2.clen;
+
+               if(ret == null)
+                       return new MatrixBlock(rl, cl, true);
+               else {
+                       ret.reset(rl, cl, true);
+                       ret.setNonZeros(0);
+                       ret.cleanupBlock(true, true);
+                       return ret;
+               }
+       }
+
        /**
         * Performs a matrix multiplication chain operation of type t(X)%*%(X%*%v) or t(X)%*%(w*(X%*%v)).
         * 
@@ -3959,16 +3983,16 @@ public class LibMatrixMult
                boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(m1.rlen, m2.clen, estNnz);
                return m2.clen < 4*1024 && sparseOut;
        }
-       
+
        public static boolean isOuterProductTSMM(int rlen, int clen, boolean left) {
                return left ? rlen == 1 & clen > 1 : rlen > 1 & clen == 1;
        }
 
-       private static MatrixBlock prepMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2 ) {
+       private static MatrixBlock prepMatrixMultRightInput( MatrixBlock m1, MatrixBlock m2, boolean tm2 ) {
                MatrixBlock ret = m2;
                
                //transpose if dense-dense, skinny rhs matrix (not vector), and memory guarded by output
-               if( checkPrepMatrixMultRightInput(m1, m2)  ) {
+               if( tm2 ) {
                        MatrixBlock tmpBlock = new MatrixBlock(m2.clen, m2.rlen, m2.sparse);
                        LibMatrixReorg.reorg(m2, tmpBlock, new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
                        ret = tmpBlock;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
index 5c92253..ac0b069 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
@@ -64,86 +64,83 @@ public class LibMatrixNative
         * @param m2 rhs matrix block
         * @param ret output matrix block
         * @param k number of threads
+        * @return the ret MatrixBlock if allocated, otherwise a new MatrixBlock.
         */
-       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
-               matrixMult(m1, m2, ret, k, true);
-       }
-       
-       public static void matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k, boolean examSparsity) {
-               // Sanity check:
-               k = k <= 0 ? NativeHelper.getMaxNumThreads() : k;
-               
-               // check inputs / outputs
-               if (m1.isEmptyBlock(false) || m2.isEmptyBlock(false)){
-                       ret.setNonZeros(0);
-                       if(examSparsity)
-                               ret.examSparsity(); // turn empty dense into sparse
-                       return;
-               }
+       public static MatrixBlock matrixMult(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int k) {
                
-               boolean isValidForNative = !isMatMultMemoryBound(m1.rlen, m1.clen, m2.clen)
-                       && !m1.isInSparseFormat() && !m2.isInSparseFormat()
-                       && (m1.getDenseBlock().isContiguous() || !isSinglePrecision())
-                       && m2.getDenseBlock().isContiguous() //contiguous but not allocated
-                       && 8L * ret.getLength() < Integer.MAX_VALUE;
-
-               if( NativeHelper.isNativeLibraryLoaded() && isValidForNative ) 
-               {
-                       ret.sparse = false;
-                       ret.allocateDenseBlock();
-                       long start = DMLScript.STATISTICS ? System.nanoTime() : 0;
-                       long nnz = 0;
-                       if( isSinglePrecision() ) {
-                               FloatBuffer fin1 = toFloatBuffer(m1.getDenseBlockValues(), inBuff, true);
-                               FloatBuffer fin2 = toFloatBuffer(m2.getDenseBlockValues(), filterBuff, true);
-                               FloatBuffer fout = toFloatBuffer(ret.getDenseBlockValues(), outBuff, false);
-                               nnz = NativeHelper.smmdd(fin1, fin2, fout,
-                                       m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns(), k);
-                               fromFloatBuffer(outBuff.get(), ret.getDenseBlockValues());
-                       }
-                       else {
-                               DenseBlock a = m1.getDenseBlock();
-                               if( a.isContiguous() ) {
-                                       nnz = NativeHelper.dmmdd(m1.getDenseBlockValues(), m2.getDenseBlockValues(),
-                                               ret.getDenseBlockValues(), m1.rlen, m1.clen, m2.clen, k);
+               if(NativeHelper.isNativeLibraryLoaded()){
+                       // Sanity check:
+                       k = k <= 0 ? NativeHelper.getMaxNumThreads() : k;
+                       
+                       // check inputs / outputs
+                       if (m1.isEmptyBlock(false) || m2.isEmptyBlock(false))
+                               return LibMatrixMult.emptyMatrixMult(m1,m2, ret);
+                       
+                       boolean isValidForNative = !isMatMultMemoryBound(m1.rlen, m1.clen, m2.clen)
+                               && !m1.isInSparseFormat() && !m2.isInSparseFormat()
+                               && (m1.getDenseBlock().isContiguous() || !isSinglePrecision())
+                               && m2.getDenseBlock().isContiguous() //contiguous but not allocated
+                               && 8L * m1.rlen * m2.clen < Integer.MAX_VALUE; //use input dims, ret may still be null here
+       
+                       if( isValidForNative ) 
+                       {
+                               // allocate output
+                               if(ret == null)
+                                       ret = new MatrixBlock(m1.rlen, m2.clen, false);
+                               else 
+                                       ret.reset(m1.rlen, m2.clen, false);
+                               ret.allocateBlock();
+                               
+                               long start = DMLScript.STATISTICS ? System.nanoTime() : 0;
+                               long nnz = 0;
+                               if( isSinglePrecision() ) {
+                                       FloatBuffer fin1 = toFloatBuffer(m1.getDenseBlockValues(), inBuff, true);
+                                       FloatBuffer fin2 = toFloatBuffer(m2.getDenseBlockValues(), filterBuff, true);
+                                       FloatBuffer fout = toFloatBuffer(ret.getDenseBlockValues(), outBuff, false);
+                                       nnz = NativeHelper.smmdd(fin1, fin2, fout,
+                                               m1.getNumRows(), m1.getNumColumns(), m2.getNumColumns(), k);
+                                       fromFloatBuffer(outBuff.get(), ret.getDenseBlockValues());
                                }
                                else {
-                                       //sequential processing of individual blocks to
-                                       //avoid segementation faults with concurrent multi-threaded BLAS calls
-                                       for(int bix = 0; bix < a.numBlocks(); bix++) {
-                                               double[] tmp = new double[a.blockSize(bix)*m2.clen];
-                                               nnz += NativeHelper.dmmdd(a.valuesAt(bix), m2.getDenseBlockValues(),
-                                                       tmp, a.blockSize(bix), m1.clen, m2.clen, k);
-                                               int rl = bix * a.blockSize();
-                                               ret.getDenseBlock().set(rl, rl+a.blockSize(bix), 0, m2.clen,
-                                                       DenseBlockFactory.createDenseBlock(tmp, new int[]{a.blockSize(bix),m2.clen}));
+                                       DenseBlock a = m1.getDenseBlock();
+                                       if( a.isContiguous() ) {
+                                               nnz = NativeHelper.dmmdd(m1.getDenseBlockValues(), m2.getDenseBlockValues(),
+                                                       ret.getDenseBlockValues(), m1.rlen, m1.clen, m2.clen, k);
+                                       }
+                                       else {
+                                               //sequential processing of individual blocks to
+                                               //avoid segmentation faults with concurrent multi-threaded BLAS calls
+                                               for(int bix = 0; bix < a.numBlocks(); bix++) {
+                                                       double[] tmp = new double[a.blockSize(bix)*m2.clen];
+                                                       nnz += NativeHelper.dmmdd(a.valuesAt(bix), m2.getDenseBlockValues(),
+                                                               tmp, a.blockSize(bix), m1.clen, m2.clen, k);
+                                                       int rl = bix * a.blockSize();
+                                                       ret.getDenseBlock().set(rl, rl+a.blockSize(bix), 0, m2.clen,
+                                                               DenseBlockFactory.createDenseBlock(tmp, new int[]{a.blockSize(bix),m2.clen}));
+                                               }
+                                       }
                                }
-                       }
-                       
-                       if(nnz > -1) {
-                               if(DMLScript.STATISTICS) {
-                                       Statistics.nativeLibMatrixMultTime += System.nanoTime() - start;
-                                       Statistics.numNativeLibMatrixMultCalls.increment();
-                               }
-                               ret.setNonZeros(nnz);
-                               if(examSparsity)
+                               
+                               if(nnz > -1) {
+                                       if(DMLScript.STATISTICS) {
+                                               Statistics.nativeLibMatrixMultTime += System.nanoTime() - start;
+                                               Statistics.numNativeLibMatrixMultCalls.increment();
+                                       }
+                                       ret.setNonZeros(nnz);
                                        ret.examSparsity();
-                               return;
+                                       return ret;
+                               }
+                               //else record failure and fallback to java
+                               Statistics.incrementNativeFailuresCounter();
+                               LOG.warn("matrixMult: Native mat mult failed. Falling back to java version ("
+                                       + "loaded=" + NativeHelper.isNativeLibraryLoaded()
+                                       + ", sparse=" + (m1.isInSparseFormat() | m2.isInSparseFormat()) + ")");
                        }
-                       //else record failure and fallback to java
-                       Statistics.incrementNativeFailuresCounter();
-                       LOG.warn("matrixMult: Native mat mult failed. Falling back to java version ("
-                               + "loaded=" + NativeHelper.isNativeLibraryLoaded()
-                               + ", sparse=" + (m1.isInSparseFormat() | m2.isInSparseFormat()) + ")");
                }
-               else if(isValidForNative)
+               else
                        LOG.warn("Was valid for native MM but native lib was not loaded");
                
-               if (k == 1)
-                       LibMatrixMult.matrixMult(m1, m2, ret, !examSparsity);
-               else
-                       LibMatrixMult.matrixMult(m1, m2, ret, k);
+               return LibMatrixMult.matrixMult(m1, m2, ret, k);
        }
        
        public static void tsmm(MatrixBlock m1, MatrixBlock ret, boolean leftTrans, int k) {
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index f208174..a0fcef6 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2616,22 +2616,6 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
                return size;
        }
 
-       public static SparsityEstimate estimateSparsityOnAggBinary(MatrixBlock m1, MatrixBlock m2, AggregateBinaryOperator op)
-       {
-               //Since MatrixMultLib always uses a dense output (except for ultra-sparse mm)
-               //with subsequent check for sparsity, we should always return a dense estimate.
-               //Once, we support more aggregate binary operations, we need to change this.
-               
-               //WARNING: KEEP CONSISTENT WITH LIBMATRIXMULT
-               //Note that it is crucial to report the right output representation because
-               //in case of block reuse (e.g., mmcj) the output 'reset' refers to either
-               //dense or sparse representation and hence would produce incorrect results
-               //if we report the wrong representation (i.e., missing reset on ultrasparse mm).
-               
-               boolean ultrasparse = (m1.isUltraSparse() || m2.isUltraSparse());
-               return new SparsityEstimate(ultrasparse, m1.getNumRows()*m2.getNumRows());
-       }
-
        private static SparsityEstimate estimateSparsityOnBinary(MatrixBlock m1, MatrixBlock m2, BinaryOperator op)
        {
                SparsityEstimate est = new SparsityEstimate();
@@ -4988,34 +4972,34 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
        }
 
        public MatrixBlock aggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, AggregateBinaryOperator op) {
+               checkAggregateBinaryOperations(m1, m2, op);
+               final int k = op.getNumThreads();
+               if(NativeHelper.isNativeLibraryLoaded())
+                       return LibMatrixNative.matrixMult(m1, m2, ret, k);
+               else 
+                       return LibMatrixMult.matrixMult(m1, m2, ret, k);
+       }
+
+       protected void checkAggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, AggregateBinaryOperator op) {
                //check input types, dimensions, configuration
-               if( m1.clen != m2.rlen ) {
+               if( m1.clen != m2.rlen )
                        throw new RuntimeException("Dimensions do not match for matrix multiplication ("+m1.clen+"!="+m2.rlen+").");
-               }
-               if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn instanceof Plus) ) {
-                       throw new DMLRuntimeException("Unsupported binary aggregate operation: ("+op.binaryFn+", "+op.aggOp+").");
-               }
-               
-               //setup meta data (dimensions, sparsity)
-               int rl = m1.rlen;
-               int cl = m2.clen;
-               SparsityEstimate sp = estimateSparsityOnAggBinary(m1, m2, op);
-               
-               //create output matrix block
-               if( ret==null )
-                       ret = new MatrixBlock(rl, cl, sp.sparse, sp.estimatedNonZeros);
-               else
-                       ret.reset(rl, cl, sp.sparse, sp.estimatedNonZeros);
-               
-               //compute matrix multiplication (only supported binary aggregate operation)
-               if( NativeHelper.isNativeLibraryLoaded() )
-                       LibMatrixNative.matrixMult(m1, m2, ret, op.getNumThreads());
-               else if( op.getNumThreads() > 1 )
-                       LibMatrixMult.matrixMult(m1, m2, ret, op.getNumThreads());
-               else
-                       LibMatrixMult.matrixMult(m1, m2, ret);
+               checkAggregateBinaryOperationsCommon(m1, m2, op);
+       }
+
+       protected void checkAggregateBinaryOperations(MatrixBlock m1, MatrixBlock m2, AggregateBinaryOperator op, boolean transposeLeft,
+                       boolean transposeRight) {
+               //check input types, dimensions, configuration
+               if((transposeLeft ? m1.rlen : m1.clen) != ( transposeRight ? m2.clen : m2.rlen) )
+                       throw new RuntimeException("Dimensions do not match for matrix multiplication ("+(transposeLeft?m1.rlen:m1.clen)+"!="+(transposeRight?m2.clen:m2.rlen)+").");
+               checkAggregateBinaryOperationsCommon(m1, m2, op);
+       }
                
-               return ret;
+       private void checkAggregateBinaryOperationsCommon(MatrixBlock m1, MatrixBlock m2, AggregateBinaryOperator op){
+               if( !(op.binaryFn instanceof Multiply && op.aggOp.increOp.fn instanceof Plus) )
+                       throw new DMLRuntimeException("Unsupported binary aggregate operation: ("+op.binaryFn+", "+op.aggOp+").");
+               if(!(m1 == this || m2 == this))
+                       throw new DMLRuntimeException("Invalid aggregateBinaryOperation: one of the inputs must be this");
        }
 
        public MatrixBlock aggregateTernaryOperations(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret,
