This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 2cb1a6070c [SYSTEMDS-3908] Improved OOC matrix multiplication w/ early outputs
2cb1a6070c is described below
commit 2cb1a6070cbfc687140c1adf8d80a55ce558be64
Author: Jessica Priebe <[email protected]>
AuthorDate: Thu Aug 21 17:25:13 2025 +0200
[SYSTEMDS-3908] Improved OOC matrix multiplication w/ early outputs
Closes #2310.
---
.../ooc/MatrixVectorBinaryOOCInstruction.java | 35 +++++++++++++++-------
1 file changed, 24 insertions(+), 11 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
index a36dc7c885..5e2d36d9df 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
@@ -20,7 +20,6 @@
package org.apache.sysds.runtime.instructions.ooc;
import java.util.HashMap;
-import java.util.Map;
import java.util.concurrent.ExecutorService;
import org.apache.sysds.common.Opcodes;
@@ -82,6 +81,9 @@ public class MatrixVectorBinaryOOCInstruction extends ComputationOOCInstruction
partitionedVector.put(key, vectorSlice);
}
+ // number of colBlocks for early block output
+ long nBlocks = min.getDataCharacteristics().getNumColBlocks();
+
LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
LocalTaskQueue<IndexedMatrixValue> qOut = new
LocalTaskQueue<>();
BinaryOperator plus =
InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
@@ -94,6 +96,7 @@ public class MatrixVectorBinaryOOCInstruction extends ComputationOOCInstruction
IndexedMatrixValue tmp = null;
try {
HashMap<Long, MatrixBlock>
partialResults = new HashMap<>();
+ HashMap<Long, Integer> cnt = new
HashMap<>();
while((tmp = qIn.dequeueTask()) !=
LocalTaskQueue.NO_MORE_TASKS) {
MatrixBlock matrixBlock =
(MatrixBlock) tmp.getValue();
long rowIndex =
tmp.getIndexes().getRowIndex();
@@ -109,19 +112,29 @@ public class MatrixVectorBinaryOOCInstruction extends ComputationOOCInstruction
qOut.enqueueTask(new
IndexedMatrixValue(tmp.getIndexes(), partialResult));
}
else {
+ // aggregation
MatrixBlock currAgg =
partialResults.get(rowIndex);
- if (currAgg == null)
+ if (currAgg == null) {
partialResults.put(rowIndex, partialResult);
- else
+
cnt.put(rowIndex, 1);
+ }
+ else {
currAgg.binaryOperationsInPlace(plus, partialResult);
- }
- }
-
- // emit aggregated blocks
- if( min.getNumColumns() >
min.getBlocksize() ) {
- for (Map.Entry<Long,
MatrixBlock> entry : partialResults.entrySet()) {
- MatrixIndexes
outIndexes = new MatrixIndexes(entry.getKey(), 1L);
- qOut.enqueueTask(new
IndexedMatrixValue(outIndexes, entry.getValue()));
+ int newCnt =
cnt.get(rowIndex) + 1;
+
+ if(newCnt ==
nBlocks){
+ //
early block output: emit aggregated block
+
MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L);
+
MatrixBlock result = partialResults.get(rowIndex);
+
qOut.enqueueTask(new IndexedMatrixValue(idx, result));
+
partialResults.remove(rowIndex);
+
cnt.remove(rowIndex);
+ }
+ else {
+ //
maintain aggregation counts if not output-ready yet
+
cnt.replace(rowIndex, newCnt);
+ }
+ }
}
}
}