This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new b98fcf2c [ISSUE-328] Cleanup unused shuffle servers to avoid app 
heartbeat after stage completed (#334)
b98fcf2c is described below

commit b98fcf2c7e897898c36c7184ca7e624e1eda60a0
Author: xianjingfeng <[email protected]>
AuthorDate: Fri Nov 18 09:41:45 2022 +0800

    [ISSUE-328] Cleanup unused shuffle servers to avoid app heartbeat after 
stage completed (#334)
    
    ### What changes were proposed in this pull request?
    
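    Clean up unused shuffle servers after a stage completes: the client now tracks shuffle servers per application and per shuffle, and drops a shuffle's servers when that shuffle is unregistered, so app heartbeats are only sent to servers that are still in use.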
    
    ### Why are the changes needed?
    
    If an application has many stages, the Spark driver sends heartbeats to every shuffle server it has ever registered, which may cause the app to expire on the shuffle-server side. In addition, if decommissioning is supported in the future, this would make it difficult for shuffle servers to exit. See #328.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added unit test `ShuffleWriteClientImplTest#testRegisterAndUnRegisterShuffleServer`.
---
 .../client/impl/ShuffleWriteClientImpl.java        | 57 ++++++++++++++++++++--
 .../client/impl/ShuffleWriteClientImplTest.java    | 20 ++++++++
 2 files changed, 72 insertions(+), 5 deletions(-)

diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index e44b9cd1..3bccd467 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -17,6 +17,7 @@
 
 package org.apache.uniffle.client.impl;
 
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -86,7 +87,8 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
   private int retryMax;
   private long retryIntervalMax;
   private List<CoordinatorClient> coordinatorClients = Lists.newLinkedList();
-  private Set<ShuffleServerInfo> shuffleServerInfoSet = 
Sets.newConcurrentHashSet();
+  //appId -> shuffleId -> servers
+  private Map<String, Map<Integer, Set<ShuffleServerInfo>>> 
shuffleServerInfoMap = Maps.newConcurrentMap();
   private CoordinatorClientFactory coordinatorClientFactory;
   private ExecutorService heartBeatExecutorService;
   private int replica;
@@ -350,7 +352,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
     String msg = "Error happened when registerShuffle with appId[" + appId + 
"], shuffleId[" + shuffleId
         + "], " + shuffleServerInfo;
     throwExceptionIfNecessary(response, msg);
-    shuffleServerInfoSet.add(shuffleServerInfo);
+    addShuffleServer(appId, shuffleId, shuffleServerInfo);
   }
 
   @Override
@@ -551,7 +553,8 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
   public void sendAppHeartbeat(String appId, long timeoutMs) {
     RssAppHeartBeatRequest request = new RssAppHeartBeatRequest(appId, 
timeoutMs);
     List<Callable<Void>> callableList = Lists.newArrayList();
-    shuffleServerInfoSet.stream().forEach(shuffleServerInfo -> {
+    Set<ShuffleServerInfo> allShuffleServers = getAllShuffleServers(appId);
+    allShuffleServers.forEach(shuffleServerInfo -> {
           callableList.add(() -> {
             try {
               ShuffleServerClient client =
@@ -607,7 +610,16 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
     RssUnregisterShuffleRequest request = new 
RssUnregisterShuffleRequest(appId, shuffleId);
     List<Callable<Void>> callableList = Lists.newArrayList();
 
-    shuffleServerInfoSet.stream().forEach(shuffleServerInfo -> {
+    Map<Integer, Set<ShuffleServerInfo>> appServerMap = 
shuffleServerInfoMap.get(appId);
+    if (appServerMap == null) {
+      return;
+    }
+    Set<ShuffleServerInfo> shuffleServerInfos = appServerMap.get(shuffleId);
+    if (shuffleServerInfos == null) {
+      return;
+    }
+
+    shuffleServerInfos.forEach(shuffleServerInfo -> {
           callableList.add(() -> {
             try {
               ShuffleServerClient client =
@@ -628,7 +640,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
     try {
       executorService =
           Executors.newFixedThreadPool(
-              Math.min(unregisterThreadPoolSize, shuffleServerInfoSet.size()),
+              Math.min(unregisterThreadPoolSize, shuffleServerInfos.size()),
               ThreadUtils.getThreadFactory("unregister-shuffle-%d")
           );
       List<Future<Void>> futures = executorService.invokeAll(callableList, 
unregisterRequestTimeSec, TimeUnit.SECONDS);
@@ -643,6 +655,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       if (executorService != null) {
         executorService.shutdownNow();
       }
+      removeShuffleServer(appId, shuffleId);
     }
   }
 
@@ -653,9 +666,43 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
     }
   }
 
+  /**
+   * Collects every shuffle server registered for any shuffle of the given application.
+   *
+   * @param appId application id
+   * @return the union of the server sets of all registered shuffles of {@code appId};
+   *     an empty set when the application has no registered shuffles
+   */
+  Set<ShuffleServerInfo> getAllShuffleServers(String appId) {
+    Map<Integer, Set<ShuffleServerInfo>> appServerMap = shuffleServerInfoMap.get(appId);
+    if (appServerMap == null) {
+      // Typed empty set instead of the raw Collections.EMPTY_SET (unchecked conversion).
+      return Collections.emptySet();
+    }
+    // The same server may serve several shuffles; the HashSet de-duplicates them.
+    Set<ShuffleServerInfo> serverInfos = Sets.newHashSet();
+    appServerMap.values().forEach(serverInfos::addAll);
+    return serverInfos;
+  }
+
+  /**
+   * Returns the client for {@code shuffleServerInfo}, obtained from
+   * {@code ShuffleServerClientFactory.getInstance()} using this writer's {@code clientType}.
+   * NOTE(review): public and @VisibleForTesting, presumably so tests can stub it — confirm.
+   */
   @VisibleForTesting
   public ShuffleServerClient getShuffleServerClient(ShuffleServerInfo 
shuffleServerInfo) {
     return 
ShuffleServerClientFactory.getInstance().getShuffleServerClient(clientType, 
shuffleServerInfo);
   }
 
+  /**
+   * Records that {@code serverInfo} holds data for shuffle {@code shuffleId} of {@code appId},
+   * so it receives app heartbeats until that shuffle is unregistered.
+   *
+   * @param appId application id
+   * @param shuffleId shuffle id within the application
+   * @param serverInfo server the shuffle was registered on
+   */
+  @VisibleForTesting
+  void addShuffleServer(String appId, int shuffleId, ShuffleServerInfo serverInfo) {
+    // computeIfAbsent makes get-or-create atomic on the concurrent maps. With a separate
+    // get()/put(), two concurrent registrations could each create a fresh map/set and one
+    // would overwrite the other, silently losing a registered server.
+    shuffleServerInfoMap
+        .computeIfAbsent(appId, key -> Maps.newConcurrentMap())
+        .computeIfAbsent(shuffleId, key -> Sets.newConcurrentHashSet())
+        .add(serverInfo);
+  }
+
+  /**
+   * Forgets the servers registered for shuffle {@code shuffleId} of {@code appId}.
+   * Servers that also back another shuffle of the same application stay registered
+   * through that other shuffle. Unknown apps or shuffles are ignored.
+   *
+   * @param appId application id
+   * @param shuffleId shuffle id within the application
+   */
+  @VisibleForTesting
+  void removeShuffleServer(String appId, int shuffleId) {
+    Map<Integer, Set<ShuffleServerInfo>> serversByShuffle = shuffleServerInfoMap.get(appId);
+    if (serversByShuffle == null) {
+      return;
+    }
+    serversByShuffle.remove(shuffleId);
+  }
 }
diff --git 
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
 
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
index 71fdc637..414d203f 100644
--- 
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
+++ 
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
@@ -30,6 +30,7 @@ import 
org.apache.uniffle.client.response.SendShuffleDataResult;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 
+import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doReturn;
@@ -56,4 +57,23 @@ public class ShuffleWriteClientImplTest {
 
     assertTrue(result.getFailedBlockIds().contains(10L));
   }
+
+  @Test
+  public void testRegisterAndUnRegisterShuffleServer() {
+    ShuffleWriteClientImpl shuffleWriteClient =
+        new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 1, 1, 1, true, 1, 1, 10, 10);
+    // Always close the client so its heartbeat executor does not leak past the test.
+    try {
+      String appId1 = "testRegisterAndUnRegisterShuffleServer-1";
+      String appId2 = "testRegisterAndUnRegisterShuffleServer-2";
+      ShuffleServerInfo server1 = new ShuffleServerInfo("host1-0", "host1", 0);
+      ShuffleServerInfo server2 = new ShuffleServerInfo("host2-0", "host2", 0);
+      ShuffleServerInfo server3 = new ShuffleServerInfo("host3-0", "host3", 0);
+      shuffleWriteClient.addShuffleServer(appId1, 0, server1);
+      shuffleWriteClient.addShuffleServer(appId1, 1, server2);
+      shuffleWriteClient.addShuffleServer(appId2, 1, server3);
+      assertEquals(2, shuffleWriteClient.getAllShuffleServers(appId1).size());
+      assertEquals(1, shuffleWriteClient.getAllShuffleServers(appId2).size());
+
+      // server1 now backs both shuffle 0 and shuffle 1 of appId1.
+      shuffleWriteClient.addShuffleServer(appId1, 1, server1);
+      assertEquals(2, shuffleWriteClient.getAllShuffleServers(appId1).size());
+
+      shuffleWriteClient.unregisterShuffle(appId1, 1);
+      // server2 was only used by shuffle 1; server1 survives through shuffle 0.
+      assertEquals(1, shuffleWriteClient.getAllShuffleServers(appId1).size());
+      assertTrue(shuffleWriteClient.getAllShuffleServers(appId1).contains(server1));
+      // Unregistering a shuffle of appId1 must not touch appId2.
+      assertEquals(1, shuffleWriteClient.getAllShuffleServers(appId2).size());
+
+      // Unregistering a shuffle of an unknown application is a no-op.
+      shuffleWriteClient.unregisterShuffle("unknown-app", 0);
+      assertTrue(shuffleWriteClient.getAllShuffleServers("unknown-app").isEmpty());
+    } finally {
+      shuffleWriteClient.close();
+    }
+  }
 }

Reply via email to