This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 31aff0ea4e [MINOR] Federated Compressed Workload Estimation Fixes
31aff0ea4e is described below
commit 31aff0ea4e87da2fd837a68d6fbeeaa7d404bb3b
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sat Dec 28 14:29:16 2024 +0100
[MINOR] Federated Compressed Workload Estimation Fixes
This commit fixes a bug for asynchronous compression on
federated workers. Previously, the compression would only trigger if
the accumulated count of federated request instructions satisfied count % 10 == 9.
This bug effectively made it impossible to perform compression if all
requests sent contained an even number of instructions.
This commit changes the logic to instruction counter >= 10, and resets the counter once a compression pass is triggered.
Closes #2159
Signed-off-by: Sebastian Baunsgaard <[email protected]>
---
.../federated/FederatedWorkloadAnalyzer.java | 26 ++++++++++++++++++----
1 file changed, 22 insertions(+), 4 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
index 1db1a458be..89abacbf19 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java
@@ -35,7 +35,7 @@ public class FederatedWorkloadAnalyzer {
protected static final Log LOG =
LogFactory.getLog(FederatedWorkerHandler.class.getName());
/** Frequency value for how many instructions before we do a pass for
compression */
- private static int compressRunFrequency = 10;
+ private static final int compressRunFrequency = 10;
/** Instruction maps to interesting variables */
private final ConcurrentHashMap<Long, ConcurrentHashMap<Long,
InstructionTypeCounter>> m;
@@ -49,14 +49,17 @@ public class FederatedWorkloadAnalyzer {
}
public void incrementWorkload(ExecutionContext ec, long tid,
Instruction ins) {
+ LOG.error("Increment Workload " + tid + " " + ins + "\n" +
this);
if(ins instanceof ComputationCPInstruction)
incrementWorkload(ec, tid, (ComputationCPInstruction)
ins);
// currently we ignore everything that is not CP instructions
}
public void compressRun(ExecutionContext ec, long tid) {
- if(counter % compressRunFrequency == compressRunFrequency - 1)
+ if(counter >= compressRunFrequency ){
+ counter = 0;
get(tid).forEach((K, V) ->
CompressedMatrixBlockFactory.compressAsync(ec, Long.toString(K), V));
+ }
}
private void incrementWorkload(ExecutionContext ec, long tid,
ComputationCPInstruction cpIns) {
@@ -77,13 +80,16 @@ public class FederatedWorkloadAnalyzer {
int r2 = (int) d2.getDim(0);
int c2 = (int) d2.getDim(1);
if(validSize(r1, c1)) {
- getOrMakeCounter(mm,
Long.parseLong(n1)).incRMM(r1);
+ getOrMakeCounter(mm,
Long.parseLong(n1)).incRMM(c2);
+ // safety add overlapping decompress for RMM
+ getOrMakeCounter(mm,
Long.parseLong(n1)).incOverlappingDecompressions();
counter++;
}
if(validSize(r2, c2)) {
- getOrMakeCounter(mm,
Long.parseLong(n2)).incLMM(c2);
+ getOrMakeCounter(mm,
Long.parseLong(n2)).incLMM(r1);
counter++;
}
+
}
}
@@ -111,4 +117,16 @@ public class FederatedWorkloadAnalyzer {
private static boolean validSize(int nRow, int nCol) {
return nRow > 90 && nRow >= nCol;
}
+
+ @Override
+ public String toString(){
+ StringBuilder sb = new StringBuilder();
+ sb.append(this.getClass().getSimpleName());
+ sb.append(" Counter: ");
+ sb.append(counter);
+ sb.append("\n");
+ sb.append(m);
+
+ return sb.toString();
+ }
}