This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new df26e676a0 [SYSTEMDS-3918] New out-of-core tmp aggregation util
df26e676a0 is described below
commit df26e676a03b8957fe402c13e415bf854e4d6fe3
Author: Jessica Priebe <[email protected]>
AuthorDate: Sun Sep 21 16:48:17 2025 +0200
[SYSTEMDS-3918] New out-of-core tmp aggregation util
Closes #2318.
---
.../ooc/AggregateUnaryOOCInstruction.java | 56 ++++++++++------------
.../ooc/MatrixVectorBinaryOOCInstruction.java | 28 ++++-------
.../runtime/instructions/ooc/OOCInstruction.java | 54 +++++++++++++++++++++
3 files changed, 89 insertions(+), 49 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
index a656cd337c..b71cdaaeb5 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java
@@ -75,14 +75,11 @@ public class AggregateUnaryOOCInstruction extends
ComputationOOCInstruction {
int blen = ConfigurationManager.getBlocksize();
if (aggun.isRowAggregate() || aggun.isColAggregate()) {
- // intermediate state per aggregation index
- HashMap<Long, MatrixBlock> aggs = new HashMap<>(); //
partial aggregates
- HashMap<Long, MatrixBlock> corrs = new HashMap<>(); //
correction blocks
- HashMap<Long, Integer> cnt = new HashMap<>(); //
processed block count per agg idx
-
DataCharacteristics chars =
ec.getDataCharacteristics(input1.getName());
// number of blocks to process per aggregation idx (row
or column dim)
- long nBlocks = aggun.isRowAggregate()?
chars.getNumColBlocks() : chars.getNumRowBlocks();
+ long emitThreshold = aggun.isRowAggregate()?
chars.getNumColBlocks() : chars.getNumRowBlocks();
+ OOCMatrixBlockTracker aggTracker = new
OOCMatrixBlockTracker(emitThreshold);
+ HashMap<Long, MatrixBlock> corrs = new HashMap<>(); //
correction blocks
LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
ec.getMatrixObject(output).setStreamHandle(qOut);
@@ -94,9 +91,8 @@ public class AggregateUnaryOOCInstruction extends
ComputationOOCInstruction {
while((tmp = q.dequeueTask())
!= LocalTaskQueue.NO_MORE_TASKS) {
long idx =
aggun.isRowAggregate() ?
tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex();
-
if(aggs.containsKey(idx)) {
- // update
existing partial aggregate for this idx
- MatrixBlock ret
= aggs.get(idx);
+ MatrixBlock ret =
aggTracker.get(idx);
+ if(ret != null) {
MatrixBlock
corr = corrs.get(idx);
// aggregation
@@ -105,9 +101,10 @@ public class AggregateUnaryOOCInstruction extends
ComputationOOCInstruction {
OperationsOnMatrixValues.incrementalAggregation(ret,
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);
-
aggs.replace(idx, ret);
-
corrs.replace(idx, corr);
-
cnt.replace(idx, cnt.get(idx) + 1);
+ if
(!aggTracker.putAndIncrementCount(idx, ret)){
+
corrs.replace(idx, corr);
+
continue;
+ }
}
else {
// first block
for this idx - init aggregate and correction
@@ -115,7 +112,7 @@ public class AggregateUnaryOOCInstruction extends
ComputationOOCInstruction {
int rows =
tmp.getValue().getNumRows();
int cols =
tmp.getValue().getNumColumns();
int extra =
_aop.correction.getNumRemovedRowsColumns();
- MatrixBlock ret
= aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new
MatrixBlock(1 + extra, cols, false);
+ ret =
aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new
MatrixBlock(1 + extra, cols, false);
MatrixBlock
corr = aggun.isRowAggregate()? new MatrixBlock(rows, 1 + extra, false) : new
MatrixBlock(1 + extra, cols, false);
// aggregation
@@ -124,25 +121,24 @@ public class AggregateUnaryOOCInstruction extends
ComputationOOCInstruction {
OperationsOnMatrixValues.incrementalAggregation(ret,
_aop.existsCorrection() ? corr : null, ltmp, _aop, true);
- aggs.put(idx,
ret);
- corrs.put(idx,
corr);
- cnt.put(idx, 1);
+
if(emitThreshold > 1){
+
aggTracker.putAndIncrementCount(idx, ret);
+
corrs.put(idx, corr);
+
continue;
+ }
}
- if(cnt.get(idx) ==
nBlocks) {
- // all input
blocks for this idx processed - emit aggregated block
- MatrixBlock ret
= aggs.get(idx);
- // drop
correction row/col
-
ret.dropLastRowsOrColumns(_aop.correction);
- MatrixIndexes
midx = aggun.isRowAggregate()? new
MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) : new MatrixIndexes(1,
tmp.getIndexes().getColumnIndex());
-
IndexedMatrixValue tmpOut = new IndexedMatrixValue(midx, ret);
-
-
qOut.enqueueTask(tmpOut);
- // drop
intermediate states
-
aggs.remove(idx);
-
corrs.remove(idx);
- cnt.remove(idx);
- }
+ // all input blocks for
this idx processed - emit aggregated block
+
ret.dropLastRowsOrColumns(_aop.correction);
+ MatrixIndexes midx =
aggun.isRowAggregate() ?
+ new
MatrixIndexes(tmp.getIndexes().getRowIndex(), 1) :
+ new
MatrixIndexes(1, tmp.getIndexes().getColumnIndex());
+ IndexedMatrixValue
tmpOut = new IndexedMatrixValue(midx, ret);
+
+
qOut.enqueueTask(tmpOut);
+ // drop intermediate
states
+ aggTracker.remove(idx);
+ corrs.remove(idx);
}
qOut.closeInput();
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
index 5e2d36d9df..ae84e4b541 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
@@ -82,7 +82,8 @@ public class MatrixVectorBinaryOOCInstruction extends
ComputationOOCInstruction
}
// number of colBlocks for early block output
- long nBlocks = min.getDataCharacteristics().getNumColBlocks();
+ long emitThreshold =
min.getDataCharacteristics().getNumColBlocks();
+ OOCMatrixBlockTracker aggTracker = new
OOCMatrixBlockTracker(emitThreshold);
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
@@ -95,8 +96,6 @@ public class MatrixVectorBinaryOOCInstruction extends
ComputationOOCInstruction
pool.submit(() -> {
IndexedMatrixValue tmp = null;
try {
- HashMap<Long, MatrixBlock>
partialResults = new HashMap<>();
- HashMap<Long, Integer> cnt = new
HashMap<>();
while((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
MatrixBlock matrixBlock =
(MatrixBlock) tmp.getValue();
long rowIndex =
tmp.getIndexes().getRowIndex();
@@ -108,31 +107,22 @@ public class MatrixVectorBinaryOOCInstruction extends
ComputationOOCInstruction
matrixBlock,
vectorSlice, new MatrixBlock(), (AggregateBinaryOperator) _optr);
// for single column block, no
aggregation needed
- if( min.getNumColumns() <=
min.getBlocksize() ) {
+ if(emitThreshold == 1) {
qOut.enqueueTask(new
IndexedMatrixValue(tmp.getIndexes(), partialResult));
}
else {
// aggregation
- MatrixBlock currAgg =
partialResults.get(rowIndex);
+ MatrixBlock currAgg =
aggTracker.get(rowIndex);
if (currAgg == null) {
-
partialResults.put(rowIndex, partialResult);
-
cnt.put(rowIndex, 1);
+
aggTracker.putAndIncrementCount(rowIndex, partialResult);
}
else {
-
currAgg.binaryOperationsInPlace(plus, partialResult);
- int newCnt =
cnt.get(rowIndex) + 1;
-
- if(newCnt ==
nBlocks){
+ currAgg =
currAgg.binaryOperations(plus, partialResult);
+ if
(aggTracker.putAndIncrementCount(rowIndex, currAgg)){
//
early block output: emit aggregated block
MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L);
-
MatrixBlock result = partialResults.get(rowIndex);
-
qOut.enqueueTask(new IndexedMatrixValue(idx, result));
-
partialResults.remove(rowIndex);
-
cnt.remove(rowIndex);
- }
- else {
- //
maintain aggregation counts if not output-ready yet
-
cnt.replace(rowIndex, newCnt);
+
qOut.enqueueTask(new IndexedMatrixValue(idx, currAgg));
+
aggTracker.remove(rowIndex);
}
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
index 0495dcfde5..2e5e6f41eb 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java
@@ -24,8 +24,11 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
+import java.util.HashMap;
+
public abstract class OOCInstruction extends Instruction {
protected static final Log LOG =
LogFactory.getLog(OOCInstruction.class.getName());
@@ -82,4 +85,55 @@ public abstract class OOCInstruction extends Instruction {
if(DMLScript.LINEAGE_DEBUGGER)
ec.maintainLineageDebuggerInfo(this);
}
+
+ /**
+ * Tracks blocks and their counts to enable early emission
+ * once all blocks for a given index are processed.
+ */
+ public static class OOCMatrixBlockTracker {
+ private final long _emitThreshold;
+ private final HashMap<Long, MatrixBlock> _blocks;
+ private final HashMap<Long, Integer> _cnts;
+
+ public OOCMatrixBlockTracker(long emitThreshold) {
+ _emitThreshold = emitThreshold;
+ _blocks = new HashMap<>();
+ _cnts = new HashMap<>();
+ }
+
+ /**
+ * Adds or updates a block for the given index and updates its
internal count.
+ * @param idx block index
+ * @param block MatrixBlock
+ * @return true if the block count reached the threshold (ready
to emit), false otherwise
+ */
+ public boolean putAndIncrementCount(Long idx, MatrixBlock
block) {
+ _blocks.put(idx, block);
+ int newCnt = _cnts.getOrDefault(idx, 0) + 1;
+ if( newCnt < _emitThreshold )
+ _cnts.put(idx, newCnt);
+ return newCnt == _emitThreshold;
+ }
+
+ public boolean incrementCount(Long idx) {
+ int newCnt = _cnts.get(idx) + 1;
+ if( newCnt < _emitThreshold )
+ _cnts.put(idx, newCnt);
+ return newCnt == _emitThreshold;
+ }
+
+ public void putAndInitCount(Long idx, MatrixBlock block) {
+ _blocks.put(idx, block);
+ _cnts.put(idx, 0);
+ }
+
+ public MatrixBlock get(Long idx) {
+ return _blocks.get(idx);
+ }
+
+ public void remove(Long idx) {
+ _blocks.remove(idx);
+ _cnts.remove(idx);
+ }
+ }
}