This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 1b48c120 [#706] feat(spark): support spill to avoid memory deadlock 
(#714)
1b48c120 is described below

commit 1b48c12026ad7be42368fb6eee2cd9a1dff2bbb5
Author: Junfan Zhang <[email protected]>
AuthorDate: Tue Mar 28 19:02:47 2023 +0800

    [#706] feat(spark): support spill to avoid memory deadlock (#714)
    
    ### What changes were proposed in this pull request?
    1. Introduce `DataPusher` to replace the `eventLoop`, so that it can be shared
       as a common component by spark2 and spark3.
    2. Implement the `spill` method in `WriteBufferManager` to avoid memory
       deadlock.
    
    ### Why are the changes needed?
    In the current codebase, when several `WriteBufferManager`s acquire memory at
    the same time, they can deadlock waiting for memory held by one another. To
    solve this, we implement the spill function to break the deadlock condition.
    
    Fix: #706
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    1. Existing UTs
    2. Newly added UTs
---
 .../org/apache/spark/shuffle/RssSparkConfig.java   |  20 ++-
 .../apache/spark/shuffle/writer/AddBlockEvent.java |  21 ++++
 .../apache/spark/shuffle/writer/DataPusher.java    | 137 +++++++++++++++++++++
 .../spark/shuffle/writer/WriteBufferManager.java   | 115 ++++++++++++++++-
 .../spark/shuffle/writer/DataPusherTest.java       | 119 ++++++++++++++++++
 .../shuffle/writer/WriteBufferManagerTest.java     |  79 ++++++++++++
 .../apache/spark/shuffle/RssShuffleManager.java    | 128 +++++++------------
 .../spark/shuffle/writer/RssShuffleWriter.java     |  25 +---
 .../spark/shuffle/writer/RssShuffleWriterTest.java | 100 +++++++++------
 .../apache/spark/shuffle/RssShuffleManager.java    | 125 ++++++-------------
 .../spark/shuffle/writer/RssShuffleWriter.java     |  25 +---
 .../java/org/apache/spark/shuffle/TestUtils.java   |   7 +-
 .../spark/shuffle/writer/RssShuffleWriterTest.java | 125 ++++++++++++-------
 .../apache/uniffle/common/util/ThreadUtils.java    |  23 +++-
 .../uniffle/common/util/ThreadUtilsTest.java       |  47 +++++++
 15 files changed, 782 insertions(+), 314 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index c8e3478c..6aa90e1a 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -28,11 +28,26 @@ import scala.Tuple2;
 import scala.runtime.AbstractFunction1;
 
 import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.config.ConfigOption;
+import org.apache.uniffle.common.config.ConfigOptions;
 import org.apache.uniffle.common.config.ConfigUtils;
 import org.apache.uniffle.common.config.RssConf;
 
 public class RssSparkConfig {
 
+  // Upper bound on the accumulated block size of one AddBlockEvent; WriteBufferManager starts a
+  // new event once this limit is exceeded.
+  // NOTE(review): this replaces the String entry "spark.rss.client.send.size.limit" (default "16m");
+  // confirm that longType() parsing still accepts size strings such as "16m" from existing configs.
+  public static final ConfigOption<Long> RSS_CLIENT_SEND_SIZE_LIMITATION = 
ConfigOptions
+      .key("rss.client.send.size.limit")
+      .longType()
+      .defaultValue(1024 * 1024 * 16L)
+      .withDescription("The max data size sent to shuffle server");
+
+  // Seconds WriteBufferManager.spill() waits for its pushed events before returning; pushes still
+  // running after the timeout are not cancelled.
+  public static final ConfigOption<Integer> RSS_MEMORY_SPILL_TIMEOUT = 
ConfigOptions
+      .key("rss.client.memory.spill.timeout.sec")
+      .intType()
+      .defaultValue(1)
+      .withDescription("The timeout of spilling data to remote shuffle server, 
"
+          + "which will be triggered by Spark TaskMemoryManager. Unit is sec, 
default value is 1");
+
   public static final String SPARK_RSS_CONFIG_PREFIX = "spark.";
 
   public static final ConfigEntry<Integer> RSS_PARTITION_NUM_PER_RANGE = 
createIntegerBuilder(
@@ -115,11 +130,6 @@ public class RssSparkConfig {
       new ConfigBuilder("spark.rss.client.heartBeat.threadNum"))
       .createWithDefault(4);
 
-  public static final ConfigEntry<String> RSS_CLIENT_SEND_SIZE_LIMIT = 
createStringBuilder(
-      new ConfigBuilder("spark.rss.client.send.size.limit")
-          .doc("The max data size sent to shuffle server"))
-      .createWithDefault("16m");
-
   public static final ConfigEntry<Integer> 
RSS_CLIENT_UNREGISTER_THREAD_POOL_SIZE = createIntegerBuilder(
       new ConfigBuilder("spark.rss.client.unregister.thread.pool.size"))
       .createWithDefault(10);
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
index a8889754..7dab0725 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.shuffle.writer;
 
+import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.uniffle.common.ShuffleBlockInfo;
@@ -25,10 +26,26 @@ public class AddBlockEvent {
 
   private String taskId;
   private List<ShuffleBlockInfo> shuffleDataInfoList;
+  // Callbacks run (in insertion order) by DataPusher once the send attempt for this event has
+  // finished, whether it succeeded or failed.
+  private List<Runnable> processedCallbackChain;

   public AddBlockEvent(String taskId, List<ShuffleBlockInfo> 
shuffleDataInfoList) {
     this.taskId = taskId;
     this.shuffleDataInfoList = shuffleDataInfoList;
+    this.processedCallbackChain = new ArrayList<>();
+  }
+
+  /**
+   * Creates an event whose callback chain starts with {@code callback}.
+   */
+  public AddBlockEvent(String taskId, List<ShuffleBlockInfo> 
shuffleBlockInfoList, Runnable callback) {
+    this.taskId = taskId;
+    this.shuffleDataInfoList = shuffleBlockInfoList;
+    this.processedCallbackChain = new ArrayList<>();
+    addCallback(callback);
+  }
+
+  /**
+   * Appends a callback to run once this event has been processed.
+   *
+   * @param callback should not throw any exception and should execute fast.
+   */
+  public void addCallback(Runnable callback) {
+    processedCallbackChain.add(callback);
   }

   public String getTaskId() {
@@ -39,6 +56,10 @@ public class AddBlockEvent {
     return shuffleDataInfoList;
   }

+  /** Returns the internal (mutable) callback list, in registration order. */
+  public List<Runnable> getProcessedCallbackChain() {
+    return processedCallbackChain;
+  }
+
   @Override
   public String toString() {
     return "AddBlockEvent: TaskId[" + taskId + "], " + shuffleDataInfoList;
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
new file mode 100644
index 00000000..ca03a784
--- /dev/null
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Queues;
+import com.google.common.collect.Sets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.ThreadUtils;
+
+/**
+ * A {@link DataPusher} that is responsible for sending data to remote
+ * shuffle servers asynchronously.
+ */
+public class DataPusher implements Closeable {
+  private static final Logger LOGGER = LoggerFactory.getLogger(DataPusher.class);
+
+  private final ExecutorService executorService;
+
+  private final ShuffleWriteClient shuffleWriteClient;
+  // Must be thread safe
+  private final Map<String, Set<Long>> taskToSuccessBlockIds;
+  // Must be thread safe
+  private final Map<String, Set<Long>> taskToFailedBlockIds;
+  // Written by setRssAppId() on the caller thread and read on the sender threads, hence volatile.
+  private volatile String rssAppId;
+  // Must be thread safe
+  private final Set<String> failedTaskIds;
+
+  /**
+   * @param shuffleWriteClient client used to send blocks to the shuffle servers
+   * @param taskToSuccessBlockIds thread-safe map collecting successfully sent block ids per task
+   * @param taskToFailedBlockIds thread-safe map collecting failed block ids per task
+   * @param failedTaskIds thread-safe set of tasks treated as invalid; sends for them are asked to cancel
+   * @param threadPoolSize core size of the sender thread pool
+   * @param threadKeepAliveTime keep-alive time of idle sender threads, in seconds
+   */
+  public DataPusher(ShuffleWriteClient shuffleWriteClient,
+      Map<String, Set<Long>> taskToSuccessBlockIds,
+      Map<String, Set<Long>> taskToFailedBlockIds,
+      Set<String> failedTaskIds,
+      int threadPoolSize,
+      int threadKeepAliveTime) {
+    this.shuffleWriteClient = shuffleWriteClient;
+    this.taskToSuccessBlockIds = taskToSuccessBlockIds;
+    this.taskToFailedBlockIds = taskToFailedBlockIds;
+    this.failedTaskIds = failedTaskIds;
+    // NOTE: with an unbounded work queue a ThreadPoolExecutor never grows beyond its core size,
+    // so the maximum pool size (threadPoolSize * 2) is effectively unused.
+    this.executorService = new ThreadPoolExecutor(
+        threadPoolSize,
+        threadPoolSize * 2,
+        threadKeepAliveTime,
+        TimeUnit.SECONDS,
+        Queues.newLinkedBlockingQueue(Integer.MAX_VALUE),
+        ThreadUtils.getThreadFactory(this.getClass().getName())
+    );
+  }
+
+  /**
+   * Asynchronously sends the blocks of {@code event} and records the resulting block ids.
+   * The event's callbacks always run after the send attempt, even if it fails.
+   *
+   * @param event the blocks to send together with their post-processing callbacks
+   * @return a future completing with the summed free-memory size of the event's blocks
+   *     (0 when the event carries no blocks)
+   * @throws RssException if the app id has not been set yet
+   */
+  public CompletableFuture<Long> send(AddBlockEvent event) {
+    if (rssAppId == null) {
+      throw new RssException("RssAppId should be set.");
+    }
+    return CompletableFuture.supplyAsync(() -> {
+      String taskId = event.getTaskId();
+      List<ShuffleBlockInfo> shuffleBlockInfoList = event.getShuffleDataInfoList();
+      try {
+        SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
+            rssAppId,
+            shuffleBlockInfoList,
+            () -> !isValidTask(taskId)
+        );
+        putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
+        putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
+      } finally {
+        // ofNullable: a missing chain must fall back to "no callbacks" instead of throwing an NPE.
+        List<Runnable> callbackChain =
+            Optional.ofNullable(event.getProcessedCallbackChain()).orElse(Collections.emptyList());
+        for (Runnable runnable : callbackChain) {
+          runnable.run();
+        }
+      }
+      // sum() yields 0 for an empty block list instead of failing like reduce(...).get().
+      return shuffleBlockInfoList.stream()
+          .mapToLong(ShuffleBlockInfo::getFreeMemory)
+          .sum();
+    }, executorService);
+  }
+
+  private synchronized void putBlockId(
+      Map<String, Set<Long>> taskToBlockIds,
+      String taskAttemptId,
+      Set<Long> blockIds) {
+    if (blockIds == null || blockIds.isEmpty()) {
+      return;
+    }
+    taskToBlockIds.computeIfAbsent(taskAttemptId, x -> Sets.newConcurrentHashSet()).addAll(blockIds);
+  }
+
+  public boolean isValidTask(String taskId) {
+    return !failedTaskIds.contains(taskId);
+  }
+
+  public void setRssAppId(String rssAppId) {
+    this.rssAppId = rssAppId;
+  }
+
+  @Override
+  public void close() throws IOException {
+    if (executorService != null) {
+      try {
+        ThreadUtils.shutdownThreadPool(executorService, 5);
+      } catch (InterruptedException interruptedException) {
+        // Keep the interrupt visible to callers further up the stack.
+        Thread.currentThread().interrupt();
+        LOGGER.error("Errors on shutdown thread pool of [{}].",
+            this.getClass().getSimpleName(), interruptedException);
+      }
+    }
+  }
+}
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 9f33be38..580fce69 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -17,10 +17,16 @@
 
 package org.apache.spark.shuffle.writer;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Function;
+import java.util.stream.Collectors;
 
 import com.clearspring.analytics.util.Lists;
 import com.google.common.annotations.VisibleForTesting;
@@ -61,6 +67,7 @@ public class WriteBufferManager extends MemoryConsumer {
   private Map<Integer, Integer> partitionToSeqNo = Maps.newHashMap();
   private long askExecutorMemory;
   private int shuffleId;
+  // Task id attached to events built by buildBlockEvents(); null when built via the legacy constructor.
+  private String taskId;
   private long taskAttemptId;
   private SerializerInstance instance;
   private ShuffleWriteMetrics shuffleWriteMetrics;
@@ -81,6 +88,9 @@ public class WriteBufferManager extends MemoryConsumer {
   private long requireMemoryInterval;
   private int requireMemoryRetryMax;
   private Codec codec;
+  // Pushes one event asynchronously; the future yields the memory size it released.
+  // NOTE(review): null when built via the legacy constructor — spill() must not dereference it
+  // in that case; confirm.
+  private Function<AddBlockEvent, CompletableFuture<Long>> spillFunc;
+  // buildBlockEvents() starts a new event once the accumulated block size exceeds this value.
+  private long sendSizeLimit;
+  // Seconds spill() waits for the pushed events before returning the memory freed so far.
+  private int memorySpillTimeoutSec;

+  // Legacy constructor: delegates with a null task id and a null spill function.
   public WriteBufferManager(
       int shuffleId,
@@ -91,12 +101,38 @@ public class WriteBufferManager extends MemoryConsumer {
       TaskMemoryManager taskMemoryManager,
       ShuffleWriteMetrics shuffleWriteMetrics,
       RssConf rssConf) {
+    this(
+        shuffleId,
+        null,
+        taskAttemptId,
+        bufferManagerOptions,
+        serializer,
+        partitionToServers,
+        taskMemoryManager,
+        shuffleWriteMetrics,
+        rssConf,
+        null
+    );
+  }
+
+  /**
+   * Full constructor.
+   *
+   * @param taskId task id stamped on every AddBlockEvent built by this manager
+   * @param spillFunc pushes an event asynchronously and yields the released memory size; used by spill()
+   */
+  public WriteBufferManager(
+      int shuffleId,
+      String taskId,
+      long taskAttemptId,
+      BufferManagerOptions bufferManagerOptions,
+      Serializer serializer,
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers,
+      TaskMemoryManager taskMemoryManager,
+      ShuffleWriteMetrics shuffleWriteMetrics,
+      RssConf rssConf,
+      Function<AddBlockEvent, CompletableFuture<Long>> spillFunc) {
     super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), 
MemoryMode.ON_HEAP);
     this.bufferSize = bufferManagerOptions.getBufferSize();
     this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
     this.instance = serializer.newInstance();
     this.buffers = Maps.newHashMap();
     this.shuffleId = shuffleId;
+    this.taskId = taskId;
     this.taskAttemptId = taskAttemptId;
     this.partitionToServers = partitionToServers;
     this.shuffleWriteMetrics = shuffleWriteMetrics;
@@ -111,6 +147,9 @@ public class WriteBufferManager extends MemoryConsumer {
             .substring(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
         RssSparkConfig.SPARK_SHUFFLE_COMPRESS_DEFAULT);
     this.codec = compress ? Codec.newInstance(rssConf) : null;
+    this.spillFunc = spillFunc;
+    this.sendSizeLimit = 
rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION);
+    this.memorySpillTimeoutSec = 
rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT);
   }
 
   public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object 
value) {
@@ -165,7 +204,7 @@ public class WriteBufferManager extends MemoryConsumer {
   }
 
   // transform all [partition, records] to [partition, ShuffleBlockInfo] and 
clear cache
-  public List<ShuffleBlockInfo> clear() {
+  // NOTE(review): synchronized presumably because spill() also calls clear(); addRecord() is not
+  // synchronized, so confirm the buffers are never mutated concurrently with a spill.
+  public synchronized List<ShuffleBlockInfo> clear() {
     List<ShuffleBlockInfo> result = Lists.newArrayList();
     long dataSize = 0;
     long memoryUsed = 0;
@@ -247,10 +286,64 @@ public class WriteBufferManager extends MemoryConsumer {
     }
   }
 
+  /**
+   * Groups blocks into {@link AddBlockEvent}s. A new event is started once the accumulated
+   * block size of the current one exceeds {@code sendSizeLimit}. Each event's callback frees
+   * exactly the memory accounted to its own blocks.
+   *
+   * @param shuffleBlockInfoList blocks to group, in order
+   * @return the events, in block order; empty when there are no blocks
+   */
+  public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockInfoList) {
+    long totalSize = 0;
+    long memoryUsed = 0;
+    List<AddBlockEvent> events = new ArrayList<>();
+    List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
+    for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
+      totalSize += sbi.getSize();
+      memoryUsed += sbi.getFreeMemory();
+      shuffleBlockInfosPerEvent.add(sbi);
+      // split shuffle data according to the size
+      if (totalSize > sendSizeLimit) {
+        events.add(createBlockEvent(shuffleBlockInfosPerEvent, totalSize, memoryUsed));
+        shuffleBlockInfosPerEvent = Lists.newArrayList();
+        totalSize = 0;
+        memoryUsed = 0;
+      }
+    }
+    if (!shuffleBlockInfosPerEvent.isEmpty()) {
+      events.add(createBlockEvent(shuffleBlockInfosPerEvent, totalSize, memoryUsed));
+    }
+    return events;
+  }
+
+  // Builds one event whose processed-callback releases the given amount of allocated memory.
+  private AddBlockEvent createBlockEvent(List<ShuffleBlockInfo> blocks, long totalSize, long memoryUsed) {
+    LOG.info("Build event with {} blocks and {} bytes", blocks.size(), totalSize);
+    return new AddBlockEvent(taskId, blocks, () -> freeAllocatedMemory(memoryUsed));
+  }
+
   @Override
   public long spill(long size, MemoryConsumer trigger) {
-    // there is no spill for such situation
-    return 0;
+    // Without a spill function (legacy constructor) there is nowhere to push buffered data:
+    // clearing the buffers here would silently drop records, so keep the previous no-op behavior.
+    if (spillFunc == null) {
+      return 0;
+    }
+    List<AddBlockEvent> events = buildBlockEvents(clear());
+    List<CompletableFuture<Long>> futures =
+        events.stream().map(x -> spillFunc.apply(x)).collect(Collectors.toList());
+    CompletableFuture<Void> allOfFutures =
+        CompletableFuture.allOf(futures.toArray(new CompletableFuture[futures.size()]));
+    try {
+      allOfFutures.get(memorySpillTimeoutSec, TimeUnit.SECONDS);
+    } catch (TimeoutException timeoutException) {
+      // A best effort strategy to wait.
+      // If timeout exception occurs, the underlying tasks won't be cancelled.
+    } catch (InterruptedException interruptedException) {
+      // Stop waiting, but keep the interrupt visible to the task thread.
+      Thread.currentThread().interrupt();
+    } catch (java.util.concurrent.ExecutionException executionException) {
+      // Failed pushes release nothing; they are counted as 0 below.
+      LOG.warn("[taskId: {}] Errors on spilling data.", taskId, executionException);
+    } finally {
+      // Only futures that already finished contribute to the released size.
+      long releasedSize = futures.stream().filter(x -> x.isDone()).mapToLong(x -> {
+        try {
+          return x.get();
+        } catch (Exception e) {
+          return 0;
+        }
+      }).sum();
+      LOG.info("[taskId: {}] Spill triggered by memory consumer of {}, released memory size: {}",
+          taskId, trigger.getClass().getSimpleName(), releasedSize);
+      return releasedSize;
+    }
   }
 
   @VisibleForTesting
@@ -307,4 +400,20 @@ public class WriteBufferManager extends MemoryConsumer {
         + estimateTime + "], requireMemoryTime[" + requireMemoryTime
         + "], uncompressedDataLen[" + uncompressedDataLen + "]";
   }
+
+  /** Test hook: replaces the task id stamped on newly built events. */
+  @VisibleForTesting
+  public void setTaskId(String newTaskId) {
+    taskId = newTaskId;
+  }
+
+  /** Test hook: replaces the function used by spill() to push events. */
+  @VisibleForTesting
+  public void setSpillFunc(Function<AddBlockEvent, CompletableFuture<Long>> newSpillFunc) {
+    spillFunc = newSpillFunc;
+  }
+
+  /** Test hook: replaces the per-event size threshold used when building events. */
+  @VisibleForTesting
+  public void setSendSizeLimit(long newSendSizeLimit) {
+    sendSizeLimit = newSendSizeLimit;
+  }
 }
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
new file mode 100644
index 00000000..20711dc0
--- /dev/null
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.function.Supplier;
+
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.common.ShuffleBlockInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class DataPusherTest {
+
+  /** ShuffleWriteClient stub whose sendShuffleData() returns a preset result without any I/O. */
+  static class FakedShuffleWriteClient extends ShuffleWriteClientImpl {
+    private SendShuffleDataResult fakedShuffleDataResult;
+
+    FakedShuffleWriteClient() {
+      this(
+          "GRPC",
+          1,
+          1,
+          10,
+          1,
+          1,
+          1,
+          false,
+          1,
+          1,
+          1,
+          1
+      );
+    }
+
+    private FakedShuffleWriteClient(String clientType, int retryMax, long retryIntervalMax, int heartBeatThreadNum,
+        int replica, int replicaWrite, int replicaRead, boolean replicaSkipEnabled, int dataTransferPoolSize,
+        int dataCommitPoolSize, int unregisterThreadPoolSize, int unregisterRequestTimeSec) {
+      super(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, replica, replicaWrite, replicaRead,
+          replicaSkipEnabled, dataTransferPoolSize, dataCommitPoolSize, unregisterThreadPoolSize,
+          unregisterRequestTimeSec);
+    }
+
+    @Override
+    public SendShuffleDataResult sendShuffleData(String appId, List<ShuffleBlockInfo> shuffleBlockInfoList,
+        Supplier<Boolean> needCancelRequest) {
+      return fakedShuffleDataResult;
+    }
+
+    public void setFakedShuffleDataResult(SendShuffleDataResult fakedShuffleDataResult) {
+      this.fakedShuffleDataResult = fakedShuffleDataResult;
+    }
+  }
+
+  @Test
+  public void testSendData() throws ExecutionException, InterruptedException, IOException {
+    FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient();
+
+    Map<String, Set<Long>> taskToSuccessBlockIds = Maps.newConcurrentMap();
+    Map<String, Set<Long>> taskToFailedBlockIds = Maps.newConcurrentMap();
+    Set<String> failedTaskIds = new HashSet<>();
+
+    // try-with-resources: shut down the pusher's thread pool so it does not outlive the test.
+    try (DataPusher dataPusher = new DataPusher(
+        shuffleWriteClient,
+        taskToSuccessBlockIds,
+        taskToFailedBlockIds,
+        failedTaskIds,
+        1,
+        2
+    )) {
+      dataPusher.setRssAppId("testSendData_appId");
+
+      // sync send
+      AddBlockEvent event = new AddBlockEvent("taskId", Arrays.asList(
+          new ShuffleBlockInfo(
+              1, 1, 1, 1, 1, new byte[1], null, 1, 100, 1
+          ))
+      );
+      shuffleWriteClient.setFakedShuffleDataResult(
+          new SendShuffleDataResult(
+              Sets.newHashSet(1L, 2L),
+              Sets.newHashSet(3L, 4L)
+          )
+      );
+      CompletableFuture<Long> future = dataPusher.send(event);
+      long memoryFree = future.get();
+      assertEquals(100, memoryFree);
+      assertTrue(taskToSuccessBlockIds.get("taskId").contains(1L));
+      assertTrue(taskToSuccessBlockIds.get("taskId").contains(2L));
+      assertTrue(taskToFailedBlockIds.get("taskId").contains(3L));
+      assertTrue(taskToFailedBlockIds.get("taskId").contains(4L));
+    }
+  }
+}
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 4f57e265..87580541 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -17,7 +17,11 @@
 
 package org.apache.spark.shuffle.writer;
 
+import java.util.Arrays;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
 
 import com.google.common.collect.Maps;
 import org.apache.commons.lang.reflect.FieldUtils;
@@ -27,6 +31,7 @@ import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.serializer.KryoSerializer;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.RssSparkConfig;
+import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
@@ -195,4 +200,78 @@ public class WriteBufferManagerTest {
     sbi = wbm.createShuffleBlock(1, mockWriterBuffer);
     assertEquals(35184374185984L, sbi.getBlockId());
   }
+
+  @Test
+  public void buildBlockEventsTest() {
+    SparkConf conf = getConf();
+    conf.set("spark.rss.client.send.size.limit", "30");
+
+    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    WriteBufferManager wbm = new WriteBufferManager(
+        0, 0, bufferOptions, new KryoSerializer(conf),
+        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), 
RssSparkConfig.toRssConf(conf));
+
+    // every block: length=4, memoryUsed=12
+    ShuffleBlockInfo info1 = new ShuffleBlockInfo(1, 1, 1, 4, 1, new byte[1], 
null, 1, 12, 1);
+    ShuffleBlockInfo info2 = new ShuffleBlockInfo(1, 1, 1, 4, 1, new byte[1], 
null, 1, 12, 1);
+    ShuffleBlockInfo info3 = new ShuffleBlockInfo(1, 1, 1, 4, 1, new byte[1], 
null, 1, 12, 1);
+    List<AddBlockEvent> events = wbm.buildBlockEvents(Arrays.asList(info1, 
info2, info3));
+    // NOTE(review): one event per block is expected even though 3 * 4 bytes < 30, so
+    // ShuffleBlockInfo.getSize() presumably counts more than the 4-byte length — confirm.
+    assertEquals(3, events.size());
+  }
+
+  @Test
+  public void spillTest() {
+    SparkConf conf = getConf();
+    conf.set("spark.rss.client.send.size.limit", "1000");
+    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+
+    // Synchronous spill function: runs the callbacks immediately and reports the freed memory.
+    Function<AddBlockEvent, CompletableFuture<Long>> spillFunc = event -> {
+      event.getProcessedCallbackChain().stream().forEach(x -> x.run());
+      return CompletableFuture.completedFuture(
+          event.getShuffleDataInfoList().stream().mapToLong(x -> 
x.getFreeMemory()).sum()
+      );
+    };
+
+    WriteBufferManager wbm = new WriteBufferManager(
+        0, "taskId_spillTest", 0, bufferOptions, new KryoSerializer(conf),
+        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(),
+        RssSparkConfig.toRssConf(conf), spillFunc);
+    WriteBufferManager spyManager = spy(wbm);
+    doReturn(512L).when(spyManager).acquireMemory(anyLong());
+
+    String testKey = "Key";
+    String testValue = "Value";
+    spyManager.addRecord(0, testKey, testValue);
+    spyManager.addRecord(1, testKey, testValue);
+
+    // case1. all events are flushed within normal time.
+    long releasedSize = spyManager.spill(1000, mock(WriteBufferManager.class));
+    assertEquals(64, releasedSize);
+
+    // case2. partial events are not flushed within normal time.
+    // when calling spill func, 2 events should be spilled. But
+    // only one event will be finished in the expected time.
+    // NOTE(review): timing-dependent — the partition-1 push sleeps 2s while spill() waits only the
+    // configured spill timeout (RSS_MEMORY_SPILL_TIMEOUT, default 1s); a heavily loaded machine could
+    // also delay the partition-0 push past the timeout.
+    spyManager.setSendSizeLimit(30);
+    spyManager.addRecord(0, testKey, testValue);
+    spyManager.addRecord(1, testKey, testValue);
+    spyManager.setSpillFunc(event -> CompletableFuture.supplyAsync(() -> {
+      int partitionId = event.getShuffleDataInfoList().get(0).getPartitionId();
+      if (partitionId == 1) {
+        try {
+          Thread.sleep(2000);
+        } catch (InterruptedException interruptedException) {
+          // ignore.
+        }
+      }
+      event.getProcessedCallbackChain().stream().forEach(x -> x.run());
+      return event.getShuffleDataInfoList().stream().mapToLong(x -> 
x.getFreeMemory()).sum();
+    }));
+    releasedSize = spyManager.spill(1000, mock(WriteBufferManager.class));
+    assertEquals(32, releasedSize);
+    assertEquals(32, spyManager.getUsedBytes());
+    // The slow push still completes in the background and eventually releases its memory.
+    Awaitility.await().timeout(3, TimeUnit.SECONDS).until(() -> 
spyManager.getUsedBytes() == 0);
+  }
 }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 0b301154..77b9ef0d 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -17,17 +17,17 @@
 
 package org.apache.spark.shuffle;
 
+import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Function;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Queues;
 import com.google.common.collect.Sets;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.ShuffleDependency;
@@ -39,11 +39,11 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.shuffle.reader.RssShuffleReader;
 import org.apache.spark.shuffle.writer.AddBlockEvent;
 import org.apache.spark.shuffle.writer.BufferManagerOptions;
+import org.apache.spark.shuffle.writer.DataPusher;
 import org.apache.spark.shuffle.writer.RssShuffleWriter;
 import org.apache.spark.shuffle.writer.WriteBufferManager;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManagerId;
-import org.apache.spark.util.EventLoop;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -54,12 +54,10 @@ import scala.collection.Seq;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
-import org.apache.uniffle.client.response.SendShuffleDataResult;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
-import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.exception.RssException;
@@ -80,7 +78,6 @@ public class RssShuffleManager implements ShuffleManager {
   private ShuffleWriteClient shuffleWriteClient;
   private Map<String, Set<Long>> taskToSuccessBlockIds = 
JavaUtils.newConcurrentMap();
   private Map<String, Set<Long>> taskToFailedBlockIds = 
JavaUtils.newConcurrentMap();
-  private Map<String, WriteBufferManager> taskToBufferManager = 
JavaUtils.newConcurrentMap();
   private final int dataReplica;
   private final int dataReplicaWrite;
   private final int dataReplicaRead;
@@ -92,58 +89,7 @@ public class RssShuffleManager implements ShuffleManager {
   private boolean dynamicConfEnabled = false;
   private final String user;
   private final String uuid;
-  private ThreadPoolExecutor threadPoolExecutor;
-  private EventLoop<AddBlockEvent> eventLoop = new 
EventLoop<AddBlockEvent>("ShuffleDataQueue") {
-
-    @Override
-    public void onReceive(AddBlockEvent event) {
-      threadPoolExecutor.execute(() -> sendShuffleData(event.getTaskId(), 
event.getShuffleDataInfoList()));
-    }
-
-    private void sendShuffleData(String taskId, List<ShuffleBlockInfo> 
shuffleDataInfoList) {
-      try {
-        SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
-            appId,
-            shuffleDataInfoList,
-            () -> !isValidTask(taskId)
-        );
-        putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
-        putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
-      } finally {
-        // data is already send, release the memory to executor
-        long releaseSize = 0;
-        for (ShuffleBlockInfo sbi : shuffleDataInfoList) {
-          releaseSize += sbi.getFreeMemory();
-        }
-        WriteBufferManager bufferManager = taskToBufferManager.get(taskId);
-        if (bufferManager != null) {
-          bufferManager.freeAllocatedMemory(releaseSize);
-        }
-        LOG.debug("Finish send data and release " + releaseSize + " bytes");
-      }
-    }
-
-    private synchronized void putBlockId(
-        Map<String, Set<Long>> taskToBlockIds,
-        String taskAttemptId,
-        Set<Long> blockIds) {
-      if (blockIds == null) {
-        return;
-      }
-      if (taskToBlockIds.get(taskAttemptId) == null) {
-        taskToBlockIds.put(taskAttemptId, Sets.newConcurrentHashSet());
-      }
-      taskToBlockIds.get(taskAttemptId).addAll(blockIds);
-    }
-
-    @Override
-    public void onError(Throwable throwable) {
-    }
-
-    @Override
-    public void onStart() {
-    }
-  };
+  private DataPusher dataPusher;
 
   public RssShuffleManager(SparkConf sparkConf, boolean isDriver) {
     if (sparkConf.getBoolean("spark.sql.adaptive.enabled", false)) {
@@ -193,19 +139,22 @@ public class RssShuffleManager implements ShuffleManager {
     sparkConf.set("spark.shuffle.reduceLocality.enabled", "false");
     LOG.info("Disable shuffle data locality in RssShuffleManager.");
     if (!sparkConf.getBoolean(RssSparkConfig.RSS_TEST_FLAG.key(), false)) {
-      // for non-driver executor, start a thread for sending shuffle data to 
shuffle server
-      LOG.info("RSS data send thread is starting");
-      eventLoop.start();
-      int poolSize = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
-      int keepAliveTime = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
-      threadPoolExecutor = new ThreadPoolExecutor(poolSize, poolSize * 2, 
keepAliveTime, TimeUnit.SECONDS,
-          Queues.newLinkedBlockingQueue(Integer.MAX_VALUE),
-          ThreadUtils.getThreadFactory("SendData"));
-
       if (isDriver) {
         heartBeatScheduledExecutorService =
             
ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
       }
+      // for non-driver executor, start a thread for sending shuffle data to 
shuffle server
+      LOG.info("RSS data pusher is starting...");
+      int poolSize = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
+      int keepAliveTime = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
+      this.dataPusher = new DataPusher(
+          shuffleWriteClient,
+          taskToSuccessBlockIds,
+          taskToFailedBlockIds,
+          failedTaskIds,
+          poolSize,
+          keepAliveTime
+      );
     }
   }
 
@@ -236,6 +185,7 @@ public class RssShuffleManager implements ShuffleManager {
     // will be called many times depend on how many shuffle stage
     if ("".equals(appId)) {
       appId = SparkEnv.get().conf().getAppId() + "_" + uuid;
+      dataPusher.setRssAppId(appId);
       LOG.info("Generate application id used in rss: " + appId);
     }
 
@@ -344,6 +294,13 @@ public class RssShuffleManager implements ShuffleManager {
     shuffleWriteClient.registerCoordinators(coordinators);
   }
 
+  /**
+   * Hands a block event to the {@link DataPusher} for asynchronous sending.
+   *
+   * @param event the blocks of one task to push; may be null
+   * @return the pusher's future for this event (its value presumably reports the memory freed
+   *     after sending -- confirm in DataPusher); when there is no pusher or no event, an
+   *     already-completed future of 0 so that a caller waiting on it can never block forever
+   */
+  public CompletableFuture<Long> sendData(AddBlockEvent event) {
+    if (dataPusher == null || event == null) {
+      return CompletableFuture.completedFuture(0L);
+    }
+    return dataPusher.send(event);
+  }
+
   // This method is called in Spark executor,
   // getting information from Spark driver via the ShuffleHandle.
   @Override
@@ -352,6 +309,7 @@ public class RssShuffleManager implements ShuffleManager {
     if (handle instanceof RssShuffleHandle) {
       RssShuffleHandle<K, V, ?> rssHandle = (RssShuffleHandle<K, V, ?>) handle;
       appId = rssHandle.getAppId();
+      dataPusher.setRssAppId(appId);
 
       int shuffleId = rssHandle.getShuffleId();
       String taskId = "" + context.taskAttemptId() + "_" + 
context.attemptNumber();
@@ -359,15 +317,16 @@ public class RssShuffleManager implements ShuffleManager {
       ShuffleWriteMetrics writeMetrics = 
context.taskMetrics().shuffleWriteMetrics();
       WriteBufferManager bufferManager = new WriteBufferManager(
           shuffleId,
+          taskId,
           context.taskAttemptId(),
           bufferOptions,
           rssHandle.getDependency().serializer(),
           rssHandle.getPartitionToServers(),
           context.taskMemoryManager(),
           writeMetrics,
-          RssSparkConfig.toRssConf(sparkConf)
+          RssSparkConfig.toRssConf(sparkConf),
+          this::sendData
       );
-      taskToBufferManager.put(taskId, bufferManager);
 
       return new RssShuffleWriter<>(rssHandle.getAppId(), shuffleId, taskId, 
context.taskAttemptId(), bufferManager,
           writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle,
@@ -448,7 +407,13 @@ public class RssShuffleManager implements ShuffleManager {
     if (heartBeatScheduledExecutorService != null) {
       heartBeatScheduledExecutorService.shutdownNow();
     }
-    threadPoolExecutor.shutdownNow();
+    if (dataPusher != null) {
+      try {
+        dataPusher.close();
+      } catch (IOException e) {
+        LOG.warn("Errors on closing data pusher", e);
+      }
+    }
     shuffleWriteClient.close();
   }
 
@@ -457,15 +422,6 @@ public class RssShuffleManager implements ShuffleManager {
     throw new RssException("RssShuffleManager.shuffleBlockResolver is not 
implemented");
   }
 
-  public EventLoop<AddBlockEvent> getEventLoop() {
-    return eventLoop;
-  }
-
-  @VisibleForTesting
-  public void setEventLoop(EventLoop<AddBlockEvent> eventLoop) {
-    this.eventLoop = eventLoop;
-  }
-
   // when speculation enable, duplicate data will be sent and reported to 
shuffle server,
   // get the actual tasks and filter the duplicate data caused by speculation 
task
   private Roaring64NavigableMap getExpectedTasks(int shuffleId, int 
startPartition, int endPartition) {
@@ -520,15 +476,9 @@ public class RssShuffleManager implements ShuffleManager {
     taskToSuccessBlockIds.get(taskId).addAll(blockIds);
   }
 
-  @VisibleForTesting
-  public Map<String, WriteBufferManager> getTaskToBufferManager() {
-    return taskToBufferManager;
-  }
-
+  /** Drops the success and failed block-id records kept for the given task attempt id. */
   public void clearTaskMeta(String taskId) {
     taskToSuccessBlockIds.remove(taskId);
     taskToFailedBlockIds.remove(taskId);
-    taskToBufferManager.remove(taskId);
   }
 
   @VisibleForTesting
@@ -550,4 +500,12 @@ public class RssShuffleManager implements ShuffleManager {
+  /** Returns true unless the task id has been recorded in {@code failedTaskIds}. */
   public boolean isValidTask(String taskId) {
     return !failedTaskIds.contains(taskId);
   }
+
+  /**
+   * Returns the pusher used to send block events. It is null when the test flag prevented its
+   * creation and none was injected via {@link #setDataPusher}.
+   */
+  public DataPusher getDataPusher() {
+    return dataPusher;
+  }
+
+  /** Replaces the data pusher; used by tests to inject a fake pusher. */
+  @VisibleForTesting
+  public void setDataPusher(DataPusher dataPusher) {
+    this.dataPusher = dataPusher;
+  }
 }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 81445f44..b0b7653f 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -81,7 +81,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private RssShuffleManager shuffleManager;
   private long sendCheckTimeout;
   private long sendCheckInterval;
-  private long sendSizeLimit;
   private boolean isMemoryShuffleEnabled;
   private final Function<String, Boolean> taskFailureCallback;
 
@@ -136,8 +135,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.shouldPartition = partitioner.numPartitions() > 1;
     this.sendCheckTimeout = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
     this.sendCheckInterval = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
-    this.sendSizeLimit = 
sparkConf.getSizeAsBytes(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(),
-        RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.defaultValue().get());
     this.bitmapSplitNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
     this.partitionToBlockIds = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
@@ -233,26 +230,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   // don't send huge block to shuffle server, or there will be OOM if shuffle 
sever receives data more than expected
   protected void postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
-    long totalSize = 0;
-    List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
-    for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
-      totalSize += sbi.getSize();
-      shuffleBlockInfosPerEvent.add(sbi);
-      // split shuffle data according to the size
-      if (totalSize > sendSizeLimit) {
-        LOG.debug("Post event to queue with " + 
shuffleBlockInfosPerEvent.size()
-            + " blocks and " + totalSize + " bytes");
-        shuffleManager.getEventLoop().post(
-            new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
-        shuffleBlockInfosPerEvent = Lists.newArrayList();
-        totalSize = 0;
-      }
-    }
-    if (!shuffleBlockInfosPerEvent.isEmpty()) {
-      LOG.debug("Post event to queue with " + shuffleBlockInfosPerEvent.size()
-          + " blocks and " + totalSize + " bytes");
-      shuffleManager.getEventLoop().post(
-          new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
+    // Grouping blocks into events is delegated to the buffer manager (presumably still bounded by
+    // the send size limit that used to be applied here -- confirm in buildBlockEvents).
+    // The futures returned by sendData are not awaited here; sending is asynchronous.
+    for (AddBlockEvent event : 
bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
+      shuffleManager.sendData(event);
     }
   }
 
diff --git 
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
 
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 81d1bd23..61ff6325 100644
--- 
a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ 
b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -21,6 +21,8 @@ import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import com.google.common.collect.Lists;
@@ -38,7 +40,6 @@ import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
-import org.apache.spark.util.EventLoop;
 import org.junit.jupiter.api.Test;
 import scala.Product2;
 import scala.Tuple2;
@@ -51,7 +52,6 @@ import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.anyLong;
@@ -126,11 +126,36 @@ public class RssShuffleWriterTest {
     manager.clearTaskMeta(taskId);
     assertTrue(manager.getSuccessBlockIds(taskId).isEmpty());
     assertTrue(manager.getFailedBlockIds(taskId).isEmpty());
-    assertNull(manager.getTaskToBufferManager().get(taskId));
 
     sc.stop();
   }
 
+  /**
+   * Test double for {@link DataPusher}: {@link #send} never talks to a shuffle server and instead
+   * returns whatever future the supplied function produces for the event.
+   */
+  static class FakedDataPusher extends DataPusher {
+    private final Function<AddBlockEvent, CompletableFuture<Long>> onSend;
+
+    FakedDataPusher(Function<AddBlockEvent, CompletableFuture<Long>> onSend) {
+      // The fake never sends, so the client, block-id maps and failed-task set are all null and
+      // the pool parameters are 1.
+      this(null, null, null, null, 1, 1, onSend);
+    }
+
+    private FakedDataPusher(
+        ShuffleWriteClient writeClient,
+        Map<String, Set<Long>> successBlockIds,
+        Map<String, Set<Long>> failedBlockIds,
+        Set<String> failedTasks,
+        int poolSize,
+        int keepAliveTime,
+        Function<AddBlockEvent, CompletableFuture<Long>> onSend) {
+      super(writeClient, successBlockIds, failedBlockIds, failedTasks, poolSize, keepAliveTime);
+      this.onSend = onSend;
+    }
+
+    @Override
+    public CompletableFuture<Long> send(AddBlockEvent event) {
+      return onSend.apply(event);
+    }
+  }
+
   @Test
   public void writeTest() throws Exception {
     SparkConf conf = new SparkConf();
@@ -149,21 +174,15 @@ public class RssShuffleWriterTest {
     RssShuffleManager manager = new RssShuffleManager(conf, false);
     List<ShuffleBlockInfo> shuffleBlockInfos = Lists.newArrayList();
 
-    manager.setEventLoop(new EventLoop<AddBlockEvent>("test") {
-      @Override
-      public void onReceive(AddBlockEvent event) {
-        assertEquals("taskId", event.getTaskId());
-        shuffleBlockInfos.addAll(event.getShuffleDataInfoList());
-        Set<Long> blockIds = event.getShuffleDataInfoList().parallelStream()
-            .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet());
-        manager.addSuccessBlockIds(event.getTaskId(), blockIds);
-      }
-
-      @Override
-      public void onError(Throwable e) {
-      }
+    DataPusher dataPusher = new FakedDataPusher(event -> {
+      assertEquals("taskId", event.getTaskId());
+      shuffleBlockInfos.addAll(event.getShuffleDataInfoList());
+      Set<Long> blockIds = event.getShuffleDataInfoList().parallelStream()
+          .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet());
+      manager.addSuccessBlockIds(event.getTaskId(), blockIds);
+      return CompletableFuture.completedFuture(0L);
     });
-    manager.getEventLoop().start();
+    manager.setDataPusher(dataPusher);
 
     Partitioner mockPartitioner = mock(Partitioner.class);
     ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
@@ -200,8 +219,8 @@ public class RssShuffleWriterTest {
     ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
     BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
     WriteBufferManager bufferManager = new WriteBufferManager(
-        0, 0, bufferOptions, kryoSerializer,
-        partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new 
RssConf());
+        0, "taskId", 0, bufferOptions, kryoSerializer,
+        partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new 
RssConf(), null);
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
     doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
 
@@ -253,37 +272,46 @@ public class RssShuffleWriterTest {
 
   @Test
   public void postBlockEventTest() throws Exception {
-    final WriteBufferManager mockBufferManager = 
mock(WriteBufferManager.class);
     final ShuffleWriteMetrics mockMetrics = mock(ShuffleWriteMetrics.class);
     ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
     Partitioner mockPartitioner = mock(Partitioner.class);
-    final RssShuffleManager mockShuffleManager = mock(RssShuffleManager.class);
     when(mockDependency.partitioner()).thenReturn(mockPartitioner);
     when(mockPartitioner.numPartitions()).thenReturn(2);
     List<AddBlockEvent> events = Lists.newArrayList();
 
-    EventLoop<AddBlockEvent> eventLoop = new EventLoop<AddBlockEvent>("test") {
-      @Override
-      public void onReceive(AddBlockEvent event) {
-        events.add(event);
-      }
+    SparkConf conf = new SparkConf();
+    conf.setAppName("postBlockEventTest").setMaster("local[2]")
+        .set(RssSparkConfig.RSS_TEST_FLAG.key(), "true")
+        .set(RssSparkConfig.RSS_TEST_MODE_ENABLE.key(), "true")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_WRITER_SERIALIZER_BUFFER_SIZE.key(), "32")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SEGMENT_SIZE.key(), "64")
+        .set(RssSparkConfig.RSS_WRITER_BUFFER_SPILL_SIZE.key(), "128")
+        .set(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS.key(), "1000")
+        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.LOCALFILE.name())
+        .set(RssSparkConfig.RSS_COORDINATOR_QUORUM.key(), 
"127.0.0.1:12345,127.0.0.1:12346")
+        .set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + 
RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION.key(), "64")
+        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.LOCALFILE.name());
 
-      @Override
-      public void onError(Throwable e) {
-      }
-    };
-    eventLoop.start();
+    TaskMemoryManager mockTaskMemoryManager = mock(TaskMemoryManager.class);
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    WriteBufferManager bufferManager = new WriteBufferManager(
+        0, 0, bufferOptions, new KryoSerializer(conf),
+        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), 
RssSparkConfig.toRssConf(conf));
+
+    RssShuffleManager manager = new RssShuffleManager(conf, false);
+    DataPusher dataPusher = new FakedDataPusher(event -> {
+      events.add(event);
+      return CompletableFuture.completedFuture(0L);
+    });
+    manager.setDataPusher(dataPusher);
 
-    when(mockShuffleManager.getEventLoop()).thenReturn(eventLoop);
     RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
     ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
-    SparkConf conf = new SparkConf();
-    conf.set(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(), "64")
-        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.LOCALFILE.name());
 
     RssShuffleWriter<String, String, String> writer = new 
RssShuffleWriter<>("appId", 0, "taskId", 1L,
-        mockBufferManager, mockMetrics, mockShuffleManager, conf, 
mockWriteClient, mockHandle);
+        bufferManager, mockMetrics, manager, conf, mockWriteClient, 
mockHandle);
     List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1, 
31);
     writer.postBlockEvent(shuffleBlockInfoList);
     Thread.sleep(500);
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 04b57136..e875132d 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -17,21 +17,20 @@
 
 package org.apache.spark.shuffle;
 
+import java.io.IOException;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Queues;
 import com.google.common.collect.Sets;
 import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
 import org.apache.hadoop.conf.Configuration;
@@ -46,12 +45,12 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.shuffle.reader.RssShuffleReader;
 import org.apache.spark.shuffle.writer.AddBlockEvent;
 import org.apache.spark.shuffle.writer.BufferManagerOptions;
+import org.apache.spark.shuffle.writer.DataPusher;
 import org.apache.spark.shuffle.writer.RssShuffleWriter;
 import org.apache.spark.shuffle.writer.WriteBufferManager;
 import org.apache.spark.sql.internal.SQLConf;
 import org.apache.spark.storage.BlockId;
 import org.apache.spark.storage.BlockManagerId;
-import org.apache.spark.util.EventLoop;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -62,12 +61,10 @@ import scala.collection.Seq;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
-import org.apache.uniffle.client.response.SendShuffleDataResult;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
-import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
@@ -84,7 +81,6 @@ public class RssShuffleManager implements ShuffleManager {
   private final String clientType;
   private final long heartbeatInterval;
   private final long heartbeatTimeout;
-  private final ThreadPoolExecutor threadPoolExecutor;
   private AtomicReference<String> id = new AtomicReference<>();
   private SparkConf sparkConf;
   private final int dataReplica;
@@ -96,7 +92,6 @@ public class RssShuffleManager implements ShuffleManager {
   private ShuffleWriteClient shuffleWriteClient;
   private final Map<String, Set<Long>> taskToSuccessBlockIds;
   private final Map<String, Set<Long>> taskToFailedBlockIds;
-  private Map<String, WriteBufferManager> taskToBufferManager = 
JavaUtils.newConcurrentMap();
   private ScheduledExecutorService heartBeatScheduledExecutorService;
   private boolean heartbeatStarted = false;
   private boolean dynamicConfEnabled = false;
@@ -104,55 +99,7 @@ public class RssShuffleManager implements ShuffleManager {
   private String user;
   private String uuid;
   private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
-  private final EventLoop<AddBlockEvent> eventLoop;
-  private final EventLoop<AddBlockEvent> defaultEventLoop = new 
EventLoop<AddBlockEvent>("ShuffleDataQueue") {
-
-    @Override
-    public void onReceive(AddBlockEvent event) {
-      threadPoolExecutor.execute(() -> sendShuffleData(event.getTaskId(), 
event.getShuffleDataInfoList()));
-    }
-
-    @Override
-    public void onError(Throwable throwable) {
-      LOG.info("Shuffle event loop error...", throwable);
-    }
-
-    @Override
-    public void onStart() {
-      LOG.info("Shuffle event loop start...");
-    }
-
-    private void sendShuffleData(String taskId, List<ShuffleBlockInfo> 
shuffleDataInfoList) {
-      try {
-        SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
-            id.get(),
-            shuffleDataInfoList,
-            () -> !isValidTask(taskId)
-        );
-        putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
-        putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
-      } finally {
-        final AtomicLong releaseSize = new AtomicLong(0);
-        shuffleDataInfoList.forEach((sbi) -> 
releaseSize.addAndGet(sbi.getFreeMemory()));
-        WriteBufferManager bufferManager = taskToBufferManager.get(taskId);
-        if (bufferManager != null) {
-          bufferManager.freeAllocatedMemory(releaseSize.get());
-        }
-        LOG.debug("Spark 3.0 finish send data and release " + releaseSize + " 
bytes");
-      }
-    }
-
-    private synchronized void putBlockId(
-        Map<String, Set<Long>> taskToBlockIds,
-        String taskAttemptId,
-        Set<Long> blockIds) {
-      if (blockIds == null || blockIds.isEmpty()) {
-        return;
-      }
-      taskToBlockIds.putIfAbsent(taskAttemptId, Sets.newConcurrentHashSet());
-      taskToBlockIds.get(taskAttemptId).addAll(blockIds);
-    }
-  };
+  private DataPusher dataPusher;
 
   public RssShuffleManager(SparkConf conf, boolean isDriver) {
     this.sparkConf = conf;
@@ -211,19 +158,28 @@ public class RssShuffleManager implements ShuffleManager {
     LOG.info("Disable shuffle data locality in RssShuffleManager.");
     taskToSuccessBlockIds = JavaUtils.newConcurrentMap();
     taskToFailedBlockIds = JavaUtils.newConcurrentMap();
-    // for non-driver executor, start a thread for sending shuffle data to 
shuffle server
-    LOG.info("RSS data send thread is starting");
-    eventLoop = defaultEventLoop;
-    eventLoop.start();
-    int poolSize = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
-    int keepAliveTime = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
-    threadPoolExecutor = new ThreadPoolExecutor(poolSize, poolSize * 2, 
keepAliveTime, TimeUnit.SECONDS,
-        Queues.newLinkedBlockingQueue(Integer.MAX_VALUE),
-        ThreadUtils.getThreadFactory("SendData"));
     if (isDriver) {
       heartBeatScheduledExecutorService =
           ThreadUtils.getDaemonSingleThreadScheduledExecutor("rss-heartbeat");
     }
+    LOG.info("Rss data pusher is starting...");
+    int poolSize = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_SIZE);
+    int keepAliveTime = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_THREAD_POOL_KEEPALIVE);
+    this.dataPusher = new DataPusher(
+        shuffleWriteClient,
+        taskToSuccessBlockIds,
+        taskToFailedBlockIds,
+        failedTaskIds,
+        poolSize,
+        keepAliveTime
+    );
+  }
+
+  /**
+   * Hands a block event to the {@link DataPusher} for asynchronous sending.
+   *
+   * @param event the blocks of one task to push; may be null
+   * @return the pusher's future for this event (its value presumably reports the memory freed
+   *     after sending -- confirm in DataPusher); when there is no pusher or no event, an
+   *     already-completed future of 0 so that a caller waiting on it can never block forever
+   */
+  public CompletableFuture<Long> sendData(AddBlockEvent event) {
+    if (dataPusher == null || event == null) {
+      return CompletableFuture.completedFuture(0L);
+    }
+    return dataPusher.send(event);
   }
 
   @VisibleForTesting
@@ -242,7 +198,7 @@ public class RssShuffleManager implements ShuffleManager {
   RssShuffleManager(
       SparkConf conf,
       boolean isDriver,
-      EventLoop<AddBlockEvent> loop,
+      DataPusher dataPusher,
       Map<String, Set<Long>> taskToSuccessBlockIds,
       Map<String, Set<Long>> taskToFailedBlockIds) {
     this.sparkConf = conf;
@@ -283,14 +239,8 @@ public class RssShuffleManager implements ShuffleManager {
         );
     this.taskToSuccessBlockIds = taskToSuccessBlockIds;
     this.taskToFailedBlockIds = taskToFailedBlockIds;
-    if (loop != null) {
-      eventLoop = loop;
-    } else {
-      eventLoop = defaultEventLoop;
-    }
-    eventLoop.start();
-    threadPoolExecutor = null;
-    heartBeatScheduledExecutorService = null;
+    this.heartBeatScheduledExecutorService = null;
+    this.dataPusher = dataPusher;
   }
 
   // This method is called in Spark driver side,
@@ -315,6 +265,7 @@ public class RssShuffleManager implements ShuffleManager {
 
     if (id.get() == null) {
       id.compareAndSet(null, SparkEnv.get().conf().getAppId() + "_" + uuid);
+      dataPusher.setRssAppId(id.get());
     }
     LOG.info("Generate application id used in rss: " + id.get());
 
@@ -390,6 +341,7 @@ public class RssShuffleManager implements ShuffleManager {
     // todo: this implement is tricky, we should refactor it
     if (id.get() == null) {
       id.compareAndSet(null, rssHandle.getAppId());
+      dataPusher.setRssAppId(id.get());
     }
     int shuffleId = rssHandle.getShuffleId();
     String taskId = "" + context.taskAttemptId() + "_" + 
context.attemptNumber();
@@ -401,10 +353,9 @@ public class RssShuffleManager implements ShuffleManager {
       writeMetrics = context.taskMetrics().shuffleWriteMetrics();
     }
     WriteBufferManager bufferManager = new WriteBufferManager(
-        shuffleId, context.taskAttemptId(), bufferOptions, 
rssHandle.getDependency().serializer(),
+        shuffleId, taskId, context.taskAttemptId(), bufferOptions, 
rssHandle.getDependency().serializer(),
         rssHandle.getPartitionToServers(), context.taskMemoryManager(),
-        writeMetrics, RssSparkConfig.toRssConf(sparkConf));
-    taskToBufferManager.put(taskId, bufferManager);
+        writeMetrics, RssSparkConfig.toRssConf(sparkConf), this::sendData);
     LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), 
rssHandle.getShuffleId());
     return new RssShuffleWriter<>(rssHandle.getAppId(), shuffleId, taskId, 
context.taskAttemptId(), bufferManager,
         writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle,
@@ -660,21 +611,21 @@ public class RssShuffleManager implements ShuffleManager {
     if (heartBeatScheduledExecutorService != null) {
       heartBeatScheduledExecutorService.shutdownNow();
     }
-    if (threadPoolExecutor != null) {
-      threadPoolExecutor.shutdownNow();
-    }
-    if (shuffleWriteClient != null) {
-      shuffleWriteClient.close();
-    }
-    if (eventLoop != null) {
-      eventLoop.stop();
+    // The pusher was constructed with shuffleWriteClient, so stop it before closing that client
+    // (same order as the former thread-pool shutdown and as the spark2 manager).
+    if (dataPusher != null) {
+      try {
+        dataPusher.close();
+      } catch (IOException e) {
+        LOG.warn("Errors on closing data pusher", e);
+      }
+    }
+    if (shuffleWriteClient != null) {
+      shuffleWriteClient.close();
     }
   }
 
+  /** Drops the success and failed block-id records kept for the given task attempt id. */
   public void clearTaskMeta(String taskId) {
     taskToSuccessBlockIds.remove(taskId);
     taskToFailedBlockIds.remove(taskId);
-    taskToBufferManager.remove(taskId);
   }
 
   @VisibleForTesting
@@ -736,12 +687,6 @@ public class RssShuffleManager implements ShuffleManager {
     }
   }
 
-  public void postEvent(AddBlockEvent addBlockEvent) {
-    if (eventLoop != null) {
-      eventLoop.post(addBlockEvent);
-    }
-  }
-
   public Set<Long> getFailedBlockIds(String taskId) {
     Set<Long> result = taskToFailedBlockIds.get(taskId);
     if (result == null) {
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 4784c390..fc1e32f5 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -75,7 +75,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final boolean shouldPartition;
   private final long sendCheckTimeout;
   private final long sendCheckInterval;
-  private final long sendSizeLimit;
   private final int bitmapSplitNum;
   private final Map<Integer, Set<Long>> partitionToBlockIds;
   private final ShuffleWriteClient shuffleWriteClient;
@@ -137,8 +136,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.shouldPartition = partitioner.numPartitions() > 1;
     this.sendCheckTimeout = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_TIMEOUT_MS);
     this.sendCheckInterval = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_SEND_CHECK_INTERVAL_MS);
-    this.sendSizeLimit = 
sparkConf.getSizeAsBytes(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(),
-        RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.defaultValue().get());
     this.bitmapSplitNum = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_BITMAP_SPLIT_NUM);
     this.partitionToBlockIds = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
@@ -231,26 +228,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   }
 
   protected void postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
-    long totalSize = 0;
-    List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
-    for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
-      totalSize += sbi.getSize();
-      shuffleBlockInfosPerEvent.add(sbi);
-      // split shuffle data according to the size
-      if (totalSize > sendSizeLimit) {
-        LOG.debug("Post event to queue with " + 
shuffleBlockInfosPerEvent.size()
-            + " blocks and " + totalSize + " bytes");
-        shuffleManager.postEvent(
-            new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
-        shuffleBlockInfosPerEvent = Lists.newArrayList();
-        totalSize = 0;
-      }
-    }
-    if (!shuffleBlockInfosPerEvent.isEmpty()) {
-      LOG.debug("Post event to queue with " + shuffleBlockInfosPerEvent.size()
-          + " blocks and " + totalSize + " bytes");
-      shuffleManager.postEvent(
-          new AddBlockEvent(taskId, shuffleBlockInfosPerEvent));
+    // The buffer manager groups the blocks into AddBlockEvents; each event is then
+    // handed to the shuffle manager's sendData for pushing.
+    for (AddBlockEvent event : 
bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
+      shuffleManager.sendData(event);
     }
   }
 
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java
index 033966a7..1c7f988a 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/TestUtils.java
@@ -22,8 +22,7 @@ import java.util.Set;
 
 import org.apache.commons.lang3.SystemUtils;
 import org.apache.spark.SparkConf;
-import org.apache.spark.shuffle.writer.AddBlockEvent;
-import org.apache.spark.util.EventLoop;
+import org.apache.spark.shuffle.writer.DataPusher;
 
 public class TestUtils {
 
@@ -33,10 +32,10 @@ public class TestUtils {
+  /**
+   * Creates an {@link RssShuffleManager} for tests, injecting the given {@link DataPusher}
+   * and the caller-supplied success/failure block-id maps.
+   */
   public static RssShuffleManager createShuffleManager(
       SparkConf conf,
       Boolean isDriver,
-      EventLoop<AddBlockEvent> loop,
+      DataPusher dataPusher,
       Map<String, Set<Long>> successBlockIds,
       Map<String, Set<Long>> failBlockIds) {
-    return new RssShuffleManager(conf, isDriver, loop, successBlockIds, 
failBlockIds);
+    return new RssShuffleManager(conf, isDriver, dataPusher, successBlockIds, 
failBlockIds);
   }
 
   public static boolean isMacOnAppleSilicon() {
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index f446cf3a..25c76497 100644
--- 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -17,11 +17,13 @@
 
 package org.apache.spark.shuffle.writer;
 
-
+import java.time.Duration;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import com.google.common.collect.Lists;
@@ -40,7 +42,7 @@ import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.RssShuffleManager;
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.apache.spark.shuffle.TestUtils;
-import org.apache.spark.util.EventLoop;
+import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Test;
 import scala.Product2;
 import scala.Tuple2;
@@ -49,7 +51,6 @@ import scala.collection.mutable.MutableList;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
-import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.storage.util.StorageType;
 
@@ -102,7 +103,7 @@ public class RssShuffleWriterTest {
     BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
     WriteBufferManager bufferManager = new WriteBufferManager(
         0, 0, bufferOptions, kryoSerializer,
-        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), 
new RssConf());
+        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), 
RssSparkConfig.toRssConf(conf));
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
 
     RssShuffleWriter<String, String, String> rssShuffleWriter = new 
RssShuffleWriter<>("appId", 0, "taskId", 1L,
@@ -135,6 +136,32 @@ public class RssShuffleWriterTest {
     sc.stop();
   }
 
+  /**
+   * Test double for {@link DataPusher}: {@link #send(AddBlockEvent)} is overridden to
+   * delegate to the supplied function instead of the parent's sending logic.
+   */
+  static class FakedDataPusher extends DataPusher {
+    // Invoked for every event passed to send(); its returned future is handed back to the caller.
+    private final Function<AddBlockEvent, CompletableFuture<Long>> sendFunc;
+
+    // Parent collaborators are passed as null and pool size / keep-alive as 1, since send()
+    // is overridden. NOTE(review): the parent constructor may still allocate resources; confirm.
+    FakedDataPusher(Function<AddBlockEvent, CompletableFuture<Long>> sendFunc) 
{
+      this(null, null, null, null, 1, 1, sendFunc);
+    }
+
+    private FakedDataPusher(
+        ShuffleWriteClient shuffleWriteClient,
+        Map<String, Set<Long>> taskToSuccessBlockIds,
+        Map<String, Set<Long>> taskToFailedBlockIds,
+        Set<String> failedTaskIds,
+        int threadPoolSize,
+        int threadKeepAliveTime,
+        Function<AddBlockEvent, CompletableFuture<Long>> sendFunc) {
+      super(shuffleWriteClient, taskToSuccessBlockIds, taskToFailedBlockIds, 
failedTaskIds, threadPoolSize,
+          threadKeepAliveTime);
+      this.sendFunc = sendFunc;
+    }
+
+    /** Delegates to the function supplied at construction time. */
+    @Override
+    public CompletableFuture<Long> send(AddBlockEvent event) {
+      return sendFunc.apply(event);
+    }
+  }
+
   @Test
   public void writeTest() throws Exception {
     SparkConf conf = new SparkConf();
@@ -151,27 +178,24 @@ public class RssShuffleWriterTest {
     // init SparkContext
     List<ShuffleBlockInfo> shuffleBlockInfos = Lists.newArrayList();
     final SparkContext sc = SparkContext.getOrCreate(conf);
-    Map<String, Set<Long>> successBlockIds = JavaUtils.newConcurrentMap();
-    EventLoop<AddBlockEvent> testLoop = new EventLoop<AddBlockEvent>("test") {
-      @Override
-      public void onReceive(AddBlockEvent event) {
-        assertEquals("taskId", event.getTaskId());
-        shuffleBlockInfos.addAll(event.getShuffleDataInfoList());
-        Set<Long> blockIds = event.getShuffleDataInfoList().parallelStream()
-            .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet());
-        successBlockIds.putIfAbsent(event.getTaskId(), 
Sets.newConcurrentHashSet());
-        successBlockIds.get(event.getTaskId()).addAll(blockIds);
-      }
-
-      @Override
-      public void onError(Throwable e) {
-      }
-    };
+    Map<String, Set<Long>> successBlockIds = Maps.newConcurrentMap();
+
+    FakedDataPusher dataPusher = new FakedDataPusher(
+        event -> {
+          assertEquals("taskId", event.getTaskId());
+          shuffleBlockInfos.addAll(event.getShuffleDataInfoList());
+          Set<Long> blockIds = event.getShuffleDataInfoList().parallelStream()
+              .map(sdi -> sdi.getBlockId()).collect(Collectors.toSet());
+          successBlockIds.putIfAbsent(event.getTaskId(), 
Sets.newConcurrentHashSet());
+          successBlockIds.get(event.getTaskId()).addAll(blockIds);
+          return new CompletableFuture<>();
+        }
+    );
 
     final RssShuffleManager manager = TestUtils.createShuffleManager(
         conf,
         false,
-        testLoop,
+        dataPusher,
         successBlockIds,
         JavaUtils.newConcurrentMap());
     Serializer kryoSerializer = new KryoSerializer(conf);
@@ -210,7 +234,11 @@ public class RssShuffleWriterTest {
     ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
     WriteBufferManager bufferManager = new WriteBufferManager(
         0, 0, bufferOptions, kryoSerializer,
-        partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new 
RssConf());
+        partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics,
+        RssSparkConfig.toRssConf(conf)
+    );
+    bufferManager.setTaskId("taskId");
+
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
     RssShuffleWriter<String, String, String> rssShuffleWriter = new 
RssShuffleWriter<>("appId", 0, "taskId", 1L,
         bufferManagerSpy, shuffleWriteMetrics, manager, conf, 
mockShuffleWriteClient, mockHandle);
@@ -265,7 +293,16 @@ public class RssShuffleWriterTest {
 
   @Test
   public void postBlockEventTest() throws Exception {
-    WriteBufferManager mockBufferManager = mock(WriteBufferManager.class);
+    SparkConf conf = new SparkConf();
+    conf.set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + 
RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION.key(), "64")
+        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.MEMORY_LOCALFILE.name());
+
+    BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
+    WriteBufferManager bufferManager = new WriteBufferManager(
+        0, 0, bufferOptions, new KryoSerializer(conf),
+        Maps.newHashMap(), mock(TaskMemoryManager.class), new 
ShuffleWriteMetrics(), RssSparkConfig.toRssConf(conf));
+    WriteBufferManager bufferManagerSpy = spy(bufferManager);
+
     ShuffleDependency<String, String, String> mockDependency = 
mock(ShuffleDependency.class);
     ShuffleWriteMetrics mockMetrics = mock(ShuffleWriteMetrics.class);
     Partitioner mockPartitioner = mock(Partitioner.class);
@@ -274,35 +311,39 @@ public class RssShuffleWriterTest {
     when(mockPartitioner.numPartitions()).thenReturn(2);
     List<AddBlockEvent> events = Lists.newArrayList();
 
-    EventLoop<AddBlockEvent> eventLoop = new EventLoop<AddBlockEvent>("test") {
-      @Override
-      public void onReceive(AddBlockEvent event) {
-        events.add(event);
-      }
+    FakedDataPusher dataPusher = new FakedDataPusher(
+        event -> {
+          events.add(event);
+          return new CompletableFuture<>();
+        }
+    );
 
-      @Override
-      public void onError(Throwable e) {
-      }
-    };
     RssShuffleManager mockShuffleManager = spy(TestUtils.createShuffleManager(
         sparkConf,
         false,
-        eventLoop,
-        JavaUtils.newConcurrentMap(),
-        JavaUtils.newConcurrentMap()));
+        dataPusher,
+        Maps.newConcurrentMap(),
+        Maps.newConcurrentMap()));
 
     RssShuffleHandle<String, String, String> mockHandle = 
mock(RssShuffleHandle.class);
     when(mockHandle.getDependency()).thenReturn(mockDependency);
     ShuffleWriteClient mockWriteClient = mock(ShuffleWriteClient.class);
-    SparkConf conf = new SparkConf();
-    conf.set(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMIT.key(), "64")
-        .set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.MEMORY_LOCALFILE.name());
+
     List<ShuffleBlockInfo> shuffleBlockInfoList = createShuffleBlockList(1, 
31);
-    RssShuffleWriter<String, String, String> writer = new 
RssShuffleWriter<>("appId", 0, "taskId", 1L,
-        mockBufferManager, mockMetrics, mockShuffleManager, conf, 
mockWriteClient, mockHandle);
+    RssShuffleWriter<String, String, String> writer = new RssShuffleWriter<>(
+        "appId",
+        0,
+        "taskId",
+        1L,
+        bufferManagerSpy,
+        mockMetrics,
+        mockShuffleManager,
+        conf,
+        mockWriteClient,
+        mockHandle
+    );
     writer.postBlockEvent(shuffleBlockInfoList);
-    Thread.sleep(500);
-    assertEquals(1, events.size());
+    Awaitility.await().timeout(Duration.ofSeconds(1)).until(() -> 
events.size() == 1);
     assertEquals(1, events.get(0).getShuffleDataInfoList().size());
     events.clear();
 
diff --git 
a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java 
b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
index b0ed870e..eb9173ba 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/ThreadUtils.java
@@ -22,15 +22,19 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
 
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
 import io.netty.util.concurrent.DefaultThreadFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
-/**
- * Provide a general method to create a thread factory to make the code more 
standardized
- */
 public class ThreadUtils {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(ThreadUtils.class);
 
+  /**
+   * Creates a thread factory producing daemon threads named {@code <factoryName>-<index>}.
+   */
   public static ThreadFactory getThreadFactory(String factoryName) {
     return new 
ThreadFactoryBuilder().setDaemon(true).setNameFormat(factoryName + 
"-%d").build();
   }
@@ -74,4 +78,17 @@ public class ThreadUtils {
+  /**
+   * Creates a cached thread pool whose threads come from {@link #getThreadFactory(String)},
+   * i.e. daemon threads named after {@code factoryName}.
+   */
   public static ExecutorService getDaemonCachedThreadPool(String factoryName) {
     return Executors.newCachedThreadPool(getThreadFactory(factoryName));
   }
+
+  /**
+   * Shuts down {@code threadPool}, first gracefully and then forcibly.
+   *
+   * <p>The pool is asked to {@link ExecutorService#shutdown()} and given {@code waitSec} seconds to
+   * finish running tasks. If it has not terminated by then, {@link ExecutorService#shutdownNow()}
+   * interrupts the remaining tasks and the pool is given another {@code waitSec} seconds; a warning
+   * is logged if it still has not terminated.
+   *
+   * @param threadPool the pool to stop; {@code null} is a no-op
+   * @param waitSec seconds to wait in each of the graceful and the forced phase
+   * @throws InterruptedException if interrupted while waiting; the pool is force-stopped via
+   *     {@code shutdownNow()} before the exception is rethrown
+   */
+  public static void shutdownThreadPool(ExecutorService threadPool, int waitSec) throws InterruptedException {
+    if (threadPool == null) {
+      return;
+    }
+    threadPool.shutdown();
+    try {
+      if (!threadPool.awaitTermination(waitSec, TimeUnit.SECONDS)) {
+        threadPool.shutdownNow();
+        if (!threadPool.awaitTermination(waitSec, TimeUnit.SECONDS)) {
+          LOGGER.warn("Thread pool did not terminate within {} seconds after shutdownNow.", waitSec);
+        }
+      }
+    } catch (InterruptedException e) {
+      // The caller stopped waiting: don't leave tasks running in the background.
+      threadPool.shutdownNow();
+      throw e;
+    }
+  }
 }
diff --git 
a/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java 
b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
new file mode 100644
index 00000000..7fafa442
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/util/ThreadUtilsTest.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.util;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ThreadUtilsTest {
+
+  @Test
+  public void shutdownThreadPoolTest() throws InterruptedException {
+    ExecutorService executorService = Executors.newFixedThreadPool(2);
+    AtomicBoolean finished = new AtomicBoolean(false);
+    executorService.submit(() -> {
+      try {
+        Thread.sleep(100000);
+      } catch (InterruptedException ignored) {
+        // Expected: the graceful wait times out, so shutdownNow() interrupts this sleep.
+      } finally {
+        finished.set(true);
+      }
+    });
+    ThreadUtils.shutdownThreadPool(executorService, 1);
+    assertTrue(finished.get());
+    assertTrue(executorService.isTerminated());
+  }
+}

Reply via email to