This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new df26e676a0 [SYSTEMDS-3918] New out-of-core tmp aggregation util
df26e676a0 is described below

commit df26e676a03b8957fe402c13e415bf854e4d6fe3
Author: Jessica Priebe <[email protected]>
AuthorDate: Sun Sep 21 16:48:17 2025 +0200

    [SYSTEMDS-3918] New out-of-core tmp aggregation util
    
    Closes #2318.
---
 .../ooc/AggregateUnaryOOCInstruction.java          | 56 ++++++++++------------
 .../ooc/MatrixVectorBinaryOOCInstruction.java      | 28 ++++-------
 .../runtime/instructions/ooc/OOCInstruction.java   | 54 +++++++++++++++++++++
 3 files changed, 89 insertions(+), 49 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
index a656cd337c..b71cdaaeb5 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
@@ -75,14 +75,11 @@ public class AggregateUnaryOOCInstruction extends 
ComputationOOCInstruction {
                int blen = ConfigurationManager.getBlocksize();
 
                if (aggun.isRowAggregate() || aggun.isColAggregate()) {
-                       // intermediate state per aggregation index
-                       HashMap<Long, MatrixBlock> aggs = new HashMap<>(); // 
partial aggregates
-                       HashMap<Long, MatrixBlock> corrs = new HashMap<>(); // 
correction blocks
-                       HashMap<Long, Integer> cnt = new HashMap<>(); // 
processed block count per agg idx
-
                        DataCharacteristics chars = 
ec.getDataCharacteristics(input1.getName());
                        // number of blocks to process per aggregation idx (row 
or column dim)
-                       long nBlocks = aggun.isRowAggregate()? 
chars.getNumColBlocks() : chars.getNumRowBlocks();
+                       long emitThreshold = aggun.isRowAggregate()? 
chars.getNumColBlocks() : chars.getNumRowBlocks();
+                       OOCMatrixBlockTracker aggTracker = new 
OOCMatrixBlockTracker(emitThreshold);
+                       HashMap<Long, MatrixBlock> corrs = new HashMap<>(); // 
correction blocks
 
                        LocalTaskQueue<IndexedMatrixValue> qOut = new 
LocalTaskQueue<>();
                        ec.getMatrixObject(output).setStreamHandle(qOut);
@@ -94,9 +91,8 @@ public class AggregateUnaryOOCInstruction extends 
ComputationOOCInstruction {
                                                while((tmp = q.dequeueTask()) 
!= LocalTaskQueue.NO_MORE_TASKS) {
                                                        long idx  = 
aggun.isRowAggregate() ?
                                                                
tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex();
-                                                       
if(aggs.containsKey(idx)) {
-                                                               // update 
existing partial aggregate for this idx
-                                                               MatrixBlock ret 
= aggs.get(idx);
+                                                       MatrixBlock ret = 
aggTracker.get(idx);
+                                                       if(ret != null) {
                                                                MatrixBlock 
corr = corrs.get(idx);
 
                                                                // aggregation
@@ -105,9 +101,10 @@ public class AggregateUnaryOOCInstruction extends 
ComputationOOCInstruction {
                                                                
OperationsOnMatrixValues.incrementalAggregation(ret,
                                                                        
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);
 
-                                                               
aggs.replace(idx, ret);
-                                                               
corrs.replace(idx, corr);
-                                                               
cnt.replace(idx, cnt.get(idx) + 1);
+                                                               if 
(!aggTracker.putAndIncrementCount(idx, ret)){
+                                                                       
corrs.replace(idx, corr);
+                                                                       
continue;
+                                                               }
                                                        }
                                                        else {
                                                                // first block 
for this idx - init aggregate and correction
@@ -115,7 +112,7 @@ public class AggregateUnaryOOCInstruction extends 
ComputationOOCInstruction {
                                                                int rows = 
tmp.getValue().getNumRows();
                                                                int cols = 
tmp.getValue().getNumColumns();
                                                                int extra = 
_aop.correction.getNumRemovedRowsColumns();
-                                                               MatrixBlock ret 
= aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new 
MatrixBlock(1 + extra, cols, false);
+                                                               ret = 
aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new 
MatrixBlock(1 + extra, cols, false);
                                                                MatrixBlock 
corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new 
MatrixBlock(1 + extra, cols, false);
 
                                                                // aggregation
@@ -124,25 +121,24 @@ public class AggregateUnaryOOCInstruction extends 
ComputationOOCInstruction {
                                                                
OperationsOnMatrixValues.incrementalAggregation(ret,
                                                                        
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);
 
-                                                               aggs.put(idx, 
ret);
-                                                               corrs.put(idx, 
corr);
-                                                               cnt.put(idx, 1);
+                                                               
if(emitThreshold > 1){
+                                                                       
aggTracker.putAndIncrementCount(idx, ret);
+                                                                       
corrs.put(idx, corr);
+                                                                       
continue;
+                                                               }
                                                        }
 
-                                                       if(cnt.get(idx) == 
nBlocks) {
-                                                               // all input 
blocks for this idx processed - emit aggregated block
-                                                               MatrixBlock ret 
= aggs.get(idx);
-                                                               // drop 
correction row/col
-                                                               
ret.dropLastRowsOrColumns(_aop.correction);
-                                                               MatrixIndexes 
midx = aggun.isRowAggregate()? new 
MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : new MatrixIndexes(1, 
tmp.getIndexes().getColumnIndex());
-                                                               
IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret);
-
-                                                               
qOut.enqueueTask(tmpOut);
-                                                               // drop 
intermediate states
-                                                               
aggs.remove(idx);
-                                                               
corrs.remove(idx);
-                                                               cnt.remove(idx);
-                                                       }
+                                                       // all input blocks for 
this idx processed - emit aggregated block
+                                                       
ret.dropLastRowsOrColumns(_aop.correction);
+                                                       MatrixIndexes midx = 
aggun.isRowAggregate() ?
+                                                               new 
MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) :
+                                                               new 
MatrixIndexes(1, tmp.getIndexes().getColumnIndex());
+                                                       IndexedMatrixValue 
tmpOut = new IndexedMatrixValue(midx, ret);
+
+                                                       
qOut.enqueueTask(tmpOut);
+                                                       // drop intermediate 
states
+                                                       aggTracker.remove(idx);
+                                                       corrs.remove(idx);
                                                }
                                                qOut.closeInput();
                                        }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
index 5e2d36d9df..ae84e4b541 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
@@ -82,7 +82,8 @@ public class MatrixVectorBinaryOOCInstruction extends 
ComputationOOCInstruction
                }
 
                // number of colBlocks for early block output
-               long nBlocks = min.getDataCharacteristics().getNumColBlocks();
+               long emitThreshold = 
min.getDataCharacteristics().getNumColBlocks();
+               OOCMatrixBlockTracker aggTracker = new 
OOCMatrixBlockTracker(emitThreshold);
 
                LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
                LocalTaskQueue<IndexedMatrixValue> qOut = new 
LocalTaskQueue<>();
@@ -95,8 +96,6 @@ public class MatrixVectorBinaryOOCInstruction extends 
ComputationOOCInstruction
                        pool.submit(() -> {
                                IndexedMatrixValue tmp = null;
                                try {
-                                       HashMap<Long, MatrixBlock> 
partialResults = new  HashMap<>();
-                                       HashMap<Long, Integer> cnt = new 
HashMap<>();
                                        while((tmp = qIn.dequeueTask()) != 
LocalTaskQueue.NO_MORE_TASKS) {
                                                MatrixBlock matrixBlock = 
(MatrixBlock) tmp.getValue();
                                                long rowIndex = 
tmp.getIndexes().getRowIndex();
@@ -108,31 +107,22 @@ public class MatrixVectorBinaryOOCInstruction extends 
ComputationOOCInstruction
                                                        matrixBlock, 
vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr);
 
                                                // for single column block, no 
aggregation needed
-                                               if( min.getNumColumns() <= 
min.getBlocksize() ) {
+                                               if(emitThreshold == 1) {
                                                        qOut.enqueueTask(new 
IndexedMatrixValue(tmp.getIndexes(), partialResult));
                                                }
                                                else {
                                                        // aggregation
-                                                       MatrixBlock currAgg = 
partialResults.get(rowIndex);
+                                                       MatrixBlock currAgg = 
aggTracker.get(rowIndex);
                                                        if (currAgg == null) {
-                                                               
partialResults.put(rowIndex, partialResult);
-                                                               
cnt.put(rowIndex, 1);
+                                                               
aggTracker.putAndIncrementCount(rowIndex, partialResult);
                                                        }
                                                        else {
-                                                               
currAgg.binaryOperationsInPlace(plus, partialResult);
-                                                               int newCnt = 
cnt.get(rowIndex) + 1;
-                                                               
-                                                               if(newCnt == 
nBlocks){
+                                                               currAgg = 
currAgg.binaryOperations(plus, partialResult);
+                                                               if 
(aggTracker.putAndIncrementCount(rowIndex, currAgg)){
                                                                        // 
early block output: emit aggregated block
                                                                        
MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L);
-                                                                       
MatrixBlock result = partialResults.get(rowIndex);
-                                                                       
qOut.enqueueTask(new IndexedMatrixValue(idx, result));
-                                                                       
partialResults.remove(rowIndex);
-                                                                       
cnt.remove(rowIndex);
-                                                               }
-                                                               else {
-                                                                       // 
maintain aggregation counts if not output-ready yet
-                                                                       
cnt.replace(rowIndex, newCnt);
+                                                                       
qOut.enqueueTask(new IndexedMatrixValue(idx, currAgg));
+                                                                       
aggTracker.remove(rowIndex);
                                                                }
                                                        }
                                                }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index 0495dcfde5..2e5e6f41eb 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -24,8 +24,11 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
+import java.util.HashMap;
+
 public abstract class OOCInstruction extends Instruction {
        protected static final Log LOG = 
LogFactory.getLog(OOCInstruction.class.getName());
 
@@ -82,4 +85,55 @@ public abstract class OOCInstruction extends Instruction {
                if(DMLScript.LINEAGE_DEBUGGER)
                        ec.maintainLineageDebuggerInfo(this);
        }
+
+       /**
+        * Tracks blocks and their counts to enable early emission
+        * once all blocks for a given index are processed.
+        */
+       public static class OOCMatrixBlockTracker {
+               private final long _emitThreshold;
+               private final HashMap<Long, MatrixBlock> _blocks;
+               private final HashMap<Long, Integer> _cnts;
+
+               public OOCMatrixBlockTracker(long emitThreshold) {
+                       _emitThreshold = emitThreshold;
+                       _blocks = new HashMap<>();
+                       _cnts = new HashMap<>();
+               }
+
+               /**
+                * Adds or updates a block for the given index and updates its 
internal count.
+                * @param idx   block index
+                * @param block MatrixBlock
+                * @return true if the block count reached the threshold (ready 
to emit), false otherwise
+                */
+               public boolean putAndIncrementCount(Long idx, MatrixBlock 
block) {
+                       _blocks.put(idx, block);
+                       int newCnt = _cnts.getOrDefault(idx, 0) + 1;
+                       if( newCnt < _emitThreshold )
+                               _cnts.put(idx, newCnt);
+                       return newCnt == _emitThreshold;
+               }
+
+               public boolean incrementCount(Long idx) {
+                       int newCnt = _cnts.get(idx) + 1;
+                       if( newCnt < _emitThreshold )
+                               _cnts.put(idx, newCnt);
+                       return newCnt == _emitThreshold;
+               }
+
+               public void putAndInitCount(Long idx, MatrixBlock block) {
+                       _blocks.put(idx, block);
+                       _cnts.put(idx, 0);
+               }
+
+               public MatrixBlock get(Long idx) {
+                       return _blocks.get(idx);
+               }
+
+               public void remove(Long idx) {
+                       _blocks.remove(idx);
+                       _cnts.remove(idx);
+               }
+       }
 }

Reply via email to