This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new c5e20d8b9 [#2153] fix(client): Fix requestShuffleAssignment in 
RssShuffleManagerBase (#2138)
c5e20d8b9 is described below

commit c5e20d8b90cbd3166639a5b6390e79fe4a64d075
Author: kqhzz <[email protected]>
AuthorDate: Thu Oct 17 10:31:25 2024 +0800

    [#2153] fix(client): Fix requestShuffleAssignment in RssShuffleManagerBase 
(#2138)
    
    ### What changes were proposed in this pull request?
    
    Modify RssShuffleManagerBase so that getShuffleAssignments and 
registerShuffleServers run inside a single retry, and fix some code style problems.
    
    ### Why are the changes needed?
    
    The logic of RssShuffleManagerBase.requestShuffleAssignment was changed in 
#2095; before this fix, getShuffleAssignments ran outside the retry that wraps 
registerShuffleServers.
    
    Fix: #2153
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    UTs (unit tests).
---
 .../shuffle/manager/RssShuffleManagerBase.java     | 40 ++++++------
 .../spark/shuffle/DelegationRssShuffleManager.java | 24 ++++---
 .../spark/shuffle/DelegationRssShuffleManager.java | 22 ++++---
 .../client/impl/ShuffleWriteClientImpl.java        | 34 +++++-----
 .../uniffle/client/api/CoordinatorClient.java      |  2 +-
 .../client/factory/CoordinatorClientFactory.java   |  3 +
 .../client/impl/grpc/CoordinatorGrpcClient.java    |  3 +-
 .../impl/grpc/CoordinatorGrpcRetryableClient.java  | 74 +++++++++++++++-------
 .../client/request/RssAccessClusterRequest.java    | 30 +++++++--
 .../request/RssGetShuffleAssignmentsRequest.java   | 20 +++++-
 .../apache/uniffle/server/RegisterHeartBeat.java   | 46 ++++----------
 11 files changed, 181 insertions(+), 117 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index 5c5f97864..d82d3a509 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -403,7 +403,7 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     int retryTimes = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
     int heartbeatThread = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
     CoordinatorClientFactory coordinatorClientFactory = 
CoordinatorClientFactory.getInstance();
-    CoordinatorGrpcRetryableClient coordinatorClients =
+    CoordinatorGrpcRetryableClient coordinatorClient =
         coordinatorClientFactory.createCoordinatorClient(
             ClientType.valueOf(clientType),
             coordinators,
@@ -423,11 +423,11 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     }
     RssFetchClientConfRequest request =
         new RssFetchClientConfRequest(timeoutMs, user, Collections.emptyMap());
-    RssFetchClientConfResponse response = 
coordinatorClients.fetchClientConf(request);
+    RssFetchClientConfResponse response = 
coordinatorClient.fetchClientConf(request);
     if (response.getStatusCode() == StatusCode.SUCCESS) {
       RssSparkShuffleUtils.applyDynamicClientConf(sparkConf, 
response.getClientConf());
     }
-    coordinatorClients.close();
+    coordinatorClient.close();
   }
 
   @Override
@@ -951,23 +951,25 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     int retryTimes = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
     faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
     try {
-      ShuffleAssignmentsInfo response =
-          shuffleWriteClient.getShuffleAssignments(
-              appId,
-              shuffleId,
-              partitionNum,
-              partitionNumPerRange,
-              assignmentTags,
-              assignmentShuffleServerNumber,
-              estimateTaskConcurrency,
-              faultyServerIds,
-              stageId,
-              stageAttemptNumber,
-              reassign,
-              retryInterval,
-              retryTimes);
       return RetryUtils.retry(
           () -> {
+            // retry zero times in shuffleWriteClient.getShuffleAssignments, 
let
+            // getShuffleAssignments and registerShuffleServers in one retry 
func
+            ShuffleAssignmentsInfo response =
+                shuffleWriteClient.getShuffleAssignments(
+                    appId,
+                    shuffleId,
+                    partitionNum,
+                    partitionNumPerRange,
+                    assignmentTags,
+                    assignmentShuffleServerNumber,
+                    estimateTaskConcurrency,
+                    faultyServerIds,
+                    stageId,
+                    stageAttemptNumber,
+                    reassign,
+                    0,
+                    0);
             registerShuffleServers(
                 appId,
                 shuffleId,
@@ -979,7 +981,7 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
           retryInterval,
           retryTimes);
     } catch (Throwable throwable) {
-      throw new RssException("registerShuffle failed!", throwable);
+      throw new RssException("getShuffleAssignments or registerShuffle 
failed!", throwable);
     }
   }
 
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java
index fa09ae878..af9295e55 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java
@@ -43,7 +43,7 @@ public class DelegationRssShuffleManager implements 
ShuffleManager {
   private static final Logger LOG = 
LoggerFactory.getLogger(DelegationRssShuffleManager.class);
 
   private final ShuffleManager delegate;
-  private final CoordinatorGrpcRetryableClient coordinatorClients;
+  private final CoordinatorGrpcRetryableClient coordinatorClient;
   private final int accessTimeoutMs;
   private final SparkConf sparkConf;
   private String user;
@@ -55,10 +55,10 @@ public class DelegationRssShuffleManager implements 
ShuffleManager {
     this.sparkConf = sparkConf;
     accessTimeoutMs = sparkConf.get(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS);
     if (isDriver) {
-      coordinatorClients = 
RssSparkShuffleUtils.createCoordinatorClients(sparkConf);
+      coordinatorClient = 
RssSparkShuffleUtils.createCoordinatorClients(sparkConf);
       delegate = createShuffleManagerInDriver();
     } else {
-      coordinatorClients = null;
+      coordinatorClient = null;
       delegate = createShuffleManagerInExecutor();
     }
 
@@ -126,13 +126,17 @@ public class DelegationRssShuffleManager implements 
ShuffleManager {
 
     Set<String> assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
     try {
-      if (coordinatorClients != null) {
+      if (coordinatorClient != null) {
         RssAccessClusterResponse response =
-            coordinatorClients.accessCluster(
+            coordinatorClient.accessCluster(
                 new RssAccessClusterRequest(
-                    accessId, assignmentTags, accessTimeoutMs, 
extraProperties, user),
-                retryInterval,
-                retryTimes);
+                    accessId,
+                    assignmentTags,
+                    accessTimeoutMs,
+                    extraProperties,
+                    user,
+                    retryInterval,
+                    retryTimes));
         if (response.getStatusCode() == StatusCode.SUCCESS) {
           LOG.warn("Success to access cluster using {}", accessId);
           uuid = response.getUuid();
@@ -205,8 +209,8 @@ public class DelegationRssShuffleManager implements 
ShuffleManager {
   @Override
   public void stop() {
     delegate.stop();
-    if (coordinatorClients != null) {
-      coordinatorClients.close();
+    if (coordinatorClient != null) {
+      coordinatorClient.close();
     }
   }
 
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java
index 81c6de7dc..64dd5449d 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/DelegationRssShuffleManager.java
@@ -43,7 +43,7 @@ public class DelegationRssShuffleManager implements 
ShuffleManager {
   private static final Logger LOG = 
LoggerFactory.getLogger(DelegationRssShuffleManager.class);
 
   private final ShuffleManager delegate;
-  private final CoordinatorGrpcRetryableClient coordinatorClients;
+  private final CoordinatorGrpcRetryableClient coordinatorClient;
   private final int accessTimeoutMs;
   private final SparkConf sparkConf;
   private String user;
@@ -55,10 +55,10 @@ public class DelegationRssShuffleManager implements 
ShuffleManager {
     this.sparkConf = sparkConf;
     accessTimeoutMs = sparkConf.get(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS);
     if (isDriver) {
-      coordinatorClients = 
RssSparkShuffleUtils.createCoordinatorClients(sparkConf);
+      coordinatorClient = 
RssSparkShuffleUtils.createCoordinatorClients(sparkConf);
       delegate = createShuffleManagerInDriver();
     } else {
-      coordinatorClients = null;
+      coordinatorClient = null;
       delegate = createShuffleManagerInExecutor();
     }
 
@@ -126,13 +126,17 @@ public class DelegationRssShuffleManager implements 
ShuffleManager {
 
     Set<String> assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
     try {
-      if (coordinatorClients != null) {
+      if (coordinatorClient != null) {
         RssAccessClusterResponse response =
-            coordinatorClients.accessCluster(
+            coordinatorClient.accessCluster(
                 new RssAccessClusterRequest(
-                    accessId, assignmentTags, accessTimeoutMs, 
extraProperties, user),
-                retryInterval,
-                retryTimes);
+                    accessId,
+                    assignmentTags,
+                    accessTimeoutMs,
+                    extraProperties,
+                    user,
+                    retryInterval,
+                    retryTimes));
         if (response.getStatusCode() == StatusCode.SUCCESS) {
           LOG.warn("Success to access cluster using {}", accessId);
           uuid = response.getUuid();
@@ -285,7 +289,7 @@ public class DelegationRssShuffleManager implements 
ShuffleManager {
   @Override
   public void stop() {
     delegate.stop();
-    coordinatorClients.close();
+    coordinatorClient.close();
   }
 
   @Override
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 3b7e9ba12..cba7ccc06 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -103,7 +103,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
   private String clientType;
   private int retryMax;
   private long retryIntervalMax;
-  private CoordinatorGrpcRetryableClient coordinatorClients;
+  private CoordinatorGrpcRetryableClient coordinatorClient;
   // appId -> shuffleId -> servers
   private Map<String, Map<Integer, Set<ShuffleServerInfo>>> 
shuffleServerInfoMap =
       JavaUtils.newConcurrentMap();
@@ -607,7 +607,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
 
   @Override
   public void registerCoordinators(String coordinators, long retryIntervalMs, 
int retryTimes) {
-    coordinatorClients =
+    coordinatorClient =
         coordinatorClientFactory.createCoordinatorClient(
             ClientType.valueOf(this.clientType),
             coordinators,
@@ -618,11 +618,11 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
 
   @Override
   public Map<String, String> fetchClientConf(int timeoutMs) {
-    if (coordinatorClients == null) {
+    if (coordinatorClient == null) {
       return Maps.newHashMap();
     }
     try {
-      return coordinatorClients
+      return coordinatorClient
           .fetchClientConf(new RssFetchClientConfRequest(timeoutMs))
           .getClientConf();
     } catch (RssException e) {
@@ -632,11 +632,13 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
 
   @Override
   public RemoteStorageInfo fetchRemoteStorage(String appId) {
-    if (coordinatorClients == null) {
+    if (coordinatorClient == null) {
       return new RemoteStorageInfo("");
     }
     try {
-      return coordinatorClients.fetchRemoteStorage(new 
RssFetchRemoteStorageRequest(appId));
+      return coordinatorClient
+          .fetchRemoteStorage(new RssFetchRemoteStorageRequest(appId))
+          .getRemoteStorageInfo();
     } catch (RssException e) {
       return new RemoteStorageInfo("");
     }
@@ -670,13 +672,15 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
             faultyServerIds,
             stageId,
             stageAttemptNumber,
-            reassign);
+            reassign,
+            retryIntervalMs,
+            retryTimes);
 
     RssGetShuffleAssignmentsResponse response =
         new RssGetShuffleAssignmentsResponse(StatusCode.INTERNAL_ERROR);
     try {
-      if (coordinatorClients != null) {
-        response = coordinatorClients.getShuffleAssignments(request, 
retryIntervalMs, retryTimes);
+      if (coordinatorClient != null) {
+        response = coordinatorClient.getShuffleAssignments(request);
       }
     } catch (RssException e) {
       String msg =
@@ -895,8 +899,8 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
   @Override
   public void registerApplicationInfo(String appId, long timeoutMs, String 
user) {
     RssApplicationInfoRequest request = new RssApplicationInfoRequest(appId, 
timeoutMs, user);
-    if (coordinatorClients != null) {
-      coordinatorClients.registerApplicationInfo(request, timeoutMs);
+    if (coordinatorClient != null) {
+      coordinatorClient.registerApplicationInfo(request);
     }
   }
 
@@ -924,16 +928,16 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
         },
         timeoutMs,
         "send heartbeat to shuffle server");
-    if (coordinatorClients != null) {
-      coordinatorClients.sendAppHeartBeat(request, timeoutMs);
+    if (coordinatorClient != null) {
+      coordinatorClient.scheduleAtFixedRateToSendAppHeartBeat(request);
     }
   }
 
   @Override
   public void close() {
     heartBeatExecutorService.shutdownNow();
-    if (coordinatorClients != null) {
-      coordinatorClients.close();
+    if (coordinatorClient != null) {
+      coordinatorClient.close();
     }
     dataTransferPool.shutdownNow();
   }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/api/CoordinatorClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/api/CoordinatorClient.java
index 0ce1ca4e7..33ed9c8c7 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/api/CoordinatorClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/api/CoordinatorClient.java
@@ -34,7 +34,7 @@ import 
org.apache.uniffle.client.response.RssSendHeartBeatResponse;
 
 public interface CoordinatorClient {
 
-  RssAppHeartBeatResponse sendAppHeartBeat(RssAppHeartBeatRequest request);
+  RssAppHeartBeatResponse 
scheduleAtFixedRateToSendAppHeartBeat(RssAppHeartBeatRequest request);
 
   RssApplicationInfoResponse registerApplicationInfo(RssApplicationInfoRequest 
request);
 
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/factory/CoordinatorClientFactory.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/factory/CoordinatorClientFactory.java
index c04cac8a1..e652a31ce 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/factory/CoordinatorClientFactory.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/factory/CoordinatorClientFactory.java
@@ -21,6 +21,7 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.stream.Collectors;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -44,6 +45,7 @@ public class CoordinatorClientFactory {
     return LazyHolder.INSTANCE;
   }
 
+  @VisibleForTesting
   public synchronized CoordinatorClient createCoordinatorClient(
       ClientType clientType, String host, int port) {
     if (clientType.equals(ClientType.GRPC) || 
clientType.equals(ClientType.GRPC_NETTY)) {
@@ -53,6 +55,7 @@ public class CoordinatorClientFactory {
     }
   }
 
+  @VisibleForTesting
   public synchronized List<CoordinatorClient> createCoordinatorClient(
       ClientType clientType, String coordinators) {
     LOG.info("Start to create coordinator clients from {}", coordinators);
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
index 2577ae3ef..8583e952e 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
@@ -243,7 +243,8 @@ public class CoordinatorGrpcClient extends GrpcClient 
implements CoordinatorClie
   }
 
   @Override
-  public RssAppHeartBeatResponse sendAppHeartBeat(RssAppHeartBeatRequest 
request) {
+  public RssAppHeartBeatResponse scheduleAtFixedRateToSendAppHeartBeat(
+      RssAppHeartBeatRequest request) {
     RssProtos.AppHeartBeatRequest rpcRequest =
         
RssProtos.AppHeartBeatRequest.newBuilder().setAppId(request.getAppId()).build();
     RssProtos.AppHeartBeatResponse rpcResponse =
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcRetryableClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcRetryableClient.java
index 12663badd..f4bf3be87 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcRetryableClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcRetryableClient.java
@@ -21,6 +21,7 @@ import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -39,13 +40,13 @@ import 
org.apache.uniffle.client.response.RssApplicationInfoResponse;
 import org.apache.uniffle.client.response.RssFetchClientConfResponse;
 import org.apache.uniffle.client.response.RssFetchRemoteStorageResponse;
 import org.apache.uniffle.client.response.RssGetShuffleAssignmentsResponse;
-import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.client.response.RssSendHeartBeatResponse;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.common.util.ThreadUtils;
 
-public class CoordinatorGrpcRetryableClient {
+public class CoordinatorGrpcRetryableClient implements CoordinatorClient {
   private static final Logger LOG = 
LoggerFactory.getLogger(CoordinatorGrpcRetryableClient.class);
   private List<CoordinatorClient> coordinatorClients;
   private long retryIntervalMs;
@@ -64,16 +65,22 @@ public class CoordinatorGrpcRetryableClient {
         ThreadUtils.getDaemonFixedThreadPool(heartBeatThreadNum, 
"client-heartbeat");
   }
 
-  public void sendAppHeartBeat(RssAppHeartBeatRequest request, long timeoutMs) 
{
+  @Override
+  public RssAppHeartBeatResponse scheduleAtFixedRateToSendAppHeartBeat(
+      RssAppHeartBeatRequest request) {
+    AtomicReference<RssAppHeartBeatResponse> rssResponse = new 
AtomicReference<>();
+    rssResponse.set(new RssAppHeartBeatResponse(StatusCode.INTERNAL_ERROR));
     ThreadUtils.executeTasks(
         heartBeatExecutorService,
         coordinatorClients,
         coordinatorClient -> {
           try {
-            RssAppHeartBeatResponse response = 
coordinatorClient.sendAppHeartBeat(request);
+            RssAppHeartBeatResponse response =
+                
coordinatorClient.scheduleAtFixedRateToSendAppHeartBeat(request);
             if (response.getStatusCode() != StatusCode.SUCCESS) {
               LOG.warn("Failed to send heartbeat to " + 
coordinatorClient.getDesc());
             } else {
+              rssResponse.set(response);
               LOG.info("Successfully send heartbeat to " + 
coordinatorClient.getDesc());
             }
           } catch (Exception e) {
@@ -81,11 +88,15 @@ public class CoordinatorGrpcRetryableClient {
           }
           return null;
         },
-        timeoutMs,
+        request.getTimeoutMs(),
         "send heartbeat to coordinator");
+    return rssResponse.get();
   }
 
-  public void registerApplicationInfo(RssApplicationInfoRequest request, long 
timeoutMs) {
+  @Override
+  public RssApplicationInfoResponse 
registerApplicationInfo(RssApplicationInfoRequest request) {
+    AtomicReference<RssApplicationInfoResponse> rssResponse = new 
AtomicReference<>();
+    rssResponse.set(new RssApplicationInfoResponse(StatusCode.INTERNAL_ERROR));
     ThreadUtils.executeTasks(
         heartBeatExecutorService,
         coordinatorClients,
@@ -96,6 +107,7 @@ public class CoordinatorGrpcRetryableClient {
             if (response.getStatusCode() != StatusCode.SUCCESS) {
               LOG.error("Failed to send applicationInfo to " + 
coordinatorClient.getDesc());
             } else {
+              rssResponse.set(response);
               LOG.info("Successfully send applicationInfo to " + 
coordinatorClient.getDesc());
             }
           } catch (Exception e) {
@@ -104,11 +116,13 @@ public class CoordinatorGrpcRetryableClient {
           }
           return null;
         },
-        timeoutMs,
+        request.getTimeoutMs(),
         "register application");
+    return rssResponse.get();
   }
 
-  public boolean sendHeartBeat(RssSendHeartBeatRequest request) {
+  @Override
+  public RssSendHeartBeatResponse sendHeartBeat(RssSendHeartBeatRequest 
request) {
     AtomicBoolean sendSuccessfully = new AtomicBoolean(false);
     ThreadUtils.executeTasks(
         heartBeatExecutorService,
@@ -124,16 +138,20 @@ public class CoordinatorGrpcRetryableClient {
             }
           } catch (Exception e) {
             LOG.error(e.getMessage());
-            return null;
           }
           return null;
         });
 
-    return sendSuccessfully.get();
+    if (sendSuccessfully.get()) {
+      return new RssSendHeartBeatResponse(StatusCode.SUCCESS);
+    } else {
+      return new RssSendHeartBeatResponse(StatusCode.INTERNAL_ERROR);
+    }
   }
 
+  @Override
   public RssGetShuffleAssignmentsResponse getShuffleAssignments(
-      RssGetShuffleAssignmentsRequest request, long retryIntervalMs, int 
retryTimes) {
+      RssGetShuffleAssignmentsRequest request) {
     try {
       return RetryUtils.retry(
           () -> {
@@ -149,7 +167,7 @@ public class CoordinatorGrpcRetryableClient {
                 LOG.info(
                     "Success to get shuffle server assignment from {}",
                     coordinatorClient.getDesc());
-                break;
+                return response;
               }
             }
             if (response.getStatusCode() != StatusCode.SUCCESS) {
@@ -164,8 +182,8 @@ public class CoordinatorGrpcRetryableClient {
     }
   }
 
-  public RssAccessClusterResponse accessCluster(
-      RssAccessClusterRequest request, long retryIntervalMs, int retryTimes) {
+  @Override
+  public RssAccessClusterResponse accessCluster(RssAccessClusterRequest 
request) {
     try {
       return RetryUtils.retry(
           () -> {
@@ -177,12 +195,10 @@ public class CoordinatorGrpcRetryableClient {
                     "Success to access cluster {} using {}",
                     coordinatorClient.getDesc(),
                     request.getAccessId());
-                break;
+                return response;
               }
             }
-            if (response.getStatusCode() == StatusCode.SUCCESS) {
-              return response;
-            } else if (response.getStatusCode() == StatusCode.ACCESS_DENIED) {
+            if (response.getStatusCode() == StatusCode.ACCESS_DENIED) {
               throw new RssException(
                   "Request to access cluster is denied using "
                       + request.getAccessId()
@@ -192,13 +208,14 @@ public class CoordinatorGrpcRetryableClient {
               throw new RssException("Fail to reach cluster for " + 
response.getMessage());
             }
           },
-          retryIntervalMs,
-          retryTimes);
+          request.getRetryIntervalMs(),
+          request.getRetryTimes());
     } catch (Throwable throwable) {
       throw new RssException("getShuffleAssignments failed!", throwable);
     }
   }
 
+  @Override
   public RssFetchClientConfResponse fetchClientConf(RssFetchClientConfRequest 
request) {
     try {
       return RetryUtils.retry(
@@ -223,7 +240,8 @@ public class CoordinatorGrpcRetryableClient {
     }
   }
 
-  public RemoteStorageInfo fetchRemoteStorage(RssFetchRemoteStorageRequest 
request) {
+  @Override
+  public RssFetchRemoteStorageResponse 
fetchRemoteStorage(RssFetchRemoteStorageRequest request) {
     try {
       return RetryUtils.retry(
           () -> {
@@ -241,7 +259,7 @@ public class CoordinatorGrpcRetryableClient {
             if (response.getStatusCode() != StatusCode.SUCCESS) {
               throw new RssException(response.getMessage());
             }
-            return response.getRemoteStorageInfo();
+            return response;
           },
           this.retryIntervalMs,
           this.retryTimes);
@@ -250,7 +268,19 @@ public class CoordinatorGrpcRetryableClient {
     }
   }
 
+  @Override
+  public String getDesc() {
+    StringBuilder result = new 
StringBuilder("CoordinatorGrpcRetryableClient:");
+    for (CoordinatorClient coordinatorClient : coordinatorClients) {
+      result.append("\n");
+      result.append(coordinatorClient.getDesc());
+    }
+    return result.toString();
+  }
+
+  @Override
   public void close() {
+    heartBeatExecutorService.shutdownNow();
     coordinatorClients.forEach(CoordinatorClient::close);
   }
 }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssAccessClusterRequest.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssAccessClusterRequest.java
index 435e20b1e..727d0562a 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssAccessClusterRequest.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssAccessClusterRequest.java
@@ -33,12 +33,11 @@ public class RssAccessClusterRequest {
    */
   private final Map<String, String> extraProperties;
 
+  private final long retryIntervalMs;
+  private final int retryTimes;
+
   public RssAccessClusterRequest(String accessId, Set<String> tags, int 
timeoutMs, String user) {
-    this.accessId = accessId;
-    this.tags = tags;
-    this.timeoutMs = timeoutMs;
-    this.extraProperties = Collections.emptyMap();
-    this.user = user;
+    this(accessId, tags, timeoutMs, Collections.emptyMap(), user, 0, 0);
   }
 
   public RssAccessClusterRequest(
@@ -47,11 +46,24 @@ public class RssAccessClusterRequest {
       int timeoutMs,
       Map<String, String> extraProperties,
       String user) {
+    this(accessId, tags, timeoutMs, extraProperties, user, 0, 0);
+  }
+
+  public RssAccessClusterRequest(
+      String accessId,
+      Set<String> tags,
+      int timeoutMs,
+      Map<String, String> extraProperties,
+      String user,
+      long retryInterval,
+      int retryTimes) {
     this.accessId = accessId;
     this.tags = tags;
     this.timeoutMs = timeoutMs;
     this.extraProperties = extraProperties;
     this.user = user;
+    this.retryIntervalMs = retryInterval;
+    this.retryTimes = retryTimes;
   }
 
   public String getAccessId() {
@@ -73,4 +85,12 @@ public class RssAccessClusterRequest {
   public String getUser() {
     return user;
   }
+
+  public long getRetryIntervalMs() {
+    return retryIntervalMs;
+  }
+
+  public int getRetryTimes() {
+    return retryTimes;
+  }
 }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
index b80b1f9c0..84e36e103 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
@@ -36,6 +36,8 @@ public class RssGetShuffleAssignmentsRequest {
   private int stageId = -1;
   private int stageAttemptNumber = 0;
   private boolean reassign = false;
+  private long retryIntervalMs;
+  private int retryTimes;
 
   @VisibleForTesting
   public RssGetShuffleAssignmentsRequest(
@@ -79,7 +81,9 @@ public class RssGetShuffleAssignmentsRequest {
         faultyServerIds,
         -1,
         0,
-        false);
+        false,
+        0,
+        0);
   }
 
   public RssGetShuffleAssignmentsRequest(
@@ -94,7 +98,9 @@ public class RssGetShuffleAssignmentsRequest {
       Set<String> faultyServerIds,
       int stageId,
       int stageAttemptNumber,
-      boolean reassign) {
+      boolean reassign,
+      long retryIntervalMs,
+      int retryTimes) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.partitionNum = partitionNum;
@@ -107,6 +113,8 @@ public class RssGetShuffleAssignmentsRequest {
     this.stageId = stageId;
     this.stageAttemptNumber = stageAttemptNumber;
     this.reassign = reassign;
+    this.retryIntervalMs = retryIntervalMs;
+    this.retryTimes = retryTimes;
   }
 
   public String getAppId() {
@@ -156,4 +164,12 @@ public class RssGetShuffleAssignmentsRequest {
   public boolean isReassign() {
     return reassign;
   }
+
+  public long getRetryIntervalMs() {
+    return retryIntervalMs;
+  }
+
+  public int getRetryTimes() {
+    return retryTimes;
+  }
 }
diff --git 
a/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java 
b/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
index d18f29cd7..4b2d4607a 100644
--- a/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
+++ b/server/src/main/java/org/apache/uniffle/server/RegisterHeartBeat.java
@@ -17,20 +17,17 @@
 
 package org.apache.uniffle.server;
 
-import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.concurrent.ExecutorService;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicBoolean;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.uniffle.client.api.CoordinatorClient;
 import org.apache.uniffle.client.factory.CoordinatorClientFactory;
+import org.apache.uniffle.client.impl.grpc.CoordinatorGrpcRetryableClient;
 import org.apache.uniffle.client.request.RssSendHeartBeatRequest;
 import org.apache.uniffle.common.ServerStatus;
 import org.apache.uniffle.common.rpc.StatusCode;
@@ -45,10 +42,9 @@ public class RegisterHeartBeat {
   private final long heartBeatInterval;
   private final ShuffleServer shuffleServer;
   private final String coordinatorQuorum;
-  private final List<CoordinatorClient> coordinatorClients;
+  private final CoordinatorGrpcRetryableClient coordinatorClient;
   private final ScheduledExecutorService service =
       ThreadUtils.getDaemonSingleThreadScheduledExecutor("startHeartBeat");
-  private final ExecutorService heartBeatExecutorService;
 
   public RegisterHeartBeat(ShuffleServer shuffleServer) {
     ShuffleServerConf conf = shuffleServer.getShuffleServerConf();
@@ -56,13 +52,14 @@ public class RegisterHeartBeat {
     this.heartBeatInterval = 
conf.getLong(ShuffleServerConf.SERVER_HEARTBEAT_INTERVAL);
     this.coordinatorQuorum = 
conf.getString(ShuffleServerConf.RSS_COORDINATOR_QUORUM);
     CoordinatorClientFactory factory = CoordinatorClientFactory.getInstance();
-    this.coordinatorClients =
+    this.coordinatorClient =
         factory.createCoordinatorClient(
-            conf.get(ShuffleServerConf.RSS_COORDINATOR_CLIENT_TYPE), 
this.coordinatorQuorum);
+            conf.get(ShuffleServerConf.RSS_COORDINATOR_CLIENT_TYPE),
+            this.coordinatorQuorum,
+            0,
+            0,
+            conf.getInteger(ShuffleServerConf.SERVER_HEARTBEAT_THREAD_NUM));
     this.shuffleServer = shuffleServer;
-    this.heartBeatExecutorService =
-        ThreadUtils.getDaemonFixedThreadPool(
-            conf.getInteger(ShuffleServerConf.SERVER_HEARTBEAT_THREAD_NUM), 
"sendHeartBeat");
   }
 
   public void startHeartBeat() {
@@ -111,7 +108,6 @@ public class RegisterHeartBeat {
       int nettyPort,
       int jettyPort,
       long startTimeMs) {
-    AtomicBoolean sendSuccessfully = new AtomicBoolean(false);
     // use `rss.server.heartbeat.interval` as the timeout option
     RssSendHeartBeatRequest request =
         new RssSendHeartBeatRequest(
@@ -130,30 +126,14 @@ public class RegisterHeartBeat {
             jettyPort,
             startTimeMs);
 
-    ThreadUtils.executeTasks(
-        heartBeatExecutorService,
-        coordinatorClients,
-        client -> client.sendHeartBeat(request),
-        request.getTimeout() * 2,
-        "send heartbeat",
-        future -> {
-          try {
-            if (future.get(request.getTimeout() * 2, 
TimeUnit.MILLISECONDS).getStatusCode()
-                == StatusCode.SUCCESS) {
-              sendSuccessfully.set(true);
-            }
-          } catch (Exception e) {
-            LOG.error(e.getMessage());
-            return null;
-          }
-          return null;
-        });
-
-    return sendSuccessfully.get();
+    if (coordinatorClient.sendHeartBeat(request).getStatusCode() == 
StatusCode.SUCCESS) {
+      return true;
+    }
+    return false;
   }
 
   public void shutdown() {
-    heartBeatExecutorService.shutdownNow();
+    coordinatorClient.close();
     service.shutdownNow();
   }
 }


Reply via email to