Repository: systemml
Updated Branches:
  refs/heads/master 130096893 -> 95de23586


[SYSTEMML-1837] Fix unary aggregate physical output size/mem estimates

Until now, dropping the correction columns/rows after unary aggregates did
not actually remove any columns or rows but simply shifted their values into
the correct positions. Hence, unary aggregates in CP returned matrices whose
physical size exceeded the respective memory estimates (operations and
buffer pool), which could potentially lead to OOMs.

This patch fixes this issue by actually slicing out the single
column/row vectors. Accordingly, we also adapt the hop memory
estimates used for intermediate memory requirements.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/57dff5df
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/57dff5df
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/57dff5df

Branch: refs/heads/master
Commit: 57dff5dfb1a499fc37300fc1df017aad1ad53ef3
Parents: 1300968
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Fri Aug 11 01:50:12 2017 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Fri Aug 11 16:49:17 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/hops/AggUnaryOp.java  |  14 +--
 .../org/apache/sysml/lops/PartialAggregate.java |  11 +-
 .../runtime/compress/CompressedMatrixBlock.java |   2 +-
 .../cp/UaggOuterChainCPInstruction.java         |   2 +-
 .../mr/AggregateUnaryInstruction.java           |   2 +-
 .../mr/BinUaggChainInstruction.java             |   2 +-
 .../mr/CumulativeAggregateInstruction.java      |   2 +-
 .../spark/AggregateTernarySPInstruction.java    |   2 +-
 .../spark/AggregateUnarySPInstruction.java      |   4 +-
 .../spark/BinUaggChainSPInstruction.java        |   2 +-
 .../spark/CumulativeAggregateSPInstruction.java |   2 +-
 .../spark/UaggOuterChainSPInstruction.java      |   2 +-
 .../AggregateDropCorrectionFunction.java        |   2 +-
 .../sysml/runtime/matrix/data/LibMatrixAgg.java |   2 +-
 .../sysml/runtime/matrix/data/MatrixBlock.java  | 118 +++++--------------
 .../runtime/matrix/data/SparseBlockFactory.java |   8 +-
 16 files changed, 68 insertions(+), 109 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index e94aaf3..7a6d463 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -360,16 +360,16 @@ public class AggUnaryOp extends Hop implements 
MultiThreadedHop
                        case SUM_SQ:
                                //worst-case correction LASTROW / LASTCOLUMN
                                if( _direction == Direction.Col ) 
//(potentially sparse)
-                                       val = 
OptimizerUtils.estimateSizeExactSparsity(1, dim2, sparsity);
+                                       val = 
OptimizerUtils.estimateSizeExactSparsity(2, dim2, sparsity);
                                else if( _direction == Direction.Row ) 
//(always dense)
-                                       val = 
OptimizerUtils.estimateSizeExactSparsity(dim1, 1, 1.0);
+                                       val = 
OptimizerUtils.estimateSizeExactSparsity(dim1, 2, 1.0);
                                break;
                        case MEAN:
                                //worst-case correction LASTTWOROWS / 
LASTTWOCOLUMNS
                                if( _direction == Direction.Col ) 
//(potentially sparse)
-                                       val = 
OptimizerUtils.estimateSizeExactSparsity(2, dim2, sparsity);
+                                       val = 
OptimizerUtils.estimateSizeExactSparsity(3, dim2, sparsity);
                                else if( _direction == Direction.Row ) 
//(always dense)
-                                       val = 
OptimizerUtils.estimateSizeExactSparsity(dim1, 2, 1.0);
+                                       val = 
OptimizerUtils.estimateSizeExactSparsity(dim1, 3, 1.0);
                                break;
                        case VAR:
                                //worst-case correction LASTFOURROWS / 
LASTFOURCOLUMNS
@@ -394,9 +394,9 @@ public class AggUnaryOp extends Hop implements 
MultiThreadedHop
                                        }
 
                                } else if( _direction == Direction.Col ) { 
//(potentially sparse)
-                                       val = 
OptimizerUtils.estimateSizeExactSparsity(4, dim2, sparsity);
+                                       val = 
OptimizerUtils.estimateSizeExactSparsity(5, dim2, sparsity);
                                } else if( _direction == Direction.Row ) { 
//(always dense)
-                                       val = 
OptimizerUtils.estimateSizeExactSparsity(dim1, 4, 1.0);
+                                       val = 
OptimizerUtils.estimateSizeExactSparsity(dim1, 5, 1.0);
                                }
                                break;
                        case MAXINDEX:
@@ -406,7 +406,7 @@ public class AggUnaryOp extends Hop implements 
MultiThreadedHop
                                        val = 3 * 
OptimizerUtils.estimateSizeExactSparsity(1, hop._dim2, 1.0);
                                else
                                        //worst-case correction LASTCOLUMN 
-                                       val = 
OptimizerUtils.estimateSizeExactSparsity(dim1, 1, 1.0);
+                                       val = 
OptimizerUtils.estimateSizeExactSparsity(dim1, 2, 1.0);
                                break;
                        default:
                                //no intermediate memory consumption

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/lops/PartialAggregate.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/PartialAggregate.java 
b/src/main/java/org/apache/sysml/lops/PartialAggregate.java
index 512c7f1..cf7826e 100644
--- a/src/main/java/org/apache/sysml/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysml/lops/PartialAggregate.java
@@ -50,8 +50,15 @@ public class PartialAggregate extends Lop
                LASTTWOCOLUMNS,
                LASTFOURROWS,
                LASTFOURCOLUMNS,
-               INVALID
-       };
+               INVALID;
+               
+               public int getNumRemovedRowsColumns() {
+                       return (this==LASTROW || this==LASTCOLUMN) ? 1 :
+                               (this==LASTTWOROWS || this==LASTTWOCOLUMNS) ? 2 
:
+                               (this==LASTFOURROWS || this==LASTFOURCOLUMNS) ? 
4 : 0;
+               }
+               
+       }
        
        private Aggregate.OperationTypes operation;
        private DirectionTypes direction;

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java 
b/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java
index d141d85..ced6d62 100644
--- a/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/compress/CompressedMatrixBlock.java
@@ -1170,7 +1170,7 @@ public class CompressedMatrixBlock extends MatrixBlock 
implements Externalizable
                
                //drop correction if necessary
                if(op.aggOp.correctionExists && inCP)
-                       ret.dropLastRowsOrColums(op.aggOp.correctionLocation);
+                       ret.dropLastRowsOrColumns(op.aggOp.correctionLocation);
        
                //post-processing
                ret.recomputeNonZeros();

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
index 746ee04..0a6db0d 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/cp/UaggOuterChainCPInstruction.java
@@ -102,7 +102,7 @@ public class UaggOuterChainCPInstruction extends 
UnaryCPInstruction
                ec.releaseMatrixInput(input2.getName(), getExtendedOpcode());
                
                if( _uaggOp.aggOp.correctionExists )
-                       
mbOut.dropLastRowsOrColums(_uaggOp.aggOp.correctionLocation);
+                       
mbOut.dropLastRowsOrColumns(_uaggOp.aggOp.correctionLocation);
                
                String output_name = output.getName();
                //final aggregation if required

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateUnaryInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateUnaryInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateUnaryInstruction.java
index a0285c0..409590a 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateUnaryInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/mr/AggregateUnaryInstruction.java
@@ -98,7 +98,7 @@ public class AggregateUnaryInstruction extends 
UnaryMRInstructionBase
                                        
OperationsOnMatrixValues.performAggregateUnary( inix, in.getValue(), 
out.getIndexes(), out.getValue(), 
                                                                                
    auop, blockRowFactor, blockColFactor);
                                        if( _dropCorr )
-                                               
((MatrixBlock)out.getValue()).dropLastRowsOrColums(auop.aggOp.correctionLocation);
+                                               
((MatrixBlock)out.getValue()).dropLastRowsOrColumns(auop.aggOp.correctionLocation);
                                }
                                
                                //put the output value in the cache

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/mr/BinUaggChainInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/mr/BinUaggChainInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/mr/BinUaggChainInstruction.java
index 2a428bb..a8badd3 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/mr/BinUaggChainInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/mr/BinUaggChainInstruction.java
@@ -98,7 +98,7 @@ public class BinUaggChainInstruction extends UnaryInstruction
                        
                        //process instruction
                        OperationsOnMatrixValues.performAggregateUnary( inIx, 
inVal, _tmpIx, _tmpVal, _uaggOp, blockRowFactor, blockColFactor);
-                       
((MatrixBlock)_tmpVal).dropLastRowsOrColums(_uaggOp.aggOp.correctionLocation);
+                       
((MatrixBlock)_tmpVal).dropLastRowsOrColumns(_uaggOp.aggOp.correctionLocation);
                        
OperationsOnMatrixValues.performBinaryIgnoreIndexes(inVal, _tmpVal, outVal, 
_bOp);
                        outIx.setIndexes(inIx);
                }

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/mr/CumulativeAggregateInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/mr/CumulativeAggregateInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/mr/CumulativeAggregateInstruction.java
index d2272ac..825fecb 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/mr/CumulativeAggregateInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/mr/CumulativeAggregateInstruction.java
@@ -87,7 +87,7 @@ public class CumulativeAggregateInstruction extends 
AggregateUnaryInstruction
                        OperationsOnMatrixValues.performAggregateUnary( inix, 
in1.getValue(), out.getIndexes(), out.getValue(), 
                                                                    
((AggregateUnaryOperator)optr), blockRowFactor, blockColFactor);
                        if( 
((AggregateUnaryOperator)optr).aggOp.correctionExists )
-                               
((MatrixBlock)out.getValue()).dropLastRowsOrColums(((AggregateUnaryOperator)optr).aggOp.correctionLocation);
+                               
((MatrixBlock)out.getValue()).dropLastRowsOrColumns(((AggregateUnaryOperator)optr).aggOp.correctionLocation);
                        
                        //cumsum expand partial aggregates
                        long rlenOut = 
(long)Math.ceil((double)_mcIn.getRows()/blockRowFactor);

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateTernarySPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateTernarySPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateTernarySPInstruction.java
index 7ac1e5b..5822ff5 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateTernarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateTernarySPInstruction.java
@@ -107,7 +107,7 @@ public class AggregateTernarySPInstruction extends 
ComputationSPInstruction
                {
                        //single block aggregation and drop correction
                        MatrixBlock ret = RDDAggregateUtils.aggStable(out, 
aggop.aggOp);
-                       
ret.dropLastRowsOrColums(aggop.aggOp.correctionLocation);
+                       
ret.dropLastRowsOrColumns(aggop.aggOp.correctionLocation);
                        
                        //put output block into symbol table (no lineage 
because single block)
                        //this also includes implicit maintenance of matrix 
characteristics

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 352a72e..36f501b 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -102,7 +102,7 @@ public class AggregateUnarySPInstruction extends 
UnarySPInstruction
                        MatrixBlock out3 = RDDAggregateUtils.aggStable(out2, 
aggop);
                        
                        //drop correction after aggregation
-                       out3.dropLastRowsOrColums(aggop.correctionLocation);
+                       out3.dropLastRowsOrColumns(aggop.correctionLocation);
                        
                        //put output block into symbol table (no lineage 
because single block)
                        //this also includes implicit maintenance of matrix 
characteristics
@@ -222,7 +222,7 @@ public class AggregateUnarySPInstruction extends 
UnarySPInstruction
                        arg0.aggregateUnaryOperations(_op, blkOut, _brlen, 
_bclen, _ix);
                        
                        //always drop correction since no aggregation
-                       
blkOut.dropLastRowsOrColums(_op.aggOp.correctionLocation);
+                       
blkOut.dropLastRowsOrColumns(_op.aggOp.correctionLocation);
                        
                        //output new tuple
                        return blkOut;

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/BinUaggChainSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/BinUaggChainSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/BinUaggChainSPInstruction.java
index 21652b0..a1b7d35 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/BinUaggChainSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/BinUaggChainSPInstruction.java
@@ -109,7 +109,7 @@ public class BinUaggChainSPInstruction extends 
UnarySPInstruction
                        arg0.aggregateUnaryOperations(_uaggOp, out1, brlen, 
bclen, null);
                        
                        //strip-off correction
-                       
out1.dropLastRowsOrColums(_uaggOp.aggOp.correctionLocation);
+                       
out1.dropLastRowsOrColumns(_uaggOp.aggOp.correctionLocation);
                
                        //perform binary operation
                        MatrixBlock out2 = new MatrixBlock();

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
index 9dd81aa..0301d54 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CumulativeAggregateSPInstruction.java
@@ -116,7 +116,7 @@ public class CumulativeAggregateSPInstruction extends 
AggregateUnarySPInstructio
                        OperationsOnMatrixValues.performAggregateUnary( ixIn, 
blkIn, ixOut, blkOut, 
                                                                    
((AggregateUnaryOperator)_op), _brlen, _bclen);
                        if( 
((AggregateUnaryOperator)_op).aggOp.correctionExists )
-                               
blkOut.dropLastRowsOrColums(((AggregateUnaryOperator)_op).aggOp.correctionLocation);
+                               
blkOut.dropLastRowsOrColumns(((AggregateUnaryOperator)_op).aggOp.correctionLocation);
                        
                        //cumsum expand partial aggregates
                        long rlenOut = (long)Math.ceil((double)_rlen/_brlen);

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/UaggOuterChainSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/UaggOuterChainSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/UaggOuterChainSPInstruction.java
index 8f74b9d..a270706 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/UaggOuterChainSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/UaggOuterChainSPInstruction.java
@@ -163,7 +163,7 @@ public class UaggOuterChainSPInstruction extends 
BinarySPInstruction
                        MatrixBlock tmp = RDDAggregateUtils.aggStable(out, 
_aggOp);
                        
                        //drop correction after aggregation
-                       tmp.dropLastRowsOrColums(_aggOp.correctionLocation);
+                       tmp.dropLastRowsOrColumns(_aggOp.correctionLocation);
 
                        //put output block into symbol table (no lineage 
because single block)
                        sec.setMatrixOutput(output.getName(), tmp, 
getExtendedOpcode());

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/AggregateDropCorrectionFunction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/AggregateDropCorrectionFunction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/AggregateDropCorrectionFunction.java
index 9ace752..d08a227 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/AggregateDropCorrectionFunction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/AggregateDropCorrectionFunction.java
@@ -44,7 +44,7 @@ public class AggregateDropCorrectionFunction implements 
Function<MatrixBlock, Ma
                MatrixBlock blkOut = new MatrixBlock(arg0);
                
                //drop correction
-               blkOut.dropLastRowsOrColums(_op.correctionLocation);
+               blkOut.dropLastRowsOrColumns(_op.correctionLocation);
                
                return blkOut;
        }       

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
index 094bc93..1ec5290 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixAgg.java
@@ -358,7 +358,7 @@ public class LibMatrixAgg
                        for( int i=0; i<tasks.size(); i++ ) {
                                MatrixBlock row = tasks.get(i).getResult();
                                if( uaop.aggOp.correctionExists )
-                                       
row.dropLastRowsOrColums(uaop.aggOp.correctionLocation);
+                                       
row.dropLastRowsOrColumns(uaop.aggOp.correctionLocation);
                                tmp.leftIndexingOperations(row, i, i, 0, n2-1, 
tmp, UpdateType.INPLACE_PINNED);
                        }
                        MatrixBlock tmp2 = cumaggregateUnaryMatrix(tmp, new 
MatrixBlock(tasks.size(), n2, false), uop);

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index 7d19d61..e9039ba 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -4235,7 +4235,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        denseAggregateUnaryHelp(op, ret, blockingFactorRow, 
blockingFactorCol, indexesIn);
                
                if(op.aggOp.correctionExists && inCP)
-                       
((MatrixBlock)result).dropLastRowsOrColums(op.aggOp.correctionLocation);
+                       
((MatrixBlock)result).dropLastRowsOrColumns(op.aggOp.correctionLocation);
                
                return ret;
        }
@@ -4384,57 +4384,28 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        result.quickSetValue(row, column, newvalue);
                }
        }
-
-       public void dropLastRowsOrColums(CorrectionLocationType 
correctionLocation) 
+       
+       public void dropLastRowsOrColumns(CorrectionLocationType 
correctionLocation) 
        {
-               //do nothing 
-               if(   correctionLocation==CorrectionLocationType.NONE 
-              || correctionLocation==CorrectionLocationType.INVALID )
-               {
-                       return;
-               }
-               
                //determine number of rows/cols to be removed
-               int step;
-               switch (correctionLocation) {
-                       case LASTROW:
-                       case LASTCOLUMN:
-                               step = 1;
-                               break;
-                       case LASTTWOROWS:
-                       case LASTTWOCOLUMNS:
-                               step = 2;
-                               break;
-                       case LASTFOURROWS:
-                       case LASTFOURCOLUMNS:
-                               step = 4;
-                               break;
-                       default:
-                               step = 0;
-               }
-
+               int step = correctionLocation.getNumRemovedRowsColumns();
+               if( step <= 0 )
+                       return; 
                
                //e.g., colSums, colMeans, colMaxs, colMeans, colVars
                if(   correctionLocation==CorrectionLocationType.LASTROW
                   || correctionLocation==CorrectionLocationType.LASTTWOROWS
                   || correctionLocation==CorrectionLocationType.LASTFOURROWS )
                {
-                       if( sparse ) //SPARSE
-                       {
-                               if(sparseBlock!=null)
-                                       for(int i=1; i<=step; i++)
-                                               if(!sparseBlock.isEmpty(rlen-i))
-                                                       
this.nonZeros-=sparseBlock.size(rlen-i);
+                       if( sparse && sparseBlock!=null ) { //SPARSE
+                               nonZeros -= recomputeNonZeros(1, rlen-1, 0, 
clen-1);
+                               sparseBlock = SparseBlockFactory
+                                       .createSparseBlock(DEFAULT_SPARSEBLOCK, 
sparseBlock.get(0));
                        }
-                       else //DENSE
-                       {
-                               if(denseBlock!=null)
-                                       for(int i=(rlen-step)*clen; 
i<rlen*clen; i++)
-                                               if(denseBlock[i]!=0)
-                                                       this.nonZeros--;
+                       else if( !sparse && denseBlock!=null ) { //DENSE
+                               nonZeros -= recomputeNonZeros(1, rlen-1, 0, 
clen-1);
+                               denseBlock = Arrays.copyOfRange(denseBlock, 0, 
clen);
                        }
-                       
-                       //just need to shrink the dimension, the deleted rows 
won't be accessed
                        rlen -= step;
                }
                
@@ -4443,51 +4414,26 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        || 
correctionLocation==CorrectionLocationType.LASTTWOCOLUMNS
                        || 
correctionLocation==CorrectionLocationType.LASTFOURCOLUMNS )
                {
-                       if(sparse) //SPARSE
-                       {
-                               if(sparseBlock!=null)
-                               {
-                                       for(int r=0; r<Math.min(rlen, 
sparseBlock.numRows()); r++)
-                                               if(!sparseBlock.isEmpty(r))
-                                               {
-                                                       int 
newSize=sparseBlock.posFIndexGTE(r, clen-step);
-                                                       if(newSize >= 0)
-                                                       {
-                                                               
this.nonZeros-=sparseBlock.size(r)-newSize;
-                                                               int pos = 
sparseBlock.pos(r);
-                                                               int cl = 
sparseBlock.indexes(r)[pos+newSize-1];
-                                                               
sparseBlock.deleteIndexRange(r, cl+1, clen);
-                                                               //TODO perf 
sparse block: truncate replaced by deleteIndexRange
-                                                       }
-                                               }
-                               }
+                       if( sparse && sparseBlock!=null ) { //SPARSE
+                               //sparse blocks are converted to a dense 
representation
+                               //because column vectors are always smaller in 
dense
+                               double[] tmp = new double[rlen];
+                               int lnnz = 0;
+                               for( int i=0; i<rlen; i++ )
+                                       lnnz += ((tmp[i] = sparseBlock.get(i, 
0))!=0)? 1 : 0;
+                               cleanupBlock(true, true);
+                               sparse = false;
+                               denseBlock = tmp;
+                               nonZeros = lnnz;
                        }
-                       else //DENSE
-                       {
-                               if(this.denseBlock!=null)
-                               {
-                                       //the first row doesn't need to be 
copied
-                                       int targetIndex=clen-step;
-                                       int sourceOffset=clen;
-                                       this.nonZeros=0;
-                                       for(int i=0; i<targetIndex; i++)
-                                               if(denseBlock[i]!=0)
-                                                       this.nonZeros++;
-                                       
-                                       //start from the 2nd row
-                                       for(int r=1; r<rlen; r++)
-                                       {
-                                               for(int c=0; c<clen-step; c++)
-                                               {
-                                                       
if((denseBlock[targetIndex]=denseBlock[sourceOffset+c])!=0)
-                                                               this.nonZeros++;
-                                                       targetIndex++;
-                                               }
-                                               sourceOffset+=clen;
-                                       }
-                               }
+                       else if( !sparse && denseBlock!=null ) { //DENSE
+                               double[] tmp = new double[rlen];
+                               int lnnz = 0;
+                               for( int i=0, aix=0; i<rlen; i++, aix+=clen )
+                                       lnnz += ((tmp[i] = 
denseBlock[aix])!=0)? 1 : 0;
+                               denseBlock = tmp;
+                               nonZeros = lnnz;
                        }
-                       
                        clen -= step;
                }
        }
@@ -4980,7 +4926,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        ret = LibMatrixAgg.aggregateTernary(m1, m2, m3, ret, 
op);
                
                if(op.aggOp.correctionExists && inCP)
-                       ret.dropLastRowsOrColums(op.aggOp.correctionLocation);
+                       ret.dropLastRowsOrColumns(op.aggOp.correctionLocation);
                return ret;
        }
 

http://git-wip-us.apache.org/repos/asf/systemml/blob/57dff5df/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockFactory.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockFactory.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockFactory.java
index 5abd7ba..d5e45ff 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockFactory.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockFactory.java
@@ -27,7 +27,7 @@ public abstract class SparseBlockFactory
                return createSparseBlock(MatrixBlock.DEFAULT_SPARSEBLOCK, rlen);
        }
 
-       public static SparseBlock createSparseBlock( SparseBlock.Type type, int 
rlen ) {
+       public static SparseBlock createSparseBlock(SparseBlock.Type type, int 
rlen) {
                switch( type ) {
                        case MCSR: return new SparseBlockMCSR(rlen, -1);
                        case CSR: return new SparseBlockCSR(rlen);
@@ -36,6 +36,12 @@ public abstract class SparseBlockFactory
                                throw new RuntimeException("Unexpected sparse 
block type: "+type.toString());
                }
        }
+       
+       public static SparseBlock createSparseBlock(SparseBlock.Type type, 
SparseRow row) {
+               SparseBlock ret = createSparseBlock(type, 1);
+               ret.set(0, row, true);
+               return ret;
+       }
 
        public static SparseBlock copySparseBlock( SparseBlock.Type type, 
SparseBlock sblock, boolean forceCopy )
        {

Reply via email to