This is an automated email from the ASF dual-hosted git repository.

maobaolong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 43879ca13 [MINOR] improvement(spark-client): Refactor 
RssShuffleManager for spark v2 and v3 to reduce redundant code (#2330)
43879ca13 is described below

commit 43879ca13aa261e6a775178c7015b8cf758053ab
Author: maobaolong <[email protected]>
AuthorDate: Thu Jan 9 20:46:04 2025 +0800

    [MINOR] improvement(spark-client): Refactor RssShuffleManager for spark v2 
and v3 to reduce redundant code (#2330)
    
    ### What changes were proposed in this pull request?
    
    Refactor and abstract the same code into base class.
    
    ### Why are the changes needed?
    
    Reduce redundant code and simplify development within the RSS spark-client 
scope.
    
    The Spark-version-independent code should be placed in the 
RssShuffleManagerBase class; the version-specific RssShuffleManager classes 
should only maintain the Spark-API-related code.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    - Existing UTs.
    - Tested on our test cluster with both Spark v2 and v3 versions.
---
 .../shuffle/manager/RssShuffleManagerBase.java     | 452 ++++++++++++++++++-
 .../apache/spark/shuffle/RssShuffleManager.java    | 409 ++----------------
 .../apache/spark/shuffle/RssShuffleManager.java    | 480 ++-------------------
 3 files changed, 521 insertions(+), 820 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index c1fc5b68e..d869c64fe 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -17,6 +17,7 @@
 
 package org.apache.uniffle.shuffle.manager;
 
+import java.io.IOException;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.util.ArrayList;
@@ -28,6 +29,9 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
@@ -40,6 +44,7 @@ import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.spark.MapOutputTracker;
@@ -58,6 +63,8 @@ import 
org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
+import org.apache.spark.shuffle.writer.AddBlockEvent;
+import org.apache.spark.shuffle.writer.DataPusher;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -65,6 +72,7 @@ import org.apache.uniffle.client.api.ShuffleManagerClient;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.factory.CoordinatorClientFactory;
 import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.client.impl.grpc.CoordinatorGrpcRetryableClient;
 import org.apache.uniffle.client.request.RssFetchClientConfRequest;
 import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
@@ -83,17 +91,38 @@ import org.apache.uniffle.common.config.ConfigOption;
 import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.rpc.GrpcServer;
 import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
+import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.common.util.RetryUtils;
+import org.apache.uniffle.common.util.RssUtils;
+import org.apache.uniffle.common.util.ThreadUtils;
 import org.apache.uniffle.shuffle.BlockIdManager;
 
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM;
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
+import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
 import static 
org.apache.uniffle.common.config.RssClientConf.HADOOP_CONFIG_KEY_PREFIX;
+import static 
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
 import static 
org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED;
 
 public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterface, ShuffleManager {
   private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleManagerBase.class);
+  protected final int dataTransferPoolSize;
+  protected final int dataCommitPoolSize;
+  protected final int dataReplica;
+  protected final int dataReplicaWrite;
+  protected final int dataReplicaRead;
+  protected final boolean dataReplicaSkipEnabled;
+  protected final Map<String, Set<Long>> taskToSuccessBlockIds;
+  protected final Map<String, FailedBlockSendTracker> 
taskToFailedBlockSendTracker;
+  private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
+
   private AtomicBoolean isInitialized = new AtomicBoolean(false);
   private Method unregisterAllMapOutputMethod;
   private Method registerShuffleMethod;
@@ -107,7 +136,8 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
   protected int maxConcurrencyPerPartitionToWrite;
   protected String clientType;
 
-  protected SparkConf sparkConf;
+  protected final SparkConf sparkConf;
+  protected final RssConf rssConf;
   protected Map<Integer, Integer> shuffleIdToPartitionNum;
   protected Map<Integer, Integer> shuffleIdToNumMapTasks;
   protected Supplier<ShuffleManagerClient> managerClientSupplier;
@@ -126,9 +156,227 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
   protected boolean partitionReassignEnabled;
   protected boolean shuffleManagerRpcServiceEnabled;
 
-  public RssShuffleManagerBase() {
+  protected boolean heartbeatStarted = false;
+  protected final long heartbeatInterval;
+  protected final long heartbeatTimeout;
+  protected String user;
+  protected String uuid;
+  protected ScheduledExecutorService heartBeatScheduledExecutorService;
+  protected final int maxFailures;
+  protected final boolean speculation;
+  protected final BlockIdLayout blockIdLayout;
+  private ShuffleManagerGrpcService service;
+  protected GrpcServer shuffleManagerServer;
+  protected DataPusher dataPusher;
+
+  public RssShuffleManagerBase(SparkConf conf, boolean isDriver) {
     LOG.info(
         "Uniffle {} version: {}", this.getClass().getName(), 
Constants.VERSION_AND_REVISION_SHORT);
+    this.sparkConf = conf;
+    checkSupported(sparkConf);
+    boolean supportsRelocation =
+        Optional.ofNullable(SparkEnv.get())
+            .map(env -> 
env.serializer().supportsRelocationOfSerializedObjects())
+            .orElse(true);
+    if (!supportsRelocation) {
+      LOG.warn(
+          "RSSShuffleManager requires a serializer which supports relocations 
of serialized object. Please set "
+              + "spark.serializer to 
org.apache.spark.serializer.KryoSerializer instead");
+    }
+    this.user = sparkConf.get("spark.rss.quota.user", "user");
+    this.uuid = sparkConf.get("spark.rss.quota.uuid", 
Long.toString(System.currentTimeMillis()));
+    this.dynamicConfEnabled = 
sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
+
+    // fetch client conf and apply them if necessary
+    if (isDriver && this.dynamicConfEnabled) {
+      fetchAndApplyDynamicConf(sparkConf);
+    }
+    RssSparkShuffleUtils.validateRssClientConf(sparkConf);
+
+    // convert spark conf to rss conf after fetching dynamic client conf
+    this.rssConf = RssSparkConfig.toRssConf(sparkConf);
+    RssUtils.setExtraJavaProperties(rssConf);
+
+    // set & check replica config
+    this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
+    this.dataReplicaWrite = 
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
+    this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
+    this.dataReplicaSkipEnabled = 
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
+    LOG.info(
+        "Check quorum config ["
+            + dataReplica
+            + ":"
+            + dataReplicaWrite
+            + ":"
+            + dataReplicaRead
+            + ":"
+            + dataReplicaSkipEnabled
+            + "]");
+    RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, 
dataReplicaRead);
+
+    this.maxConcurrencyPerPartitionToWrite = 
rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
+
+    this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
+
+    // configure block id layout
+    this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+    this.speculation = sparkConf.getBoolean("spark.speculation", false);
+    // configureBlockIdLayout requires maxFailures and speculation to be 
initialized
+    configureBlockIdLayout(sparkConf, rssConf);
+    this.blockIdLayout = BlockIdLayout.from(rssConf);
+
+    this.dataTransferPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
+    this.dataCommitPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
+
+    // External shuffle service is not supported when using remote shuffle 
service
+    sparkConf.set("spark.shuffle.service.enabled", "false");
+    sparkConf.set("spark.dynamicAllocation.shuffleTracking.enabled", "false");
+    sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
+    LOG.info("Disable external shuffle service in RssShuffleManager.");
+    sparkConf.set("spark.sql.adaptive.localShuffleReader.enabled", "false");
+    LOG.info("Disable local shuffle reader in RssShuffleManager.");
+    // If we store shuffle data in distributed filesystem or in a disaggregated
+    // shuffle cluster, we don't need shuffle data locality
+    sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
+    LOG.info("Disable shuffle data locality in RssShuffleManager.");
+
+    taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
+    taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
+    this.shuffleIdToPartitionNum = JavaUtils.newConcurrentMap();
+    this.shuffleIdToNumMapTasks = JavaUtils.newConcurrentMap();
+
+    // stage retry for write/fetch failure
+    rssStageRetryForFetchFailureEnabled =
+        rssConf.get(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED);
+    rssStageRetryForWriteFailureEnabled =
+        rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
+    if (rssStageRetryForFetchFailureEnabled || 
rssStageRetryForWriteFailureEnabled) {
+      rssStageRetryEnabled = true;
+      List<String> logTips = new ArrayList<>();
+      if (rssStageRetryForWriteFailureEnabled) {
+        logTips.add("write");
+      }
+      if (rssStageRetryForWriteFailureEnabled) {
+        logTips.add("fetch");
+      }
+      LOG.info(
+          "Activate the stage retry mechanism that will resubmit stage on {} 
failure",
+          StringUtils.join(logTips, "/"));
+    }
+
+    this.partitionReassignEnabled = 
rssConf.get(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
+    // The feature of partition reassign is exclusive with multiple replicas 
and stage retry.
+    if (partitionReassignEnabled) {
+      if (dataReplica > 1) {
+        throw new RssException(
+            "The feature of task partition reassign is incompatible with 
multiple replicas mechanism.");
+      }
+    }
+    this.blockIdSelfManagedEnabled = 
rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
+    this.shuffleManagerRpcServiceEnabled =
+        partitionReassignEnabled || rssStageRetryEnabled || 
blockIdSelfManagedEnabled;
+
+    if (isDriver) {
+      heartBeatScheduledExecutorService =
+          ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
+      if (shuffleManagerRpcServiceEnabled) {
+        LOG.info("stage resubmit is supported and enabled");
+        // start shuffle manager server
+        rssConf.set(RPC_SERVER_PORT, 0);
+        ShuffleManagerServerFactory factory = new 
ShuffleManagerServerFactory(this, rssConf);
+        service = factory.getService();
+        shuffleManagerServer = factory.getServer(service);
+        try {
+          shuffleManagerServer.start();
+          // pass this as a spark.rss.shuffle.manager.grpc.port config, so it 
can be propagated to
+          // executor properly.
+          sparkConf.set(
+              RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT, 
shuffleManagerServer.getPort());
+        } catch (Exception e) {
+          LOG.error("Failed to start shuffle manager server", e);
+          throw new RssException(e);
+        }
+      }
+    }
+    if (shuffleManagerRpcServiceEnabled) {
+      getOrCreateShuffleManagerClientSupplier();
+    }
+
+    // Start heartbeat thread.
+    this.heartbeatInterval = 
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
+    this.heartbeatTimeout =
+        sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), 
heartbeatInterval / 2);
+    heartBeatScheduledExecutorService =
+        ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
+
+    this.shuffleWriteClient = createShuffleWriteClient();
+    registerCoordinator();
+
+    LOG.info("Rss data pusher is starting...");
+    int poolSize = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
+    int keepAliveTime = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
+    this.dataPusher =
+        new DataPusher(
+            shuffleWriteClient,
+            taskToSuccessBlockIds,
+            taskToFailedBlockSendTracker,
+            failedTaskIds,
+            poolSize,
+            keepAliveTime);
+    this.partitionReassignMaxServerNum =
+        rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
+    this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+    this.rssStageResubmitManager = new RssStageResubmitManager();
+  }
+
+  @VisibleForTesting
+  protected RssShuffleManagerBase(
+      SparkConf conf,
+      boolean isDriver,
+      DataPusher dataPusher,
+      Map<String, Set<Long>> taskToSuccessBlockIds,
+      Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker) {
+    this.sparkConf = conf;
+    this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
+    this.rssConf = RssSparkConfig.toRssConf(sparkConf);
+    this.dataDistributionType = 
rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE);
+    this.blockIdLayout = BlockIdLayout.from(rssConf);
+    this.maxConcurrencyPerPartitionToWrite = 
rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
+    this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+    this.speculation = sparkConf.getBoolean("spark.speculation", false);
+    // configureBlockIdLayout requires maxFailures and speculation to be 
initialized
+    configureBlockIdLayout(sparkConf, rssConf);
+    this.heartbeatInterval = 
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
+    this.heartbeatTimeout =
+        sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), 
heartbeatInterval / 2);
+    this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
+    this.dataReplicaWrite = 
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
+    this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
+    this.dataReplicaSkipEnabled = 
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
+    LOG.info(
+        "Check quorum config ["
+            + dataReplica
+            + ":"
+            + dataReplicaWrite
+            + ":"
+            + dataReplicaRead
+            + ":"
+            + dataReplicaSkipEnabled
+            + "]");
+    RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, 
dataReplicaRead);
+
+    this.dataTransferPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
+    this.dataCommitPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
+    createShuffleWriteClient();
+
+    this.taskToSuccessBlockIds = taskToSuccessBlockIds;
+    this.heartBeatScheduledExecutorService = null;
+    this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker;
+    this.dataPusher = dataPusher;
+    this.partitionReassignMaxServerNum =
+        rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
+    this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+    this.rssStageResubmitManager = new RssStageResubmitManager();
   }
 
   public BlockIdManager getBlockIdManager() {
@@ -145,14 +393,35 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
 
   @Override
   public boolean unregisterShuffle(int shuffleId) {
-    if (blockIdManager != null) {
-      blockIdManager.remove(shuffleId);
+    try {
+      if (blockIdManager != null) {
+        blockIdManager.remove(shuffleId);
+      }
+      if (SparkEnv.get().executorId().equals("driver")) {
+        shuffleWriteClient.unregisterShuffle(getAppId(), shuffleId);
+        shuffleIdToPartitionNum.remove(shuffleId);
+        shuffleIdToNumMapTasks.remove(shuffleId);
+        if (service != null) {
+          service.unregisterShuffle(shuffleId);
+        }
+      }
+    } catch (Exception e) {
+      LOG.warn("Errors on unregistering from remote shuffle-servers", e);
     }
     return true;
   }
 
-  /** See static overload of this method. */
-  public abstract void configureBlockIdLayout(SparkConf sparkConf, RssConf 
rssConf);
+  /**
+   * Derives block id layout config from maximum number of allowed partitions. 
Computes the number
+   * of required bits for partition id and task attempt id and reserves 
remaining bits for sequence
+   * number.
+   *
+   * @param sparkConf Spark config providing max partitions
+   * @param rssConf Rss config to amend
+   */
+  public void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf) {
+    configureBlockIdLayout(sparkConf, rssConf, maxFailures, speculation);
+  }
 
   /**
    * Derives block id layout config from maximum number of allowed partitions. 
This value can be set
@@ -344,7 +613,10 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
   }
 
   /** See static overload of this method. */
-  public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo);
+  public long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo) {
+    return getTaskAttemptIdForBlockId(
+        mapIndex, attemptNo, maxFailures, speculation, 
blockIdLayout.taskAttemptIdBits);
+  }
 
   /**
    * Provides a task attempt id to be used in the block id, that is unique for 
a shuffle stage.
@@ -809,7 +1081,8 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
         LOG.info(
             "Register the new partition->servers assignment on reassign. {}",
             newServerToPartitions);
-        registerShuffleServers(id.get(), shuffleId, newServerToPartitions, 
getRemoteStorageInfo());
+        registerShuffleServers(
+            getAppId(), shuffleId, newServerToPartitions, 
getRemoteStorageInfo());
       }
 
       LOG.info(
@@ -852,8 +1125,76 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
         && managerClientSupplier instanceof ExpiringCloseableSupplier) {
       ((ExpiringCloseableSupplier<ShuffleManagerClient>) 
managerClientSupplier).close();
     }
+    if (heartBeatScheduledExecutorService != null) {
+      heartBeatScheduledExecutorService.shutdownNow();
+    }
+    if (shuffleWriteClient != null) {
+      // Unregister shuffle before closing shuffle write client.
+      shuffleWriteClient.unregisterShuffle(getAppId());
+      shuffleWriteClient.close();
+    }
+    if (dataPusher != null) {
+      try {
+        dataPusher.close();
+      } catch (IOException e) {
+        LOG.warn("Errors on closing data pusher", e);
+      }
+    }
+
+    if (shuffleManagerServer != null) {
+      try {
+        shuffleManagerServer.stop();
+      } catch (InterruptedException e) {
+        // ignore
+        LOG.info("shuffle manager server is interrupted during stop");
+      }
+    }
+  }
+
+  /** @return the unique spark id for rss shuffle */
+  @Override
+  public String getAppId() {
+    return id.get();
+  }
+
+  @Override
+  public int getPartitionNum(int shuffleId) {
+    return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0);
   }
 
+  /**
+   * @param shuffleId the shuffle id to query
+   * @return the num of map tasks for current shuffle with shuffle id.
+   */
+  @Override
+  public int getNumMaps(int shuffleId) {
+    return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0);
+  }
+
+  @VisibleForTesting
+  public void addSuccessBlockIds(String taskId, Set<Long> blockIds) {
+    if (taskToSuccessBlockIds.get(taskId) == null) {
+      taskToSuccessBlockIds.put(taskId, Sets.newHashSet());
+    }
+    taskToSuccessBlockIds.get(taskId).addAll(blockIds);
+  }
+
+  @VisibleForTesting
+  public void addFailedBlockSendTracker(
+      String taskId, FailedBlockSendTracker failedBlockSendTracker) {
+    taskToFailedBlockSendTracker.putIfAbsent(taskId, failedBlockSendTracker);
+  }
+
+  /** Create the shuffleWriteClient. */
+  protected abstract ShuffleWriteClient createShuffleWriteClient();
+
+  /**
+   * Check whether the configuration is supported.
+   *
+   * @param sparkConf the sparkConf
+   */
+  protected void checkSupported(SparkConf sparkConf) {}
+
   /**
    * Creating the shuffleAssignmentInfo from the servers and partitionIds
    *
@@ -931,7 +1272,7 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
     try {
       ShuffleAssignmentsInfo response =
           shuffleWriteClient.getShuffleAssignments(
-              id.get(),
+              getAppId(),
               shuffleId,
               partitionNum,
               partitionNumPerRange,
@@ -949,7 +1290,7 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
         response = reassignmentHandler.apply(response);
       }
       registerShuffleServers(
-          id.get(), shuffleId, response.getServerToPartitionRanges(), 
getRemoteStorageInfo());
+          getAppId(), shuffleId, response.getServerToPartitionRanges(), 
getRemoteStorageInfo());
       return response.getPartitionToServers();
     } catch (Throwable throwable) {
       throw new RssException("registerShuffle failed!", throwable);
@@ -1131,4 +1472,95 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
   public ShuffleWriteClient getShuffleWriteClient() {
     return shuffleWriteClient;
   }
+
+  protected synchronized void startHeartbeat() {
+    shuffleWriteClient.registerApplicationInfo(getAppId(), heartbeatTimeout, 
user);
+    if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false) && 
!heartbeatStarted) {
+      heartBeatScheduledExecutorService.scheduleAtFixedRate(
+          () -> {
+            try {
+              String appId = getAppId();
+              shuffleWriteClient.sendAppHeartbeat(appId, heartbeatTimeout);
+              LOG.info("Finish send heartbeat to coordinator and servers");
+            } catch (Exception e) {
+              LOG.warn("Fail to send heartbeat to coordinator and servers", e);
+            }
+          },
+          heartbeatInterval / 2,
+          heartbeatInterval,
+          TimeUnit.MILLISECONDS);
+      heartbeatStarted = true;
+    }
+  }
+
+  public void clearTaskMeta(String taskId) {
+    taskToSuccessBlockIds.remove(taskId);
+    taskToFailedBlockSendTracker.remove(taskId);
+  }
+
+  @VisibleForTesting
+  protected void registerCoordinator() {
+    String coordinators = 
sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
+    LOG.info("Start Registering coordinators {}", coordinators);
+    shuffleWriteClient.registerCoordinators(
+        coordinators,
+        this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX),
+        this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX));
+  }
+
+  public Set<Long> getFailedBlockIds(String taskId) {
+    FailedBlockSendTracker blockIdsFailedSendTracker = 
getBlockIdsFailedSendTracker(taskId);
+    if (blockIdsFailedSendTracker == null) {
+      return Collections.emptySet();
+    }
+    return blockIdsFailedSendTracker.getFailedBlockIds();
+  }
+
+  public Set<Long> getSuccessBlockIds(String taskId) {
+    Set<Long> result = taskToSuccessBlockIds.get(taskId);
+    if (result == null) {
+      result = Collections.emptySet();
+    }
+    return result;
+  }
+
+  public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) {
+    return taskToFailedBlockSendTracker.get(taskId);
+  }
+
+  public boolean markFailedTask(String taskId) {
+    LOG.info("Mark the task: {} failed.", taskId);
+    failedTaskIds.add(taskId);
+    return true;
+  }
+
+  public boolean isValidTask(String taskId) {
+    return !failedTaskIds.contains(taskId);
+  }
+
+  @VisibleForTesting
+  public void setDataPusher(DataPusher dataPusher) {
+    this.dataPusher = dataPusher;
+  }
+
+  public DataPusher getDataPusher() {
+    return dataPusher;
+  }
+
+  @VisibleForTesting
+  public Map<String, Set<Long>> getTaskToSuccessBlockIds() {
+    return taskToSuccessBlockIds;
+  }
+
+  @VisibleForTesting
+  public Map<String, FailedBlockSendTracker> getTaskToFailedBlockSendTracker() 
{
+    return taskToFailedBlockSendTracker;
+  }
+
+  public CompletableFuture<Long> sendData(AddBlockEvent event) {
+    if (dataPusher != null && event != null) {
+      return dataPusher.send(event);
+    }
+    return new CompletableFuture<>();
+  }
 }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index d62827243..8e6b8dfca 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -17,15 +17,10 @@
 
 package org.apache.spark.shuffle;
 
-import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.TimeUnit;
 
 import scala.Option;
 import scala.Tuple2;
@@ -34,7 +29,6 @@ import scala.collection.Seq;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Sets;
-import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
@@ -47,8 +41,6 @@ import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
 import org.apache.spark.shuffle.reader.RssShuffleReader;
-import org.apache.spark.shuffle.writer.AddBlockEvent;
-import org.apache.spark.shuffle.writer.DataPusher;
 import org.apache.spark.shuffle.writer.RssShuffleWriter;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManagerId;
@@ -56,205 +48,20 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.uniffle.client.impl.FailedBlockSendTracker;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.config.RssClientConf;
-import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
-import org.apache.uniffle.common.rpc.GrpcServer;
-import org.apache.uniffle.common.util.BlockIdLayout;
-import org.apache.uniffle.common.util.JavaUtils;
-import org.apache.uniffle.common.util.RssUtils;
-import org.apache.uniffle.common.util.ThreadUtils;
 import org.apache.uniffle.shuffle.RssShuffleClientFactory;
 import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;
-import org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService;
-import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory;
-
-import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
-import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
-import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
-import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
-import static 
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
 
 public class RssShuffleManager extends RssShuffleManagerBase {
-
   private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleManager.class);
-  private final long heartbeatInterval;
-  private final long heartbeatTimeout;
-  private ScheduledExecutorService heartBeatScheduledExecutorService;
-  private Map<String, Set<Long>> taskToSuccessBlockIds = 
JavaUtils.newConcurrentMap();
-  private Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker =
-      JavaUtils.newConcurrentMap();
-  private final int dataReplica;
-  private final int dataReplicaWrite;
-  private final int dataReplicaRead;
-  private final boolean dataReplicaSkipEnabled;
-  private final int dataTransferPoolSize;
-  private final int dataCommitPoolSize;
-  private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
-  private boolean heartbeatStarted = false;
-  private final int maxFailures;
-  private final boolean speculation;
-  private final BlockIdLayout blockIdLayout;
-  private final String user;
-  private final String uuid;
-  private DataPusher dataPusher;
-  private GrpcServer shuffleManagerServer;
-  private ShuffleManagerGrpcService service;
 
   public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
-    if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
-      throw new IllegalArgumentException(
-          "Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be 
false.");
-    }
-    this.sparkConf = sparkConf;
-    this.user = sparkConf.get("spark.rss.quota.user", "user");
-    this.uuid = sparkConf.get("spark.rss.quota.uuid", 
Long.toString(System.currentTimeMillis()));
-    this.dynamicConfEnabled = 
sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
-
-    // fetch client conf and apply them if necessary
-    if (isDriver && this.dynamicConfEnabled) {
-      fetchAndApplyDynamicConf(sparkConf);
-    }
-    RssSparkShuffleUtils.validateRssClientConf(sparkConf);
-
-    // configure block id layout
-    this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
-    this.speculation = sparkConf.getBoolean("spark.speculation", false);
-    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
-    RssUtils.setExtraJavaProperties(rssConf);
-    // configureBlockIdLayout requires maxFailures and speculation to be 
initialized
-    configureBlockIdLayout(sparkConf, rssConf);
-    this.blockIdLayout = BlockIdLayout.from(rssConf);
-
-    // set & check replica config
-    this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
-    this.dataReplicaWrite = 
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
-    this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
-    this.dataTransferPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
-    this.dataReplicaSkipEnabled = 
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
-    this.maxConcurrencyPerPartitionToWrite =
-        
RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
-    LOG.info(
-        "Check quorum config [{}:{}:{}:{}]",
-        dataReplica,
-        dataReplicaWrite,
-        dataReplicaRead,
-        dataReplicaSkipEnabled);
-    RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, 
dataReplicaRead);
-
-    this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
-    this.heartbeatInterval = 
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
-    this.heartbeatTimeout =
-        sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), 
heartbeatInterval / 2);
-    int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
-    long retryIntervalMax = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
-    int heartBeatThreadNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
-    this.dataCommitPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
-    int unregisterThreadPoolSize =
-        sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
-    int unregisterTimeoutSec = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
-    int unregisterRequestTimeoutSec =
-        
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
-    // External shuffle service is not supported when using remote shuffle 
service
-    sparkConf.set("spark.shuffle.service.enabled", "false");
-    LOG.info("Disable external shuffle service in RssShuffleManager.");
-    // If we store shuffle data in distributed filesystem or in a disaggregated
-    // shuffle cluster, we don't need shuffle data locality
-    sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
-    LOG.info("Disable shuffle data locality in RssShuffleManager.");
-
-    this.shuffleIdToPartitionNum = JavaUtils.newConcurrentMap();
-    this.shuffleIdToNumMapTasks = JavaUtils.newConcurrentMap();
-    // stage retry for write/fetch failure
-    rssStageRetryForFetchFailureEnabled =
-        rssConf.get(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED);
-    rssStageRetryForWriteFailureEnabled =
-        rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
-    if (rssStageRetryForFetchFailureEnabled || 
rssStageRetryForWriteFailureEnabled) {
-      rssStageRetryEnabled = true;
-      List<String> logTips = new ArrayList<>();
-      if (rssStageRetryForWriteFailureEnabled) {
-        logTips.add("write");
-      }
-      if (rssStageRetryForWriteFailureEnabled) {
-        logTips.add("fetch");
-      }
-      LOG.info(
-          "Activate the stage retry mechanism that will resubmit stage on {} 
failure",
-          StringUtils.join(logTips, "/"));
-    }
-    this.partitionReassignEnabled = 
rssConf.getBoolean(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
-    this.blockIdSelfManagedEnabled = 
rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
-    this.shuffleManagerRpcServiceEnabled =
-        partitionReassignEnabled || rssStageRetryEnabled || 
blockIdSelfManagedEnabled;
-    if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) {
-      if (isDriver) {
-        heartBeatScheduledExecutorService =
-            
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
-        if (shuffleManagerRpcServiceEnabled) {
-          LOG.info("stage resubmit is supported and enabled");
-          // start shuffle manager server
-          rssConf.set(RPC_SERVER_PORT, 0);
-          ShuffleManagerServerFactory factory = new 
ShuffleManagerServerFactory(this, rssConf);
-          service = factory.getService();
-          shuffleManagerServer = factory.getServer(service);
-          try {
-            shuffleManagerServer.start();
-            // pass this as a spark.rss.shuffle.manager.grpc.port config, so 
it can be propagated to
-            // executor properly.
-            sparkConf.set(
-                RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT, 
shuffleManagerServer.getPort());
-          } catch (Exception e) {
-            LOG.error("Failed to start shuffle manager server", e);
-            throw new RssException(e);
-          }
-        }
-      }
-      if (shuffleManagerRpcServiceEnabled) {
-        getOrCreateShuffleManagerClientSupplier();
-      }
-      this.shuffleWriteClient =
-          RssShuffleClientFactory.getInstance()
-              .createShuffleWriteClient(
-                  RssShuffleClientFactory.newWriteBuilder()
-                      .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
-                      .managerClientSupplier(managerClientSupplier)
-                      .clientType(clientType)
-                      .retryMax(retryMax)
-                      .retryIntervalMax(retryIntervalMax)
-                      .heartBeatThreadNum(heartBeatThreadNum)
-                      .replica(dataReplica)
-                      .replicaWrite(dataReplicaWrite)
-                      .replicaRead(dataReplicaRead)
-                      .replicaSkipEnabled(dataReplicaSkipEnabled)
-                      .dataTransferPoolSize(dataTransferPoolSize)
-                      .dataCommitPoolSize(dataCommitPoolSize)
-                      .unregisterThreadPoolSize(unregisterThreadPoolSize)
-                      .unregisterTimeSec(unregisterTimeoutSec)
-                      .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
-                      .rssConf(rssConf));
-      registerCoordinator();
-
-      // for non-driver executor, start a thread for sending shuffle data to 
shuffle server
-      LOG.info("RSS data pusher is starting...");
-      int poolSize = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
-      int keepAliveTime = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
-      this.dataPusher =
-          new DataPusher(
-              shuffleWriteClient,
-              taskToSuccessBlockIds,
-              taskToFailedBlockSendTracker,
-              failedTaskIds,
-              poolSize,
-              keepAliveTime);
-    }
-    this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
-    this.rssStageResubmitManager = new RssStageResubmitManager();
+    super(sparkConf, isDriver);
   }
 
   // This method is called in Spark driver side,
@@ -380,42 +187,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     return new RssShuffleHandle(shuffleId, appId, numMaps, dependency, 
hdlInfoBd);
   }
 
-  private void startHeartbeat() {
-    shuffleWriteClient.registerApplicationInfo(appId, heartbeatTimeout, user);
-    if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false) && 
!heartbeatStarted) {
-      heartBeatScheduledExecutorService.scheduleAtFixedRate(
-          () -> {
-            try {
-              shuffleWriteClient.sendAppHeartbeat(appId, heartbeatTimeout);
-              LOG.info("Finish send heartbeat to coordinator and servers");
-            } catch (Exception e) {
-              LOG.warn("Fail to send heartbeat to coordinator and servers", e);
-            }
-          },
-          heartbeatInterval / 2,
-          heartbeatInterval,
-          TimeUnit.MILLISECONDS);
-      heartbeatStarted = true;
-    }
-  }
-
-  @VisibleForTesting
-  protected void registerCoordinator() {
-    String coordinators = 
sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
-    LOG.info("Registering coordinators {}", coordinators);
-    shuffleWriteClient.registerCoordinators(
-        coordinators,
-        this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX),
-        this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX));
-  }
-
-  public CompletableFuture<Long> sendData(AddBlockEvent event) {
-    if (dataPusher != null && event != null) {
-      return dataPusher.send(event);
-    }
-    return new CompletableFuture<>();
-  }
-
   // This method is called in Spark executor,
   // getting information from Spark driver via the ShuffleHandle.
   @Override
@@ -447,24 +218,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     }
   }
 
-  /**
-   * Derives block id layout config from maximum number of allowed partitions. 
Computes the number
-   * of required bits for partition id and task attempt id and reserves 
remaining bits for sequence
-   * number.
-   *
-   * @param sparkConf Spark config providing max partitions
-   * @param rssConf Rss config to amend
-   */
-  public void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf) {
-    configureBlockIdLayout(sparkConf, rssConf, maxFailures, speculation);
-  }
-
-  @Override
-  public long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo) {
-    return getTaskAttemptIdForBlockId(
-        mapIndex, attemptNo, maxFailures, speculation, 
blockIdLayout.taskAttemptIdBits);
-  }
-
   // This method is called in Spark executor,
   // getting information from Spark driver via the ShuffleHandle.
   @Override
@@ -563,44 +316,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     return null;
   }
 
-  @Override
-  public boolean unregisterShuffle(int shuffleId) {
-    try {
-      super.unregisterShuffle(shuffleId);
-      if (SparkEnv.get().executorId().equals("driver")) {
-        shuffleWriteClient.unregisterShuffle(appId, shuffleId);
-        shuffleIdToNumMapTasks.remove(shuffleId);
-        shuffleIdToPartitionNum.remove(shuffleId);
-        if (service != null) {
-          service.unregisterShuffle(shuffleId);
-        }
-      }
-    } catch (Exception e) {
-      LOG.warn("Errors on unregistering from remote shuffle-servers", e);
-    }
-    return true;
-  }
-
-  @Override
-  public void stop() {
-    super.stop();
-    if (heartBeatScheduledExecutorService != null) {
-      heartBeatScheduledExecutorService.shutdownNow();
-    }
-    if (dataPusher != null) {
-      try {
-        dataPusher.close();
-      } catch (IOException e) {
-        LOG.warn("Errors on closing data pusher", e);
-      }
-    }
-    if (shuffleWriteClient != null) {
-      // Unregister shuffle before closing shuffle write client.
-      shuffleWriteClient.unregisterShuffle(appId);
-      shuffleWriteClient.close();
-    }
-  }
-
   @Override
   public ShuffleBlockResolver shuffleBlockResolver() {
     throw new RssException("RssShuffleManager.shuffleBlockResolver is not 
implemented");
@@ -631,93 +346,17 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     return taskIdBitmap;
   }
 
-  public Set<Long> getFailedBlockIds(String taskId) {
-    FailedBlockSendTracker blockIdsFailedSendTracker = 
getBlockIdsFailedSendTracker(taskId);
-    if (blockIdsFailedSendTracker == null) {
-      return Collections.emptySet();
-    }
-    return blockIdsFailedSendTracker.getFailedBlockIds();
-  }
-
-  public Set<Long> getSuccessBlockIds(String taskId) {
-    Set<Long> result = taskToSuccessBlockIds.get(taskId);
-    if (result == null) {
-      result = Collections.emptySet();
-    }
-    return result;
-  }
-
-  @VisibleForTesting
-  public void addSuccessBlockIds(String taskId, Set<Long> blockIds) {
-    if (taskToSuccessBlockIds.get(taskId) == null) {
-      taskToSuccessBlockIds.put(taskId, Sets.newHashSet());
-    }
-    taskToSuccessBlockIds.get(taskId).addAll(blockIds);
-  }
-
-  @VisibleForTesting
-  public void addFailedBlockSendTracker(
-      String taskId, FailedBlockSendTracker failedBlockSendTracker) {
-    taskToFailedBlockSendTracker.putIfAbsent(taskId, failedBlockSendTracker);
-  }
-
-  public void clearTaskMeta(String taskId) {
-    taskToSuccessBlockIds.remove(taskId);
-    taskToFailedBlockSendTracker.remove(taskId);
-  }
-
-  @VisibleForTesting
-  public SparkConf getSparkConf() {
-    return sparkConf;
-  }
-
   @VisibleForTesting
   public void setAppId(String appId) {
     this.appId = appId;
   }
 
-  public boolean markFailedTask(String taskId) {
-    LOG.info("Mark the task: {} failed.", taskId);
-    failedTaskIds.add(taskId);
-    return true;
-  }
-
-  public boolean isValidTask(String taskId) {
-    return !failedTaskIds.contains(taskId);
-  }
-
-  public DataPusher getDataPusher() {
-    return dataPusher;
-  }
-
-  public void setDataPusher(DataPusher dataPusher) {
-    this.dataPusher = dataPusher;
-  }
-
   /** @return the unique spark id for rss shuffle */
   @Override
   public String getAppId() {
     return appId;
   }
 
-  /**
-   * @param shuffleId the shuffleId to query
-   * @return the num of partitions(a.k.a reduce tasks) for shuffle with 
shuffle id.
-   */
-  @Override
-  public int getPartitionNum(int shuffleId) {
-    return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0);
-  }
-
-  /**
-   * @param shuffleId the shuffle id to query
-   * @return the num of map tasks for current shuffle with shuffle id.
-   */
-  @Override
-  public int getNumMaps(int shuffleId) {
-    return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0);
-  }
-
   private Roaring64NavigableMap getShuffleResult(
       String clientType,
       Set<ShuffleServerInfo> shuffleServerInfoSet,
@@ -740,10 +379,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     }
   }
 
-  public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) {
-    return taskToFailedBlockSendTracker.get(taskId);
-  }
-
   private ShuffleServerInfo assignShuffleServer(int shuffleId, String 
faultyShuffleServerId) {
     Set<String> faultyServerIds = Sets.newHashSet(faultyShuffleServerId);
     faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
@@ -754,4 +389,44 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     }
     return null;
   }
+
+  @Override
+  protected ShuffleWriteClient createShuffleWriteClient() {
+    int unregisterThreadPoolSize =
+        sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+    int unregisterTimeoutSec = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
+    int unregisterRequestTimeoutSec =
+        
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
+    long retryIntervalMax = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
+    int heartBeatThreadNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
+
+    final int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
+    return RssShuffleClientFactory.getInstance()
+        .createShuffleWriteClient(
+            RssShuffleClientFactory.newWriteBuilder()
+                .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
+                .managerClientSupplier(managerClientSupplier)
+                .clientType(clientType)
+                .retryMax(retryMax)
+                .retryIntervalMax(retryIntervalMax)
+                .heartBeatThreadNum(heartBeatThreadNum)
+                .replica(dataReplica)
+                .replicaWrite(dataReplicaWrite)
+                .replicaRead(dataReplicaRead)
+                .replicaSkipEnabled(dataReplicaSkipEnabled)
+                .dataTransferPoolSize(dataTransferPoolSize)
+                .dataCommitPoolSize(dataCommitPoolSize)
+                .unregisterThreadPoolSize(unregisterThreadPoolSize)
+                .unregisterTimeSec(unregisterTimeoutSec)
+                .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
+                .rssConf(rssConf));
+  }
+
+  @Override
+  protected void checkSupported(SparkConf sparkConf) {
+    if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
+      throw new IllegalArgumentException(
+          "Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be 
false.");
+    }
+  }
 }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 37692a05a..5e2a94102 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -17,16 +17,10 @@
 
 package org.apache.spark.shuffle;
 
-import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 
@@ -38,7 +32,6 @@ import scala.collection.Seq;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Sets;
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
-import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.MapOutputTracker;
 import org.apache.spark.ShuffleDependency;
@@ -53,7 +46,6 @@ import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
 import org.apache.spark.shuffle.reader.RssShuffleReader;
-import org.apache.spark.shuffle.writer.AddBlockEvent;
 import org.apache.spark.shuffle.writer.DataPusher;
 import org.apache.spark.shuffle.writer.RssShuffleWriter;
 import org.apache.spark.sql.internal.SQLConf;
@@ -64,6 +56,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.RemoteStorageInfo;
@@ -73,235 +66,16 @@ import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
-import org.apache.uniffle.common.rpc.GrpcServer;
-import org.apache.uniffle.common.util.BlockIdLayout;
-import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.common.util.RssUtils;
-import org.apache.uniffle.common.util.ThreadUtils;
 import org.apache.uniffle.shuffle.RssShuffleClientFactory;
 import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;
-import org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService;
-import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory;
-
-import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
-import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM;
-import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
-import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
-import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
-import static 
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
 
 public class RssShuffleManager extends RssShuffleManagerBase {
   private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleManager.class);
-  private final long heartbeatInterval;
-  private final long heartbeatTimeout;
-  private final int dataReplica;
-  private final int dataReplicaWrite;
-  private final int dataReplicaRead;
-  private final boolean dataReplicaSkipEnabled;
-  private final int dataTransferPoolSize;
-  private final int dataCommitPoolSize;
-  private final Map<String, Set<Long>> taskToSuccessBlockIds;
-  private final Map<String, FailedBlockSendTracker> 
taskToFailedBlockSendTracker;
-  private ScheduledExecutorService heartBeatScheduledExecutorService;
-  private boolean heartbeatStarted = false;
-  private final BlockIdLayout blockIdLayout;
-  private final int maxFailures;
-  private final boolean speculation;
-  private String user;
-  private String uuid;
-  private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
-  private DataPusher dataPusher;
-  private ShuffleManagerGrpcService service;
-  private GrpcServer shuffleManagerServer;
 
   public RssShuffleManager(SparkConf conf, boolean isDriver) {
-    this.sparkConf = conf;
-    boolean supportsRelocation =
-        Optional.ofNullable(SparkEnv.get())
-            .map(env -> 
env.serializer().supportsRelocationOfSerializedObjects())
-            .orElse(true);
-    if (!supportsRelocation) {
-      LOG.warn(
-          "RSSShuffleManager requires a serializer which supports relocations 
of serialized object. Please set "
-              + "spark.serializer to 
org.apache.spark.serializer.KryoSerializer instead");
-    }
-    this.user = sparkConf.get("spark.rss.quota.user", "user");
-    this.uuid = sparkConf.get("spark.rss.quota.uuid", 
Long.toString(System.currentTimeMillis()));
-    this.dynamicConfEnabled = 
sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
-
-    // fetch client conf and apply them if necessary
-    if (isDriver && this.dynamicConfEnabled) {
-      fetchAndApplyDynamicConf(sparkConf);
-    }
-    RssSparkShuffleUtils.validateRssClientConf(sparkConf);
-
-    // set & check replica config
-    this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
-    this.dataReplicaWrite = 
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
-    this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
-    this.dataReplicaSkipEnabled = 
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
-    LOG.info(
-        "Check quorum config ["
-            + dataReplica
-            + ":"
-            + dataReplicaWrite
-            + ":"
-            + dataReplicaRead
-            + ":"
-            + dataReplicaSkipEnabled
-            + "]");
-    RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, 
dataReplicaRead);
-
-    this.heartbeatInterval = 
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
-    this.heartbeatTimeout =
-        sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), 
heartbeatInterval / 2);
-    final int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
-    this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
+    super(conf, isDriver);
     this.dataDistributionType = getDataDistributionType(sparkConf);
-    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
-    RssUtils.setExtraJavaProperties(rssConf);
-    this.maxConcurrencyPerPartitionToWrite = 
rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
-    this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
-    this.speculation = sparkConf.getBoolean("spark.speculation", false);
-    // configureBlockIdLayout requires maxFailures and speculation to be 
initialized
-    configureBlockIdLayout(sparkConf, rssConf);
-    this.blockIdLayout = BlockIdLayout.from(rssConf);
-    this.dataTransferPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
-    this.dataCommitPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
-    // External shuffle service is not supported when using remote shuffle 
service
-    sparkConf.set("spark.shuffle.service.enabled", "false");
-    sparkConf.set("spark.dynamicAllocation.shuffleTracking.enabled", "false");
-    sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
-    LOG.info("Disable external shuffle service in RssShuffleManager.");
-    sparkConf.set("spark.sql.adaptive.localShuffleReader.enabled", "false");
-    LOG.info("Disable local shuffle reader in RssShuffleManager.");
-    // If we store shuffle data in distributed filesystem or in a disaggregated
-    // shuffle cluster, we don't need shuffle data locality
-    sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
-    LOG.info("Disable shuffle data locality in RssShuffleManager.");
-    taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
-    taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
-    this.shuffleIdToPartitionNum = JavaUtils.newConcurrentMap();
-    this.shuffleIdToNumMapTasks = JavaUtils.newConcurrentMap();
-    this.partitionReassignEnabled = 
rssConf.get(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
-
-    // stage retry for write/fetch failure
-    rssStageRetryForFetchFailureEnabled =
-        rssConf.get(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED);
-    rssStageRetryForWriteFailureEnabled =
-        rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
-    if (rssStageRetryForFetchFailureEnabled || 
rssStageRetryForWriteFailureEnabled) {
-      rssStageRetryEnabled = true;
-      List<String> logTips = new ArrayList<>();
-      if (rssStageRetryForWriteFailureEnabled) {
-        logTips.add("write");
-      }
-      if (rssStageRetryForWriteFailureEnabled) {
-        logTips.add("fetch");
-      }
-      LOG.info(
-          "Activate the stage retry mechanism that will resubmit stage on {} 
failure",
-          StringUtils.join(logTips, "/"));
-    }
-
-    // The feature of partition reassign is exclusive with multiple replicas 
and stage retry.
-    if (partitionReassignEnabled) {
-      if (dataReplica > 1) {
-        throw new RssException(
-            "The feature of task partition reassign is incompatible with 
multiple replicas mechanism.");
-      }
-    }
-
-    this.blockIdSelfManagedEnabled = 
rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
-    this.shuffleManagerRpcServiceEnabled =
-        partitionReassignEnabled || rssStageRetryEnabled || 
blockIdSelfManagedEnabled;
-    if (isDriver) {
-      heartBeatScheduledExecutorService =
-          ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
-      if (shuffleManagerRpcServiceEnabled) {
-        LOG.info("stage resubmit is supported and enabled");
-        // start shuffle manager server
-        rssConf.set(RPC_SERVER_PORT, 0);
-        ShuffleManagerServerFactory factory = new 
ShuffleManagerServerFactory(this, rssConf);
-        service = factory.getService();
-        shuffleManagerServer = factory.getServer(service);
-        try {
-          shuffleManagerServer.start();
-          // pass this as a spark.rss.shuffle.manager.grpc.port config, so it 
can be propagated to
-          // executor properly.
-          sparkConf.set(
-              RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT, 
shuffleManagerServer.getPort());
-        } catch (Exception e) {
-          LOG.error("Failed to start shuffle manager server", e);
-          throw new RssException(e);
-        }
-      }
-    }
-    if (shuffleManagerRpcServiceEnabled) {
-      getOrCreateShuffleManagerClientSupplier();
-    }
-    int unregisterThreadPoolSize =
-        sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
-    int unregisterTimeoutSec = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
-    int unregisterRequestTimeoutSec =
-        
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
-    long retryIntervalMax = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
-    int heartBeatThreadNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
-    shuffleWriteClient =
-        RssShuffleClientFactory.getInstance()
-            .createShuffleWriteClient(
-                RssShuffleClientFactory.newWriteBuilder()
-                    .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
-                    .managerClientSupplier(managerClientSupplier)
-                    .clientType(clientType)
-                    .retryMax(retryMax)
-                    .retryIntervalMax(retryIntervalMax)
-                    .heartBeatThreadNum(heartBeatThreadNum)
-                    .replica(dataReplica)
-                    .replicaWrite(dataReplicaWrite)
-                    .replicaRead(dataReplicaRead)
-                    .replicaSkipEnabled(dataReplicaSkipEnabled)
-                    .dataTransferPoolSize(dataTransferPoolSize)
-                    .dataCommitPoolSize(dataCommitPoolSize)
-                    .unregisterThreadPoolSize(unregisterThreadPoolSize)
-                    .unregisterTimeSec(unregisterTimeoutSec)
-                    .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
-                    .rssConf(rssConf));
-    registerCoordinator();
-
-    LOG.info("Rss data pusher is starting...");
-    int poolSize = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
-    int keepAliveTime = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
-    this.dataPusher =
-        new DataPusher(
-            shuffleWriteClient,
-            taskToSuccessBlockIds,
-            taskToFailedBlockSendTracker,
-            failedTaskIds,
-            poolSize,
-            keepAliveTime);
-    this.partitionReassignMaxServerNum =
-        rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
-    this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
-    this.rssStageResubmitManager = new RssStageResubmitManager();
-  }
-
-  public CompletableFuture<Long> sendData(AddBlockEvent event) {
-    if (dataPusher != null && event != null) {
-      return dataPusher.send(event);
-    }
-    return new CompletableFuture<>();
-  }
-
-  @VisibleForTesting
-  protected static ShuffleDataDistributionType 
getDataDistributionType(SparkConf sparkConf) {
-    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
-    if ((boolean) sparkConf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED())
-        && !rssConf.containsKey(RssClientConf.DATA_DISTRIBUTION_TYPE.key())) {
-      return ShuffleDataDistributionType.LOCAL_ORDER;
-    }
-
-    return rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE);
   }
 
   // For testing only
@@ -312,72 +86,7 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
       DataPusher dataPusher,
       Map<String, Set<Long>> taskToSuccessBlockIds,
       Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker) {
-    this.sparkConf = conf;
-    this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
-    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
-    this.dataDistributionType = 
rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE);
-    this.blockIdLayout = BlockIdLayout.from(rssConf);
-    this.maxConcurrencyPerPartitionToWrite = 
rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
-    this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
-    this.speculation = sparkConf.getBoolean("spark.speculation", false);
-    // configureBlockIdLayout requires maxFailures and speculation to be 
initialized
-    configureBlockIdLayout(sparkConf, rssConf);
-    this.heartbeatInterval = 
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
-    this.heartbeatTimeout =
-        sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(), 
heartbeatInterval / 2);
-    this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
-    this.dataReplicaWrite = 
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
-    this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
-    this.dataReplicaSkipEnabled = 
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
-    LOG.info(
-        "Check quorum config ["
-            + dataReplica
-            + ":"
-            + dataReplicaWrite
-            + ":"
-            + dataReplicaRead
-            + ":"
-            + dataReplicaSkipEnabled
-            + "]");
-    RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite, 
dataReplicaRead);
-
-    int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
-    long retryIntervalMax = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
-    int heartBeatThreadNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
-    this.dataTransferPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
-    this.dataCommitPoolSize = 
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
-    int unregisterThreadPoolSize =
-        sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
-    int unregisterTimeoutSec = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
-    int unregisterRequestTimeoutSec =
-        
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
-    shuffleWriteClient =
-        RssShuffleClientFactory.getInstance()
-            .createShuffleWriteClient(
-                RssShuffleClientFactory.getInstance()
-                    .newWriteBuilder()
-                    .clientType(clientType)
-                    .retryMax(retryMax)
-                    .retryIntervalMax(retryIntervalMax)
-                    .heartBeatThreadNum(heartBeatThreadNum)
-                    .replica(dataReplica)
-                    .replicaWrite(dataReplicaWrite)
-                    .replicaRead(dataReplicaRead)
-                    .replicaSkipEnabled(dataReplicaSkipEnabled)
-                    .dataTransferPoolSize(dataTransferPoolSize)
-                    .dataCommitPoolSize(dataCommitPoolSize)
-                    .unregisterThreadPoolSize(unregisterThreadPoolSize)
-                    .unregisterTimeSec(unregisterTimeoutSec)
-                    .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
-                    .rssConf(rssConf));
-    this.taskToSuccessBlockIds = taskToSuccessBlockIds;
-    this.heartBeatScheduledExecutorService = null;
-    this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker;
-    this.dataPusher = dataPusher;
-    this.partitionReassignMaxServerNum =
-        rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
-    this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
-    this.rssStageResubmitManager = new RssStageResubmitManager();
+    super(conf, isDriver, dataPusher, taskToSuccessBlockIds, 
taskToFailedBlockSendTracker);
   }
 
   // This method is called in Spark driver side,
@@ -527,17 +236,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
         context);
   }
 
-  @Override
-  public void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf) {
-    configureBlockIdLayout(sparkConf, rssConf, maxFailures, speculation);
-  }
-
-  @Override
-  public long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo) {
-    return getTaskAttemptIdForBlockId(
-        mapIndex, attemptNo, maxFailures, speculation, 
blockIdLayout.taskAttemptIdBits);
-  }
-
   public void setPusherAppId(RssShuffleHandle rssShuffleHandle) {
     // todo: this implement is tricky, we should refactor it
     if (id.get() == null) {
@@ -845,127 +543,52 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     return taskIdBitmap;
   }
 
-  @Override
-  public boolean unregisterShuffle(int shuffleId) {
-    try {
-      super.unregisterShuffle(shuffleId);
-      if (SparkEnv.get().executorId().equals("driver")) {
-        shuffleWriteClient.unregisterShuffle(id.get(), shuffleId);
-        shuffleIdToPartitionNum.remove(shuffleId);
-        shuffleIdToNumMapTasks.remove(shuffleId);
-        if (service != null) {
-          service.unregisterShuffle(shuffleId);
-        }
-      }
-    } catch (Exception e) {
-      LOG.warn("Errors on unregistering from remote shuffle-servers", e);
-    }
-    return true;
-  }
-
   @Override
   public ShuffleBlockResolver shuffleBlockResolver() {
     throw new RssException("RssShuffleManager.shuffleBlockResolver is not 
implemented");
   }
 
   @Override
-  public void stop() {
-    super.stop();
-    if (heartBeatScheduledExecutorService != null) {
-      heartBeatScheduledExecutorService.shutdownNow();
-    }
-    if (shuffleWriteClient != null) {
-      // Unregister shuffle before closing shuffle write client.
-      shuffleWriteClient.unregisterShuffle(getAppId());
-      shuffleWriteClient.close();
-    }
-    if (dataPusher != null) {
-      try {
-        dataPusher.close();
-      } catch (IOException e) {
-        LOG.warn("Errors on closing data pusher", e);
-      }
-    }
-
-    if (shuffleManagerServer != null) {
-      try {
-        shuffleManagerServer.stop();
-      } catch (InterruptedException e) {
-        // ignore
-        LOG.info("shuffle manager server is interrupted during stop");
-      }
-    }
-  }
+  protected ShuffleWriteClient createShuffleWriteClient() {
+    int unregisterThreadPoolSize =
+        sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+    int unregisterTimeoutSec = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
+    int unregisterRequestTimeoutSec =
+        
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
+    long retryIntervalMax = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
+    int heartBeatThreadNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
 
-  public void clearTaskMeta(String taskId) {
-    taskToSuccessBlockIds.remove(taskId);
-    taskToFailedBlockSendTracker.remove(taskId);
+    final int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
+    return RssShuffleClientFactory.getInstance()
+        .createShuffleWriteClient(
+            RssShuffleClientFactory.newWriteBuilder()
+                .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
+                .managerClientSupplier(managerClientSupplier)
+                .clientType(clientType)
+                .retryMax(retryMax)
+                .retryIntervalMax(retryIntervalMax)
+                .heartBeatThreadNum(heartBeatThreadNum)
+                .replica(dataReplica)
+                .replicaWrite(dataReplicaWrite)
+                .replicaRead(dataReplicaRead)
+                .replicaSkipEnabled(dataReplicaSkipEnabled)
+                .dataTransferPoolSize(dataTransferPoolSize)
+                .dataCommitPoolSize(dataCommitPoolSize)
+                .unregisterThreadPoolSize(unregisterThreadPoolSize)
+                .unregisterTimeSec(unregisterTimeoutSec)
+                .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
+                .rssConf(rssConf));
   }
 
   @VisibleForTesting
-  protected void registerCoordinator() {
-    String coordinators = 
sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
-    LOG.info("Start Registering coordinators {}", coordinators);
-    shuffleWriteClient.registerCoordinators(
-        coordinators,
-        this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX),
-        this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX));
-  }
-
-  private synchronized void startHeartbeat() {
-    shuffleWriteClient.registerApplicationInfo(id.get(), heartbeatTimeout, 
user);
-    if (!heartbeatStarted) {
-      heartBeatScheduledExecutorService.scheduleAtFixedRate(
-          () -> {
-            try {
-              String appId = id.get();
-              shuffleWriteClient.sendAppHeartbeat(appId, heartbeatTimeout);
-              LOG.info("Finish send heartbeat to coordinator and servers");
-            } catch (Exception e) {
-              LOG.warn("Fail to send heartbeat to coordinator and servers", e);
-            }
-          },
-          heartbeatInterval / 2,
-          heartbeatInterval,
-          TimeUnit.MILLISECONDS);
-      heartbeatStarted = true;
-    }
-  }
-
-  public Set<Long> getFailedBlockIds(String taskId) {
-    FailedBlockSendTracker blockIdsFailedSendTracker = 
getBlockIdsFailedSendTracker(taskId);
-    if (blockIdsFailedSendTracker == null) {
-      return Collections.emptySet();
-    }
-    return blockIdsFailedSendTracker.getFailedBlockIds();
-  }
-
-  public Set<Long> getSuccessBlockIds(String taskId) {
-    Set<Long> result = taskToSuccessBlockIds.get(taskId);
-    if (result == null) {
-      result = Collections.emptySet();
+  protected static ShuffleDataDistributionType 
getDataDistributionType(SparkConf sparkConf) {
+    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+    if ((boolean) sparkConf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED())
+        && !rssConf.containsKey(RssClientConf.DATA_DISTRIBUTION_TYPE.key())) {
+      return ShuffleDataDistributionType.LOCAL_ORDER;
     }
-    return result;
-  }
 
-  /** @return the unique spark id for rss shuffle */
-  @Override
-  public String getAppId() {
-    return id.get();
-  }
-
-  @Override
-  public int getPartitionNum(int shuffleId) {
-    return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0);
-  }
-
-  /**
-   * @param shuffleId the shuffle id to query
-   * @return the num of map tasks for current shuffle with shuffle id.
-   */
-  @Override
-  public int getNumMaps(int shuffleId) {
-    return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0);
+    return rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE);
   }
 
   static class ReadMetrics extends ShuffleReadMetrics {
@@ -1019,16 +642,6 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
     this.id = new AtomicReference<>(appId);
   }
 
-  public boolean markFailedTask(String taskId) {
-    LOG.info("Mark the task: {} failed.", taskId);
-    failedTaskIds.add(taskId);
-    return true;
-  }
-
-  public boolean isValidTask(String taskId) {
-    return !failedTaskIds.contains(taskId);
-  }
-
   private Roaring64NavigableMap getShuffleResultForMultiPart(
       String clientType,
       Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
@@ -1050,23 +663,4 @@ public class RssShuffleManager extends 
RssShuffleManagerBase {
           managerClientSupplier, e, sparkConf, appId, shuffleId, 
stageAttemptId, failedPartitions);
     }
   }
-
-  public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) {
-    return taskToFailedBlockSendTracker.get(taskId);
-  }
-
-  @VisibleForTesting
-  public void setDataPusher(DataPusher dataPusher) {
-    this.dataPusher = dataPusher;
-  }
-
-  @VisibleForTesting
-  public Map<String, Set<Long>> getTaskToSuccessBlockIds() {
-    return taskToSuccessBlockIds;
-  }
-
-  @VisibleForTesting
-  public Map<String, FailedBlockSendTracker> getTaskToFailedBlockSendTracker() 
{
-    return taskToFailedBlockSendTracker;
-  }
 }

Reply via email to