This is an automated email from the ASF dual-hosted git repository. zhouky pushed a commit to branch branch-0.2 in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
commit 07248402a0575867bfcc7830b8681137a632a3e2 Author: Keyong Zhou <[email protected]> AuthorDate: Tue Dec 20 20:40:42 2022 +0800 [CELEBORN-119] Add timeout for pushdata (#1097) --- .../apache/celeborn/client/ShuffleClientImpl.java | 59 +++++++-------- .../apache/celeborn/client/write/PushState.java | 83 +++++++++++++++++----- .../common/network/client/TransportClient.java | 2 +- .../common/protocol/message/StatusCode.java | 6 +- .../org/apache/celeborn/common/CelebornConf.scala | 15 +++- docs/configuration/client.md | 2 +- docs/configuration/worker.md | 1 + .../celeborn/tests/spark/PushdataTimeoutTest.scala | 80 +++++++++++++++++++++ .../service/deploy/worker/PushDataHandler.scala | 20 ++++++ 9 files changed, 212 insertions(+), 56 deletions(-) diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java index a5aa2020..2f740d3c 100644 --- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java +++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java @@ -172,7 +172,7 @@ public class ShuffleClientImpl extends ShuffleClient { } else if (mapperEnded(shuffleId, mapId, attemptId)) { logger.debug( "Retrying push data, but the mapper(map {} attempt {}) has ended.", mapId, attemptId); - pushState.inFlightBatches.remove(batchId); + pushState.removeBatch(batchId); } else { PartitionLocation newLoc = reducePartitionMap.get(shuffleId).get(partitionId); logger.info("Revive success, new location for reduce {} is {}.", partitionId, newLoc); @@ -185,7 +185,7 @@ public class ShuffleClientImpl extends ShuffleClient { PushData newPushData = new PushData(MASTER_MODE, shuffleKey, newLoc.getUniqueId(), newBuffer); ChannelFuture future = client.pushData(newPushData, callback); - pushState.addFuture(batchId, future); + pushState.pushStarted(batchId, future, callback); } catch (Exception ex) { logger.warn( "Exception raised while pushing data for shuffle {} 
map {} attempt {}" + " batch {}.", @@ -250,7 +250,7 @@ public class ShuffleClientImpl extends ShuffleClient { pushState, true); } - pushState.inFlightBatches.remove(oldGroupedBatchId); + pushState.removeBatch(oldGroupedBatchId); } private String genAddressPair(PartitionLocation loc) { @@ -354,18 +354,20 @@ public class ShuffleClientImpl extends ShuffleClient { throw pushState.exception.get(); } - ConcurrentHashMap<Integer, PartitionLocation> inFlightBatches = pushState.inFlightBatches; long timeoutMs = conf.pushLimitInFlightTimeoutMs(); long delta = conf.pushLimitInFlightSleepDeltaMs(); long times = timeoutMs / delta; try { while (times > 0) { - if (inFlightBatches.size() <= limit) { + if (pushState.inflightBatchCount() <= limit) { break; } if (pushState.exception.get() != null) { throw pushState.exception.get(); } + + pushState.failExpiredBatch(); + Thread.sleep(delta); times--; } @@ -378,10 +380,9 @@ public class ShuffleClientImpl extends ShuffleClient { "After waiting for {} ms, there are still {} batches in flight for map {}, " + "which exceeds the limit {}.", timeoutMs, - inFlightBatches.size(), + pushState.inflightBatchCount(), mapKey, limit); - logger.error("Map: {} in flight batches: {}", mapKey, inFlightBatches); throw new IOException("wait timeout for task " + mapKey, pushState.exception.get()); } if (pushState.exception.get() != null) { @@ -501,7 +502,7 @@ public class ShuffleClientImpl extends ShuffleClient { attemptId); PushState pushState = pushStates.get(mapKey); if (pushState != null) { - pushState.cancelFutures(); + pushState.cleanup(); } return 0; } @@ -540,7 +541,7 @@ public class ShuffleClientImpl extends ShuffleClient { attemptId); PushState pushState = pushStates.get(mapKey); if (pushState != null) { - pushState.cancelFutures(); + pushState.cleanup(); } return 0; } @@ -587,7 +588,7 @@ public class ShuffleClientImpl extends ShuffleClient { limitMaxInFlight(mapKey, pushState, maxInFlight); // add inFlight requests - 
pushState.inFlightBatches.put(nextBatchId, loc); + pushState.addBatch(nextBatchId); // build PushData request NettyManagedBuffer buffer = new NettyManagedBuffer(Unpooled.wrappedBuffer(body)); @@ -598,7 +599,7 @@ public class ShuffleClientImpl extends ShuffleClient { new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { - pushState.inFlightBatches.remove(nextBatchId); + pushState.removeBatch(nextBatchId); // TODO Need to adjust maxReqsInFlight if server response is congested, see // CELEBORN-62 if (response.remaining() > 0 && response.get() == StatusCode.STAGE_ENDED.getValue()) { @@ -606,7 +607,6 @@ public class ShuffleClientImpl extends ShuffleClient { .computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet()) .add(mapKey); } - pushState.removeFuture(nextBatchId); logger.debug( "Push data to {}:{} success for map {} attempt {} batch {}.", loc.getHost(), @@ -620,7 +620,6 @@ public class ShuffleClientImpl extends ShuffleClient { public void onFailure(Throwable e) { pushState.exception.compareAndSet( null, new IOException("Revived PushData failed!", e)); - pushState.removeFuture(nextBatchId); logger.error( "Push data to {}:{} failed for map {} attempt {} batch {}.", loc.getHost(), @@ -703,7 +702,7 @@ public class ShuffleClientImpl extends ShuffleClient { pushState, getPushDataFailCause(e.getMessage()))); } else { - pushState.inFlightBatches.remove(nextBatchId); + pushState.removeBatch(nextBatchId); logger.info( "Mapper shuffleId:{} mapId:{} attempt:{} already ended, remove batchId:{}.", shuffleId, @@ -719,7 +718,7 @@ public class ShuffleClientImpl extends ShuffleClient { TransportClient client = dataClientFactory.createClient(loc.getHost(), loc.getPushPort(), partitionId); ChannelFuture future = client.pushData(pushData, wrappedCallback); - pushState.addFuture(nextBatchId, future); + pushState.pushStarted(nextBatchId, future, wrappedCallback); } catch (Exception e) { logger.warn("PushData failed", e); wrappedCallback.onFailure( @@ 
-875,7 +874,7 @@ public class ShuffleClientImpl extends ShuffleClient { final int port = Integer.parseInt(splits[1]); int groupedBatchId = pushState.batchId.addAndGet(1); - pushState.inFlightBatches.put(groupedBatchId, batches.get(0).loc); + pushState.addBatch(groupedBatchId); final int numBatches = batches.size(); final String[] partitionUniqueIds = new String[numBatches]; @@ -905,7 +904,7 @@ public class ShuffleClientImpl extends ShuffleClient { mapId, attemptId, groupedBatchId); - pushState.inFlightBatches.remove(groupedBatchId); + pushState.removeBatch(groupedBatchId); // TODO Need to adjust maxReqsInFlight if server response is congested, see CELEBORN-62 if (response.remaining() > 0 && response.get() == StatusCode.STAGE_ENDED.getValue()) { mapperEndMap @@ -1017,7 +1016,8 @@ public class ShuffleClientImpl extends ShuffleClient { // do push merged data try { TransportClient client = dataClientFactory.createClient(host, port); - client.pushMergedData(mergedData, wrappedCallback); + ChannelFuture future = client.pushMergedData(mergedData, wrappedCallback); + pushState.pushStarted(groupedBatchId, future, wrappedCallback); } catch (Exception e) { logger.warn("PushMergedData failed", e); wrappedCallback.onFailure(new Exception(getPushDataFailCause(e.getMessage()).toString(), e)); @@ -1052,7 +1052,7 @@ public class ShuffleClientImpl extends ShuffleClient { PushState pushState = pushStates.remove(mapKey); if (pushState != null) { pushState.exception.compareAndSet(null, new IOException("Cleaned Up")); - pushState.cancelFutures(); + pushState.cleanup(); } } @@ -1207,6 +1207,8 @@ public class ShuffleClientImpl extends ShuffleClient { } else if (StatusCode.PUSH_DATA_FAIL_MASTER.getMessage().equals(message) || connectFail(message)) { cause = StatusCode.PUSH_DATA_FAIL_MASTER; + } else if (StatusCode.PUSH_DATA_TIMEOUT.getMessage().equals(message)) { + cause = StatusCode.PUSH_DATA_TIMEOUT; } else { cause = StatusCode.PUSH_DATA_FAIL_NON_CRITICAL_CAUSE; } @@ -1233,7 +1235,6 @@ 
public class ShuffleClientImpl extends ShuffleClient { shuffleId, mapId, attemptId, - location, () -> { String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); logger.info( @@ -1252,7 +1253,7 @@ public class ShuffleClientImpl extends ShuffleClient { attemptId, numPartitions, bufferSize); - client.sendRpcSync(handShake.toByteBuffer(), conf.pushDataRpcTimeoutMs()); + client.sendRpcSync(handShake.toByteBuffer(), conf.pushDataTimeoutMs()); return null; }); } @@ -1271,7 +1272,6 @@ public class ShuffleClientImpl extends ShuffleClient { shuffleId, mapId, attemptId, - location, () -> { String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); logger.info( @@ -1291,7 +1291,7 @@ public class ShuffleClientImpl extends ShuffleClient { currentRegionIdx, isBroadcast); ByteBuffer regionStartResponse = - client.sendRpcSync(regionStart.toByteBuffer(), conf.pushDataRpcTimeoutMs()); + client.sendRpcSync(regionStart.toByteBuffer(), conf.pushDataTimeoutMs()); if (regionStartResponse.hasRemaining() && regionStartResponse.get() == StatusCode.HARD_SPLIT.getValue()) { // if split then revive @@ -1340,7 +1340,6 @@ public class ShuffleClientImpl extends ShuffleClient { shuffleId, mapId, attemptId, - location, () -> { final String shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId); logger.info( @@ -1353,17 +1352,13 @@ public class ShuffleClientImpl extends ShuffleClient { dataClientFactory.createClient(location.getHost(), location.getPushPort()); RegionFinish regionFinish = new RegionFinish(MASTER_MODE, shuffleKey, location.getUniqueId(), attemptId); - client.sendRpcSync(regionFinish.toByteBuffer(), conf.pushDataRpcTimeoutMs()); + client.sendRpcSync(regionFinish.toByteBuffer(), conf.pushDataTimeoutMs()); return null; }); } private <R> R sendMessageInternal( - int shuffleId, - int mapId, - int attemptId, - PartitionLocation location, - ThrowingExceptionSupplier<R, Exception> supplier) + int shuffleId, int mapId, int attemptId, ThrowingExceptionSupplier<R, 
Exception> supplier) throws IOException { PushState pushState = null; int batchId = 0; @@ -1385,11 +1380,11 @@ public class ShuffleClientImpl extends ShuffleClient { // add inFlight requests batchId = pushState.batchId.incrementAndGet(); - pushState.inFlightBatches.put(batchId, location); + pushState.addBatch(batchId); return retrySendMessage(supplier); } finally { if (pushState != null) { - pushState.inFlightBatches.remove(batchId); + pushState.removeBatch(batchId); } } } diff --git a/client/src/main/java/org/apache/celeborn/client/write/PushState.java b/client/src/main/java/org/apache/celeborn/client/write/PushState.java index 67a719aa..a8f96240 100644 --- a/client/src/main/java/org/apache/celeborn/client/write/PushState.java +++ b/client/src/main/java/org/apache/celeborn/client/write/PushState.java @@ -18,8 +18,6 @@ package org.apache.celeborn.client.write; import java.io.IOException; -import java.util.HashSet; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -29,41 +27,90 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.celeborn.common.CelebornConf; +import org.apache.celeborn.common.network.client.RpcResponseCallback; import org.apache.celeborn.common.protocol.PartitionLocation; +import org.apache.celeborn.common.protocol.message.StatusCode; public class PushState { + class BatchInfo { + ChannelFuture channelFuture; + long pushTime; + RpcResponseCallback callback; + } + private static final Logger logger = LoggerFactory.getLogger(PushState.class); private int pushBufferMaxSize; + private long pushTimeout; public final AtomicInteger batchId = new AtomicInteger(); - public final ConcurrentHashMap<Integer, PartitionLocation> inFlightBatches = + private final ConcurrentHashMap<Integer, BatchInfo> inflightBatchInfos = new ConcurrentHashMap<>(); - public final ConcurrentHashMap<Integer, ChannelFuture> futures = 
new ConcurrentHashMap<>(); public AtomicReference<IOException> exception = new AtomicReference<>(); public PushState(CelebornConf conf) { pushBufferMaxSize = conf.pushBufferMaxSize(); + pushTimeout = conf.pushDataTimeoutMs(); + } + + public void addBatch(int batchId) { + inflightBatchInfos.computeIfAbsent(batchId, id -> new BatchInfo()); + } + + public void removeBatch(int batchId) { + BatchInfo info = inflightBatchInfos.remove(batchId); + if (info != null && info.channelFuture != null) { + info.channelFuture.cancel(true); + } } - public void addFuture(int batchId, ChannelFuture future) { - futures.put(batchId, future); + public int inflightBatchCount() { + return inflightBatchInfos.size(); } - public void removeFuture(int batchId) { - futures.remove(batchId); + public synchronized void failExpiredBatch() { + long currentTime = System.currentTimeMillis(); + inflightBatchInfos + .values() + .forEach( + info -> { + if (currentTime - info.pushTime > pushTimeout) { + if (info.callback != null) { + info.channelFuture.cancel(true); + info.callback.onFailure( + new IOException(StatusCode.PUSH_DATA_TIMEOUT.getMessage())); + info.channelFuture = null; + info.callback = null; + } + } + }); + } + + public void pushStarted(int batchId, ChannelFuture future, RpcResponseCallback callback) { + BatchInfo info = inflightBatchInfos.get(batchId); + // In rare cases info could be null. For example, a speculative task has one thread pushing, + // and other thread retry-pushing. 
At time 1 thread 1 finds StageEnded, then it cleans up + PushState, at the same time thread 2 pushes data and calls pushStarted, + at this time info will be null + if (info != null) { + info.pushTime = System.currentTimeMillis(); + info.channelFuture = future; + info.callback = callback; + } } - public synchronized void cancelFutures() { - if (!futures.isEmpty()) { - Set<Integer> keys = new HashSet<>(futures.keySet()); - logger.debug("Cancel all {} futures.", keys.size()); - for (Integer batchId : keys) { - ChannelFuture future = futures.remove(batchId); - if (future != null) { - future.cancel(true); - } - } + public void cleanup() { + if (!inflightBatchInfos.isEmpty()) { + logger.debug("Cancel all {} futures.", inflightBatchInfos.size()); + inflightBatchInfos + .values() + .forEach( + entry -> { + if (entry.channelFuture != null) { + entry.channelFuture.cancel(true); + } + }); + inflightBatchInfos.clear(); } } diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java index 4ffc579d..88e8c42a 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java +++ b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java @@ -288,7 +288,7 @@ public class TransportClient implements Closeable { } else { String errorMsg = String.format( - "Failed to send RPC %s to %s: %s, channel will be closed", + "Failed to send request %s to %s: %s, channel will be closed", requestId, NettyUtils.getRemoteAddress(channel), future.cause()); logger.warn(errorMsg); channel.close(); diff --git a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java index 3b3f067d..0fb46765 100644 --- a/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java +++ 
b/common/src/main/java/org/apache/celeborn/common/protocol/message/StatusCode.java @@ -66,7 +66,9 @@ public enum StatusCode { REGION_START_FAIL_SLAVE(34), REGION_START_FAIL_MASTER(35), REGION_FINISH_FAIL_SLAVE(36), - REGION_FINISH_FAIL_MASTER(37); + REGION_FINISH_FAIL_MASTER(37), + + PUSH_DATA_TIMEOUT(38); private final byte value; @@ -103,6 +105,8 @@ public enum StatusCode { msg = "RegionFinishFailMaster"; } else if (value == REGION_FINISH_FAIL_SLAVE.getValue()) { msg = "RegionFinishFailSlave"; + } else if (value == PUSH_DATA_TIMEOUT.getValue()) { + msg = "PushDataTimeout"; } return msg; diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala index cb344baa..8fc0f576 100644 --- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala +++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala @@ -556,6 +556,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se // ////////////////////////////////////////////////////// def testFetchFailure: Boolean = get(TEST_FETCH_FAILURE) def testRetryCommitFiles: Boolean = get(TEST_RETRY_COMMIT_FILE) + def testPushDataTimeout: Boolean = get(TEST_PUSHDATA_TIMEOUT) def masterHost: String = get(MASTER_HOST) @@ -679,7 +680,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se def rpcCacheSize: Int = get(RPC_CACHE_SIZE) def rpcCacheConcurrencyLevel: Int = get(RPC_CACHE_CONCURRENCY_LEVEL) def rpcCacheExpireTime: Long = get(RPC_CACHE_EXPIRE_TIME) - def pushDataRpcTimeoutMs = get(PUSH_DATA_RPC_TIMEOUT) + def pushDataTimeoutMs = get(PUSH_DATA_TIMEOUT) def registerShuffleRpcAskTimeout: RpcTimeout = new RpcTimeout( @@ -2189,8 +2190,8 @@ object CelebornConf extends Logging { .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("5s") - val PUSH_DATA_RPC_TIMEOUT: ConfigEntry[Long] = - buildConf("celeborn.push.data.rpc.timeout") + val PUSH_DATA_TIMEOUT: 
ConfigEntry[Long] = + buildConf("celeborn.push.data.timeout") .withAlternative("rss.push.data.rpc.timeout") .categories("client") .version("0.2.0") @@ -2198,6 +2199,14 @@ .timeConf(TimeUnit.MILLISECONDS) .createWithDefaultString("120s") + val TEST_PUSHDATA_TIMEOUT: ConfigEntry[Boolean] = + buildConf("celeborn.test.pushdataTimeout") + .categories("worker") + .version("0.2.0") + .doc("Whether to test pushdata timeout") + .booleanConf + .createWithDefault(false) + val REGISTER_SHUFFLE_RPC_ASK_TIMEOUT: OptionalConfigEntry[Long] = buildConf("celeborn.rpc.registerShuffle.askTimeout") .categories("client") diff --git a/docs/configuration/client.md b/docs/configuration/client.md index 8ec3aafc..b58fac84 100644 --- a/docs/configuration/client.md +++ b/docs/configuration/client.md @@ -27,7 +27,7 @@ license: | | celeborn.master.endpoints | <localhost>:9097 | Endpoints of master nodes for celeborn client to connect, allowed pattern is: `<host1>:<port1>[,<host2>:<port2>]*`, e.g. `clb1:9097,clb2:9098,clb3:9099`. If the port is omitted, 9097 will be used. | 0.2.0 | | celeborn.push.buffer.initial.size | 8k | | 0.2.0 | | celeborn.push.buffer.max.size | 64k | Max size of reducer partition buffer memory for shuffle hash writer. The pushed data will be buffered in memory before sending to Celeborn worker. For performance consideration keep this buffer size higher than 32K. Example: If reducer amount is 2000, buffer size is 64K, then each task will consume up to `64KiB * 2000 = 125MiB` heap memory. | 0.2.0 | -| celeborn.push.data.rpc.timeout | 120s | Timeout for a task to push data rpc message. | 0.2.0 | +| celeborn.push.data.timeout | 120s | Timeout for a task to push data rpc message. | 0.2.0 | | celeborn.push.limit.inFlight.sleepInterval | 50ms | Sleep interval when check netty in-flight requests to be done. | 0.2.0 | | celeborn.push.limit.inFlight.timeout | 240s | Timeout for netty in-flight requests to be done. 
| 0.2.0 | | celeborn.push.maxReqsInFlight | 32 | Amount of Netty in-flight requests. The maximum memory is `celeborn.push.maxReqsInFlight` * `celeborn.push.buffer.max.size` * compression ratio(1 in worst case), default: 64Kib * 32 = 2Mib | 0.2.0 | diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md index bd7d7d91..a5a63fc2 100644 --- a/docs/configuration/worker.md +++ b/docs/configuration/worker.md @@ -29,6 +29,7 @@ license: | | celeborn.shuffle.chuck.size | 8m | Max chunk size of reducer's merged shuffle data. For example, if a reducer's shuffle data is 128M and the data will need 16 fetch chunk requests to fetch. | 0.2.0 | | celeborn.shuffle.minPartitionSizeToEstimate | 8mb | Ignore partition size smaller than this configuration of partition size for estimation. | 0.2.0 | | celeborn.storage.hdfs.dir | <undefined> | HDFS dir configuration for Celeborn to access HDFS. | 0.2.0 | +| celeborn.test.pushdataTimeout | false | Whether to test pushdata timeout | 0.2.0 | | celeborn.worker.closeIdleConnections | false | Whether worker will close idle connections. | 0.2.0 | | celeborn.worker.commit.threads | 32 | Thread number of worker to commit shuffle data files asynchronously. | 0.2.0 | | celeborn.worker.directMemoryRatioForMemoryShuffleStorage | 0.1 | Max ratio of direct memory to store shuffle data | 0.2.0 | diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala new file mode 100644 index 00000000..ed876524 --- /dev/null +++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/PushdataTimeoutTest.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.celeborn.tests.spark + +import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.celeborn.client.ShuffleClient + +class PushdataTimeoutTest extends AnyFunSuite + with SparkTestBase + with BeforeAndAfterAll + with BeforeAndAfterEach { + + override def beforeAll(): Unit = { + logInfo("test initialized , setup rss mini cluster") + val workerConf = Map( + "celeborn.test.pushdataTimeout" -> s"true") + tuple = setupRssMiniClusterSpark(masterConfs = null, workerConfs = workerConf) + } + + override def afterAll(): Unit = { + logInfo("all test complete , stop rss mini cluster") + clearMiniCluster(tuple) + } + + override def beforeEach(): Unit = { + ShuffleClient.reset() + } + + override def afterEach(): Unit = { + System.gc() + } + + test("celeborn spark integration test - pushdata timeout") { + val sparkConf = new SparkConf().setAppName("rss-demo").setMaster("local[4]") + .set("spark.celeborn.push.data.timeout", "10s") + val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate() + val combineResult = combine(sparkSession) + val groupbyResult = groupBy(sparkSession) + val repartitionResult = repartition(sparkSession) + val sqlResult = runsql(sparkSession) + + Thread.sleep(3000L) + sparkSession.stop() + + val 
rssSparkSession = SparkSession.builder() + .config(updateSparkConf(sparkConf, false)).getOrCreate() + val rssCombineResult = combine(rssSparkSession) + val rssGroupbyResult = groupBy(rssSparkSession) + val rssRepartitionResult = repartition(rssSparkSession) + val rssSqlResult = runsql(rssSparkSession) + + assert(combineResult.equals(rssCombineResult)) + assert(groupbyResult.equals(rssGroupbyResult)) + assert(repartitionResult.equals(rssRepartitionResult)) + assert(combineResult.equals(rssCombineResult)) + assert(sqlResult.equals(rssSqlResult)) + + rssSparkSession.stop() + + } +} diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala index aba8809b..a35b40f1 100644 --- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala +++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.{AtomicBoolean, AtomicIntegerArray} import com.google.common.base.Throwables import io.netty.buffer.ByteBuf +import org.apache.celeborn.common.CelebornConf import org.apache.celeborn.common.exception.AlreadyClosedException import org.apache.celeborn.common.internal.Logging import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo} @@ -55,6 +56,8 @@ class PushDataHandler extends BaseMessageHandler with Logging { var partitionSplitMinimumSize: Long = _ var shutdown: AtomicBoolean = _ var storageManager: StorageManager = _ + var conf: CelebornConf = _ + @volatile var pushDataTimeoutTested = false def init(worker: Worker): Unit = { workerSource = worker.workerSource @@ -71,6 +74,7 @@ class PushDataHandler extends BaseMessageHandler with Logging { partitionSplitMinimumSize = worker.conf.partitionSplitMinimumSize storageManager = worker.storageManager shutdown = worker.shutdown + conf = worker.conf logInfo(s"diskReserveSize 
$diskReserveSize") } @@ -115,6 +119,12 @@ class PushDataHandler extends BaseMessageHandler with Logging { val body = pushData.body.asInstanceOf[NettyManagedBuffer].getBuf val isMaster = mode == PartitionLocation.Mode.MASTER + // For test + if (conf.testPushDataTimeout && !pushDataTimeoutTested) { + pushDataTimeoutTested = true + return + } + val key = s"${pushData.requestId}" if (isMaster) { workerSource.startTimer(WorkerSource.MasterPushDataTime, key) @@ -303,6 +313,12 @@ class PushDataHandler extends BaseMessageHandler with Logging { workerSource.startTimer(WorkerSource.SlavePushDataTime, key) } + // For test + if (conf.testPushDataTimeout && !PushDataHandler.pushDataTimeoutTested) { + PushDataHandler.pushDataTimeoutTested = true + return + } + val wrappedCallback = new RpcResponseCallback() { override def onSuccess(response: ByteBuffer): Unit = { if (isMaster) { @@ -803,3 +819,7 @@ class PushDataHandler extends BaseMessageHandler with Logging { } } } + +object PushDataHandler { + @volatile var pushDataTimeoutTested = false +}
