This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 1b48c120 [#706] feat(spark): support spill to avoid memory deadlock
(#714)
1b48c120 is described below
commit 1b48c12026ad7be42368fb6eee2cd9a1dff2bbb5
Author: Junfan Zhang <[email protected]>
AuthorDate: Tue Mar 28 19:02:47 2023 +0800
[#706] feat(spark): support spill to avoid memory deadlock (#714)
### What changes were proposed in this pull request?
1. Introduce the `DataPusher` to replace the `eventLoop`; it can serve as a
general component shared by spark2 and spark3.
2. Implement the `spill` method in `WriteBufferManager` to avoid memory
deadlock.
### Why are the changes needed?
In the current codebase, when several `WriteBufferManager`s are acquiring
memory from one another at the same time, a deadlock can happen. To solve
this, we implement the spill function to break the deadlock condition.
Fix: #706
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
1. Existing UTs
2. Newly added UTs
---
.../org/apache/spark/shuffle/RssSparkConfig.java | 20 ++-
.../apache/spark/shuffle/writer/AddBlockEvent.java | 21 ++++
.../apache/spark/shuffle/writer/DataPusher.java | 137 +++++++++++++++++++++
.../spark/shuffle/writer/WriteBufferManager.java | 115 ++++++++++++++++-
.../spark/shuffle/writer/DataPusherTest.java | 119 ++++++++++++++++++
.../shuffle/writer/WriteBufferManagerTest.java | 79 ++++++++++++
.../apache/spark/shuffle/RssShuffleManager.java | 128 +++++++------------
.../spark/shuffle/writer/RssShuffleWriter.java | 25 +---
.../spark/shuffle/writer/RssShuffleWriterTest.java | 100 +++++++++------
.../apache/spark/shuffle/RssShuffleManager.java | 125 ++++++-------------
.../spark/shuffle/writer/RssShuffleWriter.java | 25 +---
.../java/org/apache/spark/shuffle/TestUtils.java | 7 +-
.../spark/shuffle/writer/RssShuffleWriterTest.java | 125 ++++++++++++-------
.../apache/uniffle/common/util/ThreadUtils.java | 23 +++-
.../uniffle/common/util/ThreadUtilsTest.java | 47 +++++++
15 files changed, 782 insertions(+), 314 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index c8e3478c..6aa90e1a 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -28,11 +28,26 @@ import scala.Tuple2;
import scala.runtime.AbstractFunction1;
import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.config.ConfigOption;
+import org.apache.uniffle.common.config.ConfigOptions;
import org.apache.uniffle.common.config.ConfigUtils;
import org.apache.uniffle.common.config.RssConf;
public class RssSparkConfig {
+ public static final ConfigOption<Long> RSS_CLIENT_SEND_SIZE_LIMITATION =
ConfigOptions
+ .key("rss.client.send.size.limit")
+ .longType()
+ .defaultValue(1024 * 1024 * 16L)
+ .withDescription("The max data size sent to shuffle server");
+
+ public static final ConfigOption<Integer> RSS_MEMORY_SPILL_TIMEOUT =
ConfigOptions
+ .key("rss.client.memory.spill.timeout.sec")
+ .intType()
+ .defaultValue(1)
+ .withDescription("The timeout of spilling data to remote shuffle server,
"
+ + "which will be triggered by Spark TaskMemoryManager. Unit is sec,
default value is 1");
+
public static final String SPARK_RSS_CONFIG_PREFIX = "spark.";
public static final ConfigEntry<Integer> RSS_PARTITION_NUM_PER_RANGE =
createIntegerBuilder(
@@ -115,11 +130,6 @@ public class RssSparkConfig {
new ConfigBuilder("spark.rss.client.heartBeat.threadNum"))
.createWithDefault(4);
- public static final ConfigEntry<String> RSS_CLIENT_SEND_SIZE_LIMIT =
createStringBuilder(
- new ConfigBuilder("spark.rss.client.send.size.limit")
- .doc("The max data size sent to shuffle server"))
- .createWithDefault("16m");
-
public static final ConfigEntry<Integer>
RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE = createIntegerBuilder(
new ConfigBuilder("spark.rss.client.unregister.thread.pool.size"))
.createWithDefault(10);
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
index a8889754..7dab0725 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
@@ -17,6 +17,7 @@
package org.apache.spark.shuffle.writer;
+import java.util.ArrayList;
import java.util.List;
import org.apache.uniffle.common.ShuffleBlockInfo;
@@ -25,10 +26,26 @@ public class AddBlockEvent {
private String taskId;
private List<ShuffleBlockInfo> shuffleDataInfoList;
+ private List<Runnable> processedCallbackChain;
public AddBlockEvent(String taskId, List<ShuffleBlockInfo>
shuffleDataInfoList) {
this.taskId = taskId;
this.shuffleDataInfoList = shuffleDataInfoList;
+ this.processedCallbackChain = new ArrayList<>();
+ }
+
+ public AddBlockEvent(String taskId, List<ShuffleBlockInfo>
shuffleBlockInfoList, Runnable callback) {
+ this.taskId = taskId;
+ this.shuffleDataInfoList = shuffleBlockInfoList;
+ this.processedCallbackChain = new ArrayList<>();
+ addCallback(callback);
+ }
+
+ /**
+ * @param callback, should not throw any exception and execute fast.
+ */
+ public void addCallback(Runnable callback) {
+ processedCallbackChain.add(callback);
}
public String getTaskId() {
@@ -39,6 +56,10 @@ public class AddBlockEvent {
return shuffleDataInfoList;
}
+ public List<Runnable> getProcessedCallbackChain() {
+ return processedCallbackChain;
+ }
+
@Override
public String toString() {
return "AddBlockEvent: TaskId[" + taskId + "], " + shuffleDataInfoList;
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
new file mode 100644
index 00000000..ca03a784
--- /dev/null
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Queues;
+import com.google.common.collect.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ThreadUtils;
+
+/**
+ * A {@link DataPusher} that is responsible for sending data to remote
+ * shuffle servers asynchronously.
+ */
+public class DataPusher implements Closeable {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(DataPusher.class);
+
+ private final ExecutorService executorService;
+
+ private final ShuffleWriteClient shuffleWriteClient;
+ // Must be thread safe
+ private final Map<String, Set<Long>> taskToSuccessBlockIds;
+ // Must be thread safe
+ private final Map<String, Set<Long>> taskToFailedBlockIds;
+ private String rssAppId;
+ // Must be thread safe
+ private final Set<String> failedTaskIds;
+
+ public DataPusher(ShuffleWriteClient shuffleWriteClient,
+ Map<String, Set<Long>> taskToSuccessBlockIds,
+ Map<String, Set<Long>> taskToFailedBlockIds,
+ Set<String> failedTaskIds,
+ int threadPoolSize,
+ int threadKeepAliveTime) {
+ this.shuffleWriteClient = shuffleWriteClient;
+ this.taskToSuccessBlockIds = taskToSuccessBlockIds;
+ this.taskToFailedBlockIds = taskToFailedBlockIds;
+ this.failedTaskIds = failedTaskIds;
+ this.executorService = new ThreadPoolExecutor(
+ threadPoolSize,
+ threadPoolSize * 2,
+ threadKeepAliveTime,
+ TimeUnit.SECONDS,
+ Queues.newLinkedBlockingQueue(Integer.MAX_VALUE),
+ ThreadUtils.getThreadFactory(this.getClass().getName())
+ );
+ }
+
+ public CompletableFuture<Long> send(AddBlockEvent event) {
+ if (rssAppId == null) {
+ throw new RssException("RssAppId should be set.");
+ }
+ return CompletableFuture.supplyAsync(() -> {
+ String taskId = event.getTaskId();
+ List<ShuffleBlockInfo> shuffleBlockInfoList =
event.getShuffleDataInfoList();
+ try {
+ SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
+ rssAppId,
+ shuffleBlockInfoList,
+ () -> !isValidTask(taskId)
+ );
+ putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
+ putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
+ } finally {
+ List<Runnable> callbackChain =
Optional.of(event.getProcessedCallbackChain()).orElse(Collections.EMPTY_LIST);
+ for (Runnable runnable : callbackChain) {
+ runnable.run();
+ }
+ }
+ return shuffleBlockInfoList.stream()
+ .map(x -> x.getFreeMemory())
+ .reduce((a, b) -> a + b)
+ .get();
+ }, executorService);
+ }
+
+ private synchronized void putBlockId(
+ Map<String, Set<Long>> taskToBlockIds,
+ String taskAttemptId,
+ Set<Long> blockIds) {
+ if (blockIds == null || blockIds.isEmpty()) {
+ return;
+ }
+ taskToBlockIds.computeIfAbsent(taskAttemptId, x ->
Sets.newConcurrentHashSet()).addAll(blockIds);
+ }
+
+ public boolean isValidTask(String taskId) {
+ return !failedTaskIds.contains(taskId);
+ }
+
+ public void setRssAppId(String rssAppId) {
+ this.rssAppId = rssAppId;
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (executorService != null) {
+ try {
+ ThreadUtils.shutdownThreadPool(executorService, 5);
+ } catch (InterruptedException interruptedException) {
+ LOGGER.error("Errors on shutdown thread pool of [{}].",
this.getClass().getSimpleName());
+ }
+ }
+ }
+}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 9f33be38..580fce69 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -17,10 +17,16 @@
package org.apache.spark.shuffle.writer;
+import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Function;
+import java.util.stream.Collectors;
import com.clearspring.analytics.util.Lists;
import com.google.common.annotations.VisibleForTesting;
@@ -61,6 +67,7 @@ public class WriteBufferManager extends MemoryConsumer {
private Map<Integer, Integer> partitionToSeqNo = Maps.newHashMap();
private long askExecutorMemory;
private int shuffleId;
+ private String taskId;
private long taskAttemptId;
private SerializerInstance instance;
private ShuffleWriteMetrics shuffleWriteMetrics;
@@ -81,6 +88,9 @@ public class WriteBufferManager extends MemoryConsumer {
private long requireMemoryInterval;
private int requireMemoryRetryMax;
private Codec codec;
+ private Function<AddBlockEvent, CompletableFuture<Long>> spillFunc;
+ private long sendSizeLimit;
+ private int memorySpillTimeoutSec;
public WriteBufferManager(
int shuffleId,
@@ -91,12 +101,38 @@ public class WriteBufferManager extends MemoryConsumer {
TaskMemoryManager taskMemoryManager,
ShuffleWriteMetrics shuffleWriteMetrics,
RssConf rssConf) {
+ this(
+ shuffleId,
+ null,
+ taskAttemptId,
+ bufferManagerOptions,
+ serializer,
+ partitionToServers,
+ taskMemoryManager,
+ shuffleWriteMetrics,
+ rssConf,
+ null
+ );
+ }
+
+ public WriteBufferManager(
+ int shuffleId,
+ String taskId,
+ long taskAttemptId,
+ BufferManagerOptions bufferManagerOptions,
+ Serializer serializer,
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+ TaskMemoryManager taskMemoryManager,
+ ShuffleWriteMetrics shuffleWriteMetrics,
+ RssConf rssConf,
+ Function<AddBlockEvent, CompletableFuture<Long>> spillFunc) {
super(taskMemoryManager, taskMemoryManager.pageSizeBytes(),
MemoryMode.ON_HEAP);
this.bufferSize = bufferManagerOptions.getBufferSize();
this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
this.instance = serializer.newInstance();
this.buffers = Maps.newHashMap();
this.shuffleId = shuffleId;
+ this.taskId = taskId;
this.taskAttemptId = taskAttemptId;
this.partitionToServers = partitionToServers;
this.shuffleWriteMetrics = shuffleWriteMetrics;
@@ -111,6 +147,9 @@ public class WriteBufferManager extends MemoryConsumer {
.substring(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
this.codec = compress ? Codec.newInstance(rssConf) : null;
+ this.spillFunc = spillFunc;
+ this.sendSizeLimit =
rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION);
+ this.memorySpillTimeoutSec =
rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
}
public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object
value) {
@@ -165,7 +204,7 @@ public class WriteBufferManager extends MemoryConsumer {
}
// transform all [partition, records] to [partition, ShuffleBlockInfo] and
clear cache
- public List<ShuffleBlockInfo> clear() {
+ public synchronized List<ShuffleBlockInfo> clear() {
List<ShuffleBlockInfo> result = Lists.newArrayList();
long dataSize = 0;
long memoryUsed = 0;
@@ -247,10 +286,64 @@ public class WriteBufferManager extends MemoryConsumer {
}
}
+ public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo>
shuffleBlockInfoList) {
+ long totalSize = 0;
+ long memoryUsed = 0;
+ List<AddBlockEvent> events = new ArrayList<>();
+ List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
+ for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
+ totalSize += sbi.getSize();
+ memoryUsed += sbi.getFreeMemory();
+ shuffleBlockInfosPerEvent.add(sbi);
+ // split shuffle data according to the size
+ if (totalSize > sendSizeLimit) {
+ LOG.info("Build event with " + shuffleBlockInfosPerEvent.size()
+ + " blocks and " + totalSize + " bytes");
+ // Use final temporary variables for closures
+ final long _memoryUsed = memoryUsed;
+ events.add(
+ new AddBlockEvent(taskId, shuffleBlockInfosPerEvent, () ->
freeAllocatedMemory(_memoryUsed))
+ );
+ shuffleBlockInfosPerEvent = Lists.newArrayList();
+ totalSize = 0;
+ memoryUsed = 0;
+ }
+ }
+ if (!shuffleBlockInfosPerEvent.isEmpty()) {
+ LOG.info("Build event with " + shuffleBlockInfosPerEvent.size()
+ + " blocks and " + totalSize + " bytes");
+ // Use final temporary variables for closures
+ final long _memoryUsed = memoryUsed;
+ events.add(
+ new AddBlockEvent(taskId, shuffleBlockInfosPerEvent, () ->
freeAllocatedMemory(_memoryUsed))
+ );
+ }
+ return events;
+ }
+
@Override
public long spill(long size, MemoryConsumer trigger) {
- // there is no spill for such situation
- return 0;
+ List<AddBlockEvent> events = buildBlockEvents(clear());
+ List<CompletableFuture<Long>> futures = events.stream().map(x ->
spillFunc.apply(x)).collect(Collectors.toList());
+ CompletableFuture<Void> allOfFutures =
+ CompletableFuture.allOf(futures.toArray(new
CompletableFuture[futures.size()]));
+ try {
+ allOfFutures.get(memorySpillTimeoutSec, TimeUnit.SECONDS);
+ } catch (TimeoutException timeoutException) {
+ // A best effort strategy to wait.
+ // If timeout exception occurs, the underlying tasks won't be cancelled.
+ } finally {
+ long releasedSize = futures.stream().filter(x -> x.isDone()).mapToLong(x
-> {
+ try {
+ return x.get();
+ } catch (Exception e) {
+ return 0;
+ }
+ }).sum();
+ LOG.info("[taskId: {}] Spill triggered by memory consumer of {},
released memory size: {}",
+ taskId, trigger.getClass().getSimpleName(), releasedSize);
+ return releasedSize;
+ }
}
@VisibleForTesting
@@ -307,4 +400,20 @@ public class WriteBufferManager extends MemoryConsumer {
+ estimateTime + "], requireMemoryTime[" + requireMemoryTime
+ "], uncompressedDataLen[" + uncompressedDataLen + "]";
}
+
+ @VisibleForTesting
+ public void setTaskId(String taskId) {
+ this.taskId = taskId;
+ }
+
+ @VisibleForTesting
+ public void setSpillFunc(
+ Function<AddBlockEvent, CompletableFuture<Long>> spillFunc) {
+ this.spillFunc = spillFunc;
+ }
+
+ @VisibleForTesting
+ public void setSendSizeLimit(long sendSizeLimit) {
+ this.sendSizeLimit = sendSizeLimit;
+ }
}
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
new file mode 100644
index 00000000..20711dc0
--- /dev/null
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.function.Supplier;
+
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class DataPusherTest {
+
+ static class FakedShuffleWriteClient extends ShuffleWriteClientImpl {
+ private SendShuffleDataResult fakedShuffleDataResult;
+
+ FakedShuffleWriteClient() {
+ this(
+ "GRPC",
+ 1,
+ 1,
+ 10,
+ 1,
+ 1,
+ 1,
+ false,
+ 1,
+ 1,
+ 1,
+ 1
+ );
+ }
+
+ private FakedShuffleWriteClient(String clientType, int retryMax, long
retryIntervalMax, int heartBeatThreadNum,
+ int replica, int replicaWrite, int replicaRead, boolean
replicaSkipEnabled, int dataTransferPoolSize,
+ int dataCommitPoolSize, int unregisterThreadPoolSize, int
unregisterRequestTimeSec) {
+ super(clientType, retryMax, retryIntervalMax, heartBeatThreadNum,
replica, replicaWrite, replicaRead,
+ replicaSkipEnabled, dataTransferPoolSize, dataCommitPoolSize,
unregisterThreadPoolSize,
+ unregisterRequestTimeSec);
+ }
+
+ @Override
+ public SendShuffleDataResult sendShuffleData(String appId,
List<ShuffleBlockInfo> shuffleBlockInfoList,
+ Supplier<Boolean> needCancelRequest) {
+ return fakedShuffleDataResult;
+ }
+
+ public void setFakedShuffleDataResult(SendShuffleDataResult
fakedShuffleDataResult) {
+ this.fakedShuffleDataResult = fakedShuffleDataResult;
+ }
+ }
+
+ @Test
+ public void testSendData() throws ExecutionException, InterruptedException {
+ FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient();
+
+ Map<String, Set<Long>> taskToSuccessBlockIds = Maps.newConcurrentMap();
+ Map<String, Set<Long>> taskToFailedBlockIds = Maps.newConcurrentMap();
+ Set<String> failedTaskIds = new HashSet<>();
+
+ DataPusher dataPusher = new DataPusher(
+ shuffleWriteClient,
+ taskToSuccessBlockIds,
+ taskToFailedBlockIds,
+ failedTaskIds,
+ 1,
+ 2
+ );
+ dataPusher.setRssAppId("testSendData_appId");
+
+ // sync send
+ AddBlockEvent event = new AddBlockEvent("taskId", Arrays.asList(
+ new ShuffleBlockInfo(
+ 1, 1, 1, 1, 1, new byte[1], null, 1, 100, 1
+ ))
+ );
+ shuffleWriteClient.setFakedShuffleDataResult(
+ new SendShuffleDataResult(
+ Sets.newHashSet(1L, 2L),
+ Sets.newHashSet(3L, 4L)
+ )
+ );
+ CompletableFuture<Long> future = dataPusher.send(event);
+ long memoryFree = future.get();
+ assertEquals(100, memoryFree);
+ assertTrue(taskToSuccessBlockIds.get("taskId").contains(1L));
+ assertTrue(taskToSuccessBlockIds.get("taskId").contains(2L));
+ assertTrue(taskToFailedBlockIds.get("taskId").contains(3L));
+ assertTrue(taskToFailedBlockIds.get("taskId").contains(4L));
+ }
+}
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 4f57e265..87580541 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -17,7 +17,11 @@
package org.apache.spark.shuffle.writer;
+import java.util.Arrays;
import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
import com.google.common.collect.Maps;
import org.apache.commons.lang.reflect.FieldUtils;
@@ -27,6 +31,7 @@ import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssSparkConfig;
+import org.awaitility.Awaitility;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@@ -195,4 +200,78 @@ public class WriteBufferManagerTest {
sbi = wbm.createShuffleBlock(1, mockWriterBuffer);
assertEquals(35184374185984L, sbi.getBlockId());
}
+
+ @Test
+ public void buildBlockEventsTest() {
+ SparkConf conf = getConf();
+ conf.set("spark.rss.client.send.size.limit", "30");
+
+ TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+
+ BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+ WriteBufferManager wbm = new WriteBufferManager(
+ 0, 0, bufferOptions, new KryoSerializer(conf),
+ Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(),
RssSparkConfig.toRssConf(conf));
+
+ // every block: length=4, memoryUsed=12
+ ShuffleBlockInfo info1 = new ShuffleBlockInfo(1, 1, 1, 4, 1, new byte[1],
null, 1, 12, 1);
+ ShuffleBlockInfo info2 = new ShuffleBlockInfo(1, 1, 1, 4, 1, new byte[1],
null, 1, 12, 1);
+ ShuffleBlockInfo info3 = new ShuffleBlockInfo(1, 1, 1, 4, 1, new byte[1],
null, 1, 12, 1);
+ List<AddBlockEvent> events = wbm.buildBlockEvents(Arrays.asList(info1,
info2, info3));
+ assertEquals(3, events.size());
+ }
+
  /**
   * Exercises WriteBufferManager.spill: a fast flush path where all events are
   * released within the timeout, and a slow path where one event misses the
   * 1-second spill timeout and is only released asynchronously afterwards.
   */
  @Test
  public void spillTest() {
    SparkConf conf = getConf();
    conf.set("spark.rss.client.send.size.limit", "1000");
    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);

    // Synchronous spill function: runs the release callbacks immediately and
    // reports the freed memory of every block in the event.
    Function<AddBlockEvent, CompletableFuture<Long>> spillFunc = event -> {
      event.getProcessedCallbackChain().stream().forEach(x -> x.run());
      return CompletableFuture.completedFuture(
          event.getShuffleDataInfoList().stream().mapToLong(x -> x.getFreeMemory()).sum()
      );
    };

    WriteBufferManager wbm = new WriteBufferManager(
        0, "taskId_spillTest", 0, bufferOptions, new KryoSerializer(conf),
        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(),
        RssSparkConfig.toRssConf(conf), spillFunc);
    WriteBufferManager spyManager = spy(wbm);
    doReturn(512L).when(spyManager).acquireMemory(anyLong());

    String testKey = "Key";
    String testValue = "Value";
    spyManager.addRecord(0, testKey, testValue);
    spyManager.addRecord(1, testKey, testValue);

    // case1. all events are flushed within normal time.
    long releasedSize = spyManager.spill(1000, mock(WriteBufferManager.class));
    assertEquals(64, releasedSize);

    // case2. partial events are not flushed within normal time.
    // when calling spill func, 2 events should be spilled. But
    // only event will be finished in the expected time.
    spyManager.setSendSizeLimit(30);
    spyManager.addRecord(0, testKey, testValue);
    spyManager.addRecord(1, testKey, testValue);
    // Delay the flush of partition 1 past the spill timeout so that its
    // future is still pending when spill() stops waiting.
    spyManager.setSpillFunc(event -> CompletableFuture.supplyAsync(() -> {
      int partitionId = event.getShuffleDataInfoList().get(0).getPartitionId();
      if (partitionId == 1) {
        try {
          Thread.sleep(2000);
        } catch (InterruptedException interruptedException) {
          // ignore.
        }
      }
      event.getProcessedCallbackChain().stream().forEach(x -> x.run());
      return event.getShuffleDataInfoList().stream().mapToLong(x -> x.getFreeMemory()).sum();
    }));
    // Only the fast event (32 bytes) is counted within the timeout; the slow
    // event's memory is still held immediately after spill() returns...
    releasedSize = spyManager.spill(1000, mock(WriteBufferManager.class));
    assertEquals(32, releasedSize);
    assertEquals(32, spyManager.getUsedBytes());
    // ...but its callback eventually frees the remaining memory asynchronously.
    Awaitility.await().timeout(3, TimeUnit.SECONDS).until(() -> spyManager.getUsedBytes() == 0);
  }
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 0b301154..77b9ef0d 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -17,17 +17,17 @@
package org.apache.spark.shuffle;
+import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Queues;
import com.google.common.collect.Sets;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.ShuffleDependency;
@@ -39,11 +39,11 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.reader.RssShuffleReader;
import org.apache.spark.shuffle.writer.AddBlockEvent;
import org.apache.spark.shuffle.writer.BufferManagerOptions;
+import org.apache.spark.shuffle.writer.DataPusher;
import org.apache.spark.shuffle.writer.RssShuffleWriter;
import org.apache.spark.shuffle.writer.WriteBufferManager;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManagerId;
-import org.apache.spark.util.EventLoop;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -54,12 +54,10 @@ import scala.collection.Seq;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
-import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
-import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
@@ -80,7 +78,6 @@ public class RssShuffleManager implements ShuffleManager {
private ShuffleWriteClient shuffleWriteClient;
private Map<String, Set<Long>> taskToSuccessBlockIds =
JavaUtils.newConcurrentMap();
private Map<String, Set<Long>> taskToFailedBlockIds =
JavaUtils.newConcurrentMap();
- private Map<String, WriteBufferManager> taskToBufferManager =
JavaUtils.newConcurrentMap();
private final int dataReplica;
private final int dataReplicaWrite;
private final int dataReplicaRead;
@@ -92,58 +89,7 @@ public class RssShuffleManager implements ShuffleManager {
private boolean dynamicConfEnabled = false;
private final String user;
private final String uuid;
- private ThreadPoolExecutor threadPoolExecutor;
- private EventLoop<AddBlockEvent> eventLoop = new
EventLoop<AddBlockEvent>("ShuffleDataQueue") {
-
- @Override
- public void onReceive(AddBlockEvent event) {
- threadPoolExecutor.execute(() -> sendShuffleData(event.getTaskId(),
event.getShuffleDataInfoList()));
- }
-
- private void sendShuffleData(String taskId, List<ShuffleBlockInfo>
shuffleDataInfoList) {
- try {
- SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
- appId,
- shuffleDataInfoList,
- () -> !isValidTask(taskId)
- );
- putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
- putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
- } finally {
- // data is already send, release the memory to executor
- long releaseSize = 0;
- for (ShuffleBlockInfo sbi : shuffleDataInfoList) {
- releaseSize += sbi.getFreeMemory();
- }
- WriteBufferManager bufferManager = taskToBufferManager.get(taskId);
- if (bufferManager != null) {
- bufferManager.freeAllocatedMemory(releaseSize);
- }
- LOG.debug("Finish send data and release " + releaseSize + " bytes");
- }
- }
-
- private synchronized void putBlockId(
- Map<String, Set<Long>> taskToBlockIds,
- String taskAttemptId,
- Set<Long> blockIds) {
- if (blockIds == null) {
- return;
- }
- if (taskToBlockIds.get(taskAttemptId) == null) {
- taskToBlockIds.put(taskAttemptId, Sets.newConcurrentHashSet());
- }
- taskToBlockIds.get(taskAttemptId).addAll(blockIds);
- }
-
- @Override
- public void onError(Throwable throwable) {
- }
-
- @Override
- public void onStart() {
- }
- };
+ private DataPusher dataPusher;
public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
@@ -193,19 +139,22 @@ public class RssShuffleManager implements ShuffleManager {
sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
LOG.info("Disable shuffle data locality in RssShuffleManager.");
if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) {
- // for non-driver executor, start a thread for sending shuffle data to
shuffle server
- LOG.info("RSS data send thread is starting");
- eventLoop.start();
- int poolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
- int keepAliveTime =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
- threadPoolExecutor = new ThreadPoolExecutor(poolSize, poolSize * 2,
keepAliveTime, TimeUnit.SECONDS,
- Queues.newLinkedBlockingQueue(Integer.MAX_VALUE),
- ThreadUtils.getThreadFactory("SendData"));
-
if (isDriver) {
heartBeatScheduledExecutorService =
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
}
+ // for non-driver executor, start a thread for sending shuffle data to
shuffle server
+ LOG.info("RSS data pusher is starting...");
+ int poolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
+ int keepAliveTime =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
+ this.dataPusher = new DataPusher(
+ shuffleWriteClient,
+ taskToSuccessBlockIds,
+ taskToFailedBlockIds,
+ failedTaskIds,
+ poolSize,
+ keepAliveTime
+ );
}
}
@@ -236,6 +185,7 @@ public class RssShuffleManager implements ShuffleManager {
// will be called many times depend on how many shuffle stage
if ("".equals(appId)) {
appId = SparkEnv.get().conf().getAppId() + "_" + uuid;
+ dataPusher.setRssAppId(appId);
LOG.info("Generate application id used in rss: " + appId);
}
@@ -344,6 +294,13 @@ public class RssShuffleManager implements ShuffleManager {
shuffleWriteClient.registerCoordinators(coordinators);
}
+ public CompletableFuture<Long> sendData(AddBlockEvent event) {
+ if (dataPusher != null && event != null) {
+ return dataPusher.send(event);
+ }
+ return new CompletableFuture<>();
+ }
+
// This method is called in Spark executor,
// getting information from Spark driver via the ShuffleHandle.
@Override
@@ -352,6 +309,7 @@ public class RssShuffleManager implements ShuffleManager {
if (handle instanceof RssShuffleHandle) {
RssShuffleHandle<K, V, ?> rssHandle = (RssShuffleHandle<K, V, ?>) handle;
appId = rssHandle.getAppId();
+ dataPusher.setRssAppId(appId);
int shuffleId = rssHandle.getShuffleId();
String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
@@ -359,15 +317,16 @@ public class RssShuffleManager implements ShuffleManager {
ShuffleWriteMetrics writeMetrics =
context.taskMetrics().shuffleWriteMetrics();
WriteBufferManager bufferManager = new WriteBufferManager(
shuffleId,
+ taskId,
context.taskAttemptId(),
bufferOptions,
rssHandle.getDependency().serializer(),
rssHandle.getPartitionToServers(),
context.taskMemoryManager(),
writeMetrics,
- RssSparkConfig.toRssConf(sparkConf)
+ RssSparkConfig.toRssConf(sparkConf),
+ this::sendData
);
- taskToBufferManager.put(taskId, bufferManager);
return new RssShuffleWriter<>(rssHandle.getAppId(), shuffleId, taskId,
context.taskAttemptId(), bufferManager,
writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle,
@@ -448,7 +407,13 @@ public class RssShuffleManager implements ShuffleManager {
if (heartBeatScheduledExecutorService != null) {
heartBeatScheduledExecutorService.shutdownNow();
}
- threadPoolExecutor.shutdownNow();
+ if (dataPusher != null) {
+ try {
+ dataPusher.close();
+ } catch (IOException e) {
+ LOG.warn("Errors on closing data pusher", e);
+ }
+ }
shuffleWriteClient.close();
}
@@ -457,15 +422,6 @@ public class RssShuffleManager implements ShuffleManager {
throw new RssException("RssShuffleManager.shuffleBlockResolver is not
implemented");
}
- public EventLoop<AddBlockEvent> getEventLoop() {
- return eventLoop;
- }
-
- @VisibleForTesting
- public void setEventLoop(EventLoop<AddBlockEvent> eventLoop) {
- this.eventLoop = eventLoop;
- }
-
// when speculation enable, duplicate data will be sent and reported to
shuffle server,
// get the actual tasks and filter the duplicate data caused by speculation
task
private Roaring64NavigableMap getExpectedTasks(int shuffleId, int
startPartition, int endPartition) {
@@ -520,15 +476,9 @@ public class RssShuffleManager implements ShuffleManager {
taskToSuccessBlockIds.get(taskId).addAll(blockIds);
}
- @VisibleForTesting
- public Map<String, WriteBufferManager> getTaskToBufferManager() {
- return taskToBufferManager;
- }
-
public void clearTaskMeta(String taskId) {
taskToSuccessBlockIds.remove(taskId);
taskToFailedBlockIds.remove(taskId);
- taskToBufferManager.remove(taskId);
}
@VisibleForTesting
@@ -550,4 +500,12 @@ public class RssShuffleManager implements ShuffleManager {
public boolean isValidTask(String taskId) {
return !failedTaskIds.contains(taskId);
}
+
+ public DataPusher getDataPusher() {
+ return dataPusher;
+ }
+
+ public void setDataPusher(DataPusher dataPusher) {
+ this.dataPusher = dataPusher;
+ }
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 81445f44..b0b7653f 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -81,7 +81,6 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private RssShuffleManager shuffleManager;
private long sendCheckTimeout;
private long sendCheckInterval;
- private long sendSizeLimit;
private boolean isMemoryShuffleEnabled;
private final Function<String, Boolean> taskFailureCallback;
@@ -136,8 +135,6 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.shouldPartition = partitioner.numPartitions() > 1;
this.sendCheckTimeout =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
this.sendCheckInterval =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
- this.sendSizeLimit =
sparkConf.getSizeAsBytes(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(),
- RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.defaultValue().get());
this.bitmapSplitNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
this.partitionToBlockIds = Maps.newHashMap();
this.shuffleWriteClient = shuffleWriteClient;
@@ -233,26 +230,8 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
// don't send huge block to shuffle server, or there will be OOM if shuffle
sever receives data more than expected
protected void postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
- long totalSize = 0;
- List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
- for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
- totalSize += sbi.getSize();
- shuffleBlockInfosPerEvent.add(sbi);
- // split shuffle data according to the size
- if (totalSize > sendSizeLimit) {
- LOG.debug("Post event to queue with " +
shuffleBlockInfosPerEvent.size()
- + " blocks and " + totalSize + " bytes");
- shuffleManager.getEventLoop().post(
- new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
- shuffleBlockInfosPerEvent = Lists.newArrayList();
- totalSize = 0;
- }
- }
- if (!shuffleBlockInfosPerEvent.isEmpty()) {
- LOG.debug("Post event to queue with " + shuffleBlockInfosPerEvent.size()
- + " blocks and " + totalSize + " bytes");
- shuffleManager.getEventLoop().post(
- new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
+ for (AddBlockEvent event :
bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
+ shuffleManager.sendData(event);
}
}
diff --git
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 81d1bd23..61ff6325 100644
---
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -21,6 +21,8 @@ import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Function;
import java.util.stream.Collectors;
import com.google.common.collect.Lists;
@@ -38,7 +40,6 @@ import org.apache.spark.serializer.Serializer;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
-import org.apache.spark.util.EventLoop;
import org.junit.jupiter.api.Test;
import scala.Product2;
import scala.Tuple2;
@@ -51,7 +52,6 @@ import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.storage.util.StorageType;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyLong;
@@ -126,11 +126,36 @@ public class RssShuffleWriterTest {
manager.clearTaskMeta(taskId);
assertTrue(manager.getSuccessBlockIds(taskId).isEmpty());
assertTrue(manager.getFailedBlockIds(taskId).isEmpty());
- assertNull(manager.getTaskToBufferManager().get(taskId));
sc.stop();
}
+ static class FakedDataPusher extends DataPusher {
+ private final Function<AddBlockEvent, CompletableFuture<Long>> sendFunc;
+
+ FakedDataPusher(Function<AddBlockEvent, CompletableFuture<Long>> sendFunc)
{
+ this(null, null, null, null, 1, 1, sendFunc);
+ }
+
+ private FakedDataPusher(
+ ShuffleWriteClient shuffleWriteClient,
+ Map<String, Set<Long>> taskToSuccessBlockIds,
+ Map<String, Set<Long>> taskToFailedBlockIds,
+ Set<String> failedTaskIds,
+ int threadPoolSize,
+ int threadKeepAliveTime,
+ Function<AddBlockEvent, CompletableFuture<Long>> sendFunc) {
+ super(shuffleWriteClient, taskToSuccessBlockIds, taskToFailedBlockIds,
failedTaskIds, threadPoolSize,
+ threadKeepAliveTime);
+ this.sendFunc = sendFunc;
+ }
+
+ @Override
+ public CompletableFuture<Long> send(AddBlockEvent event) {
+ return sendFunc.apply(event);
+ }
+ }
+
@Test
public void writeTest() throws Exception {
SparkConf conf = new SparkConf();
@@ -149,21 +174,15 @@ public class RssShuffleWriterTest {
RssShuffleManager manager = new RssShuffleManager(conf, false);
List<ShuffleBlockInfo> shuffleBlockInfos = Lists.newArrayList();
- manager.setEventLoop(new EventLoop<AddBlockEvent>("test") {
- @Override
- public void onReceive(AddBlockEvent event) {
- assertEquals("taskId", event.getTaskId());
- shuffleBlockInfos.addAll(event.getShuffleDataInfoList());
- Set<Long> blockIds = event.getShuffleDataInfoList().parallelStream()
- .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet());
- manager.addSuccessBlockIds(event.getTaskId(), blockIds);
- }
-
- @Override
- public void onError(Throwable e) {
- }
+ DataPusher dataPusher = new FakedDataPusher(event -> {
+ assertEquals("taskId", event.getTaskId());
+ shuffleBlockInfos.addAll(event.getShuffleDataInfoList());
+ Set<Long> blockIds = event.getShuffleDataInfoList().parallelStream()
+ .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet());
+ manager.addSuccessBlockIds(event.getTaskId(), blockIds);
+ return CompletableFuture.completedFuture(0L);
});
- manager.getEventLoop().start();
+ manager.setDataPusher(dataPusher);
Partitioner mockPartitioner = mock(Partitioner.class);
ShuffleDependency<String, String, String> mockDependency =
mock(ShuffleDependency.class);
@@ -200,8 +219,8 @@ public class RssShuffleWriterTest {
ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
WriteBufferManager bufferManager = new WriteBufferManager(
- 0, 0, bufferOptions, kryoSerializer,
- partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new
RssConf());
+ 0, "taskId", 0, bufferOptions, kryoSerializer,
+ partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new
RssConf(), null);
WriteBufferManager bufferManagerSpy = spy(bufferManager);
doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
@@ -253,37 +272,46 @@ public class RssShuffleWriterTest {
@Test
public void postBlockEventTest() throws Exception {
- final WriteBufferManager mockBufferManager =
mock(WriteBufferManager.class);
final ShuffleWriteMetrics mockMetrics = mock(ShuffleWriteMetrics.class);
ShuffleDependency<String, String, String> mockDependency =
mock(ShuffleDependency.class);
Partitioner mockPartitioner = mock(Partitioner.class);
- final RssShuffleManager mockShuffleManager = mock(RssShuffleManager.class);
when(mockDependency.partitioner()).thenReturn(mockPartitioner);
when(mockPartitioner.numPartitions()).thenReturn(2);
List<AddBlockEvent> events = Lists.newArrayList();
- EventLoop<AddBlockEvent> eventLoop = new EventLoop<AddBlockEvent>("test") {
- @Override
- public void onReceive(AddBlockEvent event) {
- events.add(event);
- }
+ SparkConf conf = new SparkConf();
+ conf.setAppName("postBlockEventTest").setMaster("local[2]")
+ .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
+ .set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
+ .set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32")
+ .set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32")
+ .set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "64")
+ .set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128")
+ .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
+ .set(RssSparkConfig.RSS_STORAGE_TYPE.key(),
StorageType.LOCALFILE.name())
+ .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(),
"127.0.0.1:12345,127.0.0.1:12346")
+ .set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX +
RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION.key(), "64")
+ .set(RssSparkConfig.RSS_STORAGE_TYPE.key(),
StorageType.LOCALFILE.name());
- @Override
- public void onError(Throwable e) {
- }
- };
- eventLoop.start();
+ TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+ BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+ WriteBufferManager bufferManager = new WriteBufferManager(
+ 0, 0, bufferOptions, new KryoSerializer(conf),
+ Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(),
RssSparkConfig.toRssConf(conf));
+
+ RssShuffleManager manager = new RssShuffleManager(conf, false);
+ DataPusher dataPusher = new FakedDataPusher(event -> {
+ events.add(event);
+ return CompletableFuture.completedFuture(0L);
+ });
+ manager.setDataPusher(dataPusher);
- when(mockShuffleManager.getEventLoop()).thenReturn(eventLoop);
RssShuffleHandle<String, String, String> mockHandle =
mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
- SparkConf conf = new SparkConf();
- conf.set(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(), "64")
- .set(RssSparkConfig.RSS_STORAGE_TYPE.key(),
StorageType.LOCALFILE.name());
RssShuffleWriter<String, String, String> writer = new
RssShuffleWriter<>("appId", 0, "taskId", 1L,
- mockBufferManager, mockMetrics, mockShuffleManager, conf,
mockWriteClient, mockHandle);
+ bufferManager, mockMetrics, manager, conf, mockWriteClient,
mockHandle);
List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1,
31);
writer.postBlockEvent(shuffleBlockInfoList);
Thread.sleep(500);
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 04b57136..e875132d 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -17,21 +17,20 @@
package org.apache.spark.shuffle;
+import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Queues;
import com.google.common.collect.Sets;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.apache.hadoop.conf.Configuration;
@@ -46,12 +45,12 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.reader.RssShuffleReader;
import org.apache.spark.shuffle.writer.AddBlockEvent;
import org.apache.spark.shuffle.writer.BufferManagerOptions;
+import org.apache.spark.shuffle.writer.DataPusher;
import org.apache.spark.shuffle.writer.RssShuffleWriter;
import org.apache.spark.shuffle.writer.WriteBufferManager;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.BlockManagerId;
-import org.apache.spark.util.EventLoop;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -62,12 +61,10 @@ import scala.collection.Seq;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
-import org.apache.uniffle.client.response.SendShuffleDataResult;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
-import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
@@ -84,7 +81,6 @@ public class RssShuffleManager implements ShuffleManager {
private final String clientType;
private final long heartbeatInterval;
private final long heartbeatTimeout;
- private final ThreadPoolExecutor threadPoolExecutor;
private AtomicReference<String> id = new AtomicReference<>();
private SparkConf sparkConf;
private final int dataReplica;
@@ -96,7 +92,6 @@ public class RssShuffleManager implements ShuffleManager {
private ShuffleWriteClient shuffleWriteClient;
private final Map<String, Set<Long>> taskToSuccessBlockIds;
private final Map<String, Set<Long>> taskToFailedBlockIds;
- private Map<String, WriteBufferManager> taskToBufferManager =
JavaUtils.newConcurrentMap();
private ScheduledExecutorService heartBeatScheduledExecutorService;
private boolean heartbeatStarted = false;
private boolean dynamicConfEnabled = false;
@@ -104,55 +99,7 @@ public class RssShuffleManager implements ShuffleManager {
private String user;
private String uuid;
private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
- private final EventLoop<AddBlockEvent> eventLoop;
- private final EventLoop<AddBlockEvent> defaultEventLoop = new
EventLoop<AddBlockEvent>("ShuffleDataQueue") {
-
- @Override
- public void onReceive(AddBlockEvent event) {
- threadPoolExecutor.execute(() -> sendShuffleData(event.getTaskId(),
event.getShuffleDataInfoList()));
- }
-
- @Override
- public void onError(Throwable throwable) {
- LOG.info("Shuffle event loop error...", throwable);
- }
-
- @Override
- public void onStart() {
- LOG.info("Shuffle event loop start...");
- }
-
- private void sendShuffleData(String taskId, List<ShuffleBlockInfo>
shuffleDataInfoList) {
- try {
- SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
- id.get(),
- shuffleDataInfoList,
- () -> !isValidTask(taskId)
- );
- putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
- putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
- } finally {
- final AtomicLong releaseSize = new AtomicLong(0);
- shuffleDataInfoList.forEach((sbi) ->
releaseSize.addAndGet(sbi.getFreeMemory()));
- WriteBufferManager bufferManager = taskToBufferManager.get(taskId);
- if (bufferManager != null) {
- bufferManager.freeAllocatedMemory(releaseSize.get());
- }
- LOG.debug("Spark 3.0 finish send data and release " + releaseSize + "
bytes");
- }
- }
-
- private synchronized void putBlockId(
- Map<String, Set<Long>> taskToBlockIds,
- String taskAttemptId,
- Set<Long> blockIds) {
- if (blockIds == null || blockIds.isEmpty()) {
- return;
- }
- taskToBlockIds.putIfAbsent(taskAttemptId, Sets.newConcurrentHashSet());
- taskToBlockIds.get(taskAttemptId).addAll(blockIds);
- }
- };
+ private DataPusher dataPusher;
public RssShuffleManager(SparkConf conf, boolean isDriver) {
this.sparkConf = conf;
@@ -211,19 +158,28 @@ public class RssShuffleManager implements ShuffleManager {
LOG.info("Disable shuffle data locality in RssShuffleManager.");
taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
taskToFailedBlockIds = JavaUtils.newConcurrentMap();
- // for non-driver executor, start a thread for sending shuffle data to
shuffle server
- LOG.info("RSS data send thread is starting");
- eventLoop = defaultEventLoop;
- eventLoop.start();
- int poolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
- int keepAliveTime =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
- threadPoolExecutor = new ThreadPoolExecutor(poolSize, poolSize * 2,
keepAliveTime, TimeUnit.SECONDS,
- Queues.newLinkedBlockingQueue(Integer.MAX_VALUE),
- ThreadUtils.getThreadFactory("SendData"));
if (isDriver) {
heartBeatScheduledExecutorService =
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
}
+ LOG.info("Rss data pusher is starting...");
+ int poolSize =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
+ int keepAliveTime =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
+ this.dataPusher = new DataPusher(
+ shuffleWriteClient,
+ taskToSuccessBlockIds,
+ taskToFailedBlockIds,
+ failedTaskIds,
+ poolSize,
+ keepAliveTime
+ );
+ }
+
+ public CompletableFuture<Long> sendData(AddBlockEvent event) {
+ if (dataPusher != null && event != null) {
+ return dataPusher.send(event);
+ }
+ return new CompletableFuture<>();
}
@VisibleForTesting
@@ -242,7 +198,7 @@ public class RssShuffleManager implements ShuffleManager {
RssShuffleManager(
SparkConf conf,
boolean isDriver,
- EventLoop<AddBlockEvent> loop,
+ DataPusher dataPusher,
Map<String, Set<Long>> taskToSuccessBlockIds,
Map<String, Set<Long>> taskToFailedBlockIds) {
this.sparkConf = conf;
@@ -283,14 +239,8 @@ public class RssShuffleManager implements ShuffleManager {
);
this.taskToSuccessBlockIds = taskToSuccessBlockIds;
this.taskToFailedBlockIds = taskToFailedBlockIds;
- if (loop != null) {
- eventLoop = loop;
- } else {
- eventLoop = defaultEventLoop;
- }
- eventLoop.start();
- threadPoolExecutor = null;
- heartBeatScheduledExecutorService = null;
+ this.heartBeatScheduledExecutorService = null;
+ this.dataPusher = dataPusher;
}
// This method is called in Spark driver side,
@@ -315,6 +265,7 @@ public class RssShuffleManager implements ShuffleManager {
if (id.get() == null) {
id.compareAndSet(null, SparkEnv.get().conf().getAppId() + "_" + uuid);
+ dataPusher.setRssAppId(id.get());
}
LOG.info("Generate application id used in rss: " + id.get());
@@ -390,6 +341,7 @@ public class RssShuffleManager implements ShuffleManager {
// todo: this implement is tricky, we should refactor it
if (id.get() == null) {
id.compareAndSet(null, rssHandle.getAppId());
+ dataPusher.setRssAppId(id.get());
}
int shuffleId = rssHandle.getShuffleId();
String taskId = "" + context.taskAttemptId() + "_" +
context.attemptNumber();
@@ -401,10 +353,9 @@ public class RssShuffleManager implements ShuffleManager {
writeMetrics = context.taskMetrics().shuffleWriteMetrics();
}
WriteBufferManager bufferManager = new WriteBufferManager(
- shuffleId, context.taskAttemptId(), bufferOptions,
rssHandle.getDependency().serializer(),
+ shuffleId, taskId, context.taskAttemptId(), bufferOptions,
rssHandle.getDependency().serializer(),
rssHandle.getPartitionToServers(), context.taskMemoryManager(),
- writeMetrics, RssSparkConfig.toRssConf(sparkConf));
- taskToBufferManager.put(taskId, bufferManager);
+ writeMetrics, RssSparkConfig.toRssConf(sparkConf), this::sendData);
LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(),
rssHandle.getShuffleId());
return new RssShuffleWriter<>(rssHandle.getAppId(), shuffleId, taskId,
context.taskAttemptId(), bufferManager,
writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle,
@@ -660,21 +611,21 @@ public class RssShuffleManager implements ShuffleManager {
if (heartBeatScheduledExecutorService != null) {
heartBeatScheduledExecutorService.shutdownNow();
}
- if (threadPoolExecutor != null) {
- threadPoolExecutor.shutdownNow();
- }
if (shuffleWriteClient != null) {
shuffleWriteClient.close();
}
- if (eventLoop != null) {
- eventLoop.stop();
+ if (dataPusher != null) {
+ try {
+ dataPusher.close();
+ } catch (IOException e) {
+ LOG.warn("Errors on closing data pusher", e);
+ }
}
}
public void clearTaskMeta(String taskId) {
taskToSuccessBlockIds.remove(taskId);
taskToFailedBlockIds.remove(taskId);
- taskToBufferManager.remove(taskId);
}
@VisibleForTesting
@@ -736,12 +687,6 @@ public class RssShuffleManager implements ShuffleManager {
}
}
- public void postEvent(AddBlockEvent addBlockEvent) {
- if (eventLoop != null) {
- eventLoop.post(addBlockEvent);
- }
- }
-
public Set<Long> getFailedBlockIds(String taskId) {
Set<Long> result = taskToFailedBlockIds.get(taskId);
if (result == null) {
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 4784c390..fc1e32f5 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -75,7 +75,6 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
private final boolean shouldPartition;
private final long sendCheckTimeout;
private final long sendCheckInterval;
- private final long sendSizeLimit;
private final int bitmapSplitNum;
private final Map<Integer, Set<Long>> partitionToBlockIds;
private final ShuffleWriteClient shuffleWriteClient;
@@ -137,8 +136,6 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
this.shouldPartition = partitioner.numPartitions() > 1;
this.sendCheckTimeout =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
this.sendCheckInterval =
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
- this.sendSizeLimit =
sparkConf.getSizeAsBytes(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(),
- RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.defaultValue().get());
this.bitmapSplitNum =
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
this.partitionToBlockIds = Maps.newHashMap();
this.shuffleWriteClient = shuffleWriteClient;
@@ -231,26 +228,8 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
protected void postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
- long totalSize = 0;
- List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
- for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
- totalSize += sbi.getSize();
- shuffleBlockInfosPerEvent.add(sbi);
- // split shuffle data according to the size
- if (totalSize > sendSizeLimit) {
- LOG.debug("Post event to queue with " +
shuffleBlockInfosPerEvent.size()
- + " blocks and " + totalSize + " bytes");
- shuffleManager.postEvent(
- new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
- shuffleBlockInfosPerEvent = Lists.newArrayList();
- totalSize = 0;
- }
- }
- if (!shuffleBlockInfosPerEvent.isEmpty()) {
- LOG.debug("Post event to queue with " + shuffleBlockInfosPerEvent.size()
- + " blocks and " + totalSize + " bytes");
- shuffleManager.postEvent(
- new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
+ for (AddBlockEvent event :
bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
+ shuffleManager.sendData(event);
}
}
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java
index 033966a7..1c7f988a 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java
@@ -22,8 +22,7 @@ import java.util.Set;
import org.apache.commons.lang3.SystemUtils;
import org.apache.spark.SparkConf;
-import org.apache.spark.shuffle.writer.AddBlockEvent;
-import org.apache.spark.util.EventLoop;
+import org.apache.spark.shuffle.writer.DataPusher;
public class TestUtils {
@@ -33,10 +32,10 @@ public class TestUtils {
public static RssShuffleManager createShuffleManager(
SparkConf conf,
Boolean isDriver,
- EventLoop<AddBlockEvent> loop,
+ DataPusher dataPusher,
Map<String, Set<Long>> successBlockIds,
Map<String, Set<Long>> failBlockIds) {
- return new RssShuffleManager(conf, isDriver, loop, successBlockIds,
failBlockIds);
+ return new RssShuffleManager(conf, isDriver, dataPusher, successBlockIds,
failBlockIds);
}
public static boolean isMacOnAppleSilicon() {
diff --git
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index f446cf3a..25c76497 100644
---
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -17,11 +17,13 @@
package org.apache.spark.shuffle.writer;
-
+import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Function;
import java.util.stream.Collectors;
import com.google.common.collect.Lists;
@@ -40,7 +42,7 @@ import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.TestUtils;
-import org.apache.spark.util.EventLoop;
+import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
import scala.Product2;
import scala.Tuple2;
@@ -49,7 +51,6 @@ import scala.collection.mutable.MutableList;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.storage.util.StorageType;
@@ -102,7 +103,7 @@ public class RssShuffleWriterTest {
BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
WriteBufferManager bufferManager = new WriteBufferManager(
0, 0, bufferOptions, kryoSerializer,
- Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(),
new RssConf());
+ Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(),
RssSparkConfig.toRssConf(conf));
WriteBufferManager bufferManagerSpy = spy(bufferManager);
RssShuffleWriter<String, String, String> rssShuffleWriter = new
RssShuffleWriter<>("appId", 0, "taskId", 1L,
@@ -135,6 +136,32 @@ public class RssShuffleWriterTest {
sc.stop();
}
+ static class FakedDataPusher extends DataPusher {
+ private final Function<AddBlockEvent, CompletableFuture<Long>> sendFunc;
+
+ FakedDataPusher(Function<AddBlockEvent, CompletableFuture<Long>> sendFunc)
{
+ this(null, null, null, null, 1, 1, sendFunc);
+ }
+
+ private FakedDataPusher(
+ ShuffleWriteClient shuffleWriteClient,
+ Map<String, Set<Long>> taskToSuccessBlockIds,
+ Map<String, Set<Long>> taskToFailedBlockIds,
+ Set<String> failedTaskIds,
+ int threadPoolSize,
+ int threadKeepAliveTime,
+ Function<AddBlockEvent, CompletableFuture<Long>> sendFunc) {
+ super(shuffleWriteClient, taskToSuccessBlockIds, taskToFailedBlockIds,
failedTaskIds, threadPoolSize,
+ threadKeepAliveTime);
+ this.sendFunc = sendFunc;
+ }
+
+ @Override
+ public CompletableFuture<Long> send(AddBlockEvent event) {
+ return sendFunc.apply(event);
+ }
+ }
+
@Test
public void writeTest() throws Exception {
SparkConf conf = new SparkConf();
@@ -151,27 +178,24 @@ public class RssShuffleWriterTest {
// init SparkContext
List<ShuffleBlockInfo> shuffleBlockInfos = Lists.newArrayList();
final SparkContext sc = SparkContext.getOrCreate(conf);
- Map<String, Set<Long>> successBlockIds = JavaUtils.newConcurrentMap();
- EventLoop<AddBlockEvent> testLoop = new EventLoop<AddBlockEvent>("test") {
- @Override
- public void onReceive(AddBlockEvent event) {
- assertEquals("taskId", event.getTaskId());
- shuffleBlockInfos.addAll(event.getShuffleDataInfoList());
- Set<Long> blockIds = event.getShuffleDataInfoList().parallelStream()
- .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet());
- successBlockIds.putIfAbsent(event.getTaskId(),
Sets.newConcurrentHashSet());
- successBlockIds.get(event.getTaskId()).addAll(blockIds);
- }
-
- @Override
- public void onError(Throwable e) {
- }
- };
+ Map<String, Set<Long>> successBlockIds = Maps.newConcurrentMap();
+
+ FakedDataPusher dataPusher = new FakedDataPusher(
+ event -> {
+ assertEquals("taskId", event.getTaskId());
+ shuffleBlockInfos.addAll(event.getShuffleDataInfoList());
+ Set<Long> blockIds = event.getShuffleDataInfoList().parallelStream()
+ .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet());
+ successBlockIds.putIfAbsent(event.getTaskId(),
Sets.newConcurrentHashSet());
+ successBlockIds.get(event.getTaskId()).addAll(blockIds);
+ return new CompletableFuture<>();
+ }
+ );
final RssShuffleManager manager = TestUtils.createShuffleManager(
conf,
false,
- testLoop,
+ dataPusher,
successBlockIds,
JavaUtils.newConcurrentMap());
Serializer kryoSerializer = new KryoSerializer(conf);
@@ -210,7 +234,11 @@ public class RssShuffleWriterTest {
ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
WriteBufferManager bufferManager = new WriteBufferManager(
0, 0, bufferOptions, kryoSerializer,
- partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new
RssConf());
+ partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics,
+ RssSparkConfig.toRssConf(conf)
+ );
+ bufferManager.setTaskId("taskId");
+
WriteBufferManager bufferManagerSpy = spy(bufferManager);
RssShuffleWriter<String, String, String> rssShuffleWriter = new
RssShuffleWriter<>("appId", 0, "taskId", 1L,
bufferManagerSpy, shuffleWriteMetrics, manager, conf,
mockShuffleWriteClient, mockHandle);
@@ -265,7 +293,16 @@ public class RssShuffleWriterTest {
@Test
public void postBlockEventTest() throws Exception {
- WriteBufferManager mockBufferManager = mock(WriteBufferManager.class);
+ SparkConf conf = new SparkConf();
+ conf.set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX +
RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION.key(), "64")
+ .set(RssSparkConfig.RSS_STORAGE_TYPE.key(),
StorageType.MEMORY_LOCALFILE.name());
+
+ BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+ WriteBufferManager bufferManager = new WriteBufferManager(
+ 0, 0, bufferOptions, new KryoSerializer(conf),
+ Maps.newHashMap(), mock(TaskMemoryManager.class), new
ShuffleWriteMetrics(), RssSparkConfig.toRssConf(conf));
+ WriteBufferManager bufferManagerSpy = spy(bufferManager);
+
ShuffleDependency<String, String, String> mockDependency =
mock(ShuffleDependency.class);
ShuffleWriteMetrics mockMetrics = mock(ShuffleWriteMetrics.class);
Partitioner mockPartitioner = mock(Partitioner.class);
@@ -274,35 +311,39 @@ public class RssShuffleWriterTest {
when(mockPartitioner.numPartitions()).thenReturn(2);
List<AddBlockEvent> events = Lists.newArrayList();
- EventLoop<AddBlockEvent> eventLoop = new EventLoop<AddBlockEvent>("test") {
- @Override
- public void onReceive(AddBlockEvent event) {
- events.add(event);
- }
+ FakedDataPusher dataPusher = new FakedDataPusher(
+ event -> {
+ events.add(event);
+ return new CompletableFuture<>();
+ }
+ );
- @Override
- public void onError(Throwable e) {
- }
- };
RssShuffleManager mockShuffleManager = spy(TestUtils.createShuffleManager(
sparkConf,
false,
- eventLoop,
- JavaUtils.newConcurrentMap(),
- JavaUtils.newConcurrentMap()));
+ dataPusher,
+ Maps.newConcurrentMap(),
+ Maps.newConcurrentMap()));
RssShuffleHandle<String, String, String> mockHandle =
mock(RssShuffleHandle.class);
when(mockHandle.getDependency()).thenReturn(mockDependency);
ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
- SparkConf conf = new SparkConf();
- conf.set(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(), "64")
- .set(RssSparkConfig.RSS_STORAGE_TYPE.key(),
StorageType.MEMORY_LOCALFILE.name());
+
List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1,
31);
- RssShuffleWriter<String, String, String> writer = new
RssShuffleWriter<>("appId", 0, "taskId", 1L,
- mockBufferManager, mockMetrics, mockShuffleManager, conf,
mockWriteClient, mockHandle);
+ RssShuffleWriter<String, String, String> writer = new RssShuffleWriter<>(
+ "appId",
+ 0,
+ "taskId",
+ 1L,
+ bufferManagerSpy,
+ mockMetrics,
+ mockShuffleManager,
+ conf,
+ mockWriteClient,
+ mockHandle
+ );
writer.postBlockEvent(shuffleBlockInfoList);
- Thread.sleep(500);
- assertEquals(1, events.size());
+ Awaitility.await().timeout(Duration.ofSeconds(1)).until(() ->
events.size() == 1);
assertEquals(1, events.get(0).getShuffleDataInfoList().size());
events.clear();
diff --git
a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
index b0ed870e..eb9173ba 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
@@ -22,15 +22,19 @@ import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-/**
- * Provide a general method to create a thread factory to make the code more
standardized
- */
public class ThreadUtils {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(ThreadUtils.class);
+ /**
+ * Provide a general method to create a thread factory to make the code more
standardized
+ */
public static ThreadFactory getThreadFactory(String factoryName) {
return new
ThreadFactoryBuilder().setDaemon(true).setNameFormat(factoryName +
"-%d").build();
}
@@ -74,4 +78,17 @@ public class ThreadUtils {
public static ExecutorService getDaemonCachedThreadPool(String factoryName) {
return Executors.newCachedThreadPool(getThreadFactory(factoryName));
}
+
+ public static void shutdownThreadPool(ExecutorService threadPool, int
waitSec) throws InterruptedException {
+ if (threadPool == null) {
+ return;
+ }
+ threadPool.shutdown();
+ if (!threadPool.awaitTermination(waitSec, TimeUnit.SECONDS)) {
+ threadPool.shutdownNow();
+ if (!threadPool.awaitTermination(waitSec, TimeUnit.SECONDS)) {
+ LOGGER.warn("Thread pool didn't stop gracefully.");
+ }
+ }
+ }
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
new file mode 100644
index 00000000..7fafa442
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ThreadUtilsTest {
+
+ @Test
+ public void shutdownThreadPoolTest() throws InterruptedException {
+ ExecutorService executorService = Executors.newFixedThreadPool(2);
+ AtomicBoolean finished = new AtomicBoolean(false);
+ executorService.submit(() -> {
+ try {
+ Thread.sleep(100000);
+ } catch (InterruptedException interruptedException) {
+ // ignore
+ } finally {
+ finished.set(true);
+ }
+ });
+ ThreadUtils.shutdownThreadPool(executorService, 1);
+ assertTrue(finished.get());
+ assertTrue(executorService.isShutdown());
+ }
+}