This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 31aff0ea4e [MINOR] Federated Compressed Workload Estimation Fixes
31aff0ea4e is described below

commit 31aff0ea4e87da2fd837a68d6fbeeaa7d404bb3b
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sat Dec 28 14:29:16 2024 +0100

    [MINOR] Federated Compressed Workload Estimation Fixes
    
    This commit fixes a bug in asynchronous compression on
    federated workers. Previously, compression would only trigger if
    the running count of federated request instructions satisfied
    counter % 10 == 9. This bug effectively made it impossible to trigger
    compression if every request sent contained an even number of instructions.
    
    This commit changes the logic to trigger compression once the instruction counter is >= 10, and then resets the counter to 0.
    
    Closes #2159
    
    Signed-off-by: Sebastian Baunsgaard <[email protected]>
---
 .../federated/FederatedWorkloadAnalyzer.java       | 26 ++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
index 1db1a458be..89abacbf19 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
@@ -35,7 +35,7 @@ public class FederatedWorkloadAnalyzer {
        protected static final Log LOG = 
LogFactory.getLog(FederatedWorkerHandler.class.getName());
 
        /** Frequency value for how many instructions before we do a pass for 
compression */
-       private static int compressRunFrequency = 10;
+       private static final int compressRunFrequency = 10;
 
        /** Instruction maps to interesting variables */
        private final ConcurrentHashMap<Long, ConcurrentHashMap<Long, 
InstructionTypeCounter>> m;
@@ -49,14 +49,17 @@ public class FederatedWorkloadAnalyzer {
        }
 
        public void incrementWorkload(ExecutionContext ec, long tid, 
Instruction ins) {
+               LOG.error("Increment Workload  " + tid + " " + ins + "\n" + 
this);
                if(ins instanceof ComputationCPInstruction)
                        incrementWorkload(ec, tid, (ComputationCPInstruction) 
ins);
                // currently we ignore everything that is not CP instructions
        }
 
        public void compressRun(ExecutionContext ec, long tid) {
-               if(counter % compressRunFrequency == compressRunFrequency - 1)
+               if(counter >= compressRunFrequency ){
+                       counter = 0;
                        get(tid).forEach((K, V) -> 
CompressedMatrixBlockFactory.compressAsync(ec, Long.toString(K), V));
+               }
        }
 
        private void incrementWorkload(ExecutionContext ec, long tid, 
ComputationCPInstruction cpIns) {
@@ -77,13 +80,16 @@ public class FederatedWorkloadAnalyzer {
                        int r2 = (int) d2.getDim(0);
                        int c2 = (int) d2.getDim(1);
                        if(validSize(r1, c1)) {
-                               getOrMakeCounter(mm, 
Long.parseLong(n1)).incRMM(r1);
+                               getOrMakeCounter(mm, 
Long.parseLong(n1)).incRMM(c2);
+                               // safety add overlapping decompress for RMM
+                               getOrMakeCounter(mm, 
Long.parseLong(n1)).incOverlappingDecompressions();
                                counter++;
                        }
                        if(validSize(r2, c2)) {
-                               getOrMakeCounter(mm, 
Long.parseLong(n2)).incLMM(c2);
+                               getOrMakeCounter(mm, 
Long.parseLong(n2)).incLMM(r1);
                                counter++;
                        }
+                       
                }
        }
 
@@ -111,4 +117,16 @@ public class FederatedWorkloadAnalyzer {
        private static boolean validSize(int nRow, int nCol) {
                return nRow > 90 && nRow >= nCol;
        }
+
+       @Override 
+       public String toString(){
+               StringBuilder sb = new StringBuilder();
+               sb.append(this.getClass().getSimpleName());
+               sb.append("  Counter: ");
+               sb.append(counter);
+               sb.append("\n");
+               sb.append(m);
+
+               return sb.toString();
+       }
 }

Reply via email to