This is an automated email from the ASF dual-hosted git repository.

ywcb00 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f45ae975b9 [MINOR] FederatedLookupTable Eviction Fix
f45ae975b9 is described below

commit f45ae975b904008ac4c5fd6ebbc587dbfa865b10
Author: ywcb00 <[email protected]>
AuthorDate: Thu Jul 28 17:31:10 2022 +0200

    [MINOR] FederatedLookupTable Eviction Fix
    
    - Remove the coordinator-specific entry from the FederatedLookupTable
      when receiving a CLEAR request.
    - Fix scalar broadcasting of federated left indexing instruction.
    - Avoid ConcurrentModificationException by changing the ArrayList for
      the coordinator's traffic bytes in the federated statistics to a
      CopyOnWriteArrayList
    - Avoid race condition while obtaining the heavy hitters for statistics
    
    Closes #1663.
---
 .../federated/FederatedLookupTable.java            | 22 +++++++++++++++++-----
 .../federated/FederatedStatistics.java             |  3 ++-
 .../federated/FederatedWorkerHandler.java          |  7 +++++--
 .../instructions/fed/IndexingFEDInstruction.java   |  2 +-
 .../java/org/apache/sysds/utils/Statistics.java    | 12 ++++++------
 5 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
index afba8ac42a..188c57ed95 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
@@ -47,10 +47,6 @@ public class FederatedLookupTable {
                _lookup_table = new ConcurrentHashMap<>();
        }
 
-       public void clear() {
-               _lookup_table.clear();
-       }
-       
        /**
         * Get the ExecutionContextMap corresponding to the given host and pid of the
         * requesting coordinator from the lookup table. Create a new
@@ -61,9 +57,9 @@ public class FederatedLookupTable {
         * @return ExecutionContextMap the ECM corresponding to the requesting coordinator
         */
        public ExecutionContextMap getECM(String host, long pid) {
-               LOG.trace("Getting the ExecutionContextMap for coordinator " + pid + "@" + host);
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
                FedUniqueCoordID funCID = new FedUniqueCoordID(host, pid);
+               LOG.trace("Getting the ExecutionContextMap for coordinator " + funCID.toString());
                ExecutionContextMap ecm = _lookup_table.computeIfAbsent(funCID,
                        k -> createNewECM());
                if(ecm == null) {
@@ -79,6 +75,22 @@ public class FederatedLookupTable {
                return ecm;
        }
 
+       /**
+        * Remove the ExecutionContextMap corresponding to the given host and pid of the
+        * requesting coordinator from the lookup table. Do nothing if no entry
+        * is associated to the host and pid.
+        *
+        * @param host the host string of the requesting coordinator (usually IP address)
+        * @param pid the process id of the requesting coordinator
+        */
+       public void removeECM(String host, long pid) {
+               FedUniqueCoordID funCID = new FedUniqueCoordID(host, pid);
+               LOG.trace("Removing the ExecutionContextMap of coordinator " + funCID.toString());
+               if(_lookup_table.remove(funCID) == null)
+                       LOG.warn("Removing federated execution context map failed. "
+                               + "No valid resolution for " + funCID.toString() + " found.");
+       }
+
        /**
         * Check if there is a mapped ExecutionContextMap for the coordinator
         * with the given host and pid.
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
index 17b4012fec..b53ef801a5 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java
@@ -30,6 +30,7 @@ import java.time.format.DateTimeFormatter;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -98,7 +99,7 @@ public class FederatedStatistics {
        private static final LongAdder fedSerializationReuseBytes = new LongAdder();
        // Traffic between federated worker and a coordinator site
        // in the form of [{ datetime, coordinatorAddress, transferredBytes }, { ... }] }
-       private static List<Triple<LocalDateTime, String, Long>> coordinatorsTrafficBytes = new ArrayList<>();
+       private static CopyOnWriteArrayList<Triple<LocalDateTime, String, Long>> coordinatorsTrafficBytes = new CopyOnWriteArrayList<>();
 
        public static void logServerTraffic(long read, long written) {
                bytesReceived.add(read);
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index bfeb19cc16..509e0998eb 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -188,6 +188,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
                        
                FederatedResponse response = null; // last response
                boolean containsCLEAR = false;
+               long clearReqPid = -1;
                for(int i = 0; i < requests.length; i++) {
                        final FederatedRequest request = requests[i];
                        final RequestType t = request.getType();
@@ -233,12 +234,14 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
                                }
                        }
 
-                       if(t == RequestType.CLEAR)
+                       if(t == RequestType.CLEAR) {
                                containsCLEAR = true;
+                               clearReqPid = request.getPID();
+                       }
                }
 
                if(containsCLEAR) {
-                       _flt.clear();
+                       _flt.removeECM(remoteHost, clearReqPid);
                        printStatistics();
                }
 
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
index 128fc1d4a6..6fc8c24a7e 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -331,7 +331,7 @@ public final class IndexingFEDInstruction extends UnaryFEDInstruction {
                        FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1.getID());
 
                        if(fr2.length == 1)
-                               fedMap.execute(getTID(), true, fr2, fr1, fr3);
+                               fedMap.execute(getTID(), true, fr1, fr2[0], fr3);
                        else
                                fedMap.execute(getTID(), true, ranges, fr2[cpVarInstIx], fr2[from], fr1, fr3);
                }
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index aece9b655a..454ecac6e1 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -392,12 +392,11 @@ public class Statistics
         */
        @SuppressWarnings("unchecked")
        public static String getHeavyHitters(int num) {
-               int len = _instStats.size();
-               if (num <= 0 || len <= 0)
+               if (num <= 0 || _instStats.size() <= 0)
                        return "-";
 
                // get top k via sort
-               Entry<String, InstStats>[] tmp = _instStats.entrySet().toArray(new Entry[len]);
+               Entry<String, InstStats>[] tmp = _instStats.entrySet().toArray(Entry[]::new);
                Arrays.sort(tmp, new Comparator<Entry<String, InstStats>>() {
                        @Override
                        public int compare(Entry<String, InstStats> e1, Entry<String, InstStats> e2) {
@@ -410,6 +409,7 @@ public class Statistics
                final String timeSCol = "Time(s)";
                final String countCol = "Count";
                StringBuilder sb = new StringBuilder();
+               int len = tmp.length;
                int numHittersToDisplay = Math.min(num, len);
                int maxNumLen = String.valueOf(numHittersToDisplay).length();
                int maxInstLen = instCol.length();
@@ -466,11 +466,10 @@ public class Statistics
 
        @SuppressWarnings("unchecked")
        public static String getCPHeavyHittersMem(int num) {
-               int n = _cpMemObjs.size();
-               if ((n <= 0) || (num <= 0))
+               if ((_cpMemObjs.size() <= 0) || (num <= 0))
                        return "-";
 
-               Entry<String,Double>[] entries = _cpMemObjs.entrySet().toArray(new Entry[_cpMemObjs.size()]);
+               Entry<String,Double>[] entries = _cpMemObjs.entrySet().toArray(Entry[]::new);
                Arrays.sort(entries, new Comparator<Entry<String, Double>>() {
                        @Override
                        public int compare(Entry<String, Double> a, Entry<String, Double> b) {
@@ -478,6 +477,7 @@ public class Statistics
                        }
                });
 
+               int n = entries.length;
                int numHittersToDisplay = Math.min(num, n);
                int numPadLen = String.format("%d", numHittersToDisplay).length();
                int maxNameLength = 0;

Reply via email to