This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit af4f3d7683cf5dcd62d45858fb0290a607e66dfe Author: Matthias Boehm <[email protected]> AuthorDate: Sat Feb 13 01:31:57 2021 +0100 [SYSTEMDS-2856] Extended multi-threading binary and ternary operations This patch generalized the multi-threading of binary (sparse-unsafe) to binary (sparse-unsafe and sparse-safe matrix) and ternary operations, where the latter often calls binary sparse-safe matrix operations. For an mnist lenet parameter server scenario, this patch improved end-to-end performance from 205s to 168s. It also slightly improved other algorithms like KMeans. --- src/main/java/org/apache/sysds/hops/TernaryOp.java | 16 +- src/main/java/org/apache/sysds/lops/Ternary.java | 11 +- .../runtime/instructions/InstructionUtils.java | 6 +- .../cp/ParamservBuiltinCPInstruction.java | 6 +- .../instructions/cp/TernaryCPInstruction.java | 3 +- .../runtime/matrix/data/LibMatrixBincell.java | 241 ++++++++++++--------- .../runtime/matrix/data/LibMatrixTercell.java | 132 +++++++++++ .../sysds/runtime/matrix/data/MatrixBlock.java | 25 +-- .../runtime/matrix/operators/TernaryOperator.java | 10 +- 9 files changed, 321 insertions(+), 129 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java b/src/main/java/org/apache/sysds/hops/TernaryOp.java index f8c369c..47e42bb 100644 --- a/src/main/java/org/apache/sysds/hops/TernaryOp.java +++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java @@ -57,9 +57,8 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics; * * CTABLE op takes 2 extra inputs with target dimensions for padding and pruning. */ -public class TernaryOp extends Hop +public class TernaryOp extends MultiThreadedHop { - public static boolean ALLOW_CTABLE_SEQUENCE_REWRITES = true; private OpOp3 _op = null; @@ -147,6 +146,13 @@ public class TernaryOp extends Hop } @Override + public boolean isMultiThreadedOpType() { + return _op == OpOp3.IFELSE + || _op == OpOp3.MINUS_MULT + || _op == OpOp3.PLUS_MULT; + } + + @Override public Lop constructLops() { //return already created lops @@ -324,13 +330,17 @@ public class TernaryOp extends Hop private void constructLopsTernaryDefault() { ExecType et = optFindExecType(); + int k = 1; if( getInput().stream().allMatch(h -> h.getDataType().isScalar()) ) et = ExecType.CP; //always CP for pure scalar operations + else + k= OptimizerUtils.getConstrainedNumThreads( _maxNumThreads ); + Ternary plusmult = new Ternary(_op, getInput().get(0).constructLops(), getInput().get(1).constructLops(), getInput().get(2).constructLops(), - getDataType(),getValueType(), et ); + getDataType(),getValueType(), et, k ); setOutputDimensions(plusmult); setLineNumbers(plusmult); setLops(plusmult); diff --git a/src/main/java/org/apache/sysds/lops/Ternary.java b/src/main/java/org/apache/sysds/lops/Ternary.java index 8faea5d..a1a2d53 100644 --- a/src/main/java/org/apache/sysds/lops/Ternary.java +++ b/src/main/java/org/apache/sysds/lops/Ternary.java @@ -33,10 +33,12 @@ import org.apache.sysds.common.Types.ValueType; public class Ternary extends Lop { private final OpOp3 _op; - - public Ternary(OpOp3 op, Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, ExecType et) { + private final int _numThreads; + + public Ternary(OpOp3 op, Lop input1, Lop input2, Lop input3, DataType dt, ValueType vt, ExecType et, int numThreads) { super(Lop.Type.Ternary, dt, vt); _op = op; + _numThreads = numThreads; init(input1, input2, input3, et); } @@ -71,6 +73,11 @@ public class Ternary extends Lop sb.append( OPERAND_DELIMITOR ); sb.append( prepOutputOperand(output) ); + if( getExecType() == ExecType.CP && getDataType().isMatrix() ) { + sb.append( OPERAND_DELIMITOR ); + sb.append( _numThreads ); + } + return sb.toString(); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index 49c3452..9245132 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -595,8 +595,12 @@ public class InstructionUtils } public static TernaryOperator parseTernaryOperator(String opcode) { + return parseTernaryOperator(opcode, 1); + } + + public static TernaryOperator parseTernaryOperator(String opcode, int numThreads) { return new TernaryOperator(opcode.equals("+*") ? PlusMultiply.getFnObject() : - opcode.equals("-*") ? MinusMultiply.getFnObject() : IfElse.getFnObject()); + opcode.equals("-*") ? MinusMultiply.getFnObject() : IfElse.getFnObject(), numThreads); } /** diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 0fa5297..a99e8ee 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -320,7 +320,8 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc // Create the local workers List<LocalPSWorker> workers = IntStream.range(0, workerNum) - .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, getEpochs(), getBatchSize(), workerECs.get(i), ps)) + .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, + getEpochs(), getBatchSize(), workerECs.get(i), ps, workerNum==1)) .collect(Collectors.toList()); // Do data partition @@ -497,7 +498,8 @@ public class ParamservBuiltinCPInstruction extends ParameterizedBuiltinCPInstruc private void partitionLocally(PSScheme scheme, ExecutionContext ec, List<LocalPSWorker> workers) { MatrixObject features = ec.getMatrixObject(getParam(PS_FEATURES)); MatrixObject labels = ec.getMatrixObject(getParam(PS_LABELS)); - DataPartitionLocalScheme.Result result = new LocalDataPartitioner(scheme).doPartitioning(workers.size(), features.acquireReadAndRelease(), labels.acquireReadAndRelease()); + DataPartitionLocalScheme.Result result = new LocalDataPartitioner(scheme) + .doPartitioning(workers.size(), features.acquireReadAndRelease(), labels.acquireReadAndRelease()); List<MatrixObject> pfs = result.pFeatures; List<MatrixObject> pls = result.pLabels; if (pfs.size() < workers.size()) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java index 9d4232b..14d0090 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/TernaryCPInstruction.java @@ -38,7 +38,8 @@ public class TernaryCPInstruction extends ComputationCPInstruction { CPOperand operand2 = new CPOperand(parts[2]); CPOperand operand3 = new CPOperand(parts[3]); CPOperand outOperand = new CPOperand(parts[4]); - TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode); + int numThreads = parts.length>5 ? Integer.parseInt(parts[5]) : 1; + TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode, numThreads); return new TernaryCPInstruction(op, operand1, operand2, operand3, outOperand, opcode,str); } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java index 249d0e3..29363b0 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixBincell.java @@ -163,11 +163,19 @@ public class LibMatrixBincell * @param op binary operator */ public static void bincellOp(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { + BinaryAccessType atype = getBinaryAccessType(m1, m2); + + //preallocate for consistency + if( atype == BinaryAccessType.MATRIX_MATRIX ) + ret.allocateBlock(); //chosen outside + //execute binary cell operations + long nnz = 0; if(op.sparseSafe || isSparseSafeDivide(op, m2)) - safeBinary(m1, m2, ret, op); + nnz = safeBinary(m1, m2, ret, op, atype, 0, m1.rlen); else - unsafeBinary(m1, m2, ret, op, 0, m1.rlen); + nnz = unsafeBinary(m1, m2, ret, op, 0, m1.rlen); + ret.setNonZeros(nnz); //ensure empty results sparse representation //(no additional memory requirements) @@ -176,18 +184,20 @@ public class LibMatrixBincell } public static void bincellOp(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, int k) { + BinaryAccessType atype = getBinaryAccessType(m1, m2); + //fallback to sequential computation for specialized operations - //TODO parallel support for all sparse safe operations - if( op.sparseSafe || isSparseSafeDivide(op, m2) - || ret.getLength() < PAR_NUMCELL_THRESHOLD2 - || getBinaryAccessType(m1, m2) == BinaryAccessType.OUTER_VECTOR_VECTOR) + if( m1.isEmpty() || m2.isEmpty() + || ret.getLength() < PAR_NUMCELL_THRESHOLD2 + || ((op.sparseSafe || isSparseSafeDivide(op, m2)) + && atype != BinaryAccessType.MATRIX_MATRIX)) { bincellOp(m1, m2, ret, op); return; } //preallocate dense/sparse block for multi-threaded operations - ret.allocateBlock(); + ret.allocateBlock(); //chosen outside try { //execute binary cell operations @@ -195,7 +205,7 @@ public class LibMatrixBincell ArrayList<BincellTask> tasks = new ArrayList<>(); ArrayList<Integer> blklens = UtilFunctions.getBalancedBlockSizesDefault(ret.rlen, k, false); for( int i=0, lb=0; i<blklens.size(); lb+=blklens.get(i), i++ ) - tasks.add(new BincellTask(m1, m2, ret, op, lb, lb+blklens.get(i))); + tasks.add(new BincellTask(m1, m2, ret, op, atype, lb, lb+blklens.get(i))); List<Future<Long>> taskret = pool.invokeAll(tasks); //aggregate non-zeros @@ -286,7 +296,11 @@ public class LibMatrixBincell // private sparse-safe/sparse-unsafe implementations /////////////////////////////////// - private static void safeBinary(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { + private static long safeBinary(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, + BinaryAccessType atype, int rl, int ru) + { + //NOTE: multi-threaded over rl-ru only applied for matrix-matrix, non-empty + boolean skipEmpty = (op.fn instanceof Multiply || isSparseSafeDivide(op, m2) ); boolean copyLeftRightEmpty = (op.fn instanceof Plus || op.fn instanceof Minus @@ -297,12 +311,11 @@ public class LibMatrixBincell if( m1.isEmptyBlock(false) && m2.isEmptyBlock(false) || skipEmpty && (m1.isEmptyBlock(false) || m2.isEmptyBlock(false)) ) { - return; + return 0; } - BinaryAccessType atype = getBinaryAccessType(m1, m2); if( atype == BinaryAccessType.MATRIX_COL_VECTOR //MATRIX - VECTOR - || atype == BinaryAccessType.MATRIX_ROW_VECTOR) + || atype == BinaryAccessType.MATRIX_ROW_VECTOR) { //note: m2 vector and hence always dense if( !m1.sparse && !m2.sparse && !ret.sparse ) //DENSE all @@ -318,7 +331,7 @@ public class LibMatrixBincell safeBinaryMVDenseSparseMult(m1, m2, ret, op); else //generic combinations safeBinaryMVGeneric(m1, m2, ret, op); - } + } else if( atype == BinaryAccessType.OUTER_VECTOR_VECTOR ) //VECTOR - VECTOR { safeBinaryVVGeneric(m1, m2, ret, op); @@ -334,25 +347,27 @@ public class LibMatrixBincell ret.copyShallow(m2); } else if(m1.sparse && m2.sparse) { - safeBinaryMMSparseSparse(m1, m2, ret, op); + return safeBinaryMMSparseSparse(m1, m2, ret, op, rl, ru); } else if( !ret.sparse && (m1.sparse || m2.sparse) && (op.fn instanceof Plus || op.fn instanceof Minus || op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply || (op.fn instanceof Multiply && !m2.sparse ))) { - safeBinaryMMSparseDenseDense(m1, m2, ret, op); + return safeBinaryMMSparseDenseDense(m1, m2, ret, op, rl, ru); } else if( !ret.sparse && !m1.sparse && !m2.sparse && m1.denseBlock!=null && m2.denseBlock!=null ) { - safeBinaryMMDenseDenseDense(m1, m2, ret, op); + return safeBinaryMMDenseDenseDense(m1, m2, ret, op, rl, ru); } else if( skipEmpty && (m1.sparse || m2.sparse) ) { - safeBinaryMMSparseDenseSkip(m1, m2, ret, op); + return safeBinaryMMSparseDenseSkip(m1, m2, ret, op, rl, ru); } else { //generic case - safeBinaryMMGeneric(m1, m2, ret, op); + return safeBinaryMMGeneric(m1, m2, ret, op, rl, ru); } } + //default catch all + return ret.getNonZeros(); } private static void safeBinaryMVDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { @@ -737,12 +752,11 @@ public class LibMatrixBincell //no need to recomputeNonZeros since maintained in append value } - private static void safeBinaryMMSparseSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { - final int rlen = m1.rlen; - if(ret.sparse) - ret.allocateSparseRowsBlock(); - + private static long safeBinaryMMSparseSparse(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, + BinaryOperator op, int rl, int ru) + { //both sparse blocks existing + long lnnz = 0; if(m1.sparseBlock!=null && m2.sparseBlock!=null) { SparseBlock lsblock = m1.sparseBlock; @@ -751,7 +765,7 @@ public class LibMatrixBincell if( ret.sparse && lsblock.isAligned(rsblock) ) { SparseBlock c = ret.sparseBlock; - for(int r=0; r<rlen; r++) + for(int r=rl; r<ru; r++) if( !lsblock.isEmpty(r) ) { int alen = lsblock.size(r); int apos = lsblock.pos(r); @@ -763,12 +777,12 @@ public class LibMatrixBincell double tmp = op.fn.execute(avals[j], bvals[j]); c.append(r, aix[j], tmp); } - ret.nonZeros += c.size(r); + lnnz += c.size(r); } } else //general case { - for(int r=0; r<rlen; r++) { + for(int r=rl; r<ru; r++) { if( !lsblock.isEmpty(r) && !rsblock.isEmpty(r) ) { mergeForSparseBinary(op, lsblock.values(r), lsblock.indexes(r), lsblock.pos(r), lsblock.size(r), rsblock.values(r), rsblock.indexes(r), rsblock.pos(r), rsblock.size(r), r, ret); @@ -782,6 +796,7 @@ public class LibMatrixBincell lsblock.pos(r), lsblock.size(r), 0, r, ret); } // do nothing if both not existing + lnnz += ret.recomputeNonZeros(r, r); } } } @@ -789,55 +804,63 @@ public class LibMatrixBincell else if( m2.sparseBlock!=null ) { SparseBlock rsblock = m2.sparseBlock; - for(int r=0; r<Math.min(rlen, rsblock.numRows()); r++) { + for(int r=rl; r<Math.min(ru, rsblock.numRows()); r++) { if( rsblock.isEmpty(r) ) continue; appendRightForSparseBinary(op, rsblock.values(r), rsblock.indexes(r), rsblock.pos(r), rsblock.size(r), 0, r, ret); + lnnz += ret.recomputeNonZeros(r, r); } } //left sparse block existing else { SparseBlock lsblock = m1.sparseBlock; - for(int r=0; r<rlen; r++) { + for(int r=rl; r<ru; r++) { if( lsblock.isEmpty(r) ) continue; appendLeftForSparseBinary(op, lsblock.values(r), lsblock.indexes(r), lsblock.pos(r), lsblock.size(r), 0, r, ret); + lnnz += ret.recomputeNonZeros(r, r); } } + return lnnz; } - private static void safeBinaryMMSparseDenseDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { + private static long safeBinaryMMSparseDenseDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, + BinaryOperator op, int rl, int ru) + { //specific case in order to prevent binary search on sparse inputs (see quickget and quickset) - ret.allocateDenseBlock(); final int n = ret.clen; DenseBlock dc = ret.getDenseBlock(); //1) process left input: assignment - if( m1.sparse && m1.sparseBlock != null ) //SPARSE left { SparseBlock a = m1.sparseBlock; - for( int bi=0; bi<dc.numBlocks(); bi++ ) { - double[] c = dc.valuesAt(bi); - int blen = dc.blockSize(bi); - int off = bi * dc.blockSize(); - for( int i=0, ix=0; i<blen; i++, ix+=n ) { - int ai = off + i; - if( a.isEmpty(ai) ) continue; - int apos = a.pos(ai); - int alen = a.size(ai); - int[] aix = a.indexes(ai); - double[] avals = a.values(ai); - for(int k = apos; k < apos+alen; k++) - c[ix+aix[k]] = avals[k]; - } + for(int i=rl; i<ru; i++) { + double[] c = dc.values(i); + int cpos = dc.pos(i); + if( a.isEmpty(i) ) continue; + int apos = a.pos(i); + int alen = a.size(i); + int[] aix = a.indexes(i); + double[] avals = a.values(i); + for(int k = apos; k < apos+alen; k++) + c[cpos+aix[k]] = avals[k]; } } else if( !m1.sparse ) //DENSE left { - if( !m1.isEmptyBlock(false) ) - dc.set(m1.getDenseBlock()); + if( !m1.isEmptyBlock(false) ) { + int rlbix = dc.index(rl); + int rubix = dc.index(ru-1); + DenseBlock da = m1.getDenseBlock(); + if( rlbix == rubix ) + System.arraycopy(da.valuesAt(rlbix), da.pos(rl), dc.valuesAt(rlbix), dc.pos(rl), (ru-rl)*n); + else { + for(int i=rl; i<ru; i++) + System.arraycopy(da.values(i), da.pos(i), dc.values(i), dc.pos(i), n); + } + } else dc.set(0); } @@ -847,35 +870,32 @@ public class LibMatrixBincell if( m2.sparse && m2.sparseBlock!=null ) //SPARSE right { SparseBlock a = m2.sparseBlock; - for( int bi=0; bi<dc.numBlocks(); bi++ ) { - double[] c = dc.valuesAt(bi); - int blen = dc.blockSize(bi); - int off = bi * dc.blockSize(); - for( int i=0, ix=0; i<blen; i++, ix+=n ) { - int ai = off + i; - if( !a.isEmpty(ai) ) { - int apos = a.pos(ai); - int alen = a.size(ai); - int[] aix = a.indexes(ai); - double[] avals = a.values(ai); - for(int k = apos; k < apos+alen; k++) - c[ix+aix[k]] = op.fn.execute(c[ix+aix[k]], avals[k]); - } - //exploit temporal locality of rows - lnnz += ret.recomputeNonZeros(ai, ai, 0, n-1); + for(int i=rl; i<ru; i++) { + double[] c = dc.values(i); + int cpos = dc.pos(i); + if( !a.isEmpty(i) ) { + int apos = a.pos(i); + int alen = a.size(i); + int[] aix = a.indexes(i); + double[] avals = a.values(i); + for(int k = apos; k < apos+alen; k++) + c[cpos+aix[k]] = op.fn.execute(c[cpos+aix[k]], avals[k]); } + //exploit temporal locality of rows + lnnz += ret.recomputeNonZeros(i, i); } } else if( !m2.sparse ) //DENSE right { if( !m2.isEmptyBlock(false) ) { - for( int bi=0; bi<dc.numBlocks(); bi++ ) { - double[] a = m2.getDenseBlock().valuesAt(bi); - double[] c = dc.valuesAt(bi); - int len = dc.size(bi); - for( int i=0; i<len; i++ ) { - c[i] = op.fn.execute(c[i], a[i]); - lnnz += (c[i]!=0) ? 1 : 0; + DenseBlock da = m2.getDenseBlock(); + for( int i=rl; i<ru; i++ ) { + double[] a = da.values(i); + double[] c = dc.values(i); + int apos = da.pos(i); + for( int j = apos; j<apos+n; j++ ) { + c[j] = op.fn.execute(c[j], a[j]); + lnnz += (c[j]!=0) ? 1 : 0; } } } @@ -886,41 +906,45 @@ public class LibMatrixBincell } //3) recompute nnz - ret.setNonZeros(lnnz); + return lnnz; } - private static void safeBinaryMMDenseDenseDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { - ret.allocateDenseBlock(); + private static long safeBinaryMMDenseDenseDense(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, + BinaryOperator op, int rl, int ru) + { DenseBlock da = m1.getDenseBlock(); DenseBlock db = m2.getDenseBlock(); DenseBlock dc = ret.getDenseBlock(); ValueFunction fn = op.fn; + int clen = m1.clen; //compute dense-dense binary, maintain nnz on-the-fly long lnnz = 0; - for( int bi=0; bi<da.numBlocks(); bi++ ) { - double[] a = da.valuesAt(bi); - double[] b = db.valuesAt(bi); - double[] c = dc.valuesAt(bi); - int len = da.size(bi); - for( int i=0; i<len; i++ ) { - c[i] = fn.execute(a[i], b[i]); - lnnz += (c[i]!=0)? 1 : 0; + for(int i=rl; i<ru; i++) { + double[] a = da.values(i); + double[] b = db.values(i); + double[] c = dc.values(i); + int pos = da.pos(i); + for(int j=pos; j<pos+clen; j++) { + c[j] = fn.execute(a[j], b[j]); + lnnz += (c[j]!=0)? 1 : 0; } } - ret.setNonZeros(lnnz); + return lnnz; } - private static void safeBinaryMMSparseDenseSkip(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { + private static long safeBinaryMMSparseDenseSkip(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, + BinaryOperator op, int rl, int ru) + { SparseBlock a = m1.sparse ? m1.sparseBlock : m2.sparseBlock; if( a == null ) - return; + return 0; //prepare second input and allocate output MatrixBlock b = m1.sparse ? m2 : m1; - ret.allocateBlock(); - for( int i=0; i<a.numRows(); i++ ) { + long lnnz = 0; + for( int i=rl; i<Math.min(ru, a.numRows()); i++ ) { if( a.isEmpty(i) ) continue; int apos = a.pos(i); int alen = a.size(i); @@ -932,22 +956,28 @@ public class LibMatrixBincell double in2 = b.quickGetValue(i, aix[k]); if( in2==0 ) continue; double val = op.fn.execute(avals[k], in2); - ret.appendValue(i, aix[k], val); + lnnz += (val != 0) ? 1 : 0; + ret.appendValuePlain(i, aix[k], val); } } + return lnnz; } - private static void safeBinaryMMGeneric(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op) { - int rlen = m1.rlen; + private static long safeBinaryMMGeneric(MatrixBlock m1, MatrixBlock m2, + MatrixBlock ret, BinaryOperator op, int rl, int ru) + { int clen = m2.clen; - for(int r=0; r<rlen; r++) + long lnnz = 0; + for(int r=rl; r<ru; r++) for(int c=0; c<clen; c++) { double in1 = m1.quickGetValue(r, c); double in2 = m2.quickGetValue(r, c); if( in1==0 && in2==0) continue; double val = op.fn.execute(in1, in2); - ret.appendValue(r, c, val); + lnnz += (val != 0) ? 1 : 0; + ret.appendValuePlain(r, c, val); } + return lnnz; } /** @@ -960,7 +990,7 @@ public class LibMatrixBincell * @param bOp binary operator * */ - private static void performBinOuterOperation(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator bOp) { + private static long performBinOuterOperation(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator bOp) { int rlen = m1.rlen; int clen = ret.clen; double b[] = DataConverter.convertToDoubleVector(m2); @@ -1006,9 +1036,10 @@ public class LibMatrixBincell } ret.setNonZeros(lnnz); ret.examSparsity(); + return lnnz; } - private static void unsafeBinary(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, int rl, int ru) { + private static long unsafeBinary(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator op, int rl, int ru) { int clen = m1.clen; BinaryAccessType atype = getBinaryAccessType(m1, m2); @@ -1038,7 +1069,7 @@ public class LibMatrixBincell int clen2 = m2.clen; if(LibMatrixOuterAgg.isCompareOperator(op) && m2.getNumColumns()>16 && SortUtils.isSorted(m2)) { - performBinOuterOperation(m1, m2, ret, op); + lnnz = performBinOuterOperation(m1, m2, ret, op); } else { for(int r=rl; r<ru; r++) { @@ -1046,7 +1077,8 @@ public class LibMatrixBincell for(int c=0; c<clen2; c++) { double v2 = m2.quickGetValue(0, c); double v = op.fn.execute( v1, v2 ); - ret.appendValue(r, c, v); + lnnz += (v != 0) ? 1 : 0; + ret.appendValuePlain(r, c, v); } } } @@ -1079,9 +1111,7 @@ public class LibMatrixBincell } } - //avoid false sharing in multi-threaded ops, while - //correctly setting the nnz for single-threaded ops - ret.nonZeros = lnnz; + return lnnz; } private static void safeBinaryScalar(MatrixBlock m1, MatrixBlock ret, ScalarOperator op, int rl, int ru) { @@ -1504,25 +1534,28 @@ public class LibMatrixBincell private final MatrixBlock _m2; private final MatrixBlock _ret; private final BinaryOperator _bop; + BinaryAccessType _atype; private final int _rl; private final int _ru; - protected BincellTask( MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator bop, int rl, int ru ) { + protected BincellTask( MatrixBlock m1, MatrixBlock m2, MatrixBlock ret, BinaryOperator bop, BinaryAccessType atype, int rl, int ru ) { _m1 = m1; _m2 = m2; _ret = ret; _bop = bop; + _atype = atype; _rl = rl; _ru = ru; } @Override public Long call() { - //execute binary operation on row partition - unsafeBinary(_m1, _m2, _ret, _bop, _rl, _ru); - - //maintain block nnz (upper bounds inclusive) - return _ret.recomputeNonZeros(_rl, _ru-1); + // execute binary operation on row partition + // (including nnz maintenance) + if(_bop.sparseSafe || isSparseSafeDivide(_bop, _m2)) + return safeBinary(_m1, _m2, _ret, _bop, _atype, _rl, _ru); + else + return unsafeBinary(_m1, _m2, _ret, _bop, _rl, _ru); } } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixTercell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixTercell.java new file mode 100644 index 0000000..cb48a1f --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixTercell.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.matrix.data; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.matrix.operators.TernaryOperator; +import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.runtime.util.UtilFunctions; + +/** + * Library for ternary cellwise operations. + * + */ +public class LibMatrixTercell +{ + private static final long PAR_NUMCELL_THRESHOLD = 8*1024; + + private LibMatrixTercell() { + //prevent instantiation via private constructor + } + + public static void tercellOp(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret, TernaryOperator op) + { + final boolean s1 = (m1.rlen==1 && m1.clen==1); + final boolean s2 = (m2.rlen==1 && m2.clen==1); + final boolean s3 = (m3.rlen==1 && m3.clen==1); + final double d1 = s1 ? m1.quickGetValue(0, 0) : Double.NaN; + final double d2 = s2 ? m2.quickGetValue(0, 0) : Double.NaN; + final double d3 = s3 ? m3.quickGetValue(0, 0) : Double.NaN; + + //allocate dense/sparse output + ret.allocateBlock(); + + //execute ternary cell operations + if( op.getNumThreads() > 1 && ret.getLength() > PAR_NUMCELL_THRESHOLD) { + try { + //execute binary cell operations + ExecutorService pool = CommonThreadPool.get(op.getNumThreads()); + ArrayList<TercellTask> tasks = new ArrayList<>(); + ArrayList<Integer> blklens = UtilFunctions + .getBalancedBlockSizesDefault(ret.rlen, op.getNumThreads(), false); + for( int i=0, lb=0; i<blklens.size(); lb+=blklens.get(i), i++ ) + tasks.add(new TercellTask(m1, m2, m3, ret, op, s1, s2, s3, d1, d2, d3, lb, lb+blklens.get(i))); + List<Future<Long>> taskret = pool.invokeAll(tasks); + + //aggregate non-zeros + ret.nonZeros = 0; //reset after execute + for( Future<Long> task : taskret ) + ret.nonZeros += task.get(); + pool.shutdown(); + } + catch(InterruptedException | ExecutionException ex) { + throw new DMLRuntimeException(ex); + } + } + else { + unsafeTernary(m1, m2, m3, ret, op, s1, s2, s3, d1, d2, d3, 0, ret.rlen); + } + } + + private static void unsafeTernary(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret, + TernaryOperator op, boolean s1, boolean s2, boolean s3, double d1, double d2, double d3, int rl, int ru) + { + //basic ternary operations (all combinations sparse/dense) + int n = ret.clen; + long lnnz = 0; + for( int i=rl; i<ru; i++ ) + for( int j=0; j<n; j++ ) { + double in1 = s1 ? d1 : m1.quickGetValue(i, j); + double in2 = s2 ? d2 : m2.quickGetValue(i, j); + double in3 = s3 ? d3 : m3.quickGetValue(i, j); + double val = op.fn.execute(in1, in2, in3); + lnnz += (val != 0) ? 1 : 0; + ret.appendValuePlain(i, j, val); + } + + //set global output nnz once + ret.nonZeros = lnnz; + } + + private static class TercellTask implements Callable<Long> { + private final MatrixBlock _m1, _m2, _m3; + private final boolean _s1, _s2, _s3; + private final double _d1, _d2, _d3; + private final MatrixBlock _ret; + private final TernaryOperator _op; + private final int _rl, _ru; + + protected TercellTask(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret, TernaryOperator op, + boolean s1, boolean s2, boolean s3, double d1, double d2, double d3, int rl, int ru) { + _m1 = m1; _m2 = m2; _m3 = m3; + _s1 = s1; _s2 = s2; _s3 = s3; + _d1 = d1; _d2 = d2; _d3 = d3; + _ret = ret; + _op = op; + _rl = rl; _ru = ru; + } + + @Override + public Long call() { + //execute binary operation on row partition + unsafeTernary(_m1, _m2, _m3, _ret, _op, _s1, _s2, _s3, _d1, _d2, _d3, _rl, _ru); + + //maintain block nnz (upper bounds inclusive) + return _ret.recomputeNonZeros(_rl, _ru-1); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 5d4b869..1695465 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -2911,7 +2911,10 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab } //prepare result - ret.reset(m, n, false); + boolean sparseOutput = (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply)? + evalSparseFormatInMemory(m, n, (s1?m*n*(d1!=0?1:0):getNonZeros()) + + Math.min(s2?m*n:m2.getNonZeros(), s3?m*n:m3.getNonZeros())) : false; + ret.reset(m, n, sparseOutput); if( op.fn instanceof IfElse && (s1 || nnz==0 || nnz==(long)m*n) ) { //SPECIAL CASE for shallow-copy if-else @@ -2933,21 +2936,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab } else if (s2 != s3 && (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply) ) { //SPECIAL CASE for sparse-dense combinations of common +* and -* - BinaryOperator bop = ((ValueFunctionWithConstant)op.fn) - .setOp2Constant(s2 ? d2 : d3); - LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop); + BinaryOperator bop = ((ValueFunctionWithConstant)op.fn).setOp2Constant(s2 ? d2 : d3); + if( op.getNumThreads() > 1 ) + LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop, op.getNumThreads()); + else + LibMatrixBincell.bincellOp(this, s2 ? m3 : m2, ret, bop); } else { - ret.allocateDenseBlock(); - - //basic ternary operations - for( int i=0; i<m; i++ ) - for( int j=0; j<n; j++ ) { - double in1 = s1 ? d1 : quickGetValue(i, j); - double in2 = s2 ? d2 : m2.quickGetValue(i, j); - double in3 = s3 ? d3 : m3.quickGetValue(i, j); - ret.appendValue(i, j, op.fn.execute(in1, in2, in3)); - } + //DEFAULT CASE + LibMatrixTercell.tercellOp(this, m2, m3, ret, op); //ensure correct output representation ret.examSparsity(); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java index 1caacd7..6ff8e89 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/TernaryOperator.java @@ -32,11 +32,17 @@ public class TernaryOperator extends Operator implements Serializable private static final long serialVersionUID = 3456088891054083634L; public final TernaryValueFunction fn; - - public TernaryOperator(TernaryValueFunction p) { + private final int _k; // num threads + + public TernaryOperator(TernaryValueFunction p, int numThreads) { //ternaryop is sparse-safe iff (op 0 0 0) == 0 super (p instanceof PlusMultiply || p instanceof MinusMultiply || p instanceof IfElse); fn = p; + _k = numThreads; + } + + public int getNumThreads() { + return _k; } @Override
