This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit ab4ec284b9dbe320087c9108c041ebdeccc23282
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sun Jan 7 19:31:20 2024 +0100

    [MINOR] Change CLA to plain (non-Kahan) SUM
    
    This commit changes CLA to use the recently committed SUM operation
    without Kahan correction. It also modifies the block size used for
    parallelization to improve performance on a number of test files.
    
    Closes #1977
---
 .../sysds/runtime/compress/lib/CLALibCompAgg.java  | 53 ++++++++++++----------
 1 file changed, 29 insertions(+), 24 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java
index 95a460a2e0..999c95d54f 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java
@@ -31,6 +31,7 @@ import org.apache.commons.lang3.NotImplementedException;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.CorrectionLocationType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.compress.CompressionSettings;
@@ -214,7 +215,7 @@ public final class CLALibCompAgg {
 
        private static AggregateUnaryOperator 
replaceKahnOperations(AggregateUnaryOperator op) {
                if(op.aggOp.increOp.fn instanceof KahanPlus)
-                       return new AggregateUnaryOperator(new 
AggregateOperator(0, Plus.getPlusFnObject()), op.indexFn,
+                       return new AggregateUnaryOperator(new 
AggregateOperator(0, Plus.getPlusFnObject(), CorrectionLocationType.NONE), 
op.indexFn,
                                op.getNumThreads());
                return op;
        }
@@ -224,7 +225,7 @@ public final class CLALibCompAgg {
                int k = op.getNumThreads();
                // replace mean operation with plus.
                AggregateUnaryOperator opm = (op.aggOp.increOp.fn instanceof 
Mean) ? new AggregateUnaryOperator(
-                       new AggregateOperator(0, Plus.getPlusFnObject()), 
op.indexFn) : op;
+                       new AggregateOperator(0, Plus.getPlusFnObject(), 
CorrectionLocationType.NONE), op.indexFn) : op;
 
                if(isValidForParallelProcessing(m, op))
                        aggregateInParallel(m, o, opm, k);
@@ -415,7 +416,7 @@ public final class CLALibCompAgg {
                final ArrayList<UAOverlappingTask> tasks = new ArrayList<>();
                final int nCol = m1.getNumColumns();
                final int nRow = m1.getNumRows();
-               final int blklen = Math.max(512, nRow / k);
+               final int blklen = Math.max(64, nRow / k);
                final List<AColGroup> groups = m1.getColGroups();
                final boolean shouldFilter = 
CLALibUtils.shouldPreFilter(groups);
                if(shouldFilter) {
@@ -568,7 +569,7 @@ public final class CLALibCompAgg {
                        _op = op;
                        _rl = rl;
                        _ru = ru;
-                       _blklen = Math.max(65536  / ret.getNumColumns() / 
filteredGroups.size(), 64);
+                       _blklen = Math.max(16384  / nCol, 64);
                        _ret = ret;
                        _nCol = nCol;
                }
@@ -581,7 +582,6 @@ public final class CLALibCompAgg {
 
                private MatrixBlock decompressToTemp(MatrixBlock tmp, int rl, 
int ru, AIterator[] its) {
                        Timing time = new Timing(true);
-
                        DenseBlock db = tmp.getDenseBlock();
                        for(int i = 0; i < _groups.size(); i++) {
                                AColGroup g = _groups.get(i);
@@ -619,12 +619,34 @@ public final class CLALibCompAgg {
                        for(int i = 0; i < _groups.size(); i++)
                                if(_groups.get(i) instanceof ASDCZero)
                                        its[i] = ((ASDCZero) 
_groups.get(i)).getIterator(_rl);
-                       if(_op.indexFn instanceof ReduceCol) {
+
+                       if(_op.indexFn instanceof ReduceCol) { // row aggregates
+                               reduceCol(tmp, its, isBinaryOp);
+                               return null;
+                       }
+                       else if(_op.indexFn instanceof ReduceAll) {
+                               decompressToTemp(tmp, _rl, _ru, its);
+                               MatrixBlock outputBlock = 
LibMatrixAgg.prepareAggregateUnaryOutput(tmp, _op, null, 1000);
+                               LibMatrixAgg.aggregateUnaryMatrix(tmp, 
outputBlock, _op);
+                               
outputBlock.dropLastRowsOrColumns(_op.aggOp.correction);
+                               return outputBlock;
+                       }
+                       else { // reduce to rows.
+                               decompressToTemp(tmp, _rl, _ru, its);
+                               MatrixBlock outputBlock = 
LibMatrixAgg.prepareAggregateUnaryOutput(tmp, _op, null, 1000);
+                               LibMatrixAgg.aggregateUnaryMatrix(tmp, 
outputBlock, _op);
+                               
outputBlock.dropLastRowsOrColumns(_op.aggOp.correction);
+                               return outputBlock;
+                       }
+               }
+
+               private void reduceCol(MatrixBlock tmp,AIterator[] its, boolean 
isBinaryOp){
+                               final MatrixBlock tmpR = 
LibMatrixAgg.prepareAggregateUnaryOutput(tmp, _op, null, 1000);
                                for(int r = _rl; r < _ru; r += _blklen) {
                                        final int rbu = Math.min(r + _blklen, 
_ru);
                                        tmp.reset(rbu - r, tmp.getNumColumns(), 
false);
                                        decompressToTemp(tmp, r, rbu, its);
-                                       final MatrixBlock tmpR = 
tmp.prepareAggregateUnaryOutput(_op, null, 1000);
+                                       tmpR.reset();
                                        LibMatrixAgg.aggregateUnaryMatrix(tmp, 
tmpR, _op);
 
                                        
tmpR.dropLastRowsOrColumns(_op.aggOp.correction);
@@ -649,23 +671,6 @@ public final class CLALibCompAgg {
                                                System.arraycopy(tmpRValues, 0, 
retValues, currentIndex, length);
                                        }
                                }
-                               return null;
-
-                       }
-                       else if(_op.indexFn instanceof ReduceAll) {
-                               decompressToTemp(tmp, _rl, _ru, its);
-                               MatrixBlock outputBlock = 
tmp.prepareAggregateUnaryOutput(_op, null, 1000);
-                               LibMatrixAgg.aggregateUnaryMatrix(tmp, 
outputBlock, _op);
-                               
outputBlock.dropLastRowsOrColumns(_op.aggOp.correction);
-                               return outputBlock;
-                       }
-                       else { // reduce to rows.
-                               decompressToTemp(tmp, _rl, _ru, its);
-                               MatrixBlock outputBlock = 
tmp.prepareAggregateUnaryOutput(_op, null, 1000);
-                               LibMatrixAgg.aggregateUnaryMatrix(tmp, 
outputBlock, _op);
-                               
outputBlock.dropLastRowsOrColumns(_op.aggOp.correction);
-                               return outputBlock;
-                       }
                }
        }
 }

Reply via email to