This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new acd0f69  [SYSTEMDS-2856] Multi-threaded binary matrix-matrix, 
matrix-scalar ops
acd0f69 is described below

commit acd0f6905c6c556725421794f4010af17f2a75c5
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Feb 10 23:01:31 2021 +0100

    [SYSTEMDS-2856] Multi-threaded binary matrix-matrix, matrix-scalar ops
    
    This patch is a first step towards extended support for multi-threaded
    operations. So far, binary operations were not multi-threaded because
    output allocation dominates the runtime for many operations. With
    parallel allocators, future in-place updates, an increasing degree of
    parallelism, and somewhat inefficient sparse-unsafe code paths, this is
    changing. In this first step, we parallelize matrix-matrix unsafe
    operations and matrix-scalar safe operations, which did not have a lot
    of special-case handling and thus could simply be parallelized over row
    partitions.
    
    In a scenario with a 1M x 1050 input matrix (mostly dense except for a
    single one-hot encoded column), this patch improved the Kmeans runtime
    (50 centroids, 1 run, MKL matrix multiply, ~60 iterations) from 177s to
    109s, and the runtime of the relevant binary ops (<= and -2*) from 87s
    to 15s.
---
 src/main/java/org/apache/sysds/hops/BinaryOp.java  |  15 +-
 src/main/java/org/apache/sysds/lops/Binary.java    |  27 ++-
 src/main/java/org/apache/sysds/lops/Unary.java     |   4 +-
 .../instructions/cp/BinaryCPInstruction.java       |   2 +-
 .../cp/BinaryMatrixMatrixCPInstruction.java        |   7 +-
 .../cp/BinaryMatrixScalarCPInstruction.java        |   6 +
 .../runtime/matrix/data/LibMatrixBincell.java      | 218 +++++++++++++++++----
 .../sysds/runtime/matrix/data/LibMatrixNative.java |   4 +-
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  38 +++-
 .../runtime/matrix/operators/BinaryOperator.java   |   9 +
 .../matrix/operators/LeftScalarOperator.java       |   2 +-
 .../matrix/operators/RightScalarOperator.java      |   2 +-
 .../runtime/matrix/operators/ScalarOperator.java   |  16 +-
 13 files changed, 282 insertions(+), 68 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java 
b/src/main/java/org/apache/sysds/hops/BinaryOp.java
index cc5d58d..10e1c8d 100644
--- a/src/main/java/org/apache/sysds/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java
@@ -414,12 +414,13 @@ public class BinaryOp extends MultiThreadedHop
                                (op==OpOp2.MULT && 
HopRewriteUtils.isLiteralOfValue(right, 2d)) ? OpOp1.MULT2 : null;
                        Lop tmp = null;
                        if( ot != null ) {
-                               tmp = new 
Unary(getInput().get(0).constructLops(),
-                                       getInput().get(1).constructLops(), ot, 
getDataType(), getValueType(), et);
+                               tmp = new Unary(getInput(0).constructLops(), 
getInput(1).constructLops(),
+                                       ot, getDataType(), getValueType(), et);
                        }
                        else { //general case
-                               tmp = new 
Binary(getInput().get(0).constructLops(),
-                                       getInput().get(1).constructLops(), op, 
getDataType(), getValueType(), et);
+                               tmp = new Binary(getInput(0).constructLops(), 
getInput(1).constructLops(),
+                                       op, getDataType(), getValueType(), et,
+                                       
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
                        }
                        setOutputDimensions(tmp);
                        setLineNumbers(tmp);
@@ -458,9 +459,9 @@ public class BinaryOp extends MultiThreadedHop
                                                getDataType(), getValueType(), 
et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
                                }
                                else
-                                       binary = new 
Binary(getInput().get(0).constructLops(),
-                                               
getInput().get(1).constructLops(), op,
-                                               getDataType(), getValueType(), 
et);
+                                       binary = new 
Binary(getInput(0).constructLops(), getInput(1).constructLops(),
+                                               op, getDataType(), 
getValueType(), et,
+                                               
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
                                
                                setOutputDimensions(binary);
                                setLineNumbers(binary);
diff --git a/src/main/java/org/apache/sysds/lops/Binary.java 
b/src/main/java/org/apache/sysds/lops/Binary.java
index 9ebe551..5fba53d 100644
--- a/src/main/java/org/apache/sysds/lops/Binary.java
+++ b/src/main/java/org/apache/sysds/lops/Binary.java
@@ -34,6 +34,7 @@ import org.apache.sysds.common.Types.ValueType;
 public class Binary extends Lop 
 {
        private OpOp2 operation;
+       private final int _numThreads;
        
        /**
         * Constructor to perform a binary operation.
@@ -45,9 +46,15 @@ public class Binary extends Lop
         * @param vt value type
         * @param et exec type
         */
+       
        public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType 
vt, ExecType et) {
+               this(input1, input2, op, dt, vt, et, 1);
+       }
+       
+       public Binary(Lop input1, Lop input2, OpOp2 op, DataType dt, ValueType 
vt, ExecType et, int k) {
                super(Lop.Type.Binary, dt, vt);
                init(input1, input2, op, dt, vt, et);
+               _numThreads = k;
        }
        
        private void init(Lop input1, Lop input2, OpOp2 op, DataType dt, 
ValueType vt, ExecType et)  {
@@ -74,10 +81,20 @@ public class Binary extends Lop
 
        @Override
        public String getInstructions(String input1, String input2, String 
output) {
-               return InstructionUtils.concatOperands(
-                       getExecType().toString(), getOpcode(),
-                       getInputs().get(0).prepInputOperand(input1),
-                       getInputs().get(1).prepInputOperand(input2),
-                       prepOutputOperand(output));
+               if( getExecType() == ExecType.CP ) {
+                       return InstructionUtils.concatOperands(
+                               getExecType().name(), getOpcode(),
+                               getInputs().get(0).prepInputOperand(input1),
+                               getInputs().get(1).prepInputOperand(input2),
+                               prepOutputOperand(output),
+                               String.valueOf(_numThreads));
+               }
+               else {
+                       return InstructionUtils.concatOperands(
+                               getExecType().name(), getOpcode(),
+                               getInputs().get(0).prepInputOperand(input1),
+                               getInputs().get(1).prepInputOperand(input2),
+                               prepOutputOperand(output));
+               }
        }
 }
diff --git a/src/main/java/org/apache/sysds/lops/Unary.java 
b/src/main/java/org/apache/sysds/lops/Unary.java
index aa51477..0e34ba2 100644
--- a/src/main/java/org/apache/sysds/lops/Unary.java
+++ b/src/main/java/org/apache/sysds/lops/Unary.java
@@ -127,7 +127,9 @@ public class Unary extends Lop
                        || op==OpOp1.CUMSUMPROD
                        || op==OpOp1.EXP
                        || op==OpOp1.LOG
-                       || op==OpOp1.SIGMOID;
+                       || op==OpOp1.SIGMOID
+                       || op==OpOp1.POW2
+                       || op==OpOp1.MULT2;
        }
        
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
index 2f0aad4..188b2ac 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java
@@ -65,7 +65,7 @@ public abstract class BinaryCPInstruction extends 
ComputationCPInstruction {
        
        protected static String parseBinaryInstruction(String instr, CPOperand 
in1, CPOperand in2, CPOperand out) {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(instr);
-               InstructionUtils.checkNumFields ( parts, 3 );
+               InstructionUtils.checkNumFields ( parts, 3, 4 );
                String opcode = parts[0];
                in1.split(parts[1]);
                in2.split(parts[2]);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
index 20ddfb1..abe815a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java
@@ -25,6 +25,7 @@ import 
org.apache.sysds.runtime.compress.AbstractCompressedMatrixBlock;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
@@ -32,10 +33,14 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public class BinaryMatrixMatrixCPInstruction extends BinaryCPInstruction {
        private static final Log LOG = 
LogFactory.getLog(BinaryMatrixMatrixCPInstruction.class.getName());
-
+       
        protected BinaryMatrixMatrixCPInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand out,
                        String opcode, String istr) {
                super(CPType.Binary, op, in1, in2, out, opcode, istr);
+               if( op instanceof BinaryOperator ) {
+                       String[] parts = 
InstructionUtils.getInstructionParts(istr);
+                       
((BinaryOperator)op).setNumThreads(Integer.parseInt(parts[parts.length-1]));
+               }
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixScalarCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixScalarCPInstruction.java
index 6b00759..04932ad 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixScalarCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixScalarCPInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.cp;
 
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
@@ -30,6 +31,11 @@ public class BinaryMatrixScalarCPInstruction extends 
BinaryCPInstruction {
        protected BinaryMatrixScalarCPInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand out,
                        String opcode, String istr) {
                super(CPType.Binary, op, in1, in2, out, opcode, istr);
+               if( op instanceof ScalarOperator ) {
+                       String[] parts = 
InstructionUtils.getInstructionParts(istr);
+                       if( parts.length > 4 )
+                               
((ScalarOperator)op).setNumThreads(Integer.parseInt(parts[parts.length-1]));
+               }
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
index c287d4f..249d0e3 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java
@@ -19,7 +19,13 @@
 
 package org.apache.sysds.runtime.matrix.data;
 
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
 
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.data.DenseBlock;
@@ -48,6 +54,7 @@ import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.DataConverter;
 import org.apache.sysds.runtime.util.SortUtils;
 import org.apache.sysds.runtime.util.UtilFunctions;
@@ -61,6 +68,7 @@ import org.apache.sysds.runtime.util.UtilFunctions;
  */
 public class LibMatrixBincell 
 {
+       private static final long PAR_NUMCELL_THRESHOLD2 = 16*1024;   //Min 16K 
elements
 
        public enum BinaryAccessType {
                MATRIX_MATRIX,
@@ -94,7 +102,7 @@ public class LibMatrixBincell
                
                //execute binary cell operations
                if(op.sparseSafe)
-                       safeBinaryScalar(m1, ret, op);
+                       safeBinaryScalar(m1, ret, op, 0, m1.rlen);
                else
                        unsafeBinaryScalar(m1, ret, op);
                
@@ -104,6 +112,48 @@ public class LibMatrixBincell
                        ret.examSparsity();
        }
        
+       public static void bincellOp(MatrixBlock m1, MatrixBlock ret, 
ScalarOperator op, int k) {
+               //check internal assumptions 
+               if(   (op.sparseSafe && 
m1.isInSparseFormat()!=ret.isInSparseFormat())
+                       ||(!op.sparseSafe && ret.isInSparseFormat()) ) {
+                       throw new DMLRuntimeException("Wrong output 
representation for safe="+op.sparseSafe+": "+m1.isInSparseFormat()+", 
"+ret.isInSparseFormat());
+               }
+               
+               //fallback to single-threaded for special cases
+               if( m1.isEmpty() || !op.sparseSafe 
+                       || ret.getLength() < PAR_NUMCELL_THRESHOLD2 ) {
+                       bincellOp(m1, ret, op);
+                       return;
+               }
+               
+               //preallocate dense/sparse block for multi-threaded operations
+               ret.allocateBlock();
+               
+               try {
+                       //execute binary cell operations
+                       ExecutorService pool = CommonThreadPool.get(k);
+                       ArrayList<BincellScalarTask> tasks = new ArrayList<>();
+                       ArrayList<Integer> blklens = 
UtilFunctions.getBalancedBlockSizesDefault(ret.rlen, k, false);
+                       for( int i=0, lb=0; i<blklens.size(); 
lb+=blklens.get(i), i++ )
+                               tasks.add(new BincellScalarTask(m1, ret, op, 
lb, lb+blklens.get(i)));
+                       List<Future<Long>> taskret = pool.invokeAll(tasks);
+                       
+                       //aggregate non-zeros
+                       ret.nonZeros = 0; //reset after execute
+                       for( Future<Long> task : taskret )
+                               ret.nonZeros += task.get();
+                       pool.shutdown();
+               }
+               catch(InterruptedException | ExecutionException ex) {
+                       throw new DMLRuntimeException(ex);
+               }
+               
+               //ensure empty results sparse representation 
+               //(no additional memory requirements)
+               if( ret.isEmptyBlock(false) )
+                       ret.examSparsity();
+       }
+       
        /**
         * matrix-matrix binary operations, MM, MV
         * 
@@ -117,7 +167,7 @@ public class LibMatrixBincell
                if(op.sparseSafe || isSparseSafeDivide(op, m2))
                        safeBinary(m1, m2, ret, op);
                else
-                       unsafeBinary(m1, m2, ret, op);
+                       unsafeBinary(m1, m2, ret, op, 0, m1.rlen);
                
                //ensure empty results sparse representation 
                //(no additional memory requirements)
@@ -125,6 +175,45 @@ public class LibMatrixBincell
                        ret.examSparsity();
        }
        
+       public static void bincellOp(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator op, int k) {
+               //fallback to sequential computation for specialized operations
+               //TODO parallel support for all sparse safe operations
+               if( op.sparseSafe || isSparseSafeDivide(op, m2)
+                       || ret.getLength() < PAR_NUMCELL_THRESHOLD2
+                       || getBinaryAccessType(m1, m2) == 
BinaryAccessType.OUTER_VECTOR_VECTOR)
+               {
+                       bincellOp(m1, m2, ret, op);
+                       return;
+               }
+               
+               //preallocate dense/sparse block for multi-threaded operations
+               ret.allocateBlock();
+               
+               try {
+                       //execute binary cell operations
+                       ExecutorService pool = CommonThreadPool.get(k);
+                       ArrayList<BincellTask> tasks = new ArrayList<>();
+                       ArrayList<Integer> blklens = 
UtilFunctions.getBalancedBlockSizesDefault(ret.rlen, k, false);
+                       for( int i=0, lb=0; i<blklens.size(); 
lb+=blklens.get(i), i++ )
+                               tasks.add(new BincellTask(m1, m2, ret, op, lb, 
lb+blklens.get(i)));
+                       List<Future<Long>> taskret = pool.invokeAll(tasks);
+                       
+                       //aggregate non-zeros
+                       ret.nonZeros = 0; //reset after execute
+                       for( Future<Long> task : taskret )
+                               ret.nonZeros += task.get();
+                       pool.shutdown();
+               }
+               catch(InterruptedException | ExecutionException ex) {
+                       throw new DMLRuntimeException(ex);
+               }
+               
+               //ensure empty results sparse representation
+               //(no additional memory requirements)
+               if( ret.isEmptyBlock(false) )
+                       ret.examSparsity();
+       }
+       
        /**
         * NOTE: operations in place always require m1 and m2 to be of equal 
dimensions
         * 
@@ -919,42 +1008,40 @@ public class LibMatrixBincell
                ret.examSparsity();
        }
 
-       private static void unsafeBinary(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator op) {
-               int rlen = m1.rlen;
+       private static void unsafeBinary(MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator op, int rl, int ru) {
                int clen = m1.clen;
                BinaryAccessType atype = getBinaryAccessType(m1, m2);
                
-               if( atype == BinaryAccessType.MATRIX_COL_VECTOR ) //MATRIX - 
COL_VECTOR
-               {
-                       for(int r=0; r<rlen; r++) {
+               long lnnz = 0;
+               if( atype == BinaryAccessType.MATRIX_COL_VECTOR ) { //MATRIX - 
COL_VECTOR
+                       for(int r=rl; r<ru; r++) {
                                double v2 = m2.quickGetValue(r, 0);
                                for(int c=0; c<clen; c++) {
                                        double v1 = m1.quickGetValue(r, c);
                                        double v = op.fn.execute( v1, v2 );
-                                       ret.appendValue(r, c, v);
+                                       ret.appendValuePlain(r, c, v);
+                                       lnnz += (v!=0) ? 1 : 0;
                                }
                        }
                }
-               else if( atype == BinaryAccessType.MATRIX_ROW_VECTOR ) //MATRIX 
- ROW_VECTOR
-               {
-                       for(int r=0; r<rlen; r++)
+               else if( atype == BinaryAccessType.MATRIX_ROW_VECTOR ) { 
//MATRIX - ROW_VECTOR
+                       for(int r=rl; r<ru; r++)
                                for(int c=0; c<clen; c++) {
                                        double v1 = m1.quickGetValue(r, c);
                                        double v2 = m2.quickGetValue(0, c);
                                        double v = op.fn.execute( v1, v2 );
-                                       ret.appendValue(r, c, v);
+                                       ret.appendValuePlain(r, c, v);
+                                       lnnz += (v!=0) ? 1 : 0;
                                }
                }
-               else if( atype == BinaryAccessType.OUTER_VECTOR_VECTOR ) 
//VECTOR - VECTOR
-               {
+               else if( atype == BinaryAccessType.OUTER_VECTOR_VECTOR ) { 
//VECTOR - VECTOR
                        int clen2 = m2.clen; 
-                       
                        if(LibMatrixOuterAgg.isCompareOperator(op) 
                                && m2.getNumColumns()>16 && 
SortUtils.isSorted(m2)) {
                                performBinOuterOperation(m1, m2, ret, op);
                        } 
                        else {
-                               for(int r=0; r<rlen; r++) {
+                               for(int r=rl; r<ru; r++) {
                                        double v1 = m1.quickGetValue(r, 0);
                                        for(int c=0; c<clen2; c++) {
                                                double v2 = m2.quickGetValue(0, 
c);
@@ -974,28 +1061,30 @@ public class LibMatrixBincell
                                double[] a = m1.getDenseBlockValues();
                                double[] b = m2.getDenseBlockValues();
                                double[] c = ret.getDenseBlockValues();
-                               int lnnz = 0;
-                               for( int i=0; i<rlen; i++ ) {
+                               for( int i=rl; i<ru; i++ ) {
                                        c[i] = op.fn.execute( a[i], b[i] );
                                        lnnz += (c[i] != 0) ? 1 : 0;
                                }
-                               ret.nonZeros = lnnz;
                        }
                        //general case
-                       else 
-                       {
-                               for(int r=0; r<rlen; r++)
+                       else {
+                               for(int r=rl; r<ru; r++)
                                        for(int c=0; c<clen; c++) {
                                                double v1 = m1.quickGetValue(r, 
c);
                                                double v2 = m2.quickGetValue(r, 
c);
                                                double v = op.fn.execute( v1, 
v2 );
-                                               ret.appendValue(r, c, v);
+                                               ret.appendValuePlain(r, c, v);
+                                               lnnz += (v!=0) ? 1 : 0;
                                        }
                        }
                }
+               
+               //avoid false sharing in multi-threaded ops, while
+               //correctly setting the nnz for single-threaded ops
+               ret.nonZeros = lnnz;
        }
 
-       private static void safeBinaryScalar(MatrixBlock m1, MatrixBlock ret, 
ScalarOperator op) {
+       private static void safeBinaryScalar(MatrixBlock m1, MatrixBlock ret, 
ScalarOperator op, int rl, int ru) {
                //early abort possible since sparsesafe
                if( m1.isEmptyBlock(false) ) {
                        return;
@@ -1016,10 +1105,9 @@ public class LibMatrixBincell
                        ret.allocateSparseRowsBlock();
                        SparseBlock a = m1.sparseBlock;
                        SparseBlock c = ret.sparseBlock;
-                       int rlen = Math.min(m1.rlen, a.numRows());
                        
                        long nnz = 0;
-                       for(int r=0; r<rlen; r++) {
+                       for(int r=rl; r<ru; r++) {
                                if( a.isEmpty(r) ) continue;
                                
                                int apos = a.pos(r);
@@ -1053,7 +1141,7 @@ public class LibMatrixBincell
                        ret.nonZeros = nnz;
                }
                else { //DENSE <- DENSE
-                       denseBinaryScalar(m1, ret, op);
+                       denseBinaryScalar(m1, ret, op, rl, ru);
                }
        }
        
@@ -1078,14 +1166,15 @@ public class LibMatrixBincell
                if( ret.sparse )
                        throw new DMLRuntimeException("Unsupported unsafe 
binary scalar operations over sparse output representation.");
                
+               int m = m1.rlen;
+               int n = m1.clen;
+               
                if( m1.sparse ) //SPARSE MATRIX
                {
                        ret.allocateDenseBlock();
                        
                        SparseBlock a = m1.sparseBlock;
                        DenseBlock dc = ret.getDenseBlock();
-                       int m = m1.rlen;
-                       int n = m1.clen;
                        
                        //init dense result with unsafe 0-value
                        double val0 = op.executeScalar(0);
@@ -1115,26 +1204,27 @@ public class LibMatrixBincell
                        ret.nonZeros = nnz;
                }
                else { //DENSE MATRIX
-                       denseBinaryScalar(m1, ret, op);
+                       denseBinaryScalar(m1, ret, op, 0, m);
                }
        }
 
-       private static void denseBinaryScalar(MatrixBlock m1, MatrixBlock ret, 
ScalarOperator op) {
+       private static void denseBinaryScalar(MatrixBlock m1, MatrixBlock ret, 
ScalarOperator op, int rl, int ru) {
                //allocate dense block (if necessary), incl clear nnz
                ret.allocateDenseBlock(true);
                
                DenseBlock da = m1.getDenseBlock();
                DenseBlock dc = ret.getDenseBlock();
+               int clen = m1.clen;
                
                //compute scalar operation, incl nnz maintenance
                long nnz = 0;
-               for( int bi=0; bi<da.numBlocks(); bi++) {
-                       double[] a = da.valuesAt(bi);
-                       double[] c = dc.valuesAt(bi);
-                       int limit = da.size(bi);
-                       for( int i=0; i<limit; i++ ) {
-                               c[i] = op.executeScalar( a[i] );
-                               nnz += (c[i] != 0) ? 1 : 0;
+               for(int i=rl; i<ru; i++) {
+                       double[] a = da.values(i);
+                       double[] c = dc.values(i);
+                       int apos = da.pos(i), cpos = dc.pos(i);
+                       for(int j=0; j<clen; j++) {
+                               c[cpos+j] = op.executeScalar( a[apos+j] );
+                               nnz += (c[cpos+j] != 0) ? 1 : 0;
                        }
                }
                ret.nonZeros = nnz;
@@ -1408,4 +1498,56 @@ public class LibMatrixBincell
                if( zero )
                        c.compact(r);
        }
+       
+       private static class BincellTask implements Callable<Long> {
+               private final MatrixBlock _m1;
+               private final MatrixBlock _m2;
+               private final MatrixBlock _ret;
+               private final BinaryOperator _bop;
+               private final int _rl;
+               private final int _ru;
+
+               protected BincellTask( MatrixBlock m1, MatrixBlock m2, 
MatrixBlock ret, BinaryOperator bop, int rl, int ru ) {
+                       _m1 = m1;
+                       _m2 = m2;
+                       _ret = ret;
+                       _bop = bop;
+                       _rl = rl;
+                       _ru = ru;
+               }
+               
+               @Override
+               public Long call() {
+                       //execute binary operation on row partition
+                       unsafeBinary(_m1, _m2, _ret, _bop, _rl, _ru);
+                       
+                       //maintain block nnz (upper bounds inclusive)
+                       return _ret.recomputeNonZeros(_rl, _ru-1);
+               }
+       }
+       
+       private static class BincellScalarTask implements Callable<Long> {
+               private final MatrixBlock _m1;
+               private final MatrixBlock _ret;
+               private final ScalarOperator _sop;
+               private final int _rl;
+               private final int _ru;
+
+               protected BincellScalarTask( MatrixBlock m1, MatrixBlock ret, 
ScalarOperator sop, int rl, int ru ) {
+                       _m1 = m1;
+                       _ret = ret;
+                       _sop = sop;
+                       _rl = rl;
+                       _ru = ru;
+               }
+               
+               @Override
+               public Long call() {
+                       //execute binary operation on row partition
+                       safeBinaryScalar(_m1, _ret, _sop, _rl, _ru);
+                       
+                       //maintain block nnz (upper bounds inclusive)
+                       return _ret.recomputeNonZeros(_rl, _ru-1);
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
index d64a936..6e7ba49 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixNative.java
@@ -116,7 +116,9 @@ public class LibMatrixNative
                        Statistics.incrementNativeFailuresCounter();
                }
                //fallback to default java implementation
-               LOG.warn("matrixMult: Native mat mult failed. Falling back to 
java version.");
+               LOG.warn("matrixMult: Native mat mult failed. Falling back to 
java version ("
+                       + "loaded=" + NativeHelper.isNativeLibraryLoaded() 
+                       + ", sparse=" + (m1.isInSparseFormat() | 
m2.isInSparseFormat()) + ")");
                if (k == 1)
                        LibMatrixMult.matrixMult(m1, m2, ret, !examSparsity);
                else
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index ed090c3..5d4b869 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -694,8 +694,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                if( v == 0 ) 
                        return;
 
-               if( !sparse ) //DENSE 
-               {
+               if( !sparse ) { //DENSE 
                        //allocate on demand (w/o overwriting nnz)
                        allocateDenseBlock(false);
                        
@@ -703,8 +702,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        denseBlock.set(r, c, v);
                        nonZeros++;
                }
-               else //SPARSE
-               {
+               else { //SPARSE
                        //allocation on demand (w/o overwriting nnz)
                        allocateSparseRowsBlock(false);
                        sparseBlock.allocate(r, estimatedNNzsPerRow, clen);
@@ -715,6 +713,28 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                }
        }
 
+       public void appendValuePlain(int r, int c, double v) {
+               //early abort (append guarantees no overwrite)
+               if( v == 0 ) 
+                       return;
+
+               if( !sparse ) { //DENSE 
+                       //allocate on demand (w/o overwriting nnz)
+                       allocateDenseBlock(false);
+                       
+                       //set value only (nonZeros is not updated in this method)
+                       denseBlock.set(r, c, v);
+               }
+               else { //SPARSE
+                       //allocation on demand (w/o overwriting nnz)
+                       allocateSparseRowsBlock(false);
+                       sparseBlock.allocate(r, estimatedNNzsPerRow, clen);
+                       
+                       //append value only (nonZeros is not updated in this method)
+                       sparseBlock.append(r, c, v);
+               }
+       }
+       
        public void appendRow(int r, SparseRow row) {
                appendRow(r, row, true);
        }
@@ -2659,7 +2679,10 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        ret.reset(rlen, clen, sp, this.nonZeros);
                
                //core scalar operations
-               LibMatrixBincell.bincellOp(this, ret, op);
+               if( op.getNumThreads() > 1 )
+                       LibMatrixBincell.bincellOp(this, ret, op, op.getNumThreads());
+               else
+                       LibMatrixBincell.bincellOp(this, ret, op);
                
                return ret;
        }
@@ -2842,7 +2865,10 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        ret.reset(rows, cols, resultSparse.sparse, 
resultSparse.estimatedNonZeros);
                
                //core binary cell operation
-               LibMatrixBincell.bincellOp( this, that, ret, op );
+               if( op.getNumThreads() > 1 )
+                       LibMatrixBincell.bincellOp( this, that, ret, op, op.getNumThreads() );
+               else
+                       LibMatrixBincell.bincellOp( this, that, ret, op );
                
                return ret;
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
index bc4cdd0..7579046 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java
@@ -57,6 +57,7 @@ public class BinaryOperator  extends Operator implements 
Serializable
 
        public final ValueFunction fn;
        public final boolean commutative;
+       private int _k = 1; // num threads
        
        public BinaryOperator(ValueFunction p) {
                //binaryop is sparse-safe iff (0 op 0) == 0
@@ -70,6 +71,14 @@ public class BinaryOperator  extends Operator implements 
Serializable
                        || p instanceof And || p instanceof Or || p instanceof 
Xor;
        }
        
+       public void setNumThreads(int k) {
+               _k = k;
+       }
+       
+       public int getNumThreads() {
+               return _k;
+       }
+       
        /**
         * Method for getting the hop binary operator type for a given function 
object.
         * This is used in order to use a common code path for consistency 
between 
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/LeftScalarOperator.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/LeftScalarOperator.java
index 7a40a3f..abca742 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/LeftScalarOperator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/LeftScalarOperator.java
@@ -58,7 +58,7 @@ public class LeftScalarOperator extends ScalarOperator
 
        @Override
        public ScalarOperator setConstant(double cst) {
-               return new LeftScalarOperator(fn, cst);
+               return new LeftScalarOperator(fn, cst, getNumThreads());
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/RightScalarOperator.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/RightScalarOperator.java
index a55ed66..fe821e0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/RightScalarOperator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/RightScalarOperator.java
@@ -56,7 +56,7 @@ public class RightScalarOperator extends ScalarOperator
 
        @Override
        public ScalarOperator setConstant(double cst) {
-               return new RightScalarOperator(fn, cst);
+               return new RightScalarOperator(fn, cst, getNumThreads());
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
index 8f27209..d33bbae 100644
--- 
a/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/operators/ScalarOperator.java
@@ -44,7 +44,7 @@ public abstract class ScalarOperator extends Operator
 
        public final ValueFunction fn;
        protected final double _constant;
-       private final int k; //num threads
+       private int _k; //num threads
        
        public ScalarOperator(ValueFunction p, double cst) {
                this(p, cst, false);
@@ -63,13 +63,21 @@ public abstract class ScalarOperator extends Operator
                                || (p instanceof Builtin && 
((Builtin)p).getBuiltinCode()==BuiltinCode.MIN && cst>=0));
                fn = p;
                _constant = cst;
-               k = numThreads;
+               _k = numThreads;
        }
        
        public double getConstant() {
                return _constant;
        }
        
+       public void setNumThreads(int k) {
+               _k = k;
+       }
+       
+       public int getNumThreads() {
+               return _k;
+       }
+       
        public abstract ScalarOperator setConstant(double cst);
        
        public abstract ScalarOperator setConstant(double cst, int numThreads);
@@ -94,8 +102,4 @@ public abstract class ScalarOperator extends Operator
                        || fn instanceof Builtin && 
((Builtin)fn).getBuiltinCode()==BuiltinCode.LOG_NZ)
                        || fn instanceof BitwShiftL || fn instanceof BitwShiftR;
        }
-
-       public int getNumThreads() {
-               return k;
-       }
 }

Reply via email to