This is an automated email from the ASF dual-hosted git repository.

xianjin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 457c86536 [#1608] feat: Introduce ExpiringClosableSupplier and 
refactor ShuffleManagerClient creation (#1838)
457c86536 is described below

commit 457c865362e1dc573004b30c505287c253a6dba0
Author: xumanbu <jam...@vipshop.com>
AuthorDate: Fri Jul 26 21:24:28 2024 +0800

    [#1608] feat: Introduce ExpiringClosableSupplier and refactor 
ShuffleManagerClient creation (#1838)
    
    ### What changes were proposed in this pull request?
    1. Introduce StatefulCloseable and ExpiringCloseableSupplier
    2. Refactor ShuffleManagerClient to leverage ExpiringCloseableSupplier
    
    ### Why are the changes needed?
    For better code quality
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing UTs and new UTs.
---
 .../apache/spark/shuffle/RssSparkShuffleUtils.java |  48 +++---
 .../shuffle/reader/RssFetchFailedIterator.java     |  63 +++-----
 .../BlockIdSelfManagedShuffleWriteClient.java      |  13 +-
 .../uniffle/shuffle/RssShuffleClientFactory.java   |  12 +-
 .../shuffle/manager/RssShuffleManagerBase.java     |  36 +++--
 .../apache/spark/shuffle/RssShuffleManager.java    |  18 ++-
 .../spark/shuffle/reader/RssShuffleReader.java     |  12 +-
 .../spark/shuffle/writer/RssShuffleWriter.java     |  71 ++++-----
 .../spark/shuffle/reader/RssShuffleReaderTest.java |   6 +-
 .../spark/shuffle/writer/RssShuffleWriterTest.java |   8 +
 .../apache/spark/shuffle/RssShuffleManager.java    |   9 +-
 .../spark/shuffle/reader/RssShuffleReader.java     |  11 +-
 .../spark/shuffle/writer/RssShuffleWriter.java     |  84 +++++-----
 .../spark/shuffle/reader/RssShuffleReaderTest.java |   6 +
 .../spark/shuffle/writer/RssShuffleWriterTest.java |  14 ++
 .../common/util/ExpiringCloseableSupplier.java     | 110 +++++++++++++
 .../uniffle/common/util/StatefulCloseable.java     |  25 +++
 .../common/util/ExpiringCloseableSupplierTest.java | 172 +++++++++++++++++++++
 .../uniffle/test/ShuffleServerManagerTestBase.java |  13 +-
 .../uniffle/client/api/ShuffleManagerClient.java   |   5 +-
 .../factory/ShuffleManagerClientFactory.java       |   4 +-
 .../client/impl/grpc/ShuffleManagerGrpcClient.java |  20 ++-
 .../factory/ShuffleManagerClientFactoryTest.java   |   5 +-
 23 files changed, 545 insertions(+), 220 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
index b3763df32..feee2a331 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.shuffle;
 
-import java.io.IOException;
 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
 import java.util.Arrays;
@@ -25,6 +24,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Supplier;
 
 import scala.Option;
 import scala.reflect.ClassTag;
@@ -43,21 +43,18 @@ import org.slf4j.LoggerFactory;
 import org.apache.uniffle.client.api.CoordinatorClient;
 import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.factory.CoordinatorClientFactory;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
 import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
 import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.util.Constants;
 
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
-import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;
 
 public class RssSparkShuffleUtils {
 
@@ -346,6 +343,7 @@ public class RssSparkShuffleUtils {
   }
 
   public static RssException reportRssFetchFailedException(
+      Supplier<ShuffleManagerClient> managerClientSupplier,
       RssFetchFailedException rssFetchFailedException,
       SparkConf sparkConf,
       String appId,
@@ -355,32 +353,24 @@ public class RssSparkShuffleUtils {
     RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
     if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
         && RssSparkShuffleUtils.isStageResubmitSupported()) {
-      String driver = rssConf.getString(DRIVER_HOST, "");
-      int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
-      try (ShuffleManagerClient client =
-          ShuffleManagerClientFactory.getInstance()
-              .createShuffleManagerClient(ClientType.GRPC, driver, port)) {
-        // todo: Create a new rpc interface to report failures in batch.
-        for (int partitionId : failedPartitions) {
-          RssReportShuffleFetchFailureRequest req =
-              new RssReportShuffleFetchFailureRequest(
-                  appId,
-                  shuffleId,
-                  stageAttemptId,
-                  partitionId,
-                  rssFetchFailedException.getMessage());
-          RssReportShuffleFetchFailureResponse response = 
client.reportShuffleFetchFailure(req);
-          if (response.getReSubmitWholeStage()) {
-            // since we are going to roll out the whole stage, mapIndex 
shouldn't matter, hence -1
-            // is provided.
-            FetchFailedException ffe =
-                RssSparkShuffleUtils.createFetchFailedException(
-                    shuffleId, -1, partitionId, rssFetchFailedException);
-            return new RssException(ffe);
-          }
+      for (int partitionId : failedPartitions) {
+        RssReportShuffleFetchFailureRequest req =
+            new RssReportShuffleFetchFailureRequest(
+                appId,
+                shuffleId,
+                stageAttemptId,
+                partitionId,
+                rssFetchFailedException.getMessage());
+        RssReportShuffleFetchFailureResponse response =
+            managerClientSupplier.get().reportShuffleFetchFailure(req);
+        if (response.getReSubmitWholeStage()) {
+          // since we are going to roll out the whole stage, mapIndex 
shouldn't matter, hence -1
+          // is provided.
+          FetchFailedException ffe =
+              RssSparkShuffleUtils.createFetchFailedException(
+                  shuffleId, -1, partitionId, rssFetchFailedException);
+          return new RssException(ffe);
         }
-      } catch (IOException ioe) {
-        LOG.info("Error closing shuffle manager client with error:", ioe);
       }
     }
     return rssFetchFailedException;
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java
index c394f510b..1bc61dc74 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java
@@ -17,8 +17,8 @@
 
 package org.apache.spark.shuffle.reader;
 
-import java.io.IOException;
 import java.util.Objects;
+import java.util.function.Supplier;
 
 import scala.Product2;
 import scala.collection.AbstractIterator;
@@ -30,10 +30,8 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.api.ShuffleManagerClient;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
 import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
 import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
-import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 
@@ -52,8 +50,7 @@ public class RssFetchFailedIterator<K, C> extends 
AbstractIterator<Product2<K, C
     private int shuffleId;
     private int partitionId;
     private int stageAttemptId;
-    private String reportServerHost;
-    private int reportServerPort;
+    private Supplier<ShuffleManagerClient> managerClientSupplier;
 
     private Builder() {}
 
@@ -77,19 +74,13 @@ public class RssFetchFailedIterator<K, C> extends 
AbstractIterator<Product2<K, C
       return this;
     }
 
-    Builder reportServerHost(String host) {
-      this.reportServerHost = host;
-      return this;
-    }
-
-    Builder port(int port) {
-      this.reportServerPort = port;
+    Builder managerClientSupplier(Supplier<ShuffleManagerClient> 
managerClientSupplier) {
+      this.managerClientSupplier = managerClientSupplier;
       return this;
     }
 
     <K, C> RssFetchFailedIterator<K, C> build(Iterator<Product2<K, C>> iter) {
       Objects.requireNonNull(this.appId);
-      Objects.requireNonNull(this.reportServerHost);
       return new RssFetchFailedIterator<>(this, iter);
     }
   }
@@ -98,37 +89,23 @@ public class RssFetchFailedIterator<K, C> extends 
AbstractIterator<Product2<K, C
     return new Builder();
   }
 
-  private static ShuffleManagerClient createShuffleManagerClient(String host, 
int port)
-      throws IOException {
-    ClientType grpc = ClientType.GRPC;
-    // host is passed from spark.driver.bindAddress, which would be set when 
SparkContext is
-    // constructed.
-    return 
ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc, 
host, port);
-  }
-
   private RssException generateFetchFailedIfNecessary(RssFetchFailedException 
e) {
-    String driver = builder.reportServerHost;
-    int port = builder.reportServerPort;
-    // todo: reuse this manager client if this is a bottleneck.
-    try (ShuffleManagerClient client = createShuffleManagerClient(driver, 
port)) {
-      RssReportShuffleFetchFailureRequest req =
-          new RssReportShuffleFetchFailureRequest(
-              builder.appId,
-              builder.shuffleId,
-              builder.stageAttemptId,
-              builder.partitionId,
-              e.getMessage());
-      RssReportShuffleFetchFailureResponse response = 
client.reportShuffleFetchFailure(req);
-      if (response.getReSubmitWholeStage()) {
-        // since we are going to roll out the whole stage, mapIndex shouldn't 
matter, hence -1 is
-        // provided.
-        FetchFailedException ffe =
-            RssSparkShuffleUtils.createFetchFailedException(
-                builder.shuffleId, -1, builder.partitionId, e);
-        return new RssException(ffe);
-      }
-    } catch (IOException ioe) {
-      LOG.info("Error closing shuffle manager client with error:", ioe);
+    ShuffleManagerClient client = builder.managerClientSupplier.get();
+    RssReportShuffleFetchFailureRequest req =
+        new RssReportShuffleFetchFailureRequest(
+            builder.appId,
+            builder.shuffleId,
+            builder.stageAttemptId,
+            builder.partitionId,
+            e.getMessage());
+    RssReportShuffleFetchFailureResponse response = 
client.reportShuffleFetchFailure(req);
+    if (response.getReSubmitWholeStage()) {
+      // since we are going to roll out the whole stage, mapIndex shouldn't 
matter, hence -1 is
+      // provided.
+      FetchFailedException ffe =
+          RssSparkShuffleUtils.createFetchFailedException(
+              builder.shuffleId, -1, builder.partitionId, e);
+      return new RssException(ffe);
     }
     return e;
   }
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java
index 1429bacbf..93aa3f0fc 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java
@@ -22,6 +22,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -41,16 +42,16 @@ import org.apache.uniffle.common.util.BlockIdLayout;
  * driver side.
  */
 public class BlockIdSelfManagedShuffleWriteClient extends 
ShuffleWriteClientImpl {
-  private ShuffleManagerClient shuffleManagerClient;
+  private Supplier<ShuffleManagerClient> managerClientSupplier;
 
   public BlockIdSelfManagedShuffleWriteClient(
       RssShuffleClientFactory.ExtendWriteClientBuilder builder) {
     super(builder);
 
-    if (builder.getShuffleManagerClient() == null) {
+    if (builder.getManagerClientSupplier() == null) {
       throw new RssException("Illegal empty shuffleManagerClient. This should 
not happen");
     }
-    this.shuffleManagerClient = builder.getShuffleManagerClient();
+    this.managerClientSupplier = builder.getManagerClientSupplier();
   }
 
   @Override
@@ -73,7 +74,7 @@ public class BlockIdSelfManagedShuffleWriteClient extends 
ShuffleWriteClientImpl
     RssReportShuffleResultRequest request =
         new RssReportShuffleResultRequest(
             appId, shuffleId, taskAttemptId, partitionToBlockIds, bitmapNum);
-    shuffleManagerClient.reportShuffleResult(request);
+    managerClientSupplier.get().reportShuffleResult(request);
   }
 
   @Override
@@ -85,7 +86,7 @@ public class BlockIdSelfManagedShuffleWriteClient extends 
ShuffleWriteClientImpl
       int partitionId) {
     RssGetShuffleResultRequest request =
         new RssGetShuffleResultRequest(appId, shuffleId, partitionId, 
BlockIdLayout.DEFAULT);
-    return shuffleManagerClient.getShuffleResult(request).getBlockIdBitmap();
+    return 
managerClientSupplier.get().getShuffleResult(request).getBlockIdBitmap();
   }
 
   @Override
@@ -101,6 +102,6 @@ public class BlockIdSelfManagedShuffleWriteClient extends 
ShuffleWriteClientImpl
     RssGetShuffleResultForMultiPartRequest request =
         new RssGetShuffleResultForMultiPartRequest(
             appId, shuffleId, partitionIds, BlockIdLayout.DEFAULT);
-    return 
shuffleManagerClient.getShuffleResultForMultiPart(request).getBlockIdBitmap();
+    return 
managerClientSupplier.get().getShuffleResultForMultiPart(request).getBlockIdBitmap();
   }
 }
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java
index c19d91324..bad10ab72 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java
@@ -17,6 +17,8 @@
 
 package org.apache.uniffle.shuffle;
 
+import java.util.function.Supplier;
+
 import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
@@ -41,18 +43,18 @@ public class RssShuffleClientFactory extends 
ShuffleClientFactory {
   public static class ExtendWriteClientBuilder<T extends 
ExtendWriteClientBuilder<T>>
       extends WriteClientBuilder<T> {
     private boolean blockIdSelfManagedEnabled;
-    private ShuffleManagerClient shuffleManagerClient;
+    private Supplier<ShuffleManagerClient> managerClientSupplier;
 
     public boolean isBlockIdSelfManagedEnabled() {
       return blockIdSelfManagedEnabled;
     }
 
-    public ShuffleManagerClient getShuffleManagerClient() {
-      return shuffleManagerClient;
+    public Supplier<ShuffleManagerClient> getManagerClientSupplier() {
+      return managerClientSupplier;
     }
 
-    public T shuffleManagerClient(ShuffleManagerClient client) {
-      this.shuffleManagerClient = client;
+    public T managerClientSupplier(Supplier<ShuffleManagerClient> 
managerClientSupplier) {
+      this.managerClientSupplier = managerClientSupplier;
       return self();
     }
 
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index 6a281db2e..d314b9bb6 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -31,6 +31,7 @@ import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -78,10 +79,12 @@ import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.ConfigOption;
+import org.apache.uniffle.common.config.RssBaseConf;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
 import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.shuffle.BlockIdManager;
 
@@ -104,7 +107,7 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
   protected String clientType;
 
   protected SparkConf sparkConf;
-  protected ShuffleManagerClient shuffleManagerClient;
+  protected Supplier<ShuffleManagerClient> managerClientSupplier;
   protected boolean rssStageRetryEnabled;
   protected boolean rssStageRetryForWriteFailureEnabled;
   protected boolean rssStageRetryForFetchFailureEnabled;
@@ -588,7 +591,8 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
         new RssPartitionToShuffleServerRequest(shuffleId);
     RssReassignOnStageRetryResponse rpcPartitionToShufflerServer =
-        getOrCreateShuffleManagerClient()
+        getOrCreateShuffleManagerClientSupplier()
+            .get()
             
.getPartitionToShufflerServerWithStageRetry(rssPartitionToShuffleServerRequest);
     StageAttemptShuffleHandleInfo shuffleHandleInfo =
         StageAttemptShuffleHandleInfo.fromProto(
@@ -607,25 +611,27 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest =
         new RssPartitionToShuffleServerRequest(shuffleId);
     RssReassignOnBlockSendFailureResponse rpcPartitionToShufflerServer =
-        getOrCreateShuffleManagerClient()
+        getOrCreateShuffleManagerClientSupplier()
+            .get()
             
.getPartitionToShufflerServerWithBlockRetry(rssPartitionToShuffleServerRequest);
     MutableShuffleHandleInfo shuffleHandleInfo =
         
MutableShuffleHandleInfo.fromProto(rpcPartitionToShufflerServer.getHandle());
     return shuffleHandleInfo;
   }
 
-  // todo: automatic close client when the client is idle to avoid too much 
connections for spark
-  // driver.
-  protected ShuffleManagerClient getOrCreateShuffleManagerClient() {
-    if (shuffleManagerClient == null) {
+  protected synchronized Supplier<ShuffleManagerClient> 
getOrCreateShuffleManagerClientSupplier() {
+    if (managerClientSupplier == null) {
       RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
       String driver = rssConf.getString("driver.host", "");
       int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
-      this.shuffleManagerClient =
-          ShuffleManagerClientFactory.getInstance()
-              .createShuffleManagerClient(ClientType.GRPC, driver, port);
+      long rpcTimeout = 
rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS);
+      this.managerClientSupplier =
+          ExpiringCloseableSupplier.of(
+              () ->
+                  ShuffleManagerClientFactory.getInstance()
+                      .createShuffleManagerClient(ClientType.GRPC, driver, 
port, rpcTimeout));
     }
-    return shuffleManagerClient;
+    return managerClientSupplier;
   }
 
   @Override
@@ -808,6 +814,14 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     }
   }
 
+  @Override
+  public void stop() {
+    if (managerClientSupplier != null
+        && managerClientSupplier instanceof ExpiringCloseableSupplier) {
+      ((ExpiringCloseableSupplier<ShuffleManagerClient>) 
managerClientSupplier).close();
+    }
+  }
+
   /**
    * Creating the shuffleAssignmentInfo from the servers and partitionIds
    *
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 1e5bb4941..27db614bf 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -214,16 +214,15 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
           }
         }
       }
-
       if (shuffleManagerRpcServiceEnabled) {
-        this.shuffleManagerClient = getOrCreateShuffleManagerClient();
+        getOrCreateShuffleManagerClientSupplier();
       }
       this.shuffleWriteClient =
           RssShuffleClientFactory.getInstance()
               .createShuffleWriteClient(
                   RssShuffleClientFactory.newWriteBuilder()
                       .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
-                      .shuffleManagerClient(shuffleManagerClient)
+                      .managerClientSupplier(managerClientSupplier)
                       .clientType(clientType)
                       .retryMax(retryMax)
                       .retryIntervalMax(retryIntervalMax)
@@ -434,6 +433,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
           this,
           sparkConf,
           shuffleWriteClient,
+          managerClientSupplier,
           rssHandle,
           this::markFailedTask,
           context);
@@ -537,7 +537,8 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
           blockIdBitmap,
           taskIdBitmap,
           RssSparkConfig.toRssConf(sparkConf),
-          partitionToServers);
+          partitionToServers,
+          managerClientSupplier);
     } else {
       throw new RssException("Unexpected ShuffleHandle:" + 
handle.getClass().getName());
     }
@@ -573,6 +574,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
   @Override
   public void stop() {
+    super.stop();
     if (heartBeatScheduledExecutorService != null) {
       heartBeatScheduledExecutorService.shutdownNow();
     }
@@ -719,7 +721,13 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
           clientType, shuffleServerInfoSet, appId, shuffleId, partitionId);
     } catch (RssFetchFailedException e) {
       throw RssSparkShuffleUtils.reportRssFetchFailedException(
-          e, sparkConf, appId, shuffleId, stageAttemptId, 
Sets.newHashSet(partitionId));
+          managerClientSupplier,
+          e,
+          sparkConf,
+          appId,
+          shuffleId,
+          stageAttemptId,
+          Sets.newHashSet(partitionId));
     }
   }
 
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 3bf5840e8..4b4ec32c5 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.reader;
 
 import java.util.List;
 import java.util.Map;
+import java.util.function.Supplier;
 
 import scala.Function0;
 import scala.Function2;
@@ -47,6 +48,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
 import org.apache.uniffle.client.util.RssClientConfig;
@@ -77,6 +79,7 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
   private List<ShuffleServerInfo> shuffleServerInfoList;
   private Configuration hadoopConf;
   private RssConf rssConf;
+  private Supplier<ShuffleManagerClient> managerClientSupplier;
 
   public RssShuffleReader(
       int startPartition,
@@ -90,7 +93,8 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
       Roaring64NavigableMap blockIdBitmap,
       Roaring64NavigableMap taskIdBitmap,
       RssConf rssConf,
-      Map<Integer, List<ShuffleServerInfo>> partitionToServers) {
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+      Supplier<ShuffleManagerClient> managerClientSupplier) {
     this.appId = rssShuffleHandle.getAppId();
     this.startPartition = startPartition;
     this.endPartition = endPartition;
@@ -107,6 +111,7 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
     this.hadoopConf = hadoopConf;
     this.shuffleServerInfoList = (List<ShuffleServerInfo>) 
(partitionToServers.get(startPartition));
     this.rssConf = rssConf;
+    this.managerClientSupplier = managerClientSupplier;
     expectedTaskIdsBitmapFilterEnable = shuffleServerInfoList.size() > 1;
   }
 
@@ -235,16 +240,13 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
     // stage re-compute and shuffle manager server port are both set
     if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
         && rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0) 
{
-      String driver = rssConf.getString("driver.host", "");
-      int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
       resultIter =
           RssFetchFailedIterator.newBuilder()
               .appId(appId)
               .shuffleId(shuffleId)
               .partitionId(startPartition)
               .stageAttemptId(context.stageAttemptNumber())
-              .reportServerHost(driver)
-              .port(port)
+              .managerClientSupplier(managerClientSupplier)
               .build(resultIter);
     }
     return resultIter;
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 5ac6a7e9e..4474c99c8 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.shuffle.writer;
 
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -31,6 +30,7 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 import scala.Function1;
@@ -64,17 +64,13 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.client.request.RssReassignServersRequest;
 import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
 import org.apache.uniffle.client.response.RssReassignServersResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
-import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.config.RssClientConf;
-import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
@@ -114,6 +110,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final Set<Long> blockIds = Sets.newConcurrentHashSet();
   private TaskContext taskContext;
   private SparkConf sparkConf;
+  private Supplier<ShuffleManagerClient> managerClientSupplier;
 
   public RssShuffleWriter(
       String appId,
@@ -125,6 +122,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
+      Supplier<ShuffleManagerClient> managerClientSupplier,
       RssShuffleHandle<K, V, C> rssHandle,
       SimpleShuffleHandleInfo shuffleHandleInfo,
       TaskContext context) {
@@ -137,6 +135,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleManager,
         sparkConf,
         shuffleWriteClient,
+        managerClientSupplier,
         rssHandle,
         (tid) -> true,
         shuffleHandleInfo,
@@ -153,6 +152,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
+      Supplier<ShuffleManagerClient> managerClientSupplier,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
       ShuffleHandleInfo shuffleHandleInfo,
@@ -172,6 +172,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.bitmapSplitNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
     this.serverToPartitionToBlockIds = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
+    this.managerClientSupplier = managerClientSupplier;
     this.shuffleServersForData = shuffleHandleInfo.getServers();
     this.partitionToServers = 
shuffleHandleInfo.getAvailablePartitionServersForWriter();
     this.isMemoryShuffleEnabled =
@@ -191,6 +192,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
+      Supplier<ShuffleManagerClient> managerClientSupplier,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
       TaskContext context) {
@@ -203,6 +205,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleManager,
         sparkConf,
         shuffleWriteClient,
+        managerClientSupplier,
         rssHandle,
         taskFailureCallback,
         shuffleManager.getShuffleHandleInfo(rssHandle),
@@ -528,14 +531,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     return shuffleWriteMetrics;
   }
 
-  private static ShuffleManagerClient createShuffleManagerClient(String host, 
int port)
-      throws IOException {
-    ClientType grpc = ClientType.GRPC;
-    // Host can be inferred from `spark.driver.bindAddress`, which would be 
set when SparkContext is
-    // constructed.
-    return 
ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc, 
host, port);
-  }
-
   private void throwFetchFailedIfNecessary(Exception e) {
     // The shuffleServer is registered only when a Block fails to be sent
     if (e instanceof RssSendFailedException) {
@@ -550,34 +545,28 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
               taskContext.stageAttemptNumber(),
               shuffleServerInfos,
               e.getMessage());
-      RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
-      String driver = rssConf.getString("driver.host", "");
-      int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
-      try (ShuffleManagerClient shuffleManagerClient = 
createShuffleManagerClient(driver, port)) {
-        RssReportShuffleWriteFailureResponse response =
-            shuffleManagerClient.reportShuffleWriteFailure(req);
-        if (response.getReSubmitWholeStage()) {
-          // The shuffle server is reassigned.
-          RssReassignServersRequest rssReassignServersRequest =
-              new RssReassignServersRequest(
-                  taskContext.stageId(),
-                  taskContext.stageAttemptNumber(),
-                  shuffleId,
-                  partitioner.numPartitions());
-          RssReassignServersResponse rssReassignServersResponse =
-              
shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest);
-          LOG.info(
-              "Whether the reassignment is successful: {}",
-              rssReassignServersResponse.isNeedReassign());
-          // since we are going to roll out the whole stage, mapIndex 
shouldn't matter, hence -1 is
-          // provided.
-          FetchFailedException ffe =
-              RssSparkShuffleUtils.createFetchFailedException(
-                  shuffleId, -1, taskContext.stageAttemptNumber(), e);
-          throw new RssException(ffe);
-        }
-      } catch (IOException ioe) {
-        LOG.info("Error closing shuffle manager client with error:", ioe);
+      RssReportShuffleWriteFailureResponse response =
+          managerClientSupplier.get().reportShuffleWriteFailure(req);
+      if (response.getReSubmitWholeStage()) {
+        // The shuffle server is reassigned.
+        RssReassignServersRequest rssReassignServersRequest =
+            new RssReassignServersRequest(
+                taskContext.stageId(),
+                taskContext.stageAttemptNumber(),
+                shuffleId,
+                partitioner.numPartitions());
+        RssReassignServersResponse rssReassignServersResponse =
+            
managerClientSupplier.get().reassignOnStageResubmit(rssReassignServersRequest);
+        LOG.info(
+            "Whether the reassignment is successful: {}",
+            rssReassignServersResponse.isNeedReassign());
+        // since we are going to roll out the whole stage, mapIndex shouldn't 
matter, hence -1
+        // is
+        // provided.
+        FetchFailedException ffe =
+            RssSparkShuffleUtils.createFetchFailedException(
+                shuffleId, -1, taskContext.stageAttemptNumber(), e);
+        throw new RssException(ffe);
       }
     }
     throw new RssException(e);
diff --git 
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
 
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index f09223b1c..78fe7dec0 100644
--- 
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++ 
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -35,9 +35,11 @@ import org.apache.spark.shuffle.RssShuffleHandle;
 import org.junit.jupiter.api.Test;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
 import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler;
 import org.apache.uniffle.storage.util.StorageType;
 
@@ -85,6 +87,7 @@ public class RssShuffleReaderTest extends 
AbstractRssReaderTest {
     rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name());
     rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, 1000);
     rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE, "1000");
+    final ShuffleManagerClient mockShuffleManagerClient = 
mock(ShuffleManagerClient.class);
     RssShuffleReader<String, String> rssShuffleReaderSpy =
         spy(
             new RssShuffleReader<>(
@@ -99,7 +102,8 @@ public class RssShuffleReaderTest extends 
AbstractRssReaderTest {
                 blockIdBitmap,
                 taskIdBitmap,
                 rssConf,
-                partitionToServers));
+                partitionToServers,
+                ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient)));
 
     validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
   }
diff --git 
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
 
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index e039ad9d5..779f94117 100644
--- 
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ 
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -48,12 +48,14 @@ import org.apache.spark.shuffle.RssSparkConfig;
 import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.junit.jupiter.api.Test;
 
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -88,6 +90,7 @@ public class RssShuffleWriterTest {
 
     Serializer kryoSerializer = new KryoSerializer(conf);
     ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+    ShuffleManagerClient mockShuffleManagerClient = 
mock(ShuffleManagerClient.class);
     Partitioner mockPartitioner = mock(Partitioner.class);
     ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
     RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
@@ -124,6 +127,7 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
+            ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
             mockHandle,
             mockShuffleHandleInfo,
             contextMock);
@@ -234,6 +238,7 @@ public class RssShuffleWriterTest {
     Partitioner mockPartitioner = mock(Partitioner.class);
     ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
     final ShuffleWriteClient mockShuffleWriteClient = 
mock(ShuffleWriteClient.class);
+    final ShuffleManagerClient mockShuffleManagerClient = 
mock(ShuffleManagerClient.class);
     RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
     Serializer kryoSerializer = new KryoSerializer(conf);
@@ -299,6 +304,7 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
+            ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
             mockHandle,
             mockShuffleHandleInfo,
             contextMock);
@@ -348,6 +354,7 @@ public class RssShuffleWriterTest {
   @Test
   public void postBlockEventTest() throws Exception {
     final ShuffleWriteMetrics mockMetrics = mock(ShuffleWriteMetrics.class);
+    final ShuffleManagerClient mockShuffleManagerClient = 
mock(ShuffleManagerClient.class);
     ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
     Partitioner mockPartitioner = mock(Partitioner.class);
     when(mockDependency.partitioner()).thenReturn(mockPartitioner);
@@ -411,6 +418,7 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockWriteClient,
+            ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
             mockHandle,
             mockShuffleHandleInfo,
             contextMock);
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index bf42bf361..92e630df2 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -239,7 +239,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
       }
     }
     if (shuffleManagerRpcServiceEnabled) {
-      this.shuffleManagerClient = getOrCreateShuffleManagerClient();
+      getOrCreateShuffleManagerClientSupplier();
     }
     int unregisterThreadPoolSize =
         sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
@@ -253,7 +253,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
             .createShuffleWriteClient(
                 RssShuffleClientFactory.newWriteBuilder()
                     .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
-                    .shuffleManagerClient(shuffleManagerClient)
+                    .managerClientSupplier(managerClientSupplier)
                     .clientType(clientType)
                     .retryMax(retryMax)
                     .retryIntervalMax(retryIntervalMax)
@@ -523,6 +523,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         this,
         sparkConf,
         shuffleWriteClient,
+        managerClientSupplier,
         rssHandle,
         this::markFailedTask,
         context);
@@ -696,6 +697,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
             blockIdBitmap, startPartition, endPartition, blockIdLayout),
         taskIdBitmap,
         readMetrics,
+        managerClientSupplier,
         RssSparkConfig.toRssConf(sparkConf),
         dataDistributionType,
         shuffleHandleInfo.getAllPartitionServersForReader());
@@ -853,6 +855,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
 
   @Override
   public void stop() {
+    super.stop();
     if (heartBeatScheduledExecutorService != null) {
       heartBeatScheduledExecutorService.shutdownNow();
     }
@@ -1031,7 +1034,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
           replicaRequirementTracking);
     } catch (RssFetchFailedException e) {
       throw RssSparkShuffleUtils.reportRssFetchFailedException(
-          e, sparkConf, appId, shuffleId, stageAttemptId, failedPartitions);
+          managerClientSupplier, e, sparkConf, appId, shuffleId, 
stageAttemptId, failedPartitions);
     }
   }
 
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index bf47ced6b..19682bd65 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.reader;
 
 import java.util.List;
 import java.util.Map;
+import java.util.function.Supplier;
 
 import scala.Function0;
 import scala.Function1;
@@ -49,6 +50,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
 import org.apache.uniffle.client.util.RssClientConfig;
@@ -58,7 +60,6 @@ import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
-import static org.apache.uniffle.common.util.Constants.DRIVER_HOST;
 
 public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleReader.class);
@@ -83,6 +84,7 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
   private ShuffleReadMetrics readMetrics;
   private RssConf rssConf;
   private ShuffleDataDistributionType dataDistributionType;
+  private Supplier<ShuffleManagerClient> managerClientSupplier;
 
   public RssShuffleReader(
       int startPartition,
@@ -97,6 +99,7 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
       Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks,
       Roaring64NavigableMap taskIdBitmap,
       ShuffleReadMetrics readMetrics,
+      Supplier<ShuffleManagerClient> managerClientSupplier,
       RssConf rssConf,
       ShuffleDataDistributionType dataDistributionType,
       Map<Integer, List<ShuffleServerInfo>> allPartitionToServers) {
@@ -120,6 +123,7 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
     this.partitionToShuffleServers = allPartitionToServers;
     this.rssConf = rssConf;
     this.dataDistributionType = dataDistributionType;
+    this.managerClientSupplier = managerClientSupplier;
   }
 
   @Override
@@ -193,16 +197,13 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
     // resubmit stage and shuffle manager server port are both set
     if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED)
         && rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0) 
{
-      String driver = rssConf.getString(DRIVER_HOST, "");
-      int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
       resultIter =
           RssFetchFailedIterator.newBuilder()
               .appId(appId)
               .shuffleId(shuffleId)
               .partitionId(startPartition)
               .stageAttemptId(context.stageAttemptNumber())
-              .reportServerHost(driver)
-              .port(port)
+              .managerClientSupplier(managerClientSupplier)
               .build(resultIter);
     }
     return resultIter;
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 870141c4b..24a3b8c1c 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.shuffle.writer;
 
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -36,6 +35,7 @@ import java.util.concurrent.Future;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 import scala.Function1;
@@ -71,7 +71,6 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
-import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.client.impl.TrackingBlockStatus;
 import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest;
@@ -80,12 +79,10 @@ import 
org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest;
 import 
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
 import org.apache.uniffle.client.response.RssReassignServersResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
-import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
-import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssSendFailedException;
 import org.apache.uniffle.common.exception.RssWaitFailedException;
@@ -143,6 +140,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private static final Set<StatusCode> STATUS_CODE_WITHOUT_BLOCK_RESEND =
       Sets.newHashSet(StatusCode.NO_REGISTER);
 
+  private final Supplier<ShuffleManagerClient> managerClientSupplier;
+
   // Only for tests
   @VisibleForTesting
   public RssShuffleWriter(
@@ -155,6 +154,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
+      Supplier<ShuffleManagerClient> managerClientSupplier,
       RssShuffleHandle<K, V, C> rssHandle,
       ShuffleHandleInfo shuffleHandleInfo,
       TaskContext context) {
@@ -167,6 +167,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleManager,
         sparkConf,
         shuffleWriteClient,
+        managerClientSupplier,
         rssHandle,
         (tid) -> true,
         shuffleHandleInfo,
@@ -184,6 +185,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
+      Supplier<ShuffleManagerClient> managerClientSupplier,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
       ShuffleHandleInfo shuffleHandleInfo,
@@ -217,6 +219,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.shuffleHandleInfo = shuffleHandleInfo;
     this.taskContext = context;
     this.sparkConf = sparkConf;
+    this.managerClientSupplier = managerClientSupplier;
     this.blockFailSentRetryEnabled =
         sparkConf.getBoolean(
             RssSparkConfig.SPARK_RSS_CONFIG_PREFIX
@@ -235,6 +238,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       RssShuffleManager shuffleManager,
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
+      Supplier<ShuffleManagerClient> managerClientSupplier,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
       TaskContext context) {
@@ -247,6 +251,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
         shuffleManager,
         sparkConf,
         shuffleWriteClient,
+        managerClientSupplier,
         rssHandle,
         taskFailureCallback,
         shuffleManager.getShuffleHandleInfo(rssHandle),
@@ -618,14 +623,11 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     LOG.info(
         "Initiate reassignOnBlockSendFailure. failure partition servers: {}",
         failurePartitionToServers);
-    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
-    String driver = rssConf.getString("driver.host", "");
-    int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
-    try (ShuffleManagerClient shuffleManagerClient = 
createShuffleManagerClient(driver, port)) {
-      String executorId = SparkEnv.get().executorId();
-      long taskAttemptId = taskContext.taskAttemptId();
-      int stageId = taskContext.stageId();
-      int stageAttemptNum = taskContext.stageAttemptNumber();
+    String executorId = SparkEnv.get().executorId();
+    long taskAttemptId = taskContext.taskAttemptId();
+    int stageId = taskContext.stageId();
+    int stageAttemptNum = taskContext.stageAttemptNumber();
+    try {
       RssReassignOnBlockSendFailureRequest request =
           new RssReassignOnBlockSendFailureRequest(
               shuffleId,
@@ -635,7 +637,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
               stageId,
               stageAttemptNum);
       RssReassignOnBlockSendFailureResponse response =
-          shuffleManagerClient.reassignOnBlockSendFailure(request);
+          managerClientSupplier.get().reassignOnBlockSendFailure(request);
       if (response.getStatusCode() != StatusCode.SUCCESS) {
         String msg =
             String.format(
@@ -835,14 +837,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     return bufferManager;
   }
 
-  private static ShuffleManagerClient createShuffleManagerClient(String host, 
int port)
-      throws IOException {
-    ClientType grpc = ClientType.GRPC;
-    // Host can be inferred from `spark.driver.bindAddress`, which would be 
set when SparkContext is
-    // constructed.
-    return 
ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc, 
host, port);
-  }
-
   private void throwFetchFailedIfNecessary(Exception e) {
     // The shuffleServer is registered only when a Block fails to be sent
     if (e instanceof RssSendFailedException) {
@@ -857,33 +851,27 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
               taskContext.stageAttemptNumber(),
               shuffleServerInfos,
               e.getMessage());
-      RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
-      String driver = rssConf.getString("driver.host", "");
-      int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
-      try (ShuffleManagerClient shuffleManagerClient = 
createShuffleManagerClient(driver, port)) {
-        RssReportShuffleWriteFailureResponse response =
-            shuffleManagerClient.reportShuffleWriteFailure(req);
-        if (response.getReSubmitWholeStage()) {
-          RssReassignServersRequest rssReassignServersRequest =
-              new RssReassignServersRequest(
-                  taskContext.stageId(),
-                  taskContext.stageAttemptNumber(),
-                  shuffleId,
-                  partitioner.numPartitions());
-          RssReassignServersResponse rssReassignServersResponse =
-              
shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest);
-          LOG.info(
-              "Whether the reassignment is successful: {}",
-              rssReassignServersResponse.isNeedReassign());
-          // since we are going to roll out the whole stage, mapIndex 
shouldn't matter, hence -1 is
-          // provided.
-          FetchFailedException ffe =
-              RssSparkShuffleUtils.createFetchFailedException(
-                  shuffleId, -1, taskContext.stageAttemptNumber(), e);
-          throw new RssException(ffe);
-        }
-      } catch (IOException ioe) {
-        LOG.info("Error closing shuffle manager client with error:", ioe);
+      RssReportShuffleWriteFailureResponse response =
+          managerClientSupplier.get().reportShuffleWriteFailure(req);
+      if (response.getReSubmitWholeStage()) {
+        RssReassignServersRequest rssReassignServersRequest =
+            new RssReassignServersRequest(
+                taskContext.stageId(),
+                taskContext.stageAttemptNumber(),
+                shuffleId,
+                partitioner.numPartitions());
+        RssReassignServersResponse rssReassignServersResponse =
+            
managerClientSupplier.get().reassignOnStageResubmit(rssReassignServersRequest);
+        LOG.info(
+            "Whether the reassignment is successful: {}",
+            rssReassignServersResponse.isNeedReassign());
+        // since we are going to roll out the whole stage, mapIndex shouldn't 
matter, hence -1
+        // is
+        // provided.
+        FetchFailedException ffe =
+            RssSparkShuffleUtils.createFetchFailedException(
+                shuffleId, -1, taskContext.stageAttemptNumber(), e);
+        throw new RssException(ffe);
       }
     }
     throw new RssException(e);
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index aaff4cb8e..bc77f7192 100644
--- 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -36,10 +36,12 @@ import org.apache.spark.shuffle.RssShuffleHandle;
 import org.junit.jupiter.api.Test;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
+import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
 import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler;
 import org.apache.uniffle.storage.util.StorageType;
 
@@ -93,6 +95,7 @@ public class RssShuffleReaderTest extends 
AbstractRssReaderTest {
     rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name());
     rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, 1000);
     rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE, "1000");
+    ShuffleManagerClient mockShuffleManagerClient = 
mock(ShuffleManagerClient.class);
     RssShuffleReader<String, String> rssShuffleReaderSpy =
         spy(
             new RssShuffleReader<>(
@@ -108,6 +111,7 @@ public class RssShuffleReaderTest extends 
AbstractRssReaderTest {
                 partitionToExpectBlocks,
                 taskIdBitmap,
                 new ShuffleReadMetrics(),
+                ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
                 rssConf,
                 ShuffleDataDistributionType.NORMAL,
                 partitionToServers));
@@ -131,6 +135,7 @@ public class RssShuffleReaderTest extends 
AbstractRssReaderTest {
                 partitionToExpectBlocks,
                 taskIdBitmap,
                 new ShuffleReadMetrics(),
+                ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
                 rssConf,
                 ShuffleDataDistributionType.NORMAL,
                 partitionToServers));
@@ -151,6 +156,7 @@ public class RssShuffleReaderTest extends 
AbstractRssReaderTest {
                 partitionToExpectBlocks,
                 Roaring64NavigableMap.bitmapOf(),
                 new ShuffleReadMetrics(),
+                ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
                 rssConf,
                 ShuffleDataDistributionType.NORMAL,
                 partitionToServers));
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 53a8e7143..a4317aae8 100644
--- 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -55,12 +55,14 @@ import 
org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Test;
 
+import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.storage.util.StorageType;
 
@@ -133,6 +135,7 @@ public class RssShuffleWriterTest {
     Serializer kryoSerializer = new KryoSerializer(conf);
     Partitioner mockPartitioner = mock(Partitioner.class);
     final ShuffleWriteClient mockShuffleWriteClient = 
mock(ShuffleWriteClient.class);
+    final ShuffleManagerClient mockShuffleManagerClient = 
mock(ShuffleManagerClient.class);
     ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
     RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -179,6 +182,7 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
+            ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
             mockHandle,
             shuffleHandle,
             contextMock);
@@ -385,6 +389,7 @@ public class RssShuffleWriterTest {
     Serializer kryoSerializer = new KryoSerializer(conf);
     Partitioner mockPartitioner = mock(Partitioner.class);
     final ShuffleWriteClient mockShuffleWriteClient = 
mock(ShuffleWriteClient.class);
+    final ShuffleManagerClient mockShuffleManagerClient = 
mock(ShuffleManagerClient.class);
     ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
     RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -450,6 +455,7 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
+            ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
             mockHandle,
             shuffleHandleInfo,
             contextMock);
@@ -552,6 +558,7 @@ public class RssShuffleWriterTest {
             conf, false, null, successBlocks, taskToFailedBlockSendTracker);
 
     ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class);
+    ShuffleManagerClient mockShuffleManagerClient = 
mock(ShuffleManagerClient.class);
     Partitioner mockPartitioner = mock(Partitioner.class);
     RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
     ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
@@ -587,6 +594,7 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
+            ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
             mockHandle,
             mockShuffleHandleInfo,
             contextMock);
@@ -714,6 +722,7 @@ public class RssShuffleWriterTest {
     Serializer kryoSerializer = new KryoSerializer(conf);
     Partitioner mockPartitioner = mock(Partitioner.class);
     final ShuffleWriteClient mockShuffleWriteClient = 
mock(ShuffleWriteClient.class);
+    final ShuffleManagerClient mockShuffleManagerClient = 
mock(ShuffleManagerClient.class);
     ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
     RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -734,6 +743,7 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
+            ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
             mockHandle,
             mockShuffleHandleInfo,
             contextMock);
@@ -794,6 +804,7 @@ public class RssShuffleWriterTest {
     Serializer kryoSerializer = new KryoSerializer(conf);
     Partitioner mockPartitioner = mock(Partitioner.class);
     final ShuffleWriteClient mockShuffleWriteClient = 
mock(ShuffleWriteClient.class);
+    final ShuffleManagerClient mockShuffleManagerClient = 
mock(ShuffleManagerClient.class);
     ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
     RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
@@ -857,6 +868,7 @@ public class RssShuffleWriterTest {
             manager,
             conf,
             mockShuffleWriteClient,
+            ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
             mockHandle,
             mockShuffleHandleInfo,
             contextMock);
@@ -958,6 +970,7 @@ public class RssShuffleWriterTest {
     TaskContext contextMock = mock(TaskContext.class);
     SimpleShuffleHandleInfo mockShuffleHandleInfo = 
mock(SimpleShuffleHandleInfo.class);
     ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
+    ShuffleManagerClient mockShuffleManagerClient = 
mock(ShuffleManagerClient.class);
 
     List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1, 
31);
     RssShuffleWriter<String, String, String> writer =
@@ -971,6 +984,7 @@ public class RssShuffleWriterTest {
             mockShuffleManager,
             conf,
             mockWriteClient,
+            ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient),
             mockHandle,
             mockShuffleHandleInfo,
             contextMock);
diff --git 
a/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java
 
b/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java
new file mode 100644
index 000000000..f36f9be0c
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A caching {@link Supplier} whose supplied value is automatically closed after a period
+ * of inactivity. When an object is obtained through an ExpiringCloseableSupplier, manual
+ * closure is usually unnecessary: an idle instance is closed in the background, and a
+ * fresh one is created transparently on the next {@code get()}.
+ */
+public class ExpiringCloseableSupplier<T extends StatefulCloseable>
+    implements Supplier<T>, Serializable {
+  private static final long serialVersionUID = 0;
+  private static final Logger LOG = 
LoggerFactory.getLogger(ExpiringCloseableSupplier.class);
+  private static final int DEFAULT_DELAY_CLOSE_INTERVAL = 60000;
+  private static final ScheduledExecutorService executor =
+      
ThreadUtils.getDaemonSingleThreadScheduledExecutor("ExpiringCloseableSupplier");
+
+  private final Supplier<T> delegate;
+  private final long delayCloseInterval;
+
+  private transient volatile ScheduledFuture<?> future;
+
+  @SuppressFBWarnings("SE_TRANSIENT_FIELD_NOT_RESTORED")
+  private transient volatile long accessTime = System.currentTimeMillis();
+
+  private transient volatile T t;
+
+  private ExpiringCloseableSupplier(Supplier<T> delegate, long 
delayCloseInterval) {
+    this.delegate = delegate;
+    this.delayCloseInterval = delayCloseInterval;
+  }
+
+  public synchronized T get() {
+    accessTime = System.currentTimeMillis();
+    if (t == null || t.isClosed()) {
+      this.t = delegate.get();
+      ensureCloseFutureScheduled();
+    }
+    return t;
+  }
+
+  public synchronized void close() {
+    try {
+      if (t != null && !t.isClosed()) {
+        t.close();
+      }
+    } catch (IOException ioe) {
+      LOG.warn("Failed to close {} the resource", t.getClass().getName(), ioe);
+    } finally {
+      this.t = null;
+      this.accessTime = System.currentTimeMillis();
+      cancelCloseFuture();
+    }
+  }
+
+  private void tryClose() {
+    if (System.currentTimeMillis() - accessTime > delayCloseInterval) {
+      close();
+    }
+  }
+
+  private void ensureCloseFutureScheduled() {
+    cancelCloseFuture();
+    this.future =
+        executor.scheduleAtFixedRate(
+            this::tryClose, delayCloseInterval, delayCloseInterval, 
TimeUnit.MILLISECONDS);
+  }
+
+  private void cancelCloseFuture() {
+    if (future != null && !future.isDone()) {
+      future.cancel(false);
+      this.future = null;
+    }
+  }
+
+  public static <T extends StatefulCloseable> ExpiringCloseableSupplier<T> of(
+      Supplier<T> delegate) {
+    return new ExpiringCloseableSupplier<>(delegate, 
DEFAULT_DELAY_CLOSE_INTERVAL);
+  }
+
+  public static <T extends StatefulCloseable> ExpiringCloseableSupplier<T> of(
+      Supplier<T> delegate, long delayCloseInterval) {
+    return new ExpiringCloseableSupplier<>(delegate, delayCloseInterval);
+  }
+}
diff --git 
a/common/src/main/java/org/apache/uniffle/common/util/StatefulCloseable.java 
b/common/src/main/java/org/apache/uniffle/common/util/StatefulCloseable.java
new file mode 100644
index 000000000..a4a2453d6
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/util/StatefulCloseable.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.io.Closeable;
+
+/**
+ * A {@link Closeable} that also exposes whether it has been closed. This is the contract
+ * that ExpiringCloseableSupplier requires of the objects produced by its delegate.
+ */
+public interface StatefulCloseable extends Closeable {
+  boolean isClosed();
+}
diff --git 
a/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java
 
b/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java
new file mode 100644
index 000000000..0f791ceab
--- /dev/null
+++ 
b/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Supplier;
+
+import com.google.common.collect.Lists;
+import com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.commons.lang3.SerializationUtils;
+import org.awaitility.Awaitility;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertSame;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class ExpiringCloseableSupplierTest {
+
+  @Test
+  void testCacheable() {
+    Supplier<MockClient> cf = () -> new MockClient(false);
+    ExpiringCloseableSupplier<MockClient> mockClientSupplier = 
ExpiringCloseableSupplier.of(cf);
+
+    MockClient mockClient = mockClientSupplier.get();
+    MockClient mockClient2 = mockClientSupplier.get();
+    assertSame(mockClient, mockClient2);
+    mockClientSupplier.close();
+    mockClientSupplier.close();
+  }
+
+  @Test
+  void testAutoCloseable() {
+    Supplier<MockClient> cf = () -> new MockClient(true);
+    ExpiringCloseableSupplier<MockClient> mockClientSupplier = 
ExpiringCloseableSupplier.of(cf, 10);
+    MockClient mockClient1 = mockClientSupplier.get();
+    assertNotNull(mockClient1);
+    Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS);
+    assertTrue(mockClient1.isClosed());
+    MockClient mockClient2 = mockClientSupplier.get();
+    assertNotSame(mockClient1, mockClient2);
+    mockClientSupplier.close();
+  }
+
+  @Test
+  void testRenew() {
+    Supplier<MockClient> cf = () -> new MockClient(true);
+    ExpiringCloseableSupplier<MockClient> mockClientSupplier = 
ExpiringCloseableSupplier.of(cf);
+    MockClient mockClient = mockClientSupplier.get();
+    mockClientSupplier.close();
+    MockClient mockClient2 = mockClientSupplier.get();
+    assertNotSame(mockClient, mockClient2);
+  }
+
+  @Test
+  void testReClose() {
+    Supplier<MockClient> cf = () -> new MockClient(true);
+    ExpiringCloseableSupplier<MockClient> mockClientSupplier = 
ExpiringCloseableSupplier.of(cf);
+    mockClientSupplier.get();
+    mockClientSupplier.close();
+    mockClientSupplier.close();
+  }
+
+  @Test
+  void testDelegateExtendClose() throws IOException {
+    Supplier<MockClient> cf = () -> new MockClient(false);
+    ExpiringCloseableSupplier<MockClient> mockClientSupplier = 
ExpiringCloseableSupplier.of(cf);
+    MockClient mockClient = mockClientSupplier.get();
+    mockClient.close();
+    assertTrue(mockClient.isClosed());
+
+    MockClient mockClient1 = mockClientSupplier.get();
+    assertNotSame(mockClient, mockClient1);
+    MockClient mockClient2 = mockClientSupplier.get();
+    assertSame(mockClient1, mockClient2);
+    mockClientSupplier.close();
+  }
+
+  @Test
+  public void testSerialization() {
+    Supplier<MockClient> cf = (Supplier<MockClient> & Serializable) () -> new 
MockClient(true);
+    ExpiringCloseableSupplier<MockClient> mockClientSupplier = 
ExpiringCloseableSupplier.of(cf, 10);
+    MockClient mockClient = mockClientSupplier.get();
+
+    ExpiringCloseableSupplier<MockClient> mockClientSupplier2 =
+        SerializationUtils.roundtrip(mockClientSupplier);
+    MockClient mockClient2 = mockClientSupplier2.get();
+    assertFalse(mockClient2.isClosed());
+    assertNotSame(mockClient, mockClient2);
+    Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS);
+    assertTrue(mockClient.isClosed());
+    assertTrue(mockClient2.isClosed());
+  }
+
+  @Test
+  public void testMultipleSupplierShouldNotInterfere() {
+    Supplier<MockClient> cf = () -> new MockClient(true);
+    ExpiringCloseableSupplier<MockClient> mockClientSupplier = 
ExpiringCloseableSupplier.of(cf, 10);
+    ExpiringCloseableSupplier<MockClient> mockClientSupplier2 =
+        ExpiringCloseableSupplier.of(cf, 10);
+    MockClient mockClient = mockClientSupplier.get();
+    MockClient mockClient2 = mockClientSupplier2.get();
+    Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS);
+    assertTrue(mockClient.isClosed());
+    assertTrue(mockClient2.isClosed());
+    mockClientSupplier.close();
+    mockClientSupplier.close();
+    mockClientSupplier2.close();
+    mockClientSupplier2.close();
+  }
+
+  @Test
+  public void stressingTestManySuppliers() {
+    int num = 100000; // this should be sufficient for most production use 
cases
+    Supplier<MockClient> cf = () -> new MockClient(true);
+    List<MockClient> clients = Lists.newArrayList();
+    Random random = new Random(42);
+    for (int i = 0; i < num; i++) {
+      int delayCloseInterval = random.nextInt(1000) + 1;
+      ExpiringCloseableSupplier<MockClient> mockClientSupplier =
+          ExpiringCloseableSupplier.of(cf, delayCloseInterval);
+      MockClient mockClient = mockClientSupplier.get();
+      clients.add(mockClient);
+    }
+    Awaitility.waitAtMost(5, TimeUnit.SECONDS)
+        .until(() -> clients.stream().allMatch(MockClient::isClosed));
+  }
+
+  private static class MockClient implements StatefulCloseable, Serializable {
+    boolean withException;
+    AtomicBoolean closed = new AtomicBoolean(false);
+
+    MockClient(boolean withException) {
+      this.withException = withException;
+    }
+
+    @Override
+    public void close() throws IOException {
+      closed.set(true);
+      if (withException) {
+        throw new IOException("test exception!");
+      }
+    }
+
+    @Override
+    public boolean isClosed() {
+      return closed.get();
+    }
+  }
+}
diff --git 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java
 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java
index 831fa0f2f..abe3a9dfa 100644
--- 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java
+++ 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java
@@ -23,6 +23,7 @@ import org.junit.jupiter.api.BeforeEach;
 import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
 import org.apache.uniffle.client.impl.grpc.ShuffleManagerGrpcClient;
 import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.config.RssBaseConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.rpc.GrpcServer;
 import org.apache.uniffle.shuffle.manager.DummyRssShuffleManager;
@@ -36,12 +37,17 @@ public class ShuffleServerManagerTestBase {
   protected ShuffleManagerGrpcClient client;
   protected static final String LOCALHOST = "localhost";
   protected GrpcServer shuffleManagerServer;
+  protected RssConf rssConf;
 
   protected RssShuffleManagerInterface getShuffleManager() {
     return new DummyRssShuffleManager();
   }
 
-  protected RssConf getConf() {
+  protected ShuffleServerManagerTestBase() {
+    this.rssConf = getRssConf();
+  }
+
+  private RssConf getRssConf() {
     RssConf conf = new RssConf();
     // use a random port
     conf.set(RPC_SERVER_PORT, 0);
@@ -49,7 +55,7 @@ public class ShuffleServerManagerTestBase {
   }
 
   protected GrpcServer createShuffleManagerServer() {
-    return new ShuffleManagerServerFactory(getShuffleManager(), 
getConf()).getServer();
+    return new ShuffleManagerServerFactory(getShuffleManager(), 
rssConf).getServer();
   }
 
   @BeforeEach
@@ -57,7 +63,8 @@ public class ShuffleServerManagerTestBase {
     shuffleManagerServer = createShuffleManagerServer();
     shuffleManagerServer.start();
     int port = shuffleManagerServer.getPort();
-    client = factory.createShuffleManagerClient(ClientType.GRPC, LOCALHOST, 
port);
+    long rpcTimeout = 
rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS);
+    client = factory.createShuffleManagerClient(ClientType.GRPC, LOCALHOST, 
port, rpcTimeout);
   }
 
   @AfterEach
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
index c5b412a9e..6616fe7b1 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java
@@ -17,8 +17,6 @@
 
 package org.apache.uniffle.client.api;
 
-import java.io.Closeable;
-
 import 
org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest;
 import org.apache.uniffle.client.request.RssGetShuffleResultRequest;
 import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
@@ -34,8 +32,9 @@ import 
org.apache.uniffle.client.response.RssReassignServersResponse;
 import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleResultResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
+import org.apache.uniffle.common.util.StatefulCloseable;
 
-public interface ShuffleManagerClient extends Closeable {
+public interface ShuffleManagerClient extends StatefulCloseable {
   RssReportShuffleFetchFailureResponse reportShuffleFetchFailure(
       RssReportShuffleFetchFailureRequest request);
 
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java
index c55acdc22..66b4a2a9e 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java
@@ -33,9 +33,9 @@ public class ShuffleManagerClientFactory {
   private ShuffleManagerClientFactory() {}
 
   public ShuffleManagerGrpcClient createShuffleManagerClient(
-      ClientType clientType, String host, int port) {
+      ClientType clientType, String host, int port, long rpcTimeout) {
     if (ClientType.GRPC.equals(clientType)) {
-      return new ShuffleManagerGrpcClient(host, port);
+      return new ShuffleManagerGrpcClient(host, port, rpcTimeout);
     } else {
       throw new UnsupportedOperationException("Unsupported client type " + 
clientType);
     }
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
index 6dd9f4a1e..8cad876c2 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java
@@ -38,7 +38,6 @@ import 
org.apache.uniffle.client.response.RssReassignServersResponse;
 import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleResultResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
-import org.apache.uniffle.common.config.RssBaseConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.proto.RssProtos;
 import org.apache.uniffle.proto.RssProtos.ReportShuffleFetchFailureRequest;
@@ -48,22 +47,22 @@ import org.apache.uniffle.proto.ShuffleManagerGrpc;
 public class ShuffleManagerGrpcClient extends GrpcClient implements 
ShuffleManagerClient {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(ShuffleManagerGrpcClient.class);
-  private static RssBaseConf rssConf = new RssBaseConf();
-  private long rpcTimeout = 
rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS);
+  private final long rpcTimeout;
   private ShuffleManagerGrpc.ShuffleManagerBlockingStub blockingStub;
 
-  public ShuffleManagerGrpcClient(String host, int port) {
-    this(host, port, 3);
+  public ShuffleManagerGrpcClient(String host, int port, long rpcTimeout) {
+    this(host, port, rpcTimeout, 3);
   }
 
-  public ShuffleManagerGrpcClient(String host, int port, int maxRetryAttempts) 
{
-    this(host, port, maxRetryAttempts, true);
+  public ShuffleManagerGrpcClient(String host, int port, long rpcTimeout, int 
maxRetryAttempts) {
+    this(host, port, rpcTimeout, maxRetryAttempts, true);
   }
 
   public ShuffleManagerGrpcClient(
-      String host, int port, int maxRetryAttempts, boolean usePlaintext) {
+      String host, int port, long rpcTimeout, int maxRetryAttempts, boolean 
usePlaintext) {
     super(host, port, maxRetryAttempts, usePlaintext);
     blockingStub = ShuffleManagerGrpc.newBlockingStub(channel);
+    this.rpcTimeout = rpcTimeout;
   }
 
   public ShuffleManagerGrpc.ShuffleManagerBlockingStub getBlockingStub() {
@@ -165,4 +164,9 @@ public class ShuffleManagerGrpcClient extends GrpcClient 
implements ShuffleManag
         getBlockingStub().reportShuffleResult(request.toProto());
     return RssReportShuffleResultResponse.fromProto(response);
   }
+
+  @Override
+  public boolean isClosed() {
+    return channel.isShutdown() || channel.isTerminated();
+  }
 }
diff --git 
a/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java
 
b/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java
index 5fed54ff0..c40c06c32 100644
--- 
a/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java
+++ 
b/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java
@@ -32,10 +32,11 @@ class ShuffleManagerClientFactoryTest {
     ShuffleManagerClientFactory factory = 
ShuffleManagerClientFactory.getInstance();
     assertNotNull(factory);
     // only grpc type is supported currently
-    ShuffleManagerClient c = 
factory.createShuffleManagerClient(ClientType.GRPC, "localhost", 1234);
+    ShuffleManagerClient c =
+        factory.createShuffleManagerClient(ClientType.GRPC, "localhost", 1234, 
60000);
     assertNotNull(c);
     assertThrows(
         UnsupportedOperationException.class,
-        () -> factory.createShuffleManagerClient(ClientType.GRPC_NETTY, 
"localhost", 1234));
+        () -> factory.createShuffleManagerClient(ClientType.GRPC_NETTY, 
"localhost", 1234, 60000));
   }
 }

Reply via email to