This is an automated email from the ASF dual-hosted git repository. xianjin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push: new 457c86536 [#1608] feat: Introduce ExpiringCloseableSupplier and refactor ShuffleManagerClient creation (#1838) 457c86536 is described below commit 457c865362e1dc573004b30c505287c253a6dba0 Author: xumanbu <jam...@vipshop.com> AuthorDate: Fri Jul 26 21:24:28 2024 +0800 [#1608] feat: Introduce ExpiringCloseableSupplier and refactor ShuffleManagerClient creation (#1838) ### What changes were proposed in this pull request? 1. Introduce StatefulCloseable and ExpiringCloseableSupplier 2. Refactor ShuffleManagerClient to leverage ExpiringCloseableSupplier ### Why are the changes needed? For better code quality ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing UTs and new UTs. --- .../apache/spark/shuffle/RssSparkShuffleUtils.java | 48 +++--- .../shuffle/reader/RssFetchFailedIterator.java | 63 +++----- .../BlockIdSelfManagedShuffleWriteClient.java | 13 +- .../uniffle/shuffle/RssShuffleClientFactory.java | 12 +- .../shuffle/manager/RssShuffleManagerBase.java | 36 +++-- .../apache/spark/shuffle/RssShuffleManager.java | 18 ++- .../spark/shuffle/reader/RssShuffleReader.java | 12 +- .../spark/shuffle/writer/RssShuffleWriter.java | 71 ++++----- .../spark/shuffle/reader/RssShuffleReaderTest.java | 6 +- .../spark/shuffle/writer/RssShuffleWriterTest.java | 8 + .../apache/spark/shuffle/RssShuffleManager.java | 9 +- .../spark/shuffle/reader/RssShuffleReader.java | 11 +- .../spark/shuffle/writer/RssShuffleWriter.java | 84 +++++----- .../spark/shuffle/reader/RssShuffleReaderTest.java | 6 + .../spark/shuffle/writer/RssShuffleWriterTest.java | 14 ++ .../common/util/ExpiringCloseableSupplier.java | 110 +++++++++++++ .../uniffle/common/util/StatefulCloseable.java | 25 +++ .../common/util/ExpiringCloseableSupplierTest.java | 172 +++++++++++++++++++++ .../uniffle/test/ShuffleServerManagerTestBase.java | 13 +- .../uniffle/client/api/ShuffleManagerClient.java | 5 +- 
.../factory/ShuffleManagerClientFactory.java | 4 +- .../client/impl/grpc/ShuffleManagerGrpcClient.java | 20 ++- .../factory/ShuffleManagerClientFactoryTest.java | 5 +- 23 files changed, 545 insertions(+), 220 deletions(-) diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java index b3763df32..feee2a331 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java @@ -17,7 +17,6 @@ package org.apache.spark.shuffle; -import java.io.IOException; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.util.Arrays; @@ -25,6 +24,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import scala.Option; import scala.reflect.ClassTag; @@ -43,21 +43,18 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.client.api.CoordinatorClient; import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.factory.CoordinatorClientFactory; -import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest; import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleServerInfo; -import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssFetchFailedException; import org.apache.uniffle.common.util.Constants; import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED; -import static org.apache.uniffle.common.util.Constants.DRIVER_HOST; public class RssSparkShuffleUtils { @@ -346,6 +343,7 @@ public class RssSparkShuffleUtils { } public static RssException reportRssFetchFailedException( + Supplier<ShuffleManagerClient> managerClientSupplier, RssFetchFailedException rssFetchFailedException, SparkConf sparkConf, String appId, @@ -355,32 +353,24 @@ public class RssSparkShuffleUtils { RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED) && RssSparkShuffleUtils.isStageResubmitSupported()) { - String driver = rssConf.getString(DRIVER_HOST, ""); - int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); - try (ShuffleManagerClient client = - ShuffleManagerClientFactory.getInstance() - .createShuffleManagerClient(ClientType.GRPC, driver, port)) { - // todo: Create a new rpc interface to report failures in batch. - for (int partitionId : failedPartitions) { - RssReportShuffleFetchFailureRequest req = - new RssReportShuffleFetchFailureRequest( - appId, - shuffleId, - stageAttemptId, - partitionId, - rssFetchFailedException.getMessage()); - RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req); - if (response.getReSubmitWholeStage()) { - // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 - // is provided. 
- FetchFailedException ffe = - RssSparkShuffleUtils.createFetchFailedException( - shuffleId, -1, partitionId, rssFetchFailedException); - return new RssException(ffe); - } + for (int partitionId : failedPartitions) { + RssReportShuffleFetchFailureRequest req = + new RssReportShuffleFetchFailureRequest( + appId, + shuffleId, + stageAttemptId, + partitionId, + rssFetchFailedException.getMessage()); + RssReportShuffleFetchFailureResponse response = + managerClientSupplier.get().reportShuffleFetchFailure(req); + if (response.getReSubmitWholeStage()) { + // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 + // is provided. + FetchFailedException ffe = + RssSparkShuffleUtils.createFetchFailedException( + shuffleId, -1, partitionId, rssFetchFailedException); + return new RssException(ffe); } - } catch (IOException ioe) { - LOG.info("Error closing shuffle manager client with error:", ioe); } } return rssFetchFailedException; diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java index c394f510b..1bc61dc74 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java @@ -17,8 +17,8 @@ package org.apache.spark.shuffle.reader; -import java.io.IOException; import java.util.Objects; +import java.util.function.Supplier; import scala.Product2; import scala.collection.AbstractIterator; @@ -30,10 +30,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.uniffle.client.api.ShuffleManagerClient; -import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest; import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; -import 
org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssFetchFailedException; @@ -52,8 +50,7 @@ public class RssFetchFailedIterator<K, C> extends AbstractIterator<Product2<K, C private int shuffleId; private int partitionId; private int stageAttemptId; - private String reportServerHost; - private int reportServerPort; + private Supplier<ShuffleManagerClient> managerClientSupplier; private Builder() {} @@ -77,19 +74,13 @@ public class RssFetchFailedIterator<K, C> extends AbstractIterator<Product2<K, C return this; } - Builder reportServerHost(String host) { - this.reportServerHost = host; - return this; - } - - Builder port(int port) { - this.reportServerPort = port; + Builder managerClientSupplier(Supplier<ShuffleManagerClient> managerClientSupplier) { + this.managerClientSupplier = managerClientSupplier; return this; } <K, C> RssFetchFailedIterator<K, C> build(Iterator<Product2<K, C>> iter) { Objects.requireNonNull(this.appId); - Objects.requireNonNull(this.reportServerHost); return new RssFetchFailedIterator<>(this, iter); } } @@ -98,37 +89,23 @@ public class RssFetchFailedIterator<K, C> extends AbstractIterator<Product2<K, C return new Builder(); } - private static ShuffleManagerClient createShuffleManagerClient(String host, int port) - throws IOException { - ClientType grpc = ClientType.GRPC; - // host is passed from spark.driver.bindAddress, which would be set when SparkContext is - // constructed. - return ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc, host, port); - } - private RssException generateFetchFailedIfNecessary(RssFetchFailedException e) { - String driver = builder.reportServerHost; - int port = builder.reportServerPort; - // todo: reuse this manager client if this is a bottleneck. 
- try (ShuffleManagerClient client = createShuffleManagerClient(driver, port)) { - RssReportShuffleFetchFailureRequest req = - new RssReportShuffleFetchFailureRequest( - builder.appId, - builder.shuffleId, - builder.stageAttemptId, - builder.partitionId, - e.getMessage()); - RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req); - if (response.getReSubmitWholeStage()) { - // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is - // provided. - FetchFailedException ffe = - RssSparkShuffleUtils.createFetchFailedException( - builder.shuffleId, -1, builder.partitionId, e); - return new RssException(ffe); - } - } catch (IOException ioe) { - LOG.info("Error closing shuffle manager client with error:", ioe); + ShuffleManagerClient client = builder.managerClientSupplier.get(); + RssReportShuffleFetchFailureRequest req = + new RssReportShuffleFetchFailureRequest( + builder.appId, + builder.shuffleId, + builder.stageAttemptId, + builder.partitionId, + e.getMessage()); + RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req); + if (response.getReSubmitWholeStage()) { + // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is + // provided. 
+ FetchFailedException ffe = + RssSparkShuffleUtils.createFetchFailedException( + builder.shuffleId, -1, builder.partitionId, e); + return new RssException(ffe); } return e; } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java index 1429bacbf..93aa3f0fc 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/BlockIdSelfManagedShuffleWriteClient.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import java.util.stream.Collectors; import org.roaringbitmap.longlong.Roaring64NavigableMap; @@ -41,16 +42,16 @@ import org.apache.uniffle.common.util.BlockIdLayout; * driver side. */ public class BlockIdSelfManagedShuffleWriteClient extends ShuffleWriteClientImpl { - private ShuffleManagerClient shuffleManagerClient; + private Supplier<ShuffleManagerClient> managerClientSupplier; public BlockIdSelfManagedShuffleWriteClient( RssShuffleClientFactory.ExtendWriteClientBuilder builder) { super(builder); - if (builder.getShuffleManagerClient() == null) { + if (builder.getManagerClientSupplier() == null) { throw new RssException("Illegal empty shuffleManagerClient. 
This should not happen"); } - this.shuffleManagerClient = builder.getShuffleManagerClient(); + this.managerClientSupplier = builder.getManagerClientSupplier(); } @Override @@ -73,7 +74,7 @@ public class BlockIdSelfManagedShuffleWriteClient extends ShuffleWriteClientImpl RssReportShuffleResultRequest request = new RssReportShuffleResultRequest( appId, shuffleId, taskAttemptId, partitionToBlockIds, bitmapNum); - shuffleManagerClient.reportShuffleResult(request); + managerClientSupplier.get().reportShuffleResult(request); } @Override @@ -85,7 +86,7 @@ public class BlockIdSelfManagedShuffleWriteClient extends ShuffleWriteClientImpl int partitionId) { RssGetShuffleResultRequest request = new RssGetShuffleResultRequest(appId, shuffleId, partitionId, BlockIdLayout.DEFAULT); - return shuffleManagerClient.getShuffleResult(request).getBlockIdBitmap(); + return managerClientSupplier.get().getShuffleResult(request).getBlockIdBitmap(); } @Override @@ -101,6 +102,6 @@ public class BlockIdSelfManagedShuffleWriteClient extends ShuffleWriteClientImpl RssGetShuffleResultForMultiPartRequest request = new RssGetShuffleResultForMultiPartRequest( appId, shuffleId, partitionIds, BlockIdLayout.DEFAULT); - return shuffleManagerClient.getShuffleResultForMultiPart(request).getBlockIdBitmap(); + return managerClientSupplier.get().getShuffleResultForMultiPart(request).getBlockIdBitmap(); } } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java index c19d91324..bad10ab72 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/RssShuffleClientFactory.java @@ -17,6 +17,8 @@ package org.apache.uniffle.shuffle; +import java.util.function.Supplier; + import org.apache.uniffle.client.api.ShuffleManagerClient; import 
org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; @@ -41,18 +43,18 @@ public class RssShuffleClientFactory extends ShuffleClientFactory { public static class ExtendWriteClientBuilder<T extends ExtendWriteClientBuilder<T>> extends WriteClientBuilder<T> { private boolean blockIdSelfManagedEnabled; - private ShuffleManagerClient shuffleManagerClient; + private Supplier<ShuffleManagerClient> managerClientSupplier; public boolean isBlockIdSelfManagedEnabled() { return blockIdSelfManagedEnabled; } - public ShuffleManagerClient getShuffleManagerClient() { - return shuffleManagerClient; + public Supplier<ShuffleManagerClient> getManagerClientSupplier() { + return managerClientSupplier; } - public T shuffleManagerClient(ShuffleManagerClient client) { - this.shuffleManagerClient = client; + public T managerClientSupplier(Supplier<ShuffleManagerClient> managerClientSupplier) { + this.managerClientSupplier = managerClientSupplier; return self(); } diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java index 6a281db2e..d314b9bb6 100644 --- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java +++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java @@ -31,6 +31,7 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import com.google.common.annotations.VisibleForTesting; @@ -78,10 +79,12 @@ import org.apache.uniffle.common.ShuffleAssignmentsInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleServerInfo; import 
org.apache.uniffle.common.config.ConfigOption; +import org.apache.uniffle.common.config.RssBaseConf; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.util.ExpiringCloseableSupplier; import org.apache.uniffle.common.util.RetryUtils; import org.apache.uniffle.shuffle.BlockIdManager; @@ -104,7 +107,7 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac protected String clientType; protected SparkConf sparkConf; - protected ShuffleManagerClient shuffleManagerClient; + protected Supplier<ShuffleManagerClient> managerClientSupplier; protected boolean rssStageRetryEnabled; protected boolean rssStageRetryForWriteFailureEnabled; protected boolean rssStageRetryForFetchFailureEnabled; @@ -588,7 +591,8 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest = new RssPartitionToShuffleServerRequest(shuffleId); RssReassignOnStageRetryResponse rpcPartitionToShufflerServer = - getOrCreateShuffleManagerClient() + getOrCreateShuffleManagerClientSupplier() + .get() .getPartitionToShufflerServerWithStageRetry(rssPartitionToShuffleServerRequest); StageAttemptShuffleHandleInfo shuffleHandleInfo = StageAttemptShuffleHandleInfo.fromProto( @@ -607,25 +611,27 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac RssPartitionToShuffleServerRequest rssPartitionToShuffleServerRequest = new RssPartitionToShuffleServerRequest(shuffleId); RssReassignOnBlockSendFailureResponse rpcPartitionToShufflerServer = - getOrCreateShuffleManagerClient() + getOrCreateShuffleManagerClientSupplier() + .get() .getPartitionToShufflerServerWithBlockRetry(rssPartitionToShuffleServerRequest); MutableShuffleHandleInfo shuffleHandleInfo = 
MutableShuffleHandleInfo.fromProto(rpcPartitionToShufflerServer.getHandle()); return shuffleHandleInfo; } - // todo: automatic close client when the client is idle to avoid too much connections for spark - // driver. - protected ShuffleManagerClient getOrCreateShuffleManagerClient() { - if (shuffleManagerClient == null) { + protected synchronized Supplier<ShuffleManagerClient> getOrCreateShuffleManagerClientSupplier() { + if (managerClientSupplier == null) { RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); String driver = rssConf.getString("driver.host", ""); int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); - this.shuffleManagerClient = - ShuffleManagerClientFactory.getInstance() - .createShuffleManagerClient(ClientType.GRPC, driver, port); + long rpcTimeout = rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS); + this.managerClientSupplier = + ExpiringCloseableSupplier.of( + () -> + ShuffleManagerClientFactory.getInstance() + .createShuffleManagerClient(ClientType.GRPC, driver, port, rpcTimeout)); } - return shuffleManagerClient; + return managerClientSupplier; } @Override @@ -808,6 +814,14 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac } } + @Override + public void stop() { + if (managerClientSupplier != null + && managerClientSupplier instanceof ExpiringCloseableSupplier) { + ((ExpiringCloseableSupplier<ShuffleManagerClient>) managerClientSupplier).close(); + } + } + /** * Creating the shuffleAssignmentInfo from the servers and partitionIds * diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index 1e5bb4941..27db614bf 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -214,16 +214,15 @@ public class RssShuffleManager extends 
RssShuffleManagerBase { } } } - if (shuffleManagerRpcServiceEnabled) { - this.shuffleManagerClient = getOrCreateShuffleManagerClient(); + getOrCreateShuffleManagerClientSupplier(); } this.shuffleWriteClient = RssShuffleClientFactory.getInstance() .createShuffleWriteClient( RssShuffleClientFactory.newWriteBuilder() .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled) - .shuffleManagerClient(shuffleManagerClient) + .managerClientSupplier(managerClientSupplier) .clientType(clientType) .retryMax(retryMax) .retryIntervalMax(retryIntervalMax) @@ -434,6 +433,7 @@ public class RssShuffleManager extends RssShuffleManagerBase { this, sparkConf, shuffleWriteClient, + managerClientSupplier, rssHandle, this::markFailedTask, context); @@ -537,7 +537,8 @@ public class RssShuffleManager extends RssShuffleManagerBase { blockIdBitmap, taskIdBitmap, RssSparkConfig.toRssConf(sparkConf), - partitionToServers); + partitionToServers, + managerClientSupplier); } else { throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName()); } @@ -573,6 +574,7 @@ public class RssShuffleManager extends RssShuffleManagerBase { @Override public void stop() { + super.stop(); if (heartBeatScheduledExecutorService != null) { heartBeatScheduledExecutorService.shutdownNow(); } @@ -719,7 +721,13 @@ public class RssShuffleManager extends RssShuffleManagerBase { clientType, shuffleServerInfoSet, appId, shuffleId, partitionId); } catch (RssFetchFailedException e) { throw RssSparkShuffleUtils.reportRssFetchFailedException( - e, sparkConf, appId, shuffleId, stageAttemptId, Sets.newHashSet(partitionId)); + managerClientSupplier, + e, + sparkConf, + appId, + shuffleId, + stageAttemptId, + Sets.newHashSet(partitionId)); } } diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java index 3bf5840e8..4b4ec32c5 100644 --- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java @@ -19,6 +19,7 @@ package org.apache.spark.shuffle.reader; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import scala.Function0; import scala.Function2; @@ -47,6 +48,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.api.ShuffleReadClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; import org.apache.uniffle.client.util.RssClientConfig; @@ -77,6 +79,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> { private List<ShuffleServerInfo> shuffleServerInfoList; private Configuration hadoopConf; private RssConf rssConf; + private Supplier<ShuffleManagerClient> managerClientSupplier; public RssShuffleReader( int startPartition, @@ -90,7 +93,8 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> { Roaring64NavigableMap blockIdBitmap, Roaring64NavigableMap taskIdBitmap, RssConf rssConf, - Map<Integer, List<ShuffleServerInfo>> partitionToServers) { + Map<Integer, List<ShuffleServerInfo>> partitionToServers, + Supplier<ShuffleManagerClient> managerClientSupplier) { this.appId = rssShuffleHandle.getAppId(); this.startPartition = startPartition; this.endPartition = endPartition; @@ -107,6 +111,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> { this.hadoopConf = hadoopConf; this.shuffleServerInfoList = (List<ShuffleServerInfo>) (partitionToServers.get(startPartition)); this.rssConf = rssConf; + this.managerClientSupplier = managerClientSupplier; expectedTaskIdsBitmapFilterEnable = shuffleServerInfoList.size() > 1; } @@ -235,16 +240,13 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> { // stage re-compute 
and shuffle manager server port are both set if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED) && rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0) { - String driver = rssConf.getString("driver.host", ""); - int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); resultIter = RssFetchFailedIterator.newBuilder() .appId(appId) .shuffleId(shuffleId) .partitionId(startPartition) .stageAttemptId(context.stageAttemptNumber()) - .reportServerHost(driver) - .port(port) + .managerClientSupplier(managerClientSupplier) .build(resultIter); } return resultIter; diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 5ac6a7e9e..4474c99c8 100644 --- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -17,7 +17,6 @@ package org.apache.spark.shuffle.writer; -import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -31,6 +30,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import scala.Function1; @@ -64,17 +64,13 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.api.ShuffleWriteClient; -import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; import org.apache.uniffle.client.impl.FailedBlockSendTracker; import org.apache.uniffle.client.request.RssReassignServersRequest; import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest; import org.apache.uniffle.client.response.RssReassignServersResponse; import 
org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse; -import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; -import org.apache.uniffle.common.config.RssClientConf; -import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssSendFailedException; import org.apache.uniffle.common.exception.RssWaitFailedException; @@ -114,6 +110,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { private final Set<Long> blockIds = Sets.newConcurrentHashSet(); private TaskContext taskContext; private SparkConf sparkConf; + private Supplier<ShuffleManagerClient> managerClientSupplier; public RssShuffleWriter( String appId, @@ -125,6 +122,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, + Supplier<ShuffleManagerClient> managerClientSupplier, RssShuffleHandle<K, V, C> rssHandle, SimpleShuffleHandleInfo shuffleHandleInfo, TaskContext context) { @@ -137,6 +135,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { shuffleManager, sparkConf, shuffleWriteClient, + managerClientSupplier, rssHandle, (tid) -> true, shuffleHandleInfo, @@ -153,6 +152,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, + Supplier<ShuffleManagerClient> managerClientSupplier, RssShuffleHandle<K, V, C> rssHandle, Function<String, Boolean> taskFailureCallback, ShuffleHandleInfo shuffleHandleInfo, @@ -172,6 +172,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { this.bitmapSplitNum = sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM); this.serverToPartitionToBlockIds = Maps.newHashMap(); this.shuffleWriteClient = 
shuffleWriteClient; + this.managerClientSupplier = managerClientSupplier; this.shuffleServersForData = shuffleHandleInfo.getServers(); this.partitionToServers = shuffleHandleInfo.getAvailablePartitionServersForWriter(); this.isMemoryShuffleEnabled = @@ -191,6 +192,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, + Supplier<ShuffleManagerClient> managerClientSupplier, RssShuffleHandle<K, V, C> rssHandle, Function<String, Boolean> taskFailureCallback, TaskContext context) { @@ -203,6 +205,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { shuffleManager, sparkConf, shuffleWriteClient, + managerClientSupplier, rssHandle, taskFailureCallback, shuffleManager.getShuffleHandleInfo(rssHandle), @@ -528,14 +531,6 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { return shuffleWriteMetrics; } - private static ShuffleManagerClient createShuffleManagerClient(String host, int port) - throws IOException { - ClientType grpc = ClientType.GRPC; - // Host can be inferred from `spark.driver.bindAddress`, which would be set when SparkContext is - // constructed. 
- return ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc, host, port); - } - private void throwFetchFailedIfNecessary(Exception e) { // The shuffleServer is registered only when a Block fails to be sent if (e instanceof RssSendFailedException) { @@ -550,34 +545,28 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { taskContext.stageAttemptNumber(), shuffleServerInfos, e.getMessage()); - RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); - String driver = rssConf.getString("driver.host", ""); - int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); - try (ShuffleManagerClient shuffleManagerClient = createShuffleManagerClient(driver, port)) { - RssReportShuffleWriteFailureResponse response = - shuffleManagerClient.reportShuffleWriteFailure(req); - if (response.getReSubmitWholeStage()) { - // The shuffle server is reassigned. - RssReassignServersRequest rssReassignServersRequest = - new RssReassignServersRequest( - taskContext.stageId(), - taskContext.stageAttemptNumber(), - shuffleId, - partitioner.numPartitions()); - RssReassignServersResponse rssReassignServersResponse = - shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest); - LOG.info( - "Whether the reassignment is successful: {}", - rssReassignServersResponse.isNeedReassign()); - // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is - // provided. - FetchFailedException ffe = - RssSparkShuffleUtils.createFetchFailedException( - shuffleId, -1, taskContext.stageAttemptNumber(), e); - throw new RssException(ffe); - } - } catch (IOException ioe) { - LOG.info("Error closing shuffle manager client with error:", ioe); + RssReportShuffleWriteFailureResponse response = + managerClientSupplier.get().reportShuffleWriteFailure(req); + if (response.getReSubmitWholeStage()) { + // The shuffle server is reassigned. 
+ RssReassignServersRequest rssReassignServersRequest = + new RssReassignServersRequest( + taskContext.stageId(), + taskContext.stageAttemptNumber(), + shuffleId, + partitioner.numPartitions()); + RssReassignServersResponse rssReassignServersResponse = + managerClientSupplier.get().reassignOnStageResubmit(rssReassignServersRequest); + LOG.info( + "Whether the reassignment is successful: {}", + rssReassignServersResponse.isNeedReassign()); + // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 + // is + // provided. + FetchFailedException ffe = + RssSparkShuffleUtils.createFetchFailedException( + shuffleId, -1, taskContext.stageAttemptNumber(), e); + throw new RssException(ffe); } } throw new RssException(e); diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java index f09223b1c..78fe7dec0 100644 --- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java +++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java @@ -35,9 +35,11 @@ import org.apache.spark.shuffle.RssShuffleHandle; import org.junit.jupiter.api.Test; import org.roaringbitmap.longlong.Roaring64NavigableMap; +import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.util.ExpiringCloseableSupplier; import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler; import org.apache.uniffle.storage.util.StorageType; @@ -85,6 +87,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest { rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name()); rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, 1000); 
rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE, "1000"); + final ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class); RssShuffleReader<String, String> rssShuffleReaderSpy = spy( new RssShuffleReader<>( @@ -99,7 +102,8 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest { blockIdBitmap, taskIdBitmap, rssConf, - partitionToServers)); + partitionToServers, + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient))); validateResult(rssShuffleReaderSpy.read(), expectedData, 10); } diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java index e039ad9d5..779f94117 100644 --- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java +++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java @@ -48,12 +48,14 @@ import org.apache.spark.shuffle.RssSparkConfig; import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo; import org.junit.jupiter.api.Test; +import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.impl.FailedBlockSendTracker; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.util.ExpiringCloseableSupplier; import org.apache.uniffle.storage.util.StorageType; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -88,6 +90,7 @@ public class RssShuffleWriterTest { Serializer kryoSerializer = new KryoSerializer(conf); ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class); + ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class); Partitioner mockPartitioner = 
mock(Partitioner.class); ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class); RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class); @@ -124,6 +127,7 @@ public class RssShuffleWriterTest { manager, conf, mockShuffleWriteClient, + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -234,6 +238,7 @@ public class RssShuffleWriterTest { Partitioner mockPartitioner = mock(Partitioner.class); ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class); final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class); + final ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class); RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class); when(mockHandle.getDependency()).thenReturn(mockDependency); Serializer kryoSerializer = new KryoSerializer(conf); @@ -299,6 +304,7 @@ public class RssShuffleWriterTest { manager, conf, mockShuffleWriteClient, + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -348,6 +354,7 @@ public class RssShuffleWriterTest { @Test public void postBlockEventTest() throws Exception { final ShuffleWriteMetrics mockMetrics = mock(ShuffleWriteMetrics.class); + final ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class); ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class); Partitioner mockPartitioner = mock(Partitioner.class); when(mockDependency.partitioner()).thenReturn(mockPartitioner); @@ -411,6 +418,7 @@ public class RssShuffleWriterTest { manager, conf, mockWriteClient, + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java index bf42bf361..92e630df2 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java @@ -239,7 +239,7 @@ public class RssShuffleManager extends RssShuffleManagerBase { } } if (shuffleManagerRpcServiceEnabled) { - this.shuffleManagerClient = getOrCreateShuffleManagerClient(); + getOrCreateShuffleManagerClientSupplier(); } int unregisterThreadPoolSize = sparkConf.get(RssSparkConfig.RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE); @@ -253,7 +253,7 @@ public class RssShuffleManager extends RssShuffleManagerBase { .createShuffleWriteClient( RssShuffleClientFactory.newWriteBuilder() .blockIdSelfManagedEnabled(blockIdSelfManagedEnabled) - .shuffleManagerClient(shuffleManagerClient) + .managerClientSupplier(managerClientSupplier) .clientType(clientType) .retryMax(retryMax) .retryIntervalMax(retryIntervalMax) @@ -523,6 +523,7 @@ public class RssShuffleManager extends RssShuffleManagerBase { this, sparkConf, shuffleWriteClient, + managerClientSupplier, rssHandle, this::markFailedTask, context); @@ -696,6 +697,7 @@ public class RssShuffleManager extends RssShuffleManagerBase { blockIdBitmap, startPartition, endPartition, blockIdLayout), taskIdBitmap, readMetrics, + managerClientSupplier, RssSparkConfig.toRssConf(sparkConf), dataDistributionType, shuffleHandleInfo.getAllPartitionServersForReader()); @@ -853,6 +855,7 @@ public class RssShuffleManager extends RssShuffleManagerBase { @Override public void stop() { + super.stop(); if (heartBeatScheduledExecutorService != null) { heartBeatScheduledExecutorService.shutdownNow(); } @@ -1031,7 +1034,7 @@ public class RssShuffleManager extends RssShuffleManagerBase { replicaRequirementTracking); } catch (RssFetchFailedException e) { throw RssSparkShuffleUtils.reportRssFetchFailedException( - e, sparkConf, appId, shuffleId, stageAttemptId, 
failedPartitions); + managerClientSupplier, e, sparkConf, appId, shuffleId, stageAttemptId, failedPartitions); } } diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java index bf47ced6b..19682bd65 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java @@ -19,6 +19,7 @@ package org.apache.spark.shuffle.reader; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import scala.Function0; import scala.Function1; @@ -49,6 +50,7 @@ import org.roaringbitmap.longlong.Roaring64NavigableMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.api.ShuffleReadClient; import org.apache.uniffle.client.factory.ShuffleClientFactory; import org.apache.uniffle.client.util.RssClientConfig; @@ -58,7 +60,6 @@ import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; import static org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED; -import static org.apache.uniffle.common.util.Constants.DRIVER_HOST; public class RssShuffleReader<K, C> implements ShuffleReader<K, C> { private static final Logger LOG = LoggerFactory.getLogger(RssShuffleReader.class); @@ -83,6 +84,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> { private ShuffleReadMetrics readMetrics; private RssConf rssConf; private ShuffleDataDistributionType dataDistributionType; + private Supplier<ShuffleManagerClient> managerClientSupplier; public RssShuffleReader( int startPartition, @@ -97,6 +99,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> { Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks, 
Roaring64NavigableMap taskIdBitmap, ShuffleReadMetrics readMetrics, + Supplier<ShuffleManagerClient> managerClientSupplier, RssConf rssConf, ShuffleDataDistributionType dataDistributionType, Map<Integer, List<ShuffleServerInfo>> allPartitionToServers) { @@ -120,6 +123,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> { this.partitionToShuffleServers = allPartitionToServers; this.rssConf = rssConf; this.dataDistributionType = dataDistributionType; + this.managerClientSupplier = managerClientSupplier; } @Override @@ -193,16 +197,13 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> { // resubmit stage and shuffle manager server port are both set if (rssConf.getBoolean(RSS_RESUBMIT_STAGE_WITH_FETCH_FAILURE_ENABLED) && rssConf.getInteger(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT, 0) > 0) { - String driver = rssConf.getString(DRIVER_HOST, ""); - int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); resultIter = RssFetchFailedIterator.newBuilder() .appId(appId) .shuffleId(shuffleId) .partitionId(startPartition) .stageAttemptId(context.stageAttemptNumber()) - .reportServerHost(driver) - .port(port) + .managerClientSupplier(managerClientSupplier) .build(resultIter); } return resultIter; diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java index 870141c4b..24a3b8c1c 100644 --- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java +++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java @@ -17,7 +17,6 @@ package org.apache.spark.shuffle.writer; -import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -36,6 +35,7 @@ import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import 
java.util.function.Function; +import java.util.function.Supplier; import java.util.stream.Collectors; import scala.Function1; @@ -71,7 +71,6 @@ import org.slf4j.LoggerFactory; import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.api.ShuffleWriteClient; -import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; import org.apache.uniffle.client.impl.FailedBlockSendTracker; import org.apache.uniffle.client.impl.TrackingBlockStatus; import org.apache.uniffle.client.request.RssReassignOnBlockSendFailureRequest; @@ -80,12 +79,10 @@ import org.apache.uniffle.client.request.RssReportShuffleWriteFailureRequest; import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse; import org.apache.uniffle.client.response.RssReassignServersResponse; import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse; -import org.apache.uniffle.common.ClientType; import org.apache.uniffle.common.ReceivingFailureServer; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; -import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssSendFailedException; import org.apache.uniffle.common.exception.RssWaitFailedException; @@ -143,6 +140,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { private static final Set<StatusCode> STATUS_CODE_WITHOUT_BLOCK_RESEND = Sets.newHashSet(StatusCode.NO_REGISTER); + private final Supplier<ShuffleManagerClient> managerClientSupplier; + // Only for tests @VisibleForTesting public RssShuffleWriter( @@ -155,6 +154,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, + Supplier<ShuffleManagerClient> managerClientSupplier, RssShuffleHandle<K, V, C> 
rssHandle, ShuffleHandleInfo shuffleHandleInfo, TaskContext context) { @@ -167,6 +167,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { shuffleManager, sparkConf, shuffleWriteClient, + managerClientSupplier, rssHandle, (tid) -> true, shuffleHandleInfo, @@ -184,6 +185,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, + Supplier<ShuffleManagerClient> managerClientSupplier, RssShuffleHandle<K, V, C> rssHandle, Function<String, Boolean> taskFailureCallback, ShuffleHandleInfo shuffleHandleInfo, @@ -217,6 +219,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { this.shuffleHandleInfo = shuffleHandleInfo; this.taskContext = context; this.sparkConf = sparkConf; + this.managerClientSupplier = managerClientSupplier; this.blockFailSentRetryEnabled = sparkConf.getBoolean( RssSparkConfig.SPARK_RSS_CONFIG_PREFIX @@ -235,6 +238,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { RssShuffleManager shuffleManager, SparkConf sparkConf, ShuffleWriteClient shuffleWriteClient, + Supplier<ShuffleManagerClient> managerClientSupplier, RssShuffleHandle<K, V, C> rssHandle, Function<String, Boolean> taskFailureCallback, TaskContext context) { @@ -247,6 +251,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { shuffleManager, sparkConf, shuffleWriteClient, + managerClientSupplier, rssHandle, taskFailureCallback, shuffleManager.getShuffleHandleInfo(rssHandle), @@ -618,14 +623,11 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { LOG.info( "Initiate reassignOnBlockSendFailure. 
failure partition servers: {}", failurePartitionToServers); - RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); - String driver = rssConf.getString("driver.host", ""); - int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); - try (ShuffleManagerClient shuffleManagerClient = createShuffleManagerClient(driver, port)) { - String executorId = SparkEnv.get().executorId(); - long taskAttemptId = taskContext.taskAttemptId(); - int stageId = taskContext.stageId(); - int stageAttemptNum = taskContext.stageAttemptNumber(); + String executorId = SparkEnv.get().executorId(); + long taskAttemptId = taskContext.taskAttemptId(); + int stageId = taskContext.stageId(); + int stageAttemptNum = taskContext.stageAttemptNumber(); + try { RssReassignOnBlockSendFailureRequest request = new RssReassignOnBlockSendFailureRequest( shuffleId, @@ -635,7 +637,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { stageId, stageAttemptNum); RssReassignOnBlockSendFailureResponse response = - shuffleManagerClient.reassignOnBlockSendFailure(request); + managerClientSupplier.get().reassignOnBlockSendFailure(request); if (response.getStatusCode() != StatusCode.SUCCESS) { String msg = String.format( @@ -835,14 +837,6 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { return bufferManager; } - private static ShuffleManagerClient createShuffleManagerClient(String host, int port) - throws IOException { - ClientType grpc = ClientType.GRPC; - // Host can be inferred from `spark.driver.bindAddress`, which would be set when SparkContext is - // constructed. 
- return ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(grpc, host, port); - } - private void throwFetchFailedIfNecessary(Exception e) { // The shuffleServer is registered only when a Block fails to be sent if (e instanceof RssSendFailedException) { @@ -857,33 +851,27 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> { taskContext.stageAttemptNumber(), shuffleServerInfos, e.getMessage()); - RssConf rssConf = RssSparkConfig.toRssConf(sparkConf); - String driver = rssConf.getString("driver.host", ""); - int port = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT); - try (ShuffleManagerClient shuffleManagerClient = createShuffleManagerClient(driver, port)) { - RssReportShuffleWriteFailureResponse response = - shuffleManagerClient.reportShuffleWriteFailure(req); - if (response.getReSubmitWholeStage()) { - RssReassignServersRequest rssReassignServersRequest = - new RssReassignServersRequest( - taskContext.stageId(), - taskContext.stageAttemptNumber(), - shuffleId, - partitioner.numPartitions()); - RssReassignServersResponse rssReassignServersResponse = - shuffleManagerClient.reassignOnStageResubmit(rssReassignServersRequest); - LOG.info( - "Whether the reassignment is successful: {}", - rssReassignServersResponse.isNeedReassign()); - // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 is - // provided. 
- FetchFailedException ffe = - RssSparkShuffleUtils.createFetchFailedException( - shuffleId, -1, taskContext.stageAttemptNumber(), e); - throw new RssException(ffe); - } - } catch (IOException ioe) { - LOG.info("Error closing shuffle manager client with error:", ioe); + RssReportShuffleWriteFailureResponse response = + managerClientSupplier.get().reportShuffleWriteFailure(req); + if (response.getReSubmitWholeStage()) { + RssReassignServersRequest rssReassignServersRequest = + new RssReassignServersRequest( + taskContext.stageId(), + taskContext.stageAttemptNumber(), + shuffleId, + partitioner.numPartitions()); + RssReassignServersResponse rssReassignServersResponse = + managerClientSupplier.get().reassignOnStageResubmit(rssReassignServersRequest); + LOG.info( + "Whether the reassignment is successful: {}", + rssReassignServersResponse.isNeedReassign()); + // since we are going to roll out the whole stage, mapIndex shouldn't matter, hence -1 + // is + // provided. + FetchFailedException ffe = + RssSparkShuffleUtils.createFetchFailedException( + shuffleId, -1, taskContext.stageAttemptNumber(), e); + throw new RssException(ffe); } } throw new RssException(e); diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java index aaff4cb8e..bc77f7192 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java @@ -36,10 +36,12 @@ import org.apache.spark.shuffle.RssShuffleHandle; import org.junit.jupiter.api.Test; import org.roaringbitmap.longlong.Roaring64NavigableMap; +import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.config.RssClientConf; 
import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.common.util.ExpiringCloseableSupplier; import org.apache.uniffle.storage.handler.impl.HadoopShuffleWriteHandler; import org.apache.uniffle.storage.util.StorageType; @@ -93,6 +95,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest { rssConf.set(RssClientConf.RSS_STORAGE_TYPE, StorageType.HDFS.name()); rssConf.set(RssClientConf.RSS_INDEX_READ_LIMIT, 1000); rssConf.set(RssClientConf.RSS_CLIENT_READ_BUFFER_SIZE, "1000"); + ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class); RssShuffleReader<String, String> rssShuffleReaderSpy = spy( new RssShuffleReader<>( @@ -108,6 +111,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest { partitionToExpectBlocks, taskIdBitmap, new ShuffleReadMetrics(), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), rssConf, ShuffleDataDistributionType.NORMAL, partitionToServers)); @@ -131,6 +135,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest { partitionToExpectBlocks, taskIdBitmap, new ShuffleReadMetrics(), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), rssConf, ShuffleDataDistributionType.NORMAL, partitionToServers)); @@ -151,6 +156,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest { partitionToExpectBlocks, Roaring64NavigableMap.bitmapOf(), new ShuffleReadMetrics(), + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), rssConf, ShuffleDataDistributionType.NORMAL, partitionToServers)); diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java index 53a8e7143..a4317aae8 100644 --- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java +++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java @@ -55,12 +55,14 @@ 
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo; import org.awaitility.Awaitility; import org.junit.jupiter.api.Test; +import org.apache.uniffle.client.api.ShuffleManagerClient; import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.client.impl.FailedBlockSendTracker; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleBlockInfo; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.util.ExpiringCloseableSupplier; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.storage.util.StorageType; @@ -133,6 +135,7 @@ public class RssShuffleWriterTest { Serializer kryoSerializer = new KryoSerializer(conf); Partitioner mockPartitioner = mock(Partitioner.class); final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class); + final ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class); ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class); RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class); when(mockHandle.getDependency()).thenReturn(mockDependency); @@ -179,6 +182,7 @@ public class RssShuffleWriterTest { manager, conf, mockShuffleWriteClient, + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, shuffleHandle, contextMock); @@ -385,6 +389,7 @@ public class RssShuffleWriterTest { Serializer kryoSerializer = new KryoSerializer(conf); Partitioner mockPartitioner = mock(Partitioner.class); final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class); + final ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class); ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class); RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class); 
when(mockHandle.getDependency()).thenReturn(mockDependency); @@ -450,6 +455,7 @@ public class RssShuffleWriterTest { manager, conf, mockShuffleWriteClient, + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, shuffleHandleInfo, contextMock); @@ -552,6 +558,7 @@ public class RssShuffleWriterTest { conf, false, null, successBlocks, taskToFailedBlockSendTracker); ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class); + ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class); Partitioner mockPartitioner = mock(Partitioner.class); RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class); ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class); @@ -587,6 +594,7 @@ public class RssShuffleWriterTest { manager, conf, mockShuffleWriteClient, + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -714,6 +722,7 @@ public class RssShuffleWriterTest { Serializer kryoSerializer = new KryoSerializer(conf); Partitioner mockPartitioner = mock(Partitioner.class); final ShuffleWriteClient mockShuffleWriteClient = mock(ShuffleWriteClient.class); + final ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class); ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class); RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class); when(mockHandle.getDependency()).thenReturn(mockDependency); @@ -734,6 +743,7 @@ public class RssShuffleWriterTest { manager, conf, mockShuffleWriteClient, + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -794,6 +804,7 @@ public class RssShuffleWriterTest { Serializer kryoSerializer = new KryoSerializer(conf); Partitioner mockPartitioner = mock(Partitioner.class); final ShuffleWriteClient mockShuffleWriteClient = 
mock(ShuffleWriteClient.class); + final ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class); ShuffleDependency<String, String, String> mockDependency = mock(ShuffleDependency.class); RssShuffleHandle<String, String, String> mockHandle = mock(RssShuffleHandle.class); when(mockHandle.getDependency()).thenReturn(mockDependency); @@ -857,6 +868,7 @@ public class RssShuffleWriterTest { manager, conf, mockShuffleWriteClient, + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); @@ -958,6 +970,7 @@ public class RssShuffleWriterTest { TaskContext contextMock = mock(TaskContext.class); SimpleShuffleHandleInfo mockShuffleHandleInfo = mock(SimpleShuffleHandleInfo.class); ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class); + ShuffleManagerClient mockShuffleManagerClient = mock(ShuffleManagerClient.class); List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1, 31); RssShuffleWriter<String, String, String> writer = @@ -971,6 +984,7 @@ public class RssShuffleWriterTest { mockShuffleManager, conf, mockWriteClient, + ExpiringCloseableSupplier.of(() -> mockShuffleManagerClient), mockHandle, mockShuffleHandleInfo, contextMock); diff --git a/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java b/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java new file mode 100644 index 000000000..f36f9be0c --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/util/ExpiringCloseableSupplier.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.util; + +import java.io.IOException; +import java.io.Serializable; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A Supplier for T cacheable and autocloseable with delay by using ExpiringCloseableSupplier to + * obtain an object, manual closure may not be necessary. + */ +public class ExpiringCloseableSupplier<T extends StatefulCloseable> + implements Supplier<T>, Serializable { + private static final long serialVersionUID = 0; + private static final Logger LOG = LoggerFactory.getLogger(ExpiringCloseableSupplier.class); + private static final int DEFAULT_DELAY_CLOSE_INTERVAL = 60000; + private static final ScheduledExecutorService executor = + ThreadUtils.getDaemonSingleThreadScheduledExecutor("ExpiringCloseableSupplier"); + + private final Supplier<T> delegate; + private final long delayCloseInterval; + + private transient volatile ScheduledFuture<?> future; + + @SuppressFBWarnings("SE_TRANSIENT_FIELD_NOT_RESTORED") + private transient volatile long accessTime = System.currentTimeMillis(); + + private transient volatile T t; + + private ExpiringCloseableSupplier(Supplier<T> delegate, long delayCloseInterval) { + this.delegate = delegate; + this.delayCloseInterval = delayCloseInterval; + } + + public synchronized T get() { + accessTime = 
System.currentTimeMillis(); + if (t == null || t.isClosed()) { + this.t = delegate.get(); + ensureCloseFutureScheduled(); + } + return t; + } + + public synchronized void close() { + try { + if (t != null && !t.isClosed()) { + t.close(); + } + } catch (IOException ioe) { + LOG.warn("Failed to close {} the resource", t.getClass().getName(), ioe); + } finally { + this.t = null; + this.accessTime = System.currentTimeMillis(); + cancelCloseFuture(); + } + } + + private void tryClose() { + if (System.currentTimeMillis() - accessTime > delayCloseInterval) { + close(); + } + } + + private void ensureCloseFutureScheduled() { + cancelCloseFuture(); + this.future = + executor.scheduleAtFixedRate( + this::tryClose, delayCloseInterval, delayCloseInterval, TimeUnit.MILLISECONDS); + } + + private void cancelCloseFuture() { + if (future != null && !future.isDone()) { + future.cancel(false); + this.future = null; + } + } + + public static <T extends StatefulCloseable> ExpiringCloseableSupplier<T> of( + Supplier<T> delegate) { + return new ExpiringCloseableSupplier<>(delegate, DEFAULT_DELAY_CLOSE_INTERVAL); + } + + public static <T extends StatefulCloseable> ExpiringCloseableSupplier<T> of( + Supplier<T> delegate, long delayCloseInterval) { + return new ExpiringCloseableSupplier<>(delegate, delayCloseInterval); + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/util/StatefulCloseable.java b/common/src/main/java/org/apache/uniffle/common/util/StatefulCloseable.java new file mode 100644 index 000000000..a4a2453d6 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/util/StatefulCloseable.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.util; + +import java.io.Closeable; + +/** StatefulCloseable is an interface that utilizes the ExpiringCloseableSupplier delegate. */ +public interface StatefulCloseable extends Closeable { + boolean isClosed(); +} diff --git a/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java b/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java new file mode 100644 index 000000000..0f791ceab --- /dev/null +++ b/common/src/test/java/org/apache/uniffle/common/util/ExpiringCloseableSupplierTest.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.util; + +import java.io.IOException; +import java.io.Serializable; +import java.util.List; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.Uninterruptibles; +import org.apache.commons.lang3.SerializationUtils; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class ExpiringCloseableSupplierTest { + + @Test + void testCacheable() { + Supplier<MockClient> cf = () -> new MockClient(false); + ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf); + + MockClient mockClient = mockClientSupplier.get(); + MockClient mockClient2 = mockClientSupplier.get(); + assertSame(mockClient, mockClient2); + mockClientSupplier.close(); + mockClientSupplier.close(); + } + + @Test + void testAutoCloseable() { + Supplier<MockClient> cf = () -> new MockClient(true); + ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf, 10); + MockClient mockClient1 = mockClientSupplier.get(); + assertNotNull(mockClient1); + Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS); + assertTrue(mockClient1.isClosed()); + MockClient mockClient2 = mockClientSupplier.get(); + assertNotSame(mockClient1, mockClient2); + mockClientSupplier.close(); + } + + @Test + void testRenew() { + Supplier<MockClient> cf = () -> new MockClient(true); + 
ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf); + MockClient mockClient = mockClientSupplier.get(); + mockClientSupplier.close(); + MockClient mockClient2 = mockClientSupplier.get(); + assertNotSame(mockClient, mockClient2); + } + + @Test + void testReClose() { + Supplier<MockClient> cf = () -> new MockClient(true); + ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf); + mockClientSupplier.get(); + mockClientSupplier.close(); + mockClientSupplier.close(); + } + + @Test + void testDelegateExtendClose() throws IOException { + Supplier<MockClient> cf = () -> new MockClient(false); + ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf); + MockClient mockClient = mockClientSupplier.get(); + mockClient.close(); + assertTrue(mockClient.isClosed()); + + MockClient mockClient1 = mockClientSupplier.get(); + assertNotSame(mockClient, mockClient1); + MockClient mockClient2 = mockClientSupplier.get(); + assertSame(mockClient1, mockClient2); + mockClientSupplier.close(); + } + + @Test + public void testSerialization() { + Supplier<MockClient> cf = (Supplier<MockClient> & Serializable) () -> new MockClient(true); + ExpiringCloseableSupplier<MockClient> mockClientSupplier = ExpiringCloseableSupplier.of(cf, 10); + MockClient mockClient = mockClientSupplier.get(); + + ExpiringCloseableSupplier<MockClient> mockClientSupplier2 = + SerializationUtils.roundtrip(mockClientSupplier); + MockClient mockClient2 = mockClientSupplier2.get(); + assertFalse(mockClient2.isClosed()); + assertNotSame(mockClient, mockClient2); + Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS); + assertTrue(mockClient.isClosed()); + assertTrue(mockClient2.isClosed()); + } + + @Test + public void testMultipleSupplierShouldNotInterfere() { + Supplier<MockClient> cf = () -> new MockClient(true); + ExpiringCloseableSupplier<MockClient> mockClientSupplier = 
ExpiringCloseableSupplier.of(cf, 10); + ExpiringCloseableSupplier<MockClient> mockClientSupplier2 = + ExpiringCloseableSupplier.of(cf, 10); + MockClient mockClient = mockClientSupplier.get(); + MockClient mockClient2 = mockClientSupplier2.get(); + Uninterruptibles.sleepUninterruptibly(30, TimeUnit.MILLISECONDS); + assertTrue(mockClient.isClosed()); + assertTrue(mockClient2.isClosed()); + mockClientSupplier.close(); + mockClientSupplier.close(); + mockClientSupplier2.close(); + mockClientSupplier2.close(); + } + + @Test + public void stressingTestManySuppliers() { + int num = 100000; // this should be sufficient for most production use cases + Supplier<MockClient> cf = () -> new MockClient(true); + List<MockClient> clients = Lists.newArrayList(); + Random random = new Random(42); + for (int i = 0; i < num; i++) { + int delayCloseInterval = random.nextInt(1000) + 1; + ExpiringCloseableSupplier<MockClient> mockClientSupplier = + ExpiringCloseableSupplier.of(cf, delayCloseInterval); + MockClient mockClient = mockClientSupplier.get(); + clients.add(mockClient); + } + Awaitility.waitAtMost(5, TimeUnit.SECONDS) + .until(() -> clients.stream().allMatch(MockClient::isClosed)); + } + + private static class MockClient implements StatefulCloseable, Serializable { + boolean withException; + AtomicBoolean closed = new AtomicBoolean(false); + + MockClient(boolean withException) { + this.withException = withException; + } + + @Override + public void close() throws IOException { + closed.set(true); + if (withException) { + throw new IOException("test exception!"); + } + } + + @Override + public boolean isClosed() { + return closed.get(); + } + } +} diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java index 831fa0f2f..abe3a9dfa 100644 --- 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/ShuffleServerManagerTestBase.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.BeforeEach; import org.apache.uniffle.client.factory.ShuffleManagerClientFactory; import org.apache.uniffle.client.impl.grpc.ShuffleManagerGrpcClient; import org.apache.uniffle.common.ClientType; +import org.apache.uniffle.common.config.RssBaseConf; import org.apache.uniffle.common.config.RssConf; import org.apache.uniffle.common.rpc.GrpcServer; import org.apache.uniffle.shuffle.manager.DummyRssShuffleManager; @@ -36,12 +37,17 @@ public class ShuffleServerManagerTestBase { protected ShuffleManagerGrpcClient client; protected static final String LOCALHOST = "localhost"; protected GrpcServer shuffleManagerServer; + protected RssConf rssConf; protected RssShuffleManagerInterface getShuffleManager() { return new DummyRssShuffleManager(); } - protected RssConf getConf() { + protected ShuffleServerManagerTestBase() { + this.rssConf = getRssConf(); + } + + private RssConf getRssConf() { RssConf conf = new RssConf(); // use a random port conf.set(RPC_SERVER_PORT, 0); @@ -49,7 +55,7 @@ public class ShuffleServerManagerTestBase { } protected GrpcServer createShuffleManagerServer() { - return new ShuffleManagerServerFactory(getShuffleManager(), getConf()).getServer(); + return new ShuffleManagerServerFactory(getShuffleManager(), rssConf).getServer(); } @BeforeEach @@ -57,7 +63,8 @@ public class ShuffleServerManagerTestBase { shuffleManagerServer = createShuffleManagerServer(); shuffleManagerServer.start(); int port = shuffleManagerServer.getPort(); - client = factory.createShuffleManagerClient(ClientType.GRPC, LOCALHOST, port); + long rpcTimeout = rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS); + client = factory.createShuffleManagerClient(ClientType.GRPC, LOCALHOST, port, rpcTimeout); } @AfterEach diff 
--git a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java index c5b412a9e..6616fe7b1 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleManagerClient.java @@ -17,8 +17,6 @@ package org.apache.uniffle.client.api; -import java.io.Closeable; - import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest; import org.apache.uniffle.client.request.RssGetShuffleResultRequest; import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest; @@ -34,8 +32,9 @@ import org.apache.uniffle.client.response.RssReassignServersResponse; import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; import org.apache.uniffle.client.response.RssReportShuffleResultResponse; import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse; +import org.apache.uniffle.common.util.StatefulCloseable; -public interface ShuffleManagerClient extends Closeable { +public interface ShuffleManagerClient extends StatefulCloseable { RssReportShuffleFetchFailureResponse reportShuffleFetchFailure( RssReportShuffleFetchFailureRequest request); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java b/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java index c55acdc22..66b4a2a9e 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactory.java @@ -33,9 +33,9 @@ public class ShuffleManagerClientFactory { private ShuffleManagerClientFactory() {} public ShuffleManagerGrpcClient createShuffleManagerClient( - ClientType clientType, String host, int port) { + 
ClientType clientType, String host, int port, long rpcTimeout) { if (ClientType.GRPC.equals(clientType)) { - return new ShuffleManagerGrpcClient(host, port); + return new ShuffleManagerGrpcClient(host, port, rpcTimeout); } else { throw new UnsupportedOperationException("Unsupported client type " + clientType); } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java index 6dd9f4a1e..8cad876c2 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleManagerGrpcClient.java @@ -38,7 +38,6 @@ import org.apache.uniffle.client.response.RssReassignServersResponse; import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse; import org.apache.uniffle.client.response.RssReportShuffleResultResponse; import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse; -import org.apache.uniffle.common.config.RssBaseConf; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.proto.RssProtos; import org.apache.uniffle.proto.RssProtos.ReportShuffleFetchFailureRequest; @@ -48,22 +47,22 @@ import org.apache.uniffle.proto.ShuffleManagerGrpc; public class ShuffleManagerGrpcClient extends GrpcClient implements ShuffleManagerClient { private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerGrpcClient.class); - private static RssBaseConf rssConf = new RssBaseConf(); - private long rpcTimeout = rssConf.getLong(RssBaseConf.RSS_CLIENT_TYPE_GRPC_TIMEOUT_MS); + private final long rpcTimeout; private ShuffleManagerGrpc.ShuffleManagerBlockingStub blockingStub; - public ShuffleManagerGrpcClient(String host, int port) { - this(host, port, 3); + public ShuffleManagerGrpcClient(String host, int port, long rpcTimeout) { + this(host, port, 
rpcTimeout, 3); } - public ShuffleManagerGrpcClient(String host, int port, int maxRetryAttempts) { - this(host, port, maxRetryAttempts, true); + public ShuffleManagerGrpcClient(String host, int port, long rpcTimeout, int maxRetryAttempts) { + this(host, port, rpcTimeout, maxRetryAttempts, true); } public ShuffleManagerGrpcClient( - String host, int port, int maxRetryAttempts, boolean usePlaintext) { + String host, int port, long rpcTimeout, int maxRetryAttempts, boolean usePlaintext) { super(host, port, maxRetryAttempts, usePlaintext); blockingStub = ShuffleManagerGrpc.newBlockingStub(channel); + this.rpcTimeout = rpcTimeout; } public ShuffleManagerGrpc.ShuffleManagerBlockingStub getBlockingStub() { @@ -165,4 +164,9 @@ public class ShuffleManagerGrpcClient extends GrpcClient implements ShuffleManag getBlockingStub().reportShuffleResult(request.toProto()); return RssReportShuffleResultResponse.fromProto(response); } + + @Override + public boolean isClosed() { + return channel.isShutdown() || channel.isTerminated(); + } } diff --git a/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java b/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java index 5fed54ff0..c40c06c32 100644 --- a/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java +++ b/internal-client/src/test/java/org/apache/uniffle/client/factory/ShuffleManagerClientFactoryTest.java @@ -32,10 +32,11 @@ class ShuffleManagerClientFactoryTest { ShuffleManagerClientFactory factory = ShuffleManagerClientFactory.getInstance(); assertNotNull(factory); // only grpc type is supported currently - ShuffleManagerClient c = factory.createShuffleManagerClient(ClientType.GRPC, "localhost", 1234); + ShuffleManagerClient c = + factory.createShuffleManagerClient(ClientType.GRPC, "localhost", 1234, 60000); assertNotNull(c); assertThrows( UnsupportedOperationException.class, - () -> 
factory.createShuffleManagerClient(ClientType.GRPC_NETTY, "localhost", 1234)); + () -> factory.createShuffleManagerClient(ClientType.GRPC_NETTY, "localhost", 1234, 60000)); } }