Repository: systemml Updated Branches: refs/heads/master 130096893 -> 95de23586
[SYSTEMML-1837] Fix unary aggregate physical output size/mem estimates So far, dropping the correction columns/rows after unary aggregates did not actually remove any columns or rows but simply shifted their values into the correct positions. Hence, unary aggregates in CP returned matrices whose physical size exceeded the respective memory estimates (operations and buffer pool), which can potentially lead to OOMs. This patch fixes this issue by actually slicing out the single column / row vectors. Accordingly, we also adapt the hop memory estimates for intermediate memory requirements. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/57dff5df Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/57dff5df Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/57dff5df Branch: refs/heads/master Commit: 57dff5dfb1a499fc37300fc1df017aad1ad53ef3 Parents: 1300968 Author: Matthias Boehm <mboe...@gmail.com> Authored: Fri Aug 11 01:50:12 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Fri Aug 11 16:49:17 2017 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/AggUnaryOp.java | 14 +-- .../org/apache/sysml/lops/PartialAggregate.java | 11 +- .../runtime/compress/CompressedMatrixBlock.java | 2 +- .../cp/UaggOuterChainCPInstruction.java | 2 +- .../mr/AggregateUnaryInstruction.java | 2 +- .../mr/BinUaggChainInstruction.java | 2 +- .../mr/CumulativeAggregateInstruction.java | 2 +- .../spark/AggregateTernarySPInstruction.java | 2 +- .../spark/AggregateUnarySPInstruction.java | 4 +- .../spark/BinUaggChainSPInstruction.java | 2 +- .../spark/CumulativeAggregateSPInstruction.java | 2 +- .../spark/UaggOuterChainSPInstruction.java | 2 +- .../AggregateDropCorrectionFunction.java | 2 +- .../sysml/runtime/matrix/data/LibMatrixAgg.java | 2 +- .../sysml/runtime/matrix/data/MatrixBlock.java | 118 +++++--------------
.../runtime/matrix/data/SparseBlockFactory.java | 8 +- 16 files changed, 68 insertions(+), 109 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/hops/AggUnaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java index e94aaf3..7a6d463 100644 --- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java @@ -360,16 +360,16 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop case SUM_SQ: //worst-case correction LASTROW / LASTCOLUMN if( _direction == Direction.Col ) //(potentially sparse) - val = OptimizerUtils.estimateSizeExactSparsity(1, dim2, sparsity); + val = OptimizerUtils.estimateSizeExactSparsity(2, dim2, sparsity); else if( _direction == Direction.Row ) //(always dense) - val = OptimizerUtils.estimateSizeExactSparsity(dim1, 1, 1.0); + val = OptimizerUtils.estimateSizeExactSparsity(dim1, 2, 1.0); break; case MEAN: //worst-case correction LASTTWOROWS / LASTTWOCOLUMNS if( _direction == Direction.Col ) //(potentially sparse) - val = OptimizerUtils.estimateSizeExactSparsity(2, dim2, sparsity); + val = OptimizerUtils.estimateSizeExactSparsity(3, dim2, sparsity); else if( _direction == Direction.Row ) //(always dense) - val = OptimizerUtils.estimateSizeExactSparsity(dim1, 2, 1.0); + val = OptimizerUtils.estimateSizeExactSparsity(dim1, 3, 1.0); break; case VAR: //worst-case correction LASTFOURROWS / LASTFOURCOLUMNS @@ -394,9 +394,9 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop } } else if( _direction == Direction.Col ) { //(potentially sparse) - val = OptimizerUtils.estimateSizeExactSparsity(4, dim2, sparsity); + val = OptimizerUtils.estimateSizeExactSparsity(5, dim2, sparsity); } else if( _direction == 
Direction.Row ) { //(always dense) - val = OptimizerUtils.estimateSizeExactSparsity(dim1, 4, 1.0); + val = OptimizerUtils.estimateSizeExactSparsity(dim1, 5, 1.0); } break; case MAXINDEX: @@ -406,7 +406,7 @@ public class AggUnaryOp extends Hop implements MultiThreadedHop val = 3 * OptimizerUtils.estimateSizeExactSparsity(1, hop._dim2, 1.0); else //worst-case correction LASTCOLUMN - val = OptimizerUtils.estimateSizeExactSparsity(dim1, 1, 1.0); + val = OptimizerUtils.estimateSizeExactSparsity(dim1, 2, 1.0); break; default: //no intermediate memory consumption http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/lops/PartialAggregate.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/lops/PartialAggregate.java b/src/main/java/org/apache/sysml/lops/PartialAggregate.java index 512c7f1..cf7826e 100644 --- a/src/main/java/org/apache/sysml/lops/PartialAggregate.java +++ b/src/main/java/org/apache/sysml/lops/PartialAggregate.java @@ -50,8 +50,15 @@ public class PartialAggregate extends Lop LASTTWOCOLUMNS, LASTFOURROWS, LASTFOURCOLUMNS, - INVALID - }; + INVALID; + + public int getNumRemovedRowsColumns() { + return (this==LASTROW || this==LASTCOLUMN) ? 1 : + (this==LASTTWOROWS || this==LASTTWOCOLUMNS) ? 2 : + (this==LASTFOURROWS || this==LASTFOURCOLUMNS) ? 
4 : 0; + } + + } private Aggregate.OperationTypes operation; private DirectionTypes direction; http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java index d141d85..ced6d62 100644 --- a/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java @@ -1170,7 +1170,7 @@ public class CompressedMatrixBlock extends MatrixBlock implements Externalizable //drop correction if necessary if(op.aggOp.correctionExists && inCP) - ret.dropLastRowsOrColums(op.aggOp.correctionLocation); + ret.dropLastRowsOrColumns(op.aggOp.correctionLocation); //post-processing ret.recomputeNonZeros(); http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java index 746ee04..0a6db0d 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java @@ -102,7 +102,7 @@ public class UaggOuterChainCPInstruction extends UnaryCPInstruction ec.releaseMatrixInput(input2.getName(), getExtendedOpcode()); if( _uaggOp.aggOp.correctionExists ) - mbOut.dropLastRowsOrColums(_uaggOp.aggOp.correctionLocation); + mbOut.dropLastRowsOrColumns(_uaggOp.aggOp.correctionLocation); String output_name = output.getName(); //final aggregation if 
required http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateUnaryInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateUnaryInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateUnaryInstruction.java index a0285c0..409590a 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateUnaryInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateUnaryInstruction.java @@ -98,7 +98,7 @@ public class AggregateUnaryInstruction extends UnaryMRInstructionBase OperationsOnMatrixValues.performAggregateUnary( inix, in.getValue(), out.getIndexes(), out.getValue(), auop, blockRowFactor, blockColFactor); if( _dropCorr ) - ((MatrixBlock)out.getValue()).dropLastRowsOrColums(auop.aggOp.correctionLocation); + ((MatrixBlock)out.getValue()).dropLastRowsOrColumns(auop.aggOp.correctionLocation); } //put the output value in the cache http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/mr/BinUaggChainInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/BinUaggChainInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/BinUaggChainInstruction.java index 2a428bb..a8badd3 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/BinUaggChainInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/BinUaggChainInstruction.java @@ -98,7 +98,7 @@ public class BinUaggChainInstruction extends UnaryInstruction //process instruction OperationsOnMatrixValues.performAggregateUnary( inIx, inVal, _tmpIx, _tmpVal, _uaggOp, blockRowFactor, blockColFactor); - ((MatrixBlock)_tmpVal).dropLastRowsOrColums(_uaggOp.aggOp.correctionLocation); + 
((MatrixBlock)_tmpVal).dropLastRowsOrColumns(_uaggOp.aggOp.correctionLocation); OperationsOnMatrixValues.performBinaryIgnoreIndexes(inVal, _tmpVal, outVal, _bOp); outIx.setIndexes(inIx); } http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/mr/CumulativeAggregateInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/mr/CumulativeAggregateInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/mr/CumulativeAggregateInstruction.java index d2272ac..825fecb 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/mr/CumulativeAggregateInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/mr/CumulativeAggregateInstruction.java @@ -87,7 +87,7 @@ public class CumulativeAggregateInstruction extends AggregateUnaryInstruction OperationsOnMatrixValues.performAggregateUnary( inix, in1.getValue(), out.getIndexes(), out.getValue(), ((AggregateUnaryOperator)optr), blockRowFactor, blockColFactor); if( ((AggregateUnaryOperator)optr).aggOp.correctionExists ) - ((MatrixBlock)out.getValue()).dropLastRowsOrColums(((AggregateUnaryOperator)optr).aggOp.correctionLocation); + ((MatrixBlock)out.getValue()).dropLastRowsOrColumns(((AggregateUnaryOperator)optr).aggOp.correctionLocation); //cumsum expand partial aggregates long rlenOut = (long)Math.ceil((double)_mcIn.getRows()/blockRowFactor); http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateTernarySPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateTernarySPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateTernarySPInstruction.java index 7ac1e5b..5822ff5 100644 --- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateTernarySPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateTernarySPInstruction.java @@ -107,7 +107,7 @@ public class AggregateTernarySPInstruction extends ComputationSPInstruction { //single block aggregation and drop correction MatrixBlock ret = RDDAggregateUtils.aggStable(out, aggop.aggOp); - ret.dropLastRowsOrColums(aggop.aggOp.correctionLocation); + ret.dropLastRowsOrColumns(aggop.aggOp.correctionLocation); //put output block into symbol table (no lineage because single block) //this also includes implicit maintenance of matrix characteristics http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java index 352a72e..36f501b 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java @@ -102,7 +102,7 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction MatrixBlock out3 = RDDAggregateUtils.aggStable(out2, aggop); //drop correction after aggregation - out3.dropLastRowsOrColums(aggop.correctionLocation); + out3.dropLastRowsOrColumns(aggop.correctionLocation); //put output block into symbol table (no lineage because single block) //this also includes implicit maintenance of matrix characteristics @@ -222,7 +222,7 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction arg0.aggregateUnaryOperations(_op, blkOut, _brlen, _bclen, _ix); //always drop correction since no aggregation - 
blkOut.dropLastRowsOrColums(_op.aggOp.correctionLocation); + blkOut.dropLastRowsOrColumns(_op.aggOp.correctionLocation); //output new tuple return blkOut; http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/BinUaggChainSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/BinUaggChainSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/BinUaggChainSPInstruction.java index 21652b0..a1b7d35 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/BinUaggChainSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/BinUaggChainSPInstruction.java @@ -109,7 +109,7 @@ public class BinUaggChainSPInstruction extends UnarySPInstruction arg0.aggregateUnaryOperations(_uaggOp, out1, brlen, bclen, null); //strip-off correction - out1.dropLastRowsOrColums(_uaggOp.aggOp.correctionLocation); + out1.dropLastRowsOrColumns(_uaggOp.aggOp.correctionLocation); //perform binary operation MatrixBlock out2 = new MatrixBlock(); http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java index 9dd81aa..0301d54 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java @@ -116,7 +116,7 @@ public class CumulativeAggregateSPInstruction extends AggregateUnarySPInstructio OperationsOnMatrixValues.performAggregateUnary( ixIn, blkIn, ixOut, 
blkOut, ((AggregateUnaryOperator)_op), _brlen, _bclen); if( ((AggregateUnaryOperator)_op).aggOp.correctionExists ) - blkOut.dropLastRowsOrColums(((AggregateUnaryOperator)_op).aggOp.correctionLocation); + blkOut.dropLastRowsOrColumns(((AggregateUnaryOperator)_op).aggOp.correctionLocation); //cumsum expand partial aggregates long rlenOut = (long)Math.ceil((double)_rlen/_brlen); http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/UaggOuterChainSPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/UaggOuterChainSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/UaggOuterChainSPInstruction.java index 8f74b9d..a270706 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/UaggOuterChainSPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/UaggOuterChainSPInstruction.java @@ -163,7 +163,7 @@ public class UaggOuterChainSPInstruction extends BinarySPInstruction MatrixBlock tmp = RDDAggregateUtils.aggStable(out, _aggOp); //drop correction after aggregation - tmp.dropLastRowsOrColums(_aggOp.correctionLocation); + tmp.dropLastRowsOrColumns(_aggOp.correctionLocation); //put output block into symbol table (no lineage because single block) sec.setMatrixOutput(output.getName(), tmp, getExtendedOpcode()); http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/AggregateDropCorrectionFunction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/AggregateDropCorrectionFunction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/AggregateDropCorrectionFunction.java index 9ace752..d08a227 100644 --- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/AggregateDropCorrectionFunction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/AggregateDropCorrectionFunction.java @@ -44,7 +44,7 @@ public class AggregateDropCorrectionFunction implements Function<MatrixBlock, Ma MatrixBlock blkOut = new MatrixBlock(arg0); //drop correction - blkOut.dropLastRowsOrColums(_op.correctionLocation); + blkOut.dropLastRowsOrColumns(_op.correctionLocation); return blkOut; } http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java index 094bc93..1ec5290 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java @@ -358,7 +358,7 @@ public class LibMatrixAgg for( int i=0; i<tasks.size(); i++ ) { MatrixBlock row = tasks.get(i).getResult(); if( uaop.aggOp.correctionExists ) - row.dropLastRowsOrColums(uaop.aggOp.correctionLocation); + row.dropLastRowsOrColumns(uaop.aggOp.correctionLocation); tmp.leftIndexingOperations(row, i, i, 0, n2-1, tmp, UpdateType.INPLACE_PINNED); } MatrixBlock tmp2 = cumaggregateUnaryMatrix(tmp, new MatrixBlock(tasks.size(), n2, false), uop); http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java index 7d19d61..e9039ba 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java +++ 
b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java @@ -4235,7 +4235,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab denseAggregateUnaryHelp(op, ret, blockingFactorRow, blockingFactorCol, indexesIn); if(op.aggOp.correctionExists && inCP) - ((MatrixBlock)result).dropLastRowsOrColums(op.aggOp.correctionLocation); + ((MatrixBlock)result).dropLastRowsOrColumns(op.aggOp.correctionLocation); return ret; } @@ -4384,57 +4384,28 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab result.quickSetValue(row, column, newvalue); } } - - public void dropLastRowsOrColums(CorrectionLocationType correctionLocation) + + public void dropLastRowsOrColumns(CorrectionLocationType correctionLocation) { - //do nothing - if( correctionLocation==CorrectionLocationType.NONE - || correctionLocation==CorrectionLocationType.INVALID ) - { - return; - } - //determine number of rows/cols to be removed - int step; - switch (correctionLocation) { - case LASTROW: - case LASTCOLUMN: - step = 1; - break; - case LASTTWOROWS: - case LASTTWOCOLUMNS: - step = 2; - break; - case LASTFOURROWS: - case LASTFOURCOLUMNS: - step = 4; - break; - default: - step = 0; - } - + int step = correctionLocation.getNumRemovedRowsColumns(); + if( step <= 0 ) + return; //e.g., colSums, colMeans, colMaxs, colMeans, colVars if( correctionLocation==CorrectionLocationType.LASTROW || correctionLocation==CorrectionLocationType.LASTTWOROWS || correctionLocation==CorrectionLocationType.LASTFOURROWS ) { - if( sparse ) //SPARSE - { - if(sparseBlock!=null) - for(int i=1; i<=step; i++) - if(!sparseBlock.isEmpty(rlen-i)) - this.nonZeros-=sparseBlock.size(rlen-i); + if( sparse && sparseBlock!=null ) { //SPARSE + nonZeros -= recomputeNonZeros(1, rlen-1, 0, clen-1); + sparseBlock = SparseBlockFactory + .createSparseBlock(DEFAULT_SPARSEBLOCK, sparseBlock.get(0)); } - else //DENSE - { - if(denseBlock!=null) - for(int i=(rlen-step)*clen; i<rlen*clen; i++) - 
if(denseBlock[i]!=0) - this.nonZeros--; + else if( !sparse && denseBlock!=null ) { //DENSE + nonZeros -= recomputeNonZeros(1, rlen-1, 0, clen-1); + denseBlock = Arrays.copyOfRange(denseBlock, 0, clen); } - - //just need to shrink the dimension, the deleted rows won't be accessed rlen -= step; } @@ -4443,51 +4414,26 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab || correctionLocation==CorrectionLocationType.LASTTWOCOLUMNS || correctionLocation==CorrectionLocationType.LASTFOURCOLUMNS ) { - if(sparse) //SPARSE - { - if(sparseBlock!=null) - { - for(int r=0; r<Math.min(rlen, sparseBlock.numRows()); r++) - if(!sparseBlock.isEmpty(r)) - { - int newSize=sparseBlock.posFIndexGTE(r, clen-step); - if(newSize >= 0) - { - this.nonZeros-=sparseBlock.size(r)-newSize; - int pos = sparseBlock.pos(r); - int cl = sparseBlock.indexes(r)[pos+newSize-1]; - sparseBlock.deleteIndexRange(r, cl+1, clen); - //TODO perf sparse block: truncate replaced by deleteIndexRange - } - } - } + if( sparse && sparseBlock!=null ) { //SPARSE + //sparse blocks are converted to a dense representation + //because column vectors are always smaller in dense + double[] tmp = new double[rlen]; + int lnnz = 0; + for( int i=0; i<rlen; i++ ) + lnnz += ((tmp[i] = sparseBlock.get(i, 0))!=0)? 
1 : 0; + cleanupBlock(true, true); + sparse = false; + denseBlock = tmp; + nonZeros = lnnz; } - else //DENSE - { - if(this.denseBlock!=null) - { - //the first row doesn't need to be copied - int targetIndex=clen-step; - int sourceOffset=clen; - this.nonZeros=0; - for(int i=0; i<targetIndex; i++) - if(denseBlock[i]!=0) - this.nonZeros++; - - //start from the 2nd row - for(int r=1; r<rlen; r++) - { - for(int c=0; c<clen-step; c++) - { - if((denseBlock[targetIndex]=denseBlock[sourceOffset+c])!=0) - this.nonZeros++; - targetIndex++; - } - sourceOffset+=clen; - } - } + else if( !sparse && denseBlock!=null ) { //DENSE + double[] tmp = new double[rlen]; + int lnnz = 0; + for( int i=0, aix=0; i<rlen; i++, aix+=clen ) + lnnz += ((tmp[i] = denseBlock[aix])!=0)? 1 : 0; + denseBlock = tmp; + nonZeros = lnnz; } - clen -= step; } } @@ -4980,7 +4926,7 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab ret = LibMatrixAgg.aggregateTernary(m1, m2, m3, ret, op); if(op.aggOp.correctionExists && inCP) - ret.dropLastRowsOrColums(op.aggOp.correctionLocation); + ret.dropLastRowsOrColumns(op.aggOp.correctionLocation); return ret; } http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockFactory.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockFactory.java b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockFactory.java index 5abd7ba..d5e45ff 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockFactory.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockFactory.java @@ -27,7 +27,7 @@ public abstract class SparseBlockFactory return createSparseBlock(MatrixBlock.DEFAULT_SPARSEBLOCK, rlen); } - public static SparseBlock createSparseBlock( SparseBlock.Type type, int rlen ) { + public static SparseBlock createSparseBlock(SparseBlock.Type 
type, int rlen) { switch( type ) { case MCSR: return new SparseBlockMCSR(rlen, -1); case CSR: return new SparseBlockCSR(rlen); @@ -36,6 +36,12 @@ public abstract class SparseBlockFactory throw new RuntimeException("Unexpected sparse block type: "+type.toString()); } } + + public static SparseBlock createSparseBlock(SparseBlock.Type type, SparseRow row) { + SparseBlock ret = createSparseBlock(type, 1); + ret.set(0, row, true); + return ret; + } public static SparseBlock copySparseBlock( SparseBlock.Type type, SparseBlock sblock, boolean forceCopy ) {