[SYSTEMML-2225] Fix reblock ultra-sparse, incl mem efficiency read

This patch improves the robustness of reblocking ultra-sparse matrices by
hardening the CSR index lookups and by handling empty blocks better during
reblock. Furthermore, it includes a fix that avoids unnecessary CSR block
creation for empty blocks on initial read.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/addd6e12
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/addd6e12
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/addd6e12

Branch: refs/heads/master
Commit: addd6e121ac8d81af0f90859666b9ac1ec1e5009
Parents: c516145
Author: Matthias Boehm <mboe...@gmail.com>
Authored: Sat Mar 31 19:06:19 2018 -0700
Committer: Matthias Boehm <mboe...@gmail.com>
Committed: Sat Mar 31 21:27:34 2018 -0700

----------------------------------------------------------------------
 .../spark/functions/CopyBlockPairFunction.java  |  2 +-
 .../functions/ExtractBlockForBinaryReblock.java | 31 ++++++++++----------
 .../sysml/runtime/matrix/data/MatrixBlock.java  |  8 +++--
 .../runtime/matrix/data/SparseBlockCSR.java     |  7 +++--
 4 files changed, 26 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/addd6e12/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyBlockPairFunction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyBlockPairFunction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyBlockPairFunction.java
index 5423a8f..8ff6c2a 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyBlockPairFunction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/CopyBlockPairFunction.java
@@ -71,7 +71,7 @@ public class CopyBlockPairFunction implements 
PairFlatMapFunction<Iterator<Tuple
                                MatrixIndexes ix = new MatrixIndexes(arg._1());
                                MatrixBlock block = null;
                                //always create deep copies in more 
memory-efficient CSR representation 
-                               //if block is already in sparse format          
        
+                               //if block is already in sparse format
                                if( Checkpoint.CHECKPOINT_SPARSE_CSR && 
arg._2.isInSparseFormat() )
                                        block = new MatrixBlock(arg._2, 
SparseBlock.Type.CSR, true);
                                else

http://git-wip-us.apache.org/repos/asf/systemml/blob/addd6e12/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java
index a2a1ce0..66a2271 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/ExtractBlockForBinaryReblock.java
@@ -70,37 +70,36 @@ public class ExtractBlockForBinaryReblock implements 
PairFlatMapFunction<Tuple2<
                long endRowGlobalCellIndex = 
getEndGlobalIndex(ixIn.getRowIndex(), true, true);
                long startColGlobalCellIndex = 
UtilFunctions.computeCellIndex(ixIn.getColumnIndex(), in_bclen, 0);
                long endColGlobalCellIndex = 
getEndGlobalIndex(ixIn.getColumnIndex(), true, false);
-               assert(startRowGlobalCellIndex <= endRowGlobalCellIndex && 
startColGlobalCellIndex <= endColGlobalCellIndex);
                
                long out_startRowBlockIndex = 
UtilFunctions.computeBlockIndex(startRowGlobalCellIndex, out_brlen);
                long out_endRowBlockIndex = 
UtilFunctions.computeBlockIndex(endRowGlobalCellIndex, out_brlen);
                long out_startColBlockIndex = 
UtilFunctions.computeBlockIndex(startColGlobalCellIndex, out_bclen);
                long out_endColBlockIndex = 
UtilFunctions.computeBlockIndex(endColGlobalCellIndex, out_bclen);
-               assert(out_startRowBlockIndex <= out_endRowBlockIndex && 
out_startColBlockIndex <= out_endColBlockIndex);
                
                ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> retVal = new 
ArrayList<>();
                
                for(long i = out_startRowBlockIndex; i <= out_endRowBlockIndex; 
i++) {
                        for(long j = out_startColBlockIndex; j <= 
out_endColBlockIndex; j++) {
                                MatrixIndexes indx = new MatrixIndexes(i, j);
-                               long rowLower = 
Math.max(UtilFunctions.computeCellIndex(i, out_brlen, 0), 
startRowGlobalCellIndex);
-                               long rowUpper = Math.min(getEndGlobalIndex(i, 
false, true), endRowGlobalCellIndex);
-                               long colLower = 
Math.max(UtilFunctions.computeCellIndex(j, out_bclen, 0), 
startColGlobalCellIndex);
-                               long colUpper = Math.min(getEndGlobalIndex(j, 
false, false), endColGlobalCellIndex);
-                               
                                int new_lrlen = 
UtilFunctions.computeBlockSize(rlen, i, out_brlen);
                                int new_lclen = 
UtilFunctions.computeBlockSize(clen, j, out_bclen);
                                MatrixBlock blk = new MatrixBlock(new_lrlen, 
new_lclen, true);
                                
-                               int in_i1 = 
UtilFunctions.computeCellInBlock(rowLower, in_brlen);
-                               int out_i1 = 
UtilFunctions.computeCellInBlock(rowLower, out_brlen);
-                               
-                               for(long i1 = rowLower; i1 <= rowUpper; i1++, 
in_i1++, out_i1++) {
-                                       int in_j1 = 
UtilFunctions.computeCellInBlock(colLower, in_bclen);
-                                       int out_j1 = 
UtilFunctions.computeCellInBlock(colLower, out_bclen);
-                                       for(long j1 = colLower; j1 <= colUpper; 
j1++, in_j1++, out_j1++) {
-                                               double val = in.getValue(in_i1, 
in_j1);
-                                               blk.appendValue(out_i1, out_j1, 
val);
+                               if( !in.isEmptyBlock(false) ) {
+                                       long rowLower = 
Math.max(UtilFunctions.computeCellIndex(i, out_brlen, 0), 
startRowGlobalCellIndex);
+                                       long rowUpper = 
Math.min(getEndGlobalIndex(i, false, true), endRowGlobalCellIndex);
+                                       long colLower = 
Math.max(UtilFunctions.computeCellIndex(j, out_bclen, 0), 
startColGlobalCellIndex);
+                                       long colUpper = 
Math.min(getEndGlobalIndex(j, false, false), endColGlobalCellIndex);
+                                       int in_i1 = 
UtilFunctions.computeCellInBlock(rowLower, in_brlen);
+                                       int out_i1 = 
UtilFunctions.computeCellInBlock(rowLower, out_brlen);
+                                       
+                                       for(long i1 = rowLower; i1 <= rowUpper; 
i1++, in_i1++, out_i1++) {
+                                               int in_j1 = 
UtilFunctions.computeCellInBlock(colLower, in_bclen);
+                                               int out_j1 = 
UtilFunctions.computeCellInBlock(colLower, out_bclen);
+                                               for(long j1 = colLower; j1 <= 
colUpper; j1++, in_j1++, out_j1++) {
+                                                       double val = 
in.quickGetValue(in_i1, in_j1);
+                                                       blk.appendValue(out_i1, 
out_j1, val);
+                                               }
                                        }
                                }
                                retVal.add(new Tuple2<>(indx, blk));

http://git-wip-us.apache.org/repos/asf/systemml/blob/addd6e12/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
index f19fbe5..efe2365 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/MatrixBlock.java
@@ -184,10 +184,12 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        throw new RuntimeException("Sparse matrix block 
expected.");
                
                //deep copy and change sparse block type
-               nonZeros = that.nonZeros;
-               estimatedNNzsPerRow = that.estimatedNNzsPerRow;
-               sparseBlock = SparseBlockFactory
+               if( !that.isEmptyBlock(false) ) {
+                       nonZeros = that.nonZeros;
+                       estimatedNNzsPerRow = that.estimatedNNzsPerRow;
+                       sparseBlock = SparseBlockFactory
                                .copySparseBlock(stype, that.sparseBlock, deep);
+               }
        }
        
        ////////

http://git-wip-us.apache.org/repos/asf/systemml/blob/addd6e12/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java 
b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java
index de0c34b..6bbc81d 100644
--- a/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java
+++ b/src/main/java/org/apache/sysml/runtime/matrix/data/SparseBlockCSR.java
@@ -734,16 +734,20 @@ public class SparseBlockCSR extends SparseBlock
 
        @Override
        public double get(int r, int c) {
+               if( isEmpty(r) )
+                       return 0;
                int pos = pos(r);
                int len = size(r);
                
                //search for existing col index in [pos,pos+len)
-               int index = Arrays.binarySearch(_indexes, pos, pos+len, c);     
        
+               int index = Arrays.binarySearch(_indexes, pos, pos+len, c);
                return (index >= 0) ? _values[index] : 0;
        }
        
        @Override 
        public SparseRow get(int r) {
+               if( isEmpty(r) )
+                       return new SparseRowScalar();
                int pos = pos(r);
                int len = size(r);
                
@@ -751,7 +755,6 @@ public class SparseBlockCSR extends SparseBlock
                System.arraycopy(_indexes, pos, row.indexes(), 0, len);
                System.arraycopy(_values, pos, row.values(), 0, len);
                row.setSize(len);
-               
                return row;
        }
        

Reply via email to