This is an automated email from the ASF dual-hosted git repository.
maobaolong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 43879ca13 [MINOR] improvement(spark-client): Refactor
RssShuffleManager for spark v2 and v3 to reduce redundant code (#2330)
43879ca13 is described below
commit 43879ca13aa261e6a775178c7015b8cf758053ab
Author: maobaolong <[email protected]>
AuthorDate: Thu Jan 9 20:46:04 2025 +0800
[MINOR] improvement(spark-client): Refactor RssShuffleManager for spark v2
and v3 to reduce redundant code (#2330)
### What changes were proposed in this pull request?
Refactor and abstract the duplicated code into the base class.
### Why are the changes needed?
Reduce the redundant code and simplify development within the rss spark-client
scope.
Code that is not tied to a specific Spark version should be placed into the
RssShuffleManagerBase class; the RssShuffleManager class itself should only
maintain the Spark-API-related code.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- Existing UTs.
- Our test cluster, tested both spark v2 and v3 version.
---
.../shuffle/manager/RssShuffleManagerBase.java | 452 ++++++++++++++++++-
.../apache/spark/shuffle/RssShuffleManager.java | 409 ++----------------
.../apache/spark/shuffle/RssShuffleManager.java | 480 ++-------------------
3 files changed, 521 insertions(+), 820 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index c1fc5b68e..d869c64fe 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -17,6 +17,7 @@
package org.apache.uniffle.shuffle.manager;
+import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
@@ -28,6 +29,9 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
@@ -40,6 +44,7 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.spark.MapOutputTracker;
@@ -58,6 +63,8 @@ import
org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
+import org.apache.spark.shuffle.writer.AddBlockEvent;
+import org.apache.spark.shuffle.writer.DataPusher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -65,6 +72,7 @@ import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
+import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.impl.grpc.CoordinatorGrpcRetryableClient;
import org.apache.uniffle.client.request.RssFetchClientConfRequest;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
@@ -83,17 +91,38 @@ import org.apache.uniffle.common.config.ConfigOption;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.rpc.GrpcServer;
import org.apache.uniffle.common.rpc.StatusCode;
+import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
+import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RetryUtils;
+import org.apache.uniffle.common.util.RssUtils;
+import org.apache.uniffle.common.util.ThreadUtils;
import org.apache.uniffle.shuffle.BlockIdManager;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
+import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
import static
org.apache.uniffle.common.config.RssClientConf.HADOOP_CONFIG_KEY_PREFIX;
+import static
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
import static
org.apache.uniffle.common.config.RssClientConf.RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED;
public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterface, ShuffleManager {
private static final Logger LOG =
LoggerFactory.getLogger(RssShuffleManagerBase.class);
+ protected final int dataTransferPoolSize;
+ protected final int dataCommitPoolSize;
+ protected final int dataReplica;
+ protected final int dataReplicaWrite;
+ protected final int dataReplicaRead;
+ protected final boolean dataReplicaSkipEnabled;
+ protected final Map<String, Set<Long>> taskToSuccessBlockIds;
+ protected final Map<String, FailedBlockSendTracker>
taskToFailedBlockSendTracker;
+ private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
+
private AtomicBoolean isInitialized = new AtomicBoolean(false);
private Method unregisterAllMapOutputMethod;
private Method registerShuffleMethod;
@@ -107,7 +136,8 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
protected int maxConcurrencyPerPartitionToWrite;
protected String clientType;
- protected SparkConf sparkConf;
+ protected final SparkConf sparkConf;
+ protected final RssConf rssConf;
protected Map<Integer, Integer> shuffleIdToPartitionNum;
protected Map<Integer, Integer> shuffleIdToNumMapTasks;
protected Supplier<ShuffleManagerClient> managerClientSupplier;
@@ -126,9 +156,227 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
protected boolean partitionReassignEnabled;
protected boolean shuffleManagerRpcServiceEnabled;
- public RssShuffleManagerBase() {
+ protected boolean heartbeatStarted = false;
+ protected final long heartbeatInterval;
+ protected final long heartbeatTimeout;
+ protected String user;
+ protected String uuid;
+ protected ScheduledExecutorService heartBeatScheduledExecutorService;
+ protected final int maxFailures;
+ protected final boolean speculation;
+ protected final BlockIdLayout blockIdLayout;
+ private ShuffleManagerGrpcService service;
+ protected GrpcServer shuffleManagerServer;
+ protected DataPusher dataPusher;
+
+ public RssShuffleManagerBase(SparkConf conf, boolean isDriver) {
LOG.info(
"Uniffle {} version: {}", this.getClass().getName(),
Constants.VERSION_AND_REVISION_SHORT);
+ this.sparkConf = conf;
+ checkSupported(sparkConf);
+ boolean supportsRelocation =
+ Optional.ofNullable(SparkEnv.get())
+ .map(env ->
env.serializer().supportsRelocationOfSerializedObjects())
+ .orElse(true);
+ if (!supportsRelocation) {
+ LOG.warn(
+ "RSSShuffleManager requires a serializer which supports relocations
of serialized object. Please set "
+ + "spark.serializer to
org.apache.spark.serializer.KryoSerializer instead");
+ }
+ this.user = sparkConf.get("spark.rss.quota.user", "user");
+ this.uuid = sparkConf.get("spark.rss.quota.uuid",
Long.toString(System.currentTimeMillis()));
+ this.dynamicConfEnabled =
sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
+
+ // fetch client conf and apply them if necessary
+ if (isDriver && this.dynamicConfEnabled) {
+ fetchAndApplyDynamicConf(sparkConf);
+ }
+ RssSparkShuffleUtils.validateRssClientConf(sparkConf);
+
+ // convert spark conf to rss conf after fetching dynamic client conf
+ this.rssConf = RssSparkConfig.toRssConf(sparkConf);
+ RssUtils.setExtraJavaProperties(rssConf);
+
+ // set & check replica config
+ this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
+ this.dataReplicaWrite =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
+ this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
+ this.dataReplicaSkipEnabled =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
+ LOG.info(
+ "Check quorum config ["
+ + dataReplica
+ + ":"
+ + dataReplicaWrite
+ + ":"
+ + dataReplicaRead
+ + ":"
+ + dataReplicaSkipEnabled
+ + "]");
+ RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite,
dataReplicaRead);
+
+ this.maxConcurrencyPerPartitionToWrite =
rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
+
+ this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
+
+ // configure block id layout
+ this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+ this.speculation = sparkConf.getBoolean("spark.speculation", false);
+ // configureBlockIdLayout requires maxFailures and speculation to be
initialized
+ configureBlockIdLayout(sparkConf, rssConf);
+ this.blockIdLayout = BlockIdLayout.from(rssConf);
+
+ this.dataTransferPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
+ this.dataCommitPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
+
+ // External shuffle service is not supported when using remote shuffle
service
+ sparkConf.set("spark.shuffle.service.enabled", "false");
+ sparkConf.set("spark.dynamicAllocation.shuffleTracking.enabled", "false");
+ sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
+ LOG.info("Disable external shuffle service in RssShuffleManager.");
+ sparkConf.set("spark.sql.adaptive.localShuffleReader.enabled", "false");
+ LOG.info("Disable local shuffle reader in RssShuffleManager.");
+ // If we store shuffle data in distributed filesystem or in a disaggregated
+ // shuffle cluster, we don't need shuffle data locality
+ sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
+ LOG.info("Disable shuffle data locality in RssShuffleManager.");
+
+ taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
+ taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
+ this.shuffleIdToPartitionNum = JavaUtils.newConcurrentMap();
+ this.shuffleIdToNumMapTasks = JavaUtils.newConcurrentMap();
+
+ // stage retry for write/fetch failure
+ rssStageRetryForFetchFailureEnabled =
+ rssConf.get(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED);
+ rssStageRetryForWriteFailureEnabled =
+ rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
+ if (rssStageRetryForFetchFailureEnabled ||
rssStageRetryForWriteFailureEnabled) {
+ rssStageRetryEnabled = true;
+ List<String> logTips = new ArrayList<>();
+ if (rssStageRetryForWriteFailureEnabled) {
+ logTips.add("write");
+ }
+      if (rssStageRetryForFetchFailureEnabled) {
+ logTips.add("fetch");
+ }
+ LOG.info(
+ "Activate the stage retry mechanism that will resubmit stage on {}
failure",
+ StringUtils.join(logTips, "/"));
+ }
+
+ this.partitionReassignEnabled =
rssConf.get(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
+ // The feature of partition reassign is exclusive with multiple replicas
and stage retry.
+ if (partitionReassignEnabled) {
+ if (dataReplica > 1) {
+ throw new RssException(
+ "The feature of task partition reassign is incompatible with
multiple replicas mechanism.");
+ }
+ }
+ this.blockIdSelfManagedEnabled =
rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
+ this.shuffleManagerRpcServiceEnabled =
+ partitionReassignEnabled || rssStageRetryEnabled ||
blockIdSelfManagedEnabled;
+
+ if (isDriver) {
+ heartBeatScheduledExecutorService =
+ ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
+ if (shuffleManagerRpcServiceEnabled) {
+ LOG.info("stage resubmit is supported and enabled");
+ // start shuffle manager server
+ rssConf.set(RPC_SERVER_PORT, 0);
+ ShuffleManagerServerFactory factory = new
ShuffleManagerServerFactory(this, rssConf);
+ service = factory.getService();
+ shuffleManagerServer = factory.getServer(service);
+ try {
+ shuffleManagerServer.start();
+ // pass this as a spark.rss.shuffle.manager.grpc.port config, so it
can be propagated to
+ // executor properly.
+ sparkConf.set(
+ RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT,
shuffleManagerServer.getPort());
+ } catch (Exception e) {
+ LOG.error("Failed to start shuffle manager server", e);
+ throw new RssException(e);
+ }
+ }
+ }
+ if (shuffleManagerRpcServiceEnabled) {
+ getOrCreateShuffleManagerClientSupplier();
+ }
+
+ // Start heartbeat thread.
+ this.heartbeatInterval =
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
+ this.heartbeatTimeout =
+ sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(),
heartbeatInterval / 2);
+ heartBeatScheduledExecutorService =
+ ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
+
+ this.shuffleWriteClient = createShuffleWriteClient();
+ registerCoordinator();
+
+ LOG.info("Rss data pusher is starting...");
+ int poolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
+ int keepAliveTime =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
+ this.dataPusher =
+ new DataPusher(
+ shuffleWriteClient,
+ taskToSuccessBlockIds,
+ taskToFailedBlockSendTracker,
+ failedTaskIds,
+ poolSize,
+ keepAliveTime);
+ this.partitionReassignMaxServerNum =
+ rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
+ this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+ this.rssStageResubmitManager = new RssStageResubmitManager();
+ }
+
+ @VisibleForTesting
+ protected RssShuffleManagerBase(
+ SparkConf conf,
+ boolean isDriver,
+ DataPusher dataPusher,
+ Map<String, Set<Long>> taskToSuccessBlockIds,
+ Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker) {
+ this.sparkConf = conf;
+ this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
+ this.rssConf = RssSparkConfig.toRssConf(sparkConf);
+ this.dataDistributionType =
rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE);
+ this.blockIdLayout = BlockIdLayout.from(rssConf);
+ this.maxConcurrencyPerPartitionToWrite =
rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
+ this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
+ this.speculation = sparkConf.getBoolean("spark.speculation", false);
+ // configureBlockIdLayout requires maxFailures and speculation to be
initialized
+ configureBlockIdLayout(sparkConf, rssConf);
+ this.heartbeatInterval =
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
+ this.heartbeatTimeout =
+ sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(),
heartbeatInterval / 2);
+ this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
+ this.dataReplicaWrite =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
+ this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
+ this.dataReplicaSkipEnabled =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
+ LOG.info(
+ "Check quorum config ["
+ + dataReplica
+ + ":"
+ + dataReplicaWrite
+ + ":"
+ + dataReplicaRead
+ + ":"
+ + dataReplicaSkipEnabled
+ + "]");
+ RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite,
dataReplicaRead);
+
+ this.dataTransferPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
+ this.dataCommitPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
+ createShuffleWriteClient();
+
+ this.taskToSuccessBlockIds = taskToSuccessBlockIds;
+ this.heartBeatScheduledExecutorService = null;
+ this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker;
+ this.dataPusher = dataPusher;
+ this.partitionReassignMaxServerNum =
+ rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
+ this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
+ this.rssStageResubmitManager = new RssStageResubmitManager();
}
public BlockIdManager getBlockIdManager() {
@@ -145,14 +393,35 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
@Override
public boolean unregisterShuffle(int shuffleId) {
- if (blockIdManager != null) {
- blockIdManager.remove(shuffleId);
+ try {
+ if (blockIdManager != null) {
+ blockIdManager.remove(shuffleId);
+ }
+ if (SparkEnv.get().executorId().equals("driver")) {
+ shuffleWriteClient.unregisterShuffle(getAppId(), shuffleId);
+ shuffleIdToPartitionNum.remove(shuffleId);
+ shuffleIdToNumMapTasks.remove(shuffleId);
+ if (service != null) {
+ service.unregisterShuffle(shuffleId);
+ }
+ }
+ } catch (Exception e) {
+ LOG.warn("Errors on unregistering from remote shuffle-servers", e);
}
return true;
}
- /** See static overload of this method. */
- public abstract void configureBlockIdLayout(SparkConf sparkConf, RssConf
rssConf);
+ /**
+ * Derives block id layout config from maximum number of allowed partitions.
Computes the number
+ * of required bits for partition id and task attempt id and reserves
remaining bits for sequence
+ * number.
+ *
+ * @param sparkConf Spark config providing max partitions
+ * @param rssConf Rss config to amend
+ */
+ public void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf) {
+ configureBlockIdLayout(sparkConf, rssConf, maxFailures, speculation);
+ }
/**
* Derives block id layout config from maximum number of allowed partitions.
This value can be set
@@ -344,7 +613,10 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
}
/** See static overload of this method. */
- public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo);
+ public long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo) {
+ return getTaskAttemptIdForBlockId(
+ mapIndex, attemptNo, maxFailures, speculation,
blockIdLayout.taskAttemptIdBits);
+ }
/**
* Provides a task attempt id to be used in the block id, that is unique for
a shuffle stage.
@@ -809,7 +1081,8 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
LOG.info(
"Register the new partition->servers assignment on reassign. {}",
newServerToPartitions);
- registerShuffleServers(id.get(), shuffleId, newServerToPartitions,
getRemoteStorageInfo());
+ registerShuffleServers(
+ getAppId(), shuffleId, newServerToPartitions,
getRemoteStorageInfo());
}
LOG.info(
@@ -852,8 +1125,76 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
&& managerClientSupplier instanceof ExpiringCloseableSupplier) {
((ExpiringCloseableSupplier<ShuffleManagerClient>)
managerClientSupplier).close();
}
+ if (heartBeatScheduledExecutorService != null) {
+ heartBeatScheduledExecutorService.shutdownNow();
+ }
+ if (shuffleWriteClient != null) {
+ // Unregister shuffle before closing shuffle write client.
+ shuffleWriteClient.unregisterShuffle(getAppId());
+ shuffleWriteClient.close();
+ }
+ if (dataPusher != null) {
+ try {
+ dataPusher.close();
+ } catch (IOException e) {
+ LOG.warn("Errors on closing data pusher", e);
+ }
+ }
+
+ if (shuffleManagerServer != null) {
+ try {
+ shuffleManagerServer.stop();
+ } catch (InterruptedException e) {
+ // ignore
+ LOG.info("shuffle manager server is interrupted during stop");
+ }
+ }
+ }
+
+ /** @return the unique spark id for rss shuffle */
+ @Override
+ public String getAppId() {
+ return id.get();
+ }
+
+ @Override
+ public int getPartitionNum(int shuffleId) {
+ return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0);
}
+ /**
+ * @param shuffleId the shuffle id to query
+ * @return the num of map tasks for current shuffle with shuffle id.
+ */
+ @Override
+ public int getNumMaps(int shuffleId) {
+ return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0);
+ }
+
+ @VisibleForTesting
+ public void addSuccessBlockIds(String taskId, Set<Long> blockIds) {
+ if (taskToSuccessBlockIds.get(taskId) == null) {
+ taskToSuccessBlockIds.put(taskId, Sets.newHashSet());
+ }
+ taskToSuccessBlockIds.get(taskId).addAll(blockIds);
+ }
+
+ @VisibleForTesting
+ public void addFailedBlockSendTracker(
+ String taskId, FailedBlockSendTracker failedBlockSendTracker) {
+ taskToFailedBlockSendTracker.putIfAbsent(taskId, failedBlockSendTracker);
+ }
+
+ /** Create the shuffleWriteClient. */
+ protected abstract ShuffleWriteClient createShuffleWriteClient();
+
+ /**
+ * Check whether the configuration is supported.
+ *
+ * @param sparkConf the sparkConf
+ */
+ protected void checkSupported(SparkConf sparkConf) {}
+
/**
* Creating the shuffleAssignmentInfo from the servers and partitionIds
*
@@ -931,7 +1272,7 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
try {
ShuffleAssignmentsInfo response =
shuffleWriteClient.getShuffleAssignments(
- id.get(),
+ getAppId(),
shuffleId,
partitionNum,
partitionNumPerRange,
@@ -949,7 +1290,7 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
response = reassignmentHandler.apply(response);
}
registerShuffleServers(
- id.get(), shuffleId, response.getServerToPartitionRanges(),
getRemoteStorageInfo());
+ getAppId(), shuffleId, response.getServerToPartitionRanges(),
getRemoteStorageInfo());
return response.getPartitionToServers();
} catch (Throwable throwable) {
throw new RssException("registerShuffle failed!", throwable);
@@ -1131,4 +1472,95 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
public ShuffleWriteClient getShuffleWriteClient() {
return shuffleWriteClient;
}
+
+ protected synchronized void startHeartbeat() {
+ shuffleWriteClient.registerApplicationInfo(getAppId(), heartbeatTimeout,
user);
+ if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false) &&
!heartbeatStarted) {
+ heartBeatScheduledExecutorService.scheduleAtFixedRate(
+ () -> {
+ try {
+ String appId = getAppId();
+ shuffleWriteClient.sendAppHeartbeat(appId, heartbeatTimeout);
+ LOG.info("Finish send heartbeat to coordinator and servers");
+ } catch (Exception e) {
+ LOG.warn("Fail to send heartbeat to coordinator and servers", e);
+ }
+ },
+ heartbeatInterval / 2,
+ heartbeatInterval,
+ TimeUnit.MILLISECONDS);
+ heartbeatStarted = true;
+ }
+ }
+
+ public void clearTaskMeta(String taskId) {
+ taskToSuccessBlockIds.remove(taskId);
+ taskToFailedBlockSendTracker.remove(taskId);
+ }
+
+ @VisibleForTesting
+ protected void registerCoordinator() {
+ String coordinators =
sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
+ LOG.info("Start Registering coordinators {}", coordinators);
+ shuffleWriteClient.registerCoordinators(
+ coordinators,
+ this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX),
+ this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX));
+ }
+
+ public Set<Long> getFailedBlockIds(String taskId) {
+ FailedBlockSendTracker blockIdsFailedSendTracker =
getBlockIdsFailedSendTracker(taskId);
+ if (blockIdsFailedSendTracker == null) {
+ return Collections.emptySet();
+ }
+ return blockIdsFailedSendTracker.getFailedBlockIds();
+ }
+
+ public Set<Long> getSuccessBlockIds(String taskId) {
+ Set<Long> result = taskToSuccessBlockIds.get(taskId);
+ if (result == null) {
+ result = Collections.emptySet();
+ }
+ return result;
+ }
+
+ public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) {
+ return taskToFailedBlockSendTracker.get(taskId);
+ }
+
+ public boolean markFailedTask(String taskId) {
+ LOG.info("Mark the task: {} failed.", taskId);
+ failedTaskIds.add(taskId);
+ return true;
+ }
+
+ public boolean isValidTask(String taskId) {
+ return !failedTaskIds.contains(taskId);
+ }
+
+ @VisibleForTesting
+ public void setDataPusher(DataPusher dataPusher) {
+ this.dataPusher = dataPusher;
+ }
+
+ public DataPusher getDataPusher() {
+ return dataPusher;
+ }
+
+ @VisibleForTesting
+ public Map<String, Set<Long>> getTaskToSuccessBlockIds() {
+ return taskToSuccessBlockIds;
+ }
+
+ @VisibleForTesting
+ public Map<String, FailedBlockSendTracker> getTaskToFailedBlockSendTracker()
{
+ return taskToFailedBlockSendTracker;
+ }
+
+ public CompletableFuture<Long> sendData(AddBlockEvent event) {
+ if (dataPusher != null && event != null) {
+ return dataPusher.send(event);
+ }
+ return new CompletableFuture<>();
+ }
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index d62827243..8e6b8dfca 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -17,15 +17,10 @@
package org.apache.spark.shuffle;
-import java.io.IOException;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.TimeUnit;
import scala.Option;
import scala.Tuple2;
@@ -34,7 +29,6 @@ import scala.collection.Seq;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Sets;
-import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
@@ -47,8 +41,6 @@ import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
import org.apache.spark.shuffle.reader.RssShuffleReader;
-import org.apache.spark.shuffle.writer.AddBlockEvent;
-import org.apache.spark.shuffle.writer.DataPusher;
import org.apache.spark.shuffle.writer.RssShuffleWriter;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManagerId;
@@ -56,205 +48,20 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.uniffle.client.impl.FailedBlockSendTracker;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.config.RssClientConf;
-import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
-import org.apache.uniffle.common.rpc.GrpcServer;
-import org.apache.uniffle.common.util.BlockIdLayout;
-import org.apache.uniffle.common.util.JavaUtils;
-import org.apache.uniffle.common.util.RssUtils;
-import org.apache.uniffle.common.util.ThreadUtils;
import org.apache.uniffle.shuffle.RssShuffleClientFactory;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;
-import org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService;
-import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory;
-
-import static
org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
-import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
-import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
-import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
-import static
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
public class RssShuffleManager extends RssShuffleManagerBase {
-
private static final Logger LOG =
LoggerFactory.getLogger(RssShuffleManager.class);
- private final long heartbeatInterval;
- private final long heartbeatTimeout;
- private ScheduledExecutorService heartBeatScheduledExecutorService;
- private Map<String, Set<Long>> taskToSuccessBlockIds =
JavaUtils.newConcurrentMap();
- private Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker =
- JavaUtils.newConcurrentMap();
- private final int dataReplica;
- private final int dataReplicaWrite;
- private final int dataReplicaRead;
- private final boolean dataReplicaSkipEnabled;
- private final int dataTransferPoolSize;
- private final int dataCommitPoolSize;
- private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
- private boolean heartbeatStarted = false;
- private final int maxFailures;
- private final boolean speculation;
- private final BlockIdLayout blockIdLayout;
- private final String user;
- private final String uuid;
- private DataPusher dataPusher;
- private GrpcServer shuffleManagerServer;
- private ShuffleManagerGrpcService service;
public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
- if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
- throw new IllegalArgumentException(
- "Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be
false.");
- }
- this.sparkConf = sparkConf;
- this.user = sparkConf.get("spark.rss.quota.user", "user");
- this.uuid = sparkConf.get("spark.rss.quota.uuid",
Long.toString(System.currentTimeMillis()));
- this.dynamicConfEnabled =
sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
-
- // fetch client conf and apply them if necessary
- if (isDriver && this.dynamicConfEnabled) {
- fetchAndApplyDynamicConf(sparkConf);
- }
- RssSparkShuffleUtils.validateRssClientConf(sparkConf);
-
- // configure block id layout
- this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
- this.speculation = sparkConf.getBoolean("spark.speculation", false);
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- RssUtils.setExtraJavaProperties(rssConf);
- // configureBlockIdLayout requires maxFailures and speculation to be
initialized
- configureBlockIdLayout(sparkConf, rssConf);
- this.blockIdLayout = BlockIdLayout.from(rssConf);
-
- // set & check replica config
- this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
- this.dataReplicaWrite =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
- this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
- this.dataTransferPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
- this.dataReplicaSkipEnabled =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
- this.maxConcurrencyPerPartitionToWrite =
-
RssSparkConfig.toRssConf(sparkConf).get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
- LOG.info(
- "Check quorum config [{}:{}:{}:{}]",
- dataReplica,
- dataReplicaWrite,
- dataReplicaRead,
- dataReplicaSkipEnabled);
- RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite,
dataReplicaRead);
-
- this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
- this.heartbeatInterval =
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
- this.heartbeatTimeout =
- sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(),
heartbeatInterval / 2);
- int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
- long retryIntervalMax =
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
- int heartBeatThreadNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
- this.dataCommitPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
- int unregisterThreadPoolSize =
- sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
- int unregisterTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
- int unregisterRequestTimeoutSec =
-
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
- // External shuffle service is not supported when using remote shuffle
service
- sparkConf.set("spark.shuffle.service.enabled", "false");
- LOG.info("Disable external shuffle service in RssShuffleManager.");
- // If we store shuffle data in distributed filesystem or in a disaggregated
- // shuffle cluster, we don't need shuffle data locality
- sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
- LOG.info("Disable shuffle data locality in RssShuffleManager.");
-
- this.shuffleIdToPartitionNum = JavaUtils.newConcurrentMap();
- this.shuffleIdToNumMapTasks = JavaUtils.newConcurrentMap();
- // stage retry for write/fetch failure
- rssStageRetryForFetchFailureEnabled =
- rssConf.get(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED);
- rssStageRetryForWriteFailureEnabled =
- rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
- if (rssStageRetryForFetchFailureEnabled ||
rssStageRetryForWriteFailureEnabled) {
- rssStageRetryEnabled = true;
- List<String> logTips = new ArrayList<>();
- if (rssStageRetryForWriteFailureEnabled) {
- logTips.add("write");
- }
- if (rssStageRetryForWriteFailureEnabled) {
- logTips.add("fetch");
- }
- LOG.info(
- "Activate the stage retry mechanism that will resubmit stage on {}
failure",
- StringUtils.join(logTips, "/"));
- }
- this.partitionReassignEnabled =
rssConf.getBoolean(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
- this.blockIdSelfManagedEnabled =
rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
- this.shuffleManagerRpcServiceEnabled =
- partitionReassignEnabled || rssStageRetryEnabled ||
blockIdSelfManagedEnabled;
- if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) {
- if (isDriver) {
- heartBeatScheduledExecutorService =
-
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
- if (shuffleManagerRpcServiceEnabled) {
- LOG.info("stage resubmit is supported and enabled");
- // start shuffle manager server
- rssConf.set(RPC_SERVER_PORT, 0);
- ShuffleManagerServerFactory factory = new
ShuffleManagerServerFactory(this, rssConf);
- service = factory.getService();
- shuffleManagerServer = factory.getServer(service);
- try {
- shuffleManagerServer.start();
- // pass this as a spark.rss.shuffle.manager.grpc.port config, so
it can be propagated to
- // executor properly.
- sparkConf.set(
- RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT,
shuffleManagerServer.getPort());
- } catch (Exception e) {
- LOG.error("Failed to start shuffle manager server", e);
- throw new RssException(e);
- }
- }
- }
- if (shuffleManagerRpcServiceEnabled) {
- getOrCreateShuffleManagerClientSupplier();
- }
- this.shuffleWriteClient =
- RssShuffleClientFactory.getInstance()
- .createShuffleWriteClient(
- RssShuffleClientFactory.newWriteBuilder()
- .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
- .managerClientSupplier(managerClientSupplier)
- .clientType(clientType)
- .retryMax(retryMax)
- .retryIntervalMax(retryIntervalMax)
- .heartBeatThreadNum(heartBeatThreadNum)
- .replica(dataReplica)
- .replicaWrite(dataReplicaWrite)
- .replicaRead(dataReplicaRead)
- .replicaSkipEnabled(dataReplicaSkipEnabled)
- .dataTransferPoolSize(dataTransferPoolSize)
- .dataCommitPoolSize(dataCommitPoolSize)
- .unregisterThreadPoolSize(unregisterThreadPoolSize)
- .unregisterTimeSec(unregisterTimeoutSec)
- .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
- .rssConf(rssConf));
- registerCoordinator();
-
- // for non-driver executor, start a thread for sending shuffle data to
shuffle server
- LOG.info("RSS data pusher is starting...");
- int poolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
- int keepAliveTime =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
- this.dataPusher =
- new DataPusher(
- shuffleWriteClient,
- taskToSuccessBlockIds,
- taskToFailedBlockSendTracker,
- failedTaskIds,
- poolSize,
- keepAliveTime);
- }
- this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
- this.rssStageResubmitManager = new RssStageResubmitManager();
+ super(sparkConf, isDriver);
}
// This method is called in Spark driver side,
@@ -380,42 +187,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
return new RssShuffleHandle(shuffleId, appId, numMaps, dependency,
hdlInfoBd);
}
- private void startHeartbeat() {
- shuffleWriteClient.registerApplicationInfo(appId, heartbeatTimeout, user);
- if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false) &&
!heartbeatStarted) {
- heartBeatScheduledExecutorService.scheduleAtFixedRate(
- () -> {
- try {
- shuffleWriteClient.sendAppHeartbeat(appId, heartbeatTimeout);
- LOG.info("Finish send heartbeat to coordinator and servers");
- } catch (Exception e) {
- LOG.warn("Fail to send heartbeat to coordinator and servers", e);
- }
- },
- heartbeatInterval / 2,
- heartbeatInterval,
- TimeUnit.MILLISECONDS);
- heartbeatStarted = true;
- }
- }
-
- @VisibleForTesting
- protected void registerCoordinator() {
- String coordinators =
sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
- LOG.info("Registering coordinators {}", coordinators);
- shuffleWriteClient.registerCoordinators(
- coordinators,
- this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX),
- this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX));
- }
-
- public CompletableFuture<Long> sendData(AddBlockEvent event) {
- if (dataPusher != null && event != null) {
- return dataPusher.send(event);
- }
- return new CompletableFuture<>();
- }
-
// This method is called in Spark executor,
// getting information from Spark driver via the ShuffleHandle.
@Override
@@ -447,24 +218,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
}
- /**
- * Derives block id layout config from maximum number of allowed partitions.
Computes the number
- * of required bits for partition id and task attempt id and reserves
remaining bits for sequence
- * number.
- *
- * @param sparkConf Spark config providing max partitions
- * @param rssConf Rss config to amend
- */
- public void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf) {
- configureBlockIdLayout(sparkConf, rssConf, maxFailures, speculation);
- }
-
- @Override
- public long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo) {
- return getTaskAttemptIdForBlockId(
- mapIndex, attemptNo, maxFailures, speculation,
blockIdLayout.taskAttemptIdBits);
- }
-
// This method is called in Spark executor,
// getting information from Spark driver via the ShuffleHandle.
@Override
@@ -563,44 +316,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
return null;
}
- @Override
- public boolean unregisterShuffle(int shuffleId) {
- try {
- super.unregisterShuffle(shuffleId);
- if (SparkEnv.get().executorId().equals("driver")) {
- shuffleWriteClient.unregisterShuffle(appId, shuffleId);
- shuffleIdToNumMapTasks.remove(shuffleId);
- shuffleIdToPartitionNum.remove(shuffleId);
- if (service != null) {
- service.unregisterShuffle(shuffleId);
- }
- }
- } catch (Exception e) {
- LOG.warn("Errors on unregistering from remote shuffle-servers", e);
- }
- return true;
- }
-
- @Override
- public void stop() {
- super.stop();
- if (heartBeatScheduledExecutorService != null) {
- heartBeatScheduledExecutorService.shutdownNow();
- }
- if (dataPusher != null) {
- try {
- dataPusher.close();
- } catch (IOException e) {
- LOG.warn("Errors on closing data pusher", e);
- }
- }
- if (shuffleWriteClient != null) {
- // Unregister shuffle before closing shuffle write client.
- shuffleWriteClient.unregisterShuffle(appId);
- shuffleWriteClient.close();
- }
- }
-
@Override
public ShuffleBlockResolver shuffleBlockResolver() {
throw new RssException("RssShuffleManager.shuffleBlockResolver is not
implemented");
@@ -631,93 +346,17 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
return taskIdBitmap;
}
- public Set<Long> getFailedBlockIds(String taskId) {
- FailedBlockSendTracker blockIdsFailedSendTracker =
getBlockIdsFailedSendTracker(taskId);
- if (blockIdsFailedSendTracker == null) {
- return Collections.emptySet();
- }
- return blockIdsFailedSendTracker.getFailedBlockIds();
- }
-
- public Set<Long> getSuccessBlockIds(String taskId) {
- Set<Long> result = taskToSuccessBlockIds.get(taskId);
- if (result == null) {
- result = Collections.emptySet();
- }
- return result;
- }
-
- @VisibleForTesting
- public void addSuccessBlockIds(String taskId, Set<Long> blockIds) {
- if (taskToSuccessBlockIds.get(taskId) == null) {
- taskToSuccessBlockIds.put(taskId, Sets.newHashSet());
- }
- taskToSuccessBlockIds.get(taskId).addAll(blockIds);
- }
-
- @VisibleForTesting
- public void addFailedBlockSendTracker(
- String taskId, FailedBlockSendTracker failedBlockSendTracker) {
- taskToFailedBlockSendTracker.putIfAbsent(taskId, failedBlockSendTracker);
- }
-
- public void clearTaskMeta(String taskId) {
- taskToSuccessBlockIds.remove(taskId);
- taskToFailedBlockSendTracker.remove(taskId);
- }
-
- @VisibleForTesting
- public SparkConf getSparkConf() {
- return sparkConf;
- }
-
@VisibleForTesting
public void setAppId(String appId) {
this.appId = appId;
}
- public boolean markFailedTask(String taskId) {
- LOG.info("Mark the task: {} failed.", taskId);
- failedTaskIds.add(taskId);
- return true;
- }
-
- public boolean isValidTask(String taskId) {
- return !failedTaskIds.contains(taskId);
- }
-
- public DataPusher getDataPusher() {
- return dataPusher;
- }
-
- public void setDataPusher(DataPusher dataPusher) {
- this.dataPusher = dataPusher;
- }
-
/** @return the unique spark id for rss shuffle */
@Override
public String getAppId() {
return appId;
}
- /**
- * @param shuffleId the shuffleId to query
- * @return the num of partitions(a.k.a reduce tasks) for shuffle with
shuffle id.
- */
- @Override
- public int getPartitionNum(int shuffleId) {
- return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0);
- }
-
- /**
- * @param shuffleId the shuffle id to query
- * @return the num of map tasks for current shuffle with shuffle id.
- */
- @Override
- public int getNumMaps(int shuffleId) {
- return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0);
- }
-
private Roaring64NavigableMap getShuffleResult(
String clientType,
Set<ShuffleServerInfo> shuffleServerInfoSet,
@@ -740,10 +379,6 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
}
- public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) {
- return taskToFailedBlockSendTracker.get(taskId);
- }
-
private ShuffleServerInfo assignShuffleServer(int shuffleId, String
faultyShuffleServerId) {
Set<String> faultyServerIds = Sets.newHashSet(faultyShuffleServerId);
faultyServerIds.addAll(rssStageResubmitManager.getServerIdBlackList());
@@ -754,4 +389,44 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
}
return null;
}
+
+ @Override
+ protected ShuffleWriteClient createShuffleWriteClient() {
+ int unregisterThreadPoolSize =
+ sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+ int unregisterTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
+ int unregisterRequestTimeoutSec =
+
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
+ long retryIntervalMax =
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
+ int heartBeatThreadNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
+
+ final int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
+ return RssShuffleClientFactory.getInstance()
+ .createShuffleWriteClient(
+ RssShuffleClientFactory.newWriteBuilder()
+ .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
+ .managerClientSupplier(managerClientSupplier)
+ .clientType(clientType)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
+ .heartBeatThreadNum(heartBeatThreadNum)
+ .replica(dataReplica)
+ .replicaWrite(dataReplicaWrite)
+ .replicaRead(dataReplicaRead)
+ .replicaSkipEnabled(dataReplicaSkipEnabled)
+ .dataTransferPoolSize(dataTransferPoolSize)
+ .dataCommitPoolSize(dataCommitPoolSize)
+ .unregisterThreadPoolSize(unregisterThreadPoolSize)
+ .unregisterTimeSec(unregisterTimeoutSec)
+ .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
+ .rssConf(rssConf));
+ }
+
+ @Override
+ protected void checkSupported(SparkConf sparkConf) {
+ if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
+ throw new IllegalArgumentException(
+ "Spark2 doesn't support AQE, spark.sql.adaptive.enabled should be
false.");
+ }
+ }
}
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 37692a05a..5e2a94102 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -17,16 +17,10 @@
package org.apache.spark.shuffle;
-import java.io.IOException;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.Set;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
@@ -38,7 +32,6 @@ import scala.collection.Seq;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Sets;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
-import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.ShuffleDependency;
@@ -53,7 +46,6 @@ import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
import org.apache.spark.shuffle.reader.RssShuffleReader;
-import org.apache.spark.shuffle.writer.AddBlockEvent;
import org.apache.spark.shuffle.writer.DataPusher;
import org.apache.spark.shuffle.writer.RssShuffleWriter;
import org.apache.spark.sql.internal.SQLConf;
@@ -64,6 +56,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
+import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.impl.FailedBlockSendTracker;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.RemoteStorageInfo;
@@ -73,235 +66,16 @@ import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.exception.RssFetchFailedException;
-import org.apache.uniffle.common.rpc.GrpcServer;
-import org.apache.uniffle.common.util.BlockIdLayout;
-import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RssUtils;
-import org.apache.uniffle.common.util.ThreadUtils;
import org.apache.uniffle.shuffle.RssShuffleClientFactory;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;
-import org.apache.uniffle.shuffle.manager.ShuffleManagerGrpcService;
-import org.apache.uniffle.shuffle.manager.ShuffleManagerServerFactory;
-
-import static
org.apache.spark.shuffle.RssSparkConfig.RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED;
-import static
org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM;
-import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED;
-import static
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
-import static org.apache.uniffle.common.config.RssBaseConf.RPC_SERVER_PORT;
-import static
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
public class RssShuffleManager extends RssShuffleManagerBase {
private static final Logger LOG =
LoggerFactory.getLogger(RssShuffleManager.class);
- private final long heartbeatInterval;
- private final long heartbeatTimeout;
- private final int dataReplica;
- private final int dataReplicaWrite;
- private final int dataReplicaRead;
- private final boolean dataReplicaSkipEnabled;
- private final int dataTransferPoolSize;
- private final int dataCommitPoolSize;
- private final Map<String, Set<Long>> taskToSuccessBlockIds;
- private final Map<String, FailedBlockSendTracker>
taskToFailedBlockSendTracker;
- private ScheduledExecutorService heartBeatScheduledExecutorService;
- private boolean heartbeatStarted = false;
- private final BlockIdLayout blockIdLayout;
- private final int maxFailures;
- private final boolean speculation;
- private String user;
- private String uuid;
- private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
- private DataPusher dataPusher;
- private ShuffleManagerGrpcService service;
- private GrpcServer shuffleManagerServer;
public RssShuffleManager(SparkConf conf, boolean isDriver) {
- this.sparkConf = conf;
- boolean supportsRelocation =
- Optional.ofNullable(SparkEnv.get())
- .map(env ->
env.serializer().supportsRelocationOfSerializedObjects())
- .orElse(true);
- if (!supportsRelocation) {
- LOG.warn(
- "RSSShuffleManager requires a serializer which supports relocations
of serialized object. Please set "
- + "spark.serializer to
org.apache.spark.serializer.KryoSerializer instead");
- }
- this.user = sparkConf.get("spark.rss.quota.user", "user");
- this.uuid = sparkConf.get("spark.rss.quota.uuid",
Long.toString(System.currentTimeMillis()));
- this.dynamicConfEnabled =
sparkConf.get(RssSparkConfig.RSS_DYNAMIC_CLIENT_CONF_ENABLED);
-
- // fetch client conf and apply them if necessary
- if (isDriver && this.dynamicConfEnabled) {
- fetchAndApplyDynamicConf(sparkConf);
- }
- RssSparkShuffleUtils.validateRssClientConf(sparkConf);
-
- // set & check replica config
- this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
- this.dataReplicaWrite =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
- this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
- this.dataReplicaSkipEnabled =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
- LOG.info(
- "Check quorum config ["
- + dataReplica
- + ":"
- + dataReplicaWrite
- + ":"
- + dataReplicaRead
- + ":"
- + dataReplicaSkipEnabled
- + "]");
- RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite,
dataReplicaRead);
-
- this.heartbeatInterval =
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
- this.heartbeatTimeout =
- sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(),
heartbeatInterval / 2);
- final int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
- this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
+ super(conf, isDriver);
this.dataDistributionType = getDataDistributionType(sparkConf);
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- RssUtils.setExtraJavaProperties(rssConf);
- this.maxConcurrencyPerPartitionToWrite =
rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
- this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
- this.speculation = sparkConf.getBoolean("spark.speculation", false);
- // configureBlockIdLayout requires maxFailures and speculation to be
initialized
- configureBlockIdLayout(sparkConf, rssConf);
- this.blockIdLayout = BlockIdLayout.from(rssConf);
- this.dataTransferPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
- this.dataCommitPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
- // External shuffle service is not supported when using remote shuffle
service
- sparkConf.set("spark.shuffle.service.enabled", "false");
- sparkConf.set("spark.dynamicAllocation.shuffleTracking.enabled", "false");
- sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
- LOG.info("Disable external shuffle service in RssShuffleManager.");
- sparkConf.set("spark.sql.adaptive.localShuffleReader.enabled", "false");
- LOG.info("Disable local shuffle reader in RssShuffleManager.");
- // If we store shuffle data in distributed filesystem or in a disaggregated
- // shuffle cluster, we don't need shuffle data locality
- sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
- LOG.info("Disable shuffle data locality in RssShuffleManager.");
- taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
- taskToFailedBlockSendTracker = JavaUtils.newConcurrentMap();
- this.shuffleIdToPartitionNum = JavaUtils.newConcurrentMap();
- this.shuffleIdToNumMapTasks = JavaUtils.newConcurrentMap();
- this.partitionReassignEnabled =
rssConf.get(RssClientConf.RSS_CLIENT_REASSIGN_ENABLED);
-
- // stage retry for write/fetch failure
- rssStageRetryForFetchFailureEnabled =
- rssConf.get(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED);
- rssStageRetryForWriteFailureEnabled =
- rssConf.get(RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED);
- if (rssStageRetryForFetchFailureEnabled ||
rssStageRetryForWriteFailureEnabled) {
- rssStageRetryEnabled = true;
- List<String> logTips = new ArrayList<>();
- if (rssStageRetryForWriteFailureEnabled) {
- logTips.add("write");
- }
- if (rssStageRetryForWriteFailureEnabled) {
- logTips.add("fetch");
- }
- LOG.info(
- "Activate the stage retry mechanism that will resubmit stage on {}
failure",
- StringUtils.join(logTips, "/"));
- }
-
- // The feature of partition reassign is exclusive with multiple replicas
and stage retry.
- if (partitionReassignEnabled) {
- if (dataReplica > 1) {
- throw new RssException(
- "The feature of task partition reassign is incompatible with
multiple replicas mechanism.");
- }
- }
-
- this.blockIdSelfManagedEnabled =
rssConf.getBoolean(RSS_BLOCK_ID_SELF_MANAGEMENT_ENABLED);
- this.shuffleManagerRpcServiceEnabled =
- partitionReassignEnabled || rssStageRetryEnabled ||
blockIdSelfManagedEnabled;
- if (isDriver) {
- heartBeatScheduledExecutorService =
- ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
- if (shuffleManagerRpcServiceEnabled) {
- LOG.info("stage resubmit is supported and enabled");
- // start shuffle manager server
- rssConf.set(RPC_SERVER_PORT, 0);
- ShuffleManagerServerFactory factory = new
ShuffleManagerServerFactory(this, rssConf);
- service = factory.getService();
- shuffleManagerServer = factory.getServer(service);
- try {
- shuffleManagerServer.start();
- // pass this as a spark.rss.shuffle.manager.grpc.port config, so it
can be propagated to
- // executor properly.
- sparkConf.set(
- RssSparkConfig.RSS_SHUFFLE_MANAGER_GRPC_PORT,
shuffleManagerServer.getPort());
- } catch (Exception e) {
- LOG.error("Failed to start shuffle manager server", e);
- throw new RssException(e);
- }
- }
- }
- if (shuffleManagerRpcServiceEnabled) {
- getOrCreateShuffleManagerClientSupplier();
- }
- int unregisterThreadPoolSize =
- sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
- int unregisterTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
- int unregisterRequestTimeoutSec =
-
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
- long retryIntervalMax =
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
- int heartBeatThreadNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
- shuffleWriteClient =
- RssShuffleClientFactory.getInstance()
- .createShuffleWriteClient(
- RssShuffleClientFactory.newWriteBuilder()
- .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
- .managerClientSupplier(managerClientSupplier)
- .clientType(clientType)
- .retryMax(retryMax)
- .retryIntervalMax(retryIntervalMax)
- .heartBeatThreadNum(heartBeatThreadNum)
- .replica(dataReplica)
- .replicaWrite(dataReplicaWrite)
- .replicaRead(dataReplicaRead)
- .replicaSkipEnabled(dataReplicaSkipEnabled)
- .dataTransferPoolSize(dataTransferPoolSize)
- .dataCommitPoolSize(dataCommitPoolSize)
- .unregisterThreadPoolSize(unregisterThreadPoolSize)
- .unregisterTimeSec(unregisterTimeoutSec)
- .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
- .rssConf(rssConf));
- registerCoordinator();
-
- LOG.info("Rss data pusher is starting...");
- int poolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
- int keepAliveTime =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
- this.dataPusher =
- new DataPusher(
- shuffleWriteClient,
- taskToSuccessBlockIds,
- taskToFailedBlockSendTracker,
- failedTaskIds,
- poolSize,
- keepAliveTime);
- this.partitionReassignMaxServerNum =
- rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
- this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
- this.rssStageResubmitManager = new RssStageResubmitManager();
- }
-
- public CompletableFuture<Long> sendData(AddBlockEvent event) {
- if (dataPusher != null && event != null) {
- return dataPusher.send(event);
- }
- return new CompletableFuture<>();
- }
-
- @VisibleForTesting
- protected static ShuffleDataDistributionType
getDataDistributionType(SparkConf sparkConf) {
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- if ((boolean) sparkConf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED())
- && !rssConf.containsKey(RssClientConf.DATA_DISTRIBUTION_TYPE.key())) {
- return ShuffleDataDistributionType.LOCAL_ORDER;
- }
-
- return rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE);
}
// For testing only
@@ -312,72 +86,7 @@ public class RssShuffleManager extends
RssShuffleManagerBase {
DataPusher dataPusher,
Map<String, Set<Long>> taskToSuccessBlockIds,
Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker) {
- this.sparkConf = conf;
- this.clientType = sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
- RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
- this.dataDistributionType =
rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE);
- this.blockIdLayout = BlockIdLayout.from(rssConf);
- this.maxConcurrencyPerPartitionToWrite =
rssConf.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE);
- this.maxFailures = sparkConf.getInt("spark.task.maxFailures", 4);
- this.speculation = sparkConf.getBoolean("spark.speculation", false);
- // configureBlockIdLayout requires maxFailures and speculation to be
initialized
- configureBlockIdLayout(sparkConf, rssConf);
- this.heartbeatInterval =
sparkConf.get(RssSparkConfig.RSS_HEARTBEAT_INTERVAL);
- this.heartbeatTimeout =
- sparkConf.getLong(RssSparkConfig.RSS_HEARTBEAT_TIMEOUT.key(),
heartbeatInterval / 2);
- this.dataReplica = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA);
- this.dataReplicaWrite =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_WRITE);
- this.dataReplicaRead = sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_READ);
- this.dataReplicaSkipEnabled =
sparkConf.get(RssSparkConfig.RSS_DATA_REPLICA_SKIP_ENABLED);
- LOG.info(
- "Check quorum config ["
- + dataReplica
- + ":"
- + dataReplicaWrite
- + ":"
- + dataReplicaRead
- + ":"
- + dataReplicaSkipEnabled
- + "]");
- RssUtils.checkQuorumSetting(dataReplica, dataReplicaWrite,
dataReplicaRead);
-
- int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
- long retryIntervalMax =
sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
- int heartBeatThreadNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
- this.dataTransferPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_TRANSFER_POOL_SIZE);
- this.dataCommitPoolSize =
sparkConf.get(RssSparkConfig.RSS_DATA_COMMIT_POOL_SIZE);
- int unregisterThreadPoolSize =
- sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
- int unregisterTimeoutSec =
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
- int unregisterRequestTimeoutSec =
-
sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
- shuffleWriteClient =
- RssShuffleClientFactory.getInstance()
- .createShuffleWriteClient(
- RssShuffleClientFactory.getInstance()
- .newWriteBuilder()
- .clientType(clientType)
- .retryMax(retryMax)
- .retryIntervalMax(retryIntervalMax)
- .heartBeatThreadNum(heartBeatThreadNum)
- .replica(dataReplica)
- .replicaWrite(dataReplicaWrite)
- .replicaRead(dataReplicaRead)
- .replicaSkipEnabled(dataReplicaSkipEnabled)
- .dataTransferPoolSize(dataTransferPoolSize)
- .dataCommitPoolSize(dataCommitPoolSize)
- .unregisterThreadPoolSize(unregisterThreadPoolSize)
- .unregisterTimeSec(unregisterTimeoutSec)
- .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
- .rssConf(rssConf));
- this.taskToSuccessBlockIds = taskToSuccessBlockIds;
- this.heartBeatScheduledExecutorService = null;
- this.taskToFailedBlockSendTracker = taskToFailedBlockSendTracker;
- this.dataPusher = dataPusher;
- this.partitionReassignMaxServerNum =
- rssConf.get(RSS_PARTITION_REASSIGN_MAX_REASSIGNMENT_SERVER_NUM);
- this.shuffleHandleInfoManager = new ShuffleHandleInfoManager();
- this.rssStageResubmitManager = new RssStageResubmitManager();
+ super(conf, isDriver, dataPusher, taskToSuccessBlockIds,
taskToFailedBlockSendTracker);
}
// This method is called in Spark driver side,
@@ -527,17 +236,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
context);
}
- @Override
- public void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf) {
- configureBlockIdLayout(sparkConf, rssConf, maxFailures, speculation);
- }
-
- @Override
- public long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo) {
- return getTaskAttemptIdForBlockId(
- mapIndex, attemptNo, maxFailures, speculation, blockIdLayout.taskAttemptIdBits);
- }
-
public void setPusherAppId(RssShuffleHandle rssShuffleHandle) {
// todo: this implement is tricky, we should refactor it
if (id.get() == null) {
@@ -845,127 +543,52 @@ public class RssShuffleManager extends RssShuffleManagerBase {
return taskIdBitmap;
}
- @Override
- public boolean unregisterShuffle(int shuffleId) {
- try {
- super.unregisterShuffle(shuffleId);
- if (SparkEnv.get().executorId().equals("driver")) {
- shuffleWriteClient.unregisterShuffle(id.get(), shuffleId);
- shuffleIdToPartitionNum.remove(shuffleId);
- shuffleIdToNumMapTasks.remove(shuffleId);
- if (service != null) {
- service.unregisterShuffle(shuffleId);
- }
- }
- } catch (Exception e) {
- LOG.warn("Errors on unregistering from remote shuffle-servers", e);
- }
- return true;
- }
-
@Override
public ShuffleBlockResolver shuffleBlockResolver() {
throw new RssException("RssShuffleManager.shuffleBlockResolver is not implemented");
}
@Override
- public void stop() {
- super.stop();
- if (heartBeatScheduledExecutorService != null) {
- heartBeatScheduledExecutorService.shutdownNow();
- }
- if (shuffleWriteClient != null) {
- // Unregister shuffle before closing shuffle write client.
- shuffleWriteClient.unregisterShuffle(getAppId());
- shuffleWriteClient.close();
- }
- if (dataPusher != null) {
- try {
- dataPusher.close();
- } catch (IOException e) {
- LOG.warn("Errors on closing data pusher", e);
- }
- }
-
- if (shuffleManagerServer != null) {
- try {
- shuffleManagerServer.stop();
- } catch (InterruptedException e) {
- // ignore
- LOG.info("shuffle manager server is interrupted during stop");
- }
- }
- }
+ protected ShuffleWriteClient createShuffleWriteClient() {
+ int unregisterThreadPoolSize =
+ sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE);
+ int unregisterTimeoutSec = sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_TIMEOUT_SEC);
+ int unregisterRequestTimeoutSec =
+ sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_REQUEST_TIMEOUT_SEC);
+ long retryIntervalMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
+ int heartBeatThreadNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
- public void clearTaskMeta(String taskId) {
- taskToSuccessBlockIds.remove(taskId);
- taskToFailedBlockSendTracker.remove(taskId);
+ final int retryMax = sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
+ return RssShuffleClientFactory.getInstance()
+ .createShuffleWriteClient(
+ RssShuffleClientFactory.newWriteBuilder()
+ .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled)
+ .managerClientSupplier(managerClientSupplier)
+ .clientType(clientType)
+ .retryMax(retryMax)
+ .retryIntervalMax(retryIntervalMax)
+ .heartBeatThreadNum(heartBeatThreadNum)
+ .replica(dataReplica)
+ .replicaWrite(dataReplicaWrite)
+ .replicaRead(dataReplicaRead)
+ .replicaSkipEnabled(dataReplicaSkipEnabled)
+ .dataTransferPoolSize(dataTransferPoolSize)
+ .dataCommitPoolSize(dataCommitPoolSize)
+ .unregisterThreadPoolSize(unregisterThreadPoolSize)
+ .unregisterTimeSec(unregisterTimeoutSec)
+ .unregisterRequestTimeSec(unregisterRequestTimeoutSec)
+ .rssConf(rssConf));
}
@VisibleForTesting
- protected void registerCoordinator() {
- String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
- LOG.info("Start Registering coordinators {}", coordinators);
- shuffleWriteClient.registerCoordinators(
- coordinators,
- this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX),
- this.sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX));
- }
-
- private synchronized void startHeartbeat() {
- shuffleWriteClient.registerApplicationInfo(id.get(), heartbeatTimeout, user);
- if (!heartbeatStarted) {
- heartBeatScheduledExecutorService.scheduleAtFixedRate(
- () -> {
- try {
- String appId = id.get();
- shuffleWriteClient.sendAppHeartbeat(appId, heartbeatTimeout);
- LOG.info("Finish send heartbeat to coordinator and servers");
- } catch (Exception e) {
- LOG.warn("Fail to send heartbeat to coordinator and servers", e);
- }
- },
- heartbeatInterval / 2,
- heartbeatInterval,
- TimeUnit.MILLISECONDS);
- heartbeatStarted = true;
- }
- }
-
- public Set<Long> getFailedBlockIds(String taskId) {
- FailedBlockSendTracker blockIdsFailedSendTracker = getBlockIdsFailedSendTracker(taskId);
- if (blockIdsFailedSendTracker == null) {
- return Collections.emptySet();
- }
- return blockIdsFailedSendTracker.getFailedBlockIds();
- }
-
- public Set<Long> getSuccessBlockIds(String taskId) {
- Set<Long> result = taskToSuccessBlockIds.get(taskId);
- if (result == null) {
- result = Collections.emptySet();
+ protected static ShuffleDataDistributionType getDataDistributionType(SparkConf sparkConf) {
+ RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
+ if ((boolean) sparkConf.get(SQLConf.ADAPTIVE_EXECUTION_ENABLED())
+ && !rssConf.containsKey(RssClientConf.DATA_DISTRIBUTION_TYPE.key())) {
+ return ShuffleDataDistributionType.LOCAL_ORDER;
}
- return result;
- }
- /** @return the unique spark id for rss shuffle */
- @Override
- public String getAppId() {
- return id.get();
- }
-
- @Override
- public int getPartitionNum(int shuffleId) {
- return shuffleIdToPartitionNum.getOrDefault(shuffleId, 0);
- }
-
- /**
- * @param shuffleId the shuffle id to query
- * @return the num of map tasks for current shuffle with shuffle id.
- */
- @Override
- public int getNumMaps(int shuffleId) {
- return shuffleIdToNumMapTasks.getOrDefault(shuffleId, 0);
+ return rssConf.get(RssClientConf.DATA_DISTRIBUTION_TYPE);
}
static class ReadMetrics extends ShuffleReadMetrics {
@@ -1019,16 +642,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
this.id = new AtomicReference<>(appId);
}
- public boolean markFailedTask(String taskId) {
- LOG.info("Mark the task: {} failed.", taskId);
- failedTaskIds.add(taskId);
- return true;
- }
-
- public boolean isValidTask(String taskId) {
- return !failedTaskIds.contains(taskId);
- }
-
private Roaring64NavigableMap getShuffleResultForMultiPart(
String clientType,
Map<ShuffleServerInfo, Set<Integer>> serverToPartitions,
@@ -1050,23 +663,4 @@ public class RssShuffleManager extends RssShuffleManagerBase {
managerClientSupplier, e, sparkConf, appId, shuffleId, stageAttemptId, failedPartitions);
}
}
-
- public FailedBlockSendTracker getBlockIdsFailedSendTracker(String taskId) {
- return taskToFailedBlockSendTracker.get(taskId);
- }
-
- @VisibleForTesting
- public void setDataPusher(DataPusher dataPusher) {
- this.dataPusher = dataPusher;
- }
-
- @VisibleForTesting
- public Map<String, Set<Long>> getTaskToSuccessBlockIds() {
- return taskToSuccessBlockIds;
- }
-
- @VisibleForTesting
- public Map<String, FailedBlockSendTracker> getTaskToFailedBlockSendTracker() {
- return taskToFailedBlockSendTracker;
- }
}