This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 2cb1a6070c [SYSTEMDS-3908] Improved OOC matrix multiplication w/ early outputs
2cb1a6070c is described below

commit 2cb1a6070cbfc687140c1adf8d80a55ce558be64
Author: Jessica Priebe <[email protected]>
AuthorDate: Thu Aug 21 17:25:13 2025 +0200

    [SYSTEMDS-3908] Improved OOC matrix multiplication w/ early outputs
    
    Closes #2310.
---
 .../ooc/MatrixVectorBinaryOOCInstruction.java      | 35 +++++++++++++++-------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
index a36dc7c885..5e2d36d9df 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MatrixVectorBinaryOOCInstruction.java
@@ -20,7 +20,6 @@
 package org.apache.sysds.runtime.instructions.ooc;
 
 import java.util.HashMap;
-import java.util.Map;
 import java.util.concurrent.ExecutorService;
 
 import org.apache.sysds.common.Opcodes;
@@ -82,6 +81,9 @@ public class MatrixVectorBinaryOOCInstruction extends 
ComputationOOCInstruction
                        partitionedVector.put(key, vectorSlice);
                }
 
+               // number of colBlocks for early block output
+               long nBlocks = min.getDataCharacteristics().getNumColBlocks();
+
                LocalTaskQueue<IndexedMatrixValue> qIn = min.getStreamHandle();
                LocalTaskQueue<IndexedMatrixValue> qOut = new 
LocalTaskQueue<>();
                BinaryOperator plus = 
InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
@@ -94,6 +96,7 @@ public class MatrixVectorBinaryOOCInstruction extends 
ComputationOOCInstruction
                                IndexedMatrixValue tmp = null;
                                try {
                                        HashMap<Long, MatrixBlock> 
partialResults = new  HashMap<>();
+                                       HashMap<Long, Integer> cnt = new 
HashMap<>();
                                        while((tmp = qIn.dequeueTask()) != 
LocalTaskQueue.NO_MORE_TASKS) {
                                                MatrixBlock matrixBlock = 
(MatrixBlock) tmp.getValue();
                                                long rowIndex = 
tmp.getIndexes().getRowIndex();
@@ -109,19 +112,29 @@ public class MatrixVectorBinaryOOCInstruction extends 
ComputationOOCInstruction
                                                        qOut.enqueueTask(new 
IndexedMatrixValue(tmp.getIndexes(), partialResult));
                                                }
                                                else {
+                                                       // aggregation
                                                        MatrixBlock currAgg = 
partialResults.get(rowIndex);
-                                                       if (currAgg == null)
+                                                       if (currAgg == null) {
                                                                
partialResults.put(rowIndex, partialResult);
-                                                       else
+                                                               
cnt.put(rowIndex, 1);
+                                                       }
+                                                       else {
                                                                
currAgg.binaryOperationsInPlace(plus, partialResult);
-                                               }
-                                       }
-                                       
-                                       // emit aggregated blocks
-                                       if( min.getNumColumns() > 
min.getBlocksize() ) {
-                                               for (Map.Entry<Long, 
MatrixBlock> entry : partialResults.entrySet()) {
-                                                       MatrixIndexes 
outIndexes = new MatrixIndexes(entry.getKey(), 1L);
-                                                       qOut.enqueueTask(new 
IndexedMatrixValue(outIndexes, entry.getValue()));
+                                                               int newCnt = 
cnt.get(rowIndex) + 1;
+                                                               
+                                                               if(newCnt == 
nBlocks){
+                                                                       // 
early block output: emit aggregated block
+                                                                       
MatrixIndexes idx = new MatrixIndexes(rowIndex, 1L);
+                                                                       
MatrixBlock result = partialResults.get(rowIndex);
+                                                                       
qOut.enqueueTask(new IndexedMatrixValue(idx, result));
+                                                                       
partialResults.remove(rowIndex);
+                                                                       
cnt.remove(rowIndex);
+                                                               }
+                                                               else {
+                                                                       // 
maintain aggregation counts if not output-ready yet
+                                                                       
cnt.replace(rowIndex, newCnt);
+                                                               }
+                                                       }
                                                }
                                        }
                                }

Reply via email to