janniklinde commented on code in PR #2361:
URL: https://github.com/apache/systemds/pull/2361#discussion_r2613867240


##########
src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java:
##########
@@ -288,6 +289,12 @@ else if((ct == CompressionType.SDC || ct == 
CompressionType.CONST) //
                else if(ct == CompressionType.DDC) {
                        return directCompressDDC(colIndexes, cg);
                }
+               else if(ct == CompressionType.DeltaDDC) {
+                       return directCompressDeltaDDC(colIndexes, cg);
+               }
+               else if(ct == CompressionType.CONST && cs.preferDeltaEncoding) {

Review Comment:
   Why would you encode CONST as DeltaDDC? 



##########
src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java:
##########
@@ -1105,4 +1109,74 @@ protected boolean allowShallowIdentityRightMult() {
                return true;
        }
 
+       public AColGroup convertToDeltaDDC() {
+               int numCols = _colIndexes.size();
+               int numRows = _data.size();
+               
+               DblArrayCountHashMap map = new 
DblArrayCountHashMap(Math.max(numRows, 64));
+               double[] rowDelta = new double[numCols];
+               double[] prevRow = new double[numCols];
+               DblArray dblArray = new DblArray(rowDelta);
+               int[] rowToDictId = new int[numRows];
+               
+               double[] dictVals = null;
+               try {
+                       dictVals = _dict.getValues();

Review Comment:
   Why is this wrapped in a try ... catch? I don't think there is a 
scenario where this would fail.



##########
src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUnary.java:
##########
@@ -43,6 +49,70 @@ public static MatrixBlock 
unaryOperations(CompressedMatrixBlock m, UnaryOperator
                final boolean overlapping = m.isOverlapping();
                final int r = m.getNumRows();
                final int c = m.getNumColumns();
+               
+               if(Builtin.isBuiltinCode(op.fn, BuiltinCode.CUMSUM, 
BuiltinCode.ROWCUMSUM)) {
+                       List<AColGroup> groups = m.getColGroups();
+                       boolean allDDC = true;
+                       for(AColGroup g : groups) {
+                               if(g.getCompType() != CompressionType.DDC) {
+                                       allDDC = false;
+                                       break;
+                               }
+                       }
+                       
+                       if(allDDC && !groups.isEmpty()) {
+                               MatrixBlock uncompressed = 
m.getUncompressed("CUMSUM/ROWCUMSUM requires uncompression", 
op.getNumThreads());
+                               MatrixBlock opResult = 
uncompressed.unaryOperations(op, null);
+                               
+                               List<AColGroup> convertedGroups = new 
ArrayList<>(groups.size());
+                               for(AColGroup g : groups) {
+                                       AColGroup converted = ((ColGroupDDC) 
g).convertToDeltaDDC();
+                                       if(converted == null) {
+                                               allDDC = false;
+                                               break;
+                                       }
+                                       convertedGroups.add(converted);
+                               }
+                               
+                               if(allDDC) {
+                                       CompressedMatrixBlock ret = new 
CompressedMatrixBlock(m.getNumRows(), m.getNumColumns());
+                                       
ret.allocateColGroupList(convertedGroups);
+                                       ret.recomputeNonZeros();
+                                       
+                                       MatrixBlock verifyDecompressed = 
ret.getUncompressed("Verification", op.getNumThreads());
+                                       if(verifyDecompressed.equals(opResult)) 
{
+                                               return ret;
+                                       }
+                               }
+                       }
+                       
+                       MatrixBlock uncompressed = 
m.getUncompressed("CUMSUM/ROWCUMSUM requires uncompression", 
op.getNumThreads());
+                       MatrixBlock opResult = uncompressed.unaryOperations(op, 
null);
+                       
+                       CompressionSettingsBuilder csb = new 
CompressionSettingsBuilder();
+                       csb.clearValidCompression();
+                       csb.setPreferDeltaEncoding(true);
+                       csb.addValidCompression(CompressionType.DeltaDDC);
+                       csb.addValidCompression(CompressionType.UNCOMPRESSED);
+                       csb.setTransposeInput("false");
+                       Pair<MatrixBlock, CompressionStatistics> compressedPair 
= CompressedMatrixBlockFactory.compress(opResult, op.getNumThreads(), csb);
+                       MatrixBlock compressedResult = compressedPair.getLeft();
+                       
+                       if(compressedResult == null) {
+                               compressedResult = opResult;
+                       }
+                       
+                       CompressedMatrixBlock finalResult;
+                       if(compressedResult instanceof CompressedMatrixBlock) {
+                               finalResult = (CompressedMatrixBlock) 
compressedResult;
+                       }
+                       else {
+                               finalResult = 
CompressedMatrixBlockFactory.genUncompressedCompressedMatrixBlock(compressedResult);
+                       }
+                       
+                       return finalResult;
+               }
+               

Review Comment:
   In general, it might make more sense to move this branch below the 
`if (m.isEmpty())` check.



##########
src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUnary.java:
##########
@@ -43,6 +49,70 @@ public static MatrixBlock 
unaryOperations(CompressedMatrixBlock m, UnaryOperator
                final boolean overlapping = m.isOverlapping();
                final int r = m.getNumRows();
                final int c = m.getNumColumns();
+               
+               if(Builtin.isBuiltinCode(op.fn, BuiltinCode.CUMSUM, 
BuiltinCode.ROWCUMSUM)) {
+                       List<AColGroup> groups = m.getColGroups();
+                       boolean allDDC = true;
+                       for(AColGroup g : groups) {
+                               if(g.getCompType() != CompressionType.DDC) {
+                                       allDDC = false;
+                                       break;
+                               }
+                       }
+                       
+                       if(allDDC && !groups.isEmpty()) {
+                               MatrixBlock uncompressed = 
m.getUncompressed("CUMSUM/ROWCUMSUM requires uncompression", 
op.getNumThreads());
+                               MatrixBlock opResult = 
uncompressed.unaryOperations(op, null);
+                               
+                               List<AColGroup> convertedGroups = new 
ArrayList<>(groups.size());
+                               for(AColGroup g : groups) {
+                                       AColGroup converted = ((ColGroupDDC) 
g).convertToDeltaDDC();
+                                       if(converted == null) {
+                                               allDDC = false;
+                                               break;
+                                       }
+                                       convertedGroups.add(converted);
+                               }
+                               
+                               if(allDDC) {
+                                       CompressedMatrixBlock ret = new 
CompressedMatrixBlock(m.getNumRows(), m.getNumColumns());
+                                       
ret.allocateColGroupList(convertedGroups);
+                                       ret.recomputeNonZeros();
+                                       
+                                       MatrixBlock verifyDecompressed = 
ret.getUncompressed("Verification", op.getNumThreads());
+                                       if(verifyDecompressed.equals(opResult)) 
{
+                                               return ret;
+                                       }
+                               }
+                       }
+                       
+                       MatrixBlock uncompressed = 
m.getUncompressed("CUMSUM/ROWCUMSUM requires uncompression", 
op.getNumThreads());
+                       MatrixBlock opResult = uncompressed.unaryOperations(op, 
null);
+                       
+                       CompressionSettingsBuilder csb = new 
CompressionSettingsBuilder();
+                       csb.clearValidCompression();
+                       csb.setPreferDeltaEncoding(true);
+                       csb.addValidCompression(CompressionType.DeltaDDC);
+                       csb.addValidCompression(CompressionType.UNCOMPRESSED);
+                       csb.setTransposeInput("false");
+                       Pair<MatrixBlock, CompressionStatistics> compressedPair 
= CompressedMatrixBlockFactory.compress(opResult, op.getNumThreads(), csb);
+                       MatrixBlock compressedResult = compressedPair.getLeft();
+                       
+                       if(compressedResult == null) {
+                               compressedResult = opResult;
+                       }
+                       
+                       CompressedMatrixBlock finalResult;
+                       if(compressedResult instanceof CompressedMatrixBlock) {
+                               finalResult = (CompressedMatrixBlock) 
compressedResult;
+                       }
+                       else {
+                               finalResult = 
CompressedMatrixBlockFactory.genUncompressedCompressedMatrixBlock(compressedResult);
+                       }
+                       
+                       return finalResult;

Review Comment:
   This part is unnecessary. Let it just fall through and let the case 
`LibMatrixAgg.isSupportedUnaryOperator(op)` handle it.



##########
src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java:
##########
@@ -1105,4 +1109,74 @@ protected boolean allowShallowIdentityRightMult() {
                return true;
        }
 
+       public AColGroup convertToDeltaDDC() {
+               int numCols = _colIndexes.size();
+               int numRows = _data.size();
+               
+               DblArrayCountHashMap map = new 
DblArrayCountHashMap(Math.max(numRows, 64));
+               double[] rowDelta = new double[numCols];
+               double[] prevRow = new double[numCols];
+               DblArray dblArray = new DblArray(rowDelta);
+               int[] rowToDictId = new int[numRows];
+               
+               double[] dictVals = null;
+               try {
+                       dictVals = _dict.getValues();
+               } catch (Exception e) {
+               }
+
+               for(int i = 0; i < numRows; i++) {
+                       int dictIdx = _data.getIndex(i);
+                       if(dictVals != null) {
+                               int off = dictIdx * numCols;
+                               for(int j = 0; j < numCols; j++) {
+                                       double val = dictVals[off + j];
+                                       if(i == 0) {
+                                               rowDelta[j] = val;
+                                               prevRow[j] = val;
+                                       } else {
+                                               rowDelta[j] = val - prevRow[j];
+                                               prevRow[j] = val;
+                                       }
+                               }
+                       } else {

Review Comment:
   Can this case actually happen? If not, remove the null check.



##########
src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUnary.java:
##########
@@ -43,6 +49,70 @@ public static MatrixBlock 
unaryOperations(CompressedMatrixBlock m, UnaryOperator
                final boolean overlapping = m.isOverlapping();
                final int r = m.getNumRows();
                final int c = m.getNumColumns();
+               
+               if(Builtin.isBuiltinCode(op.fn, BuiltinCode.CUMSUM, 
BuiltinCode.ROWCUMSUM)) {

Review Comment:
   Don't handle the `ROWCUMSUM` case; it would not be efficient (and the DDC 
reinterpretation would be wrong).



##########
src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java:
##########
@@ -311,7 +318,7 @@ else if(ct == CompressionType.SDC && colIndexes.size() == 1 
&& !t) {
                        return new ColGroupEmpty(colIndexes);
                }
                final IntArrayList[] of = ubm.getOffsetList();
-               if(of.length == 1 && of[0].size() == nRow) { // If this always 
constant
+               if(of.length == 1 && of[0].size() == nRow && ct != 
CompressionType.DeltaDDC) { // If this always constant

Review Comment:
   Why would you encode CONST as DeltaDDC?



##########
src/main/java/org/apache/sysds/runtime/compress/utils/DblArrayCountHashMap.java:
##########
@@ -40,7 +40,7 @@ protected final int hash(DblArray key) {
        }
 
        protected final DArrCounts create(DblArray key, int id) {
-               return new DArrCounts(key, id);
+               return new DArrCounts(new DblArray(key), id);

Review Comment:
   You don't need to create a copy of `key` because `new DArrCounts(...)` 
already takes care of that. So you can safely revert that change



##########
src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUnary.java:
##########
@@ -43,6 +49,70 @@ public static MatrixBlock 
unaryOperations(CompressedMatrixBlock m, UnaryOperator
                final boolean overlapping = m.isOverlapping();
                final int r = m.getNumRows();
                final int c = m.getNumColumns();
+               
+               if(Builtin.isBuiltinCode(op.fn, BuiltinCode.CUMSUM, 
BuiltinCode.ROWCUMSUM)) {
+                       List<AColGroup> groups = m.getColGroups();
+                       boolean allDDC = true;
+                       for(AColGroup g : groups) {
+                               if(g.getCompType() != CompressionType.DDC) {
+                                       allDDC = false;
+                                       break;
+                               }
+                       }
+                       
+                       if(allDDC && !groups.isEmpty()) {
+                               MatrixBlock uncompressed = 
m.getUncompressed("CUMSUM/ROWCUMSUM requires uncompression", 
op.getNumThreads());
+                               MatrixBlock opResult = 
uncompressed.unaryOperations(op, null);

Review Comment:
   Don't uncompress to do this redundant operation. In the case of `CUMSUM`, 
the reinterpretation should always be correct, so there is no need to verify.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to