This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new e127253c [Improvement] Task fast fail once blocks fail to send (#332)
e127253c is described below

commit e127253c44e4a89f93d22a6999274f442cfb0c1c
Author: Junfan Zhang <[email protected]>
AuthorDate: Fri Nov 25 00:05:14 2022 +0800

    [Improvement] Task fast fail once blocks fail to send (#332)
    
    ### What changes were proposed in this pull request?
    [Improvement] Task fast fail once blocks fail to send
    
    1. In the single-replica mechanism, the failure of even one batch of sent data should 
make the task fail fast.
    2. When some remaining block events in dataTransferPool are waiting to be sent, we 
should abandon them.
    3. More precisely, we need to interrupt send requests belonging to failed 
tasks. (Using the `GrpcFuture` to cancel them)
    
    ### Why are the changes needed?
    1. When a shuffle-server is down, in the current codebase the shuffle-write 
client will block and retry too many times. Instead, it should fail fast once 
partial blocks fail to send.
    2. When using a custom retry policy in the rpc layer like #308, this PR will 
solve the potential problem of waiting too long when specifying the 
1min retry time.
    
    After this patch, the failure time is limited to 2min. Before, it could take 
10min+.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    1. UTs
    2. Online real tests
---
 .../hadoop/mapred/SortWriteBufferManager.java      |  2 +-
 .../hadoop/mapred/SortWriteBufferManagerTest.java  |  4 +-
 .../hadoop/mapreduce/task/reduce/FetcherTest.java  |  4 +-
 .../apache/spark/shuffle/RssShuffleManager.java    | 20 ++++++-
 .../spark/shuffle/writer/RssShuffleWriter.java     | 39 ++++++++++++
 .../apache/spark/shuffle/RssShuffleManager.java    | 21 ++++++-
 .../spark/shuffle/writer/RssShuffleWriter.java     | 64 ++++++++++++++++----
 .../uniffle/client/api/ShuffleWriteClient.java     |  4 +-
 .../client/impl/ShuffleWriteClientImpl.java        | 69 +++++++++++++++++-----
 .../apache/uniffle/client/util/ClientUtils.java    | 44 ++++++++++++++
 .../org/apache/uniffle/client/ClientUtilsTest.java | 60 +++++++++++++++++++
 .../client/impl/ShuffleWriteClientImplTest.java    | 30 +++++++++-
 .../java/org/apache/uniffle/test/QuorumTest.java   | 18 +++++-
 .../apache/uniffle/test/ShuffleServerGrpcTest.java |  2 +-
 .../uniffle/test/ShuffleWithRssClientTest.java     |  4 +-
 .../client/impl/grpc/ShuffleServerGrpcClient.java  |  4 +-
 16 files changed, 348 insertions(+), 41 deletions(-)

diff --git 
a/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java 
b/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
index 36ade47e..36d5ab96 100644
--- 
a/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
+++ 
b/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
@@ -240,7 +240,7 @@ public class SortWriteBufferManager<K, V> {
           for (ShuffleBlockInfo block : shuffleBlocks) {
             size += block.getFreeMemory();
           }
-          SendShuffleDataResult result = 
shuffleWriteClient.sendShuffleData(appId, shuffleBlocks);
+          SendShuffleDataResult result = 
shuffleWriteClient.sendShuffleData(appId, shuffleBlocks, () -> false);
           successBlockIds.addAll(result.getSuccessBlockIds());
           failedBlockIds.addAll(result.getFailedBlockIds());
         } catch (Throwable t) {
diff --git 
a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
 
b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 30267515..1e07335b 100644
--- 
a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ 
b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -21,6 +21,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Random;
 import java.util.Set;
+import java.util.function.Supplier;
 
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
@@ -266,7 +267,8 @@ public class SortWriteBufferManagerTest {
     }
 
     @Override
-    public SendShuffleDataResult sendShuffleData(String appId, 
List<ShuffleBlockInfo> shuffleBlockInfoList) {
+    public SendShuffleDataResult sendShuffleData(String appId, 
List<ShuffleBlockInfo> shuffleBlockInfoList,
+        Supplier<Boolean> needCancelRequest) {
       if (mode == 0) {
         throw new RssException("send data failed");
       } else if (mode == 1) {
diff --git 
a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
 
b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index 04e816c2..7c7ce990 100644
--- 
a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ 
b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -25,6 +25,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.TreeMap;
+import java.util.function.Supplier;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
@@ -352,7 +353,8 @@ public class FetcherTest {
     public List<byte[]> data = new LinkedList<>();
 
     @Override
-    public SendShuffleDataResult sendShuffleData(String appId, 
List<ShuffleBlockInfo> shuffleBlockInfoList) {
+    public SendShuffleDataResult sendShuffleData(String appId, 
List<ShuffleBlockInfo> shuffleBlockInfoList,
+        Supplier<Boolean> needCancelRequest) {
       if (mode == 0) {
         throw new RssException("send data failed");
       } else if (mode == 1) {
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index e71ac82b..eea58970 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -25,6 +25,7 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Maps;
@@ -86,6 +87,7 @@ public class RssShuffleManager implements ShuffleManager {
   private final boolean dataReplicaSkipEnabled;
   private final int dataTransferPoolSize;
   private final int dataCommitPoolSize;
+  private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
   private boolean heartbeatStarted = false;
   private boolean dynamicConfEnabled = false;
   private final String user;
@@ -100,7 +102,11 @@ public class RssShuffleManager implements ShuffleManager {
 
     private void sendShuffleData(String taskId, List<ShuffleBlockInfo> 
shuffleDataInfoList) {
       try {
-        SendShuffleDataResult result = 
shuffleWriteClient.sendShuffleData(appId, shuffleDataInfoList);
+        SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
+            appId,
+            shuffleDataInfoList,
+            () -> !isValidTask(taskId)
+        );
         putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
         putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
       } finally {
@@ -353,7 +359,8 @@ public class RssShuffleManager implements ShuffleManager {
       taskToBufferManager.put(taskId, bufferManager);
 
       return new RssShuffleWriter(rssHandle.getAppId(), shuffleId, taskId, 
context.taskAttemptId(), bufferManager,
-          writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle);
+          writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle,
+          (Function<String, Boolean>) tid -> markFailedTask(tid));
     } else {
       throw new RuntimeException("Unexpected ShuffleHandle:" + 
handle.getClass().getName());
     }
@@ -523,4 +530,13 @@ public class RssShuffleManager implements ShuffleManager {
     this.appId = appId;
   }
 
+  public boolean markFailedTask(String taskId) {
+    LOG.info("Mark the task: {} failed.", taskId);
+    failedTaskIds.add(taskId);
+    return true;
+  }
+
+  public boolean isValidTask(String taskId) {
+    return !failedTaskIds.contains(taskId);
+  }
 }
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 8313ebb2..68a447d3 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -25,6 +25,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
@@ -82,6 +83,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private long sendCheckInterval;
   private long sendSizeLimit;
   private boolean isMemoryShuffleEnabled;
+  private final Function<String, Boolean> taskFailureCallback;
 
   public RssShuffleWriter(
       String appId,
@@ -94,6 +96,33 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle rssHandle) {
+    this(
+        appId,
+        shuffleId,
+        taskId,
+        taskAttemptId,
+        bufferManager,
+        shuffleWriteMetrics,
+        shuffleManager,
+        sparkConf,
+        shuffleWriteClient,
+        rssHandle,
+        (tid) -> true
+    );
+  }
+
+  public RssShuffleWriter(
+      String appId,
+      int shuffleId,
+      String taskId,
+      long taskAttemptId,
+      WriteBufferManager bufferManager,
+      ShuffleWriteMetrics shuffleWriteMetrics,
+      RssShuffleManager shuffleManager,
+      SparkConf sparkConf,
+      ShuffleWriteClient shuffleWriteClient,
+      RssShuffleHandle rssHandle,
+      Function<String, Boolean> taskFailureCallback) {
     this.appId = appId;
     this.bufferManager = bufferManager;
     this.shuffleId = shuffleId;
@@ -116,6 +145,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.partitionToServers = rssHandle.getPartitionToServers();
     this.isMemoryShuffleEnabled = isMemoryShuffleEnabled(
         sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
+    this.taskFailureCallback = taskFailureCallback;
   }
 
   private boolean isMemoryShuffleEnabled(String storageType) {
@@ -133,6 +163,15 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   @Override
   public void write(Iterator<Product2<K, V>> records) {
+    try {
+      writeImpl(records);
+    } catch (Exception e) {
+      taskFailureCallback.apply(taskId);
+      throw e;
+    }
+  }
+
+  private void writeImpl(Iterator<Product2<K,V>> records) {
     List<ShuffleBlockInfo> shuffleBlockInfos = null;
     Set<Long> blockIds = Sets.newConcurrentHashSet();
     while (records.hasNext()) {
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 0ceb53fa..3a6033ad 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -27,6 +27,7 @@ import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -98,6 +99,7 @@ public class RssShuffleManager implements ShuffleManager {
   private final ShuffleDataDistributionType dataDistributionType;
   private String user;
   private String uuid;
+  private Set<String> failedTaskIds = Sets.newConcurrentHashSet();
   private final EventLoop eventLoop;
   private final EventLoop defaultEventLoop = new 
EventLoop<AddBlockEvent>("ShuffleDataQueue") {
 
@@ -118,7 +120,11 @@ public class RssShuffleManager implements ShuffleManager {
 
     private void sendShuffleData(String taskId, List<ShuffleBlockInfo> 
shuffleDataInfoList) {
       try {
-        SendShuffleDataResult result = 
shuffleWriteClient.sendShuffleData(id.get(), shuffleDataInfoList);
+        SendShuffleDataResult result = shuffleWriteClient.sendShuffleData(
+            id.get(),
+            shuffleDataInfoList,
+            () -> !isValidTask(taskId)
+        );
         putBlockId(taskToSuccessBlockIds, taskId, result.getSuccessBlockIds());
         putBlockId(taskToFailedBlockIds, taskId, result.getFailedBlockIds());
       } finally {
@@ -369,7 +375,8 @@ public class RssShuffleManager implements ShuffleManager {
     taskToBufferManager.put(taskId, bufferManager);
     LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), 
rssHandle.getShuffleId());
     return new RssShuffleWriter(rssHandle.getAppId(), shuffleId, taskId, 
context.taskAttemptId(), bufferManager,
-        writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle);
+        writeMetrics, this, sparkConf, shuffleWriteClient, rssHandle,
+        (Function<String, Boolean>) tid -> markFailedTask(tid));
   }
 
   @Override
@@ -772,4 +779,14 @@ public class RssShuffleManager implements ShuffleManager {
   public String getId() {
     return id.get();
   }
+
+  public boolean markFailedTask(String taskId) {
+    LOG.info("Mark the task: {} failed.", taskId);
+    failedTaskIds.add(taskId);
+    return true;
+  }
+
+  public boolean isValidTask(String taskId) {
+    return !failedTaskIds.contains(taskId);
+  }
 }
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 0a6cc324..5ea1a54f 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -26,6 +26,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
@@ -82,6 +83,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private final Set shuffleServersForData;
   private final long[] partitionLengths;
   private boolean isMemoryShuffleEnabled;
+  private final Function<String, Boolean> taskFailureCallback;
 
   public RssShuffleWriter(
       String appId,
@@ -94,6 +96,33 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       SparkConf sparkConf,
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle rssHandle) {
+    this(
+        appId,
+        shuffleId,
+        taskId,
+        taskAttemptId,
+        bufferManager,
+        shuffleWriteMetrics,
+        shuffleManager,
+        sparkConf,
+        shuffleWriteClient,
+        rssHandle,
+        (tid) -> true
+    );
+  }
+
+  public RssShuffleWriter(
+      String appId,
+      int shuffleId,
+      String taskId,
+      long taskAttemptId,
+      WriteBufferManager bufferManager,
+      ShuffleWriteMetrics shuffleWriteMetrics,
+      RssShuffleManager shuffleManager,
+      SparkConf sparkConf,
+      ShuffleWriteClient shuffleWriteClient,
+      RssShuffleHandle rssHandle,
+      Function<String, Boolean> taskFailureCallback) {
     LOG.warn("RssShuffle start write taskAttemptId data" + taskAttemptId);
     this.shuffleManager = shuffleManager;
     this.appId = appId;
@@ -119,6 +148,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     partitionToServers = rssHandle.getPartitionToServers();
     this.isMemoryShuffleEnabled = isMemoryShuffleEnabled(
         sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
+    this.taskFailureCallback = taskFailureCallback;
   }
 
   private boolean isMemoryShuffleEnabled(String storageType) {
@@ -127,9 +157,21 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
 
   @Override
   public void write(Iterator<Product2<K, V>> records) throws IOException {
+    try {
+      writeImpl(records);
+    } catch (Exception e) {
+      taskFailureCallback.apply(taskId);
+      throw e;
+    }
+  }
+
+  private void writeImpl(Iterator<Product2<K,V>> records) {
     List<ShuffleBlockInfo> shuffleBlockInfos = null;
     Set<Long> blockIds = Sets.newConcurrentHashSet();
     while (records.hasNext()) {
+      // Task should fast fail when sending data failed
+      checkIfBlocksFailed();
+
       Product2<K, V> record = records.next();
       K key = record._1();
       int partition = getPartition(key);
@@ -214,17 +256,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   protected void checkBlockSendResult(Set<Long> blockIds) throws 
RuntimeException {
     long start = System.currentTimeMillis();
     while (true) {
+      checkIfBlocksFailed();
       Set<Long> successBlockIds = shuffleManager.getSuccessBlockIds(taskId);
-      Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
-
-      if (!failedBlockIds.isEmpty()) {
-        String errorMsg = "Send failed: Task[" + taskId + "]"
-            + " failed because " + failedBlockIds.size()
-            + " blocks can't be sent to shuffle server.";
-        LOG.error(errorMsg);
-        throw new RssException(errorMsg);
-      }
-
       blockIds.removeAll(successBlockIds);
       if (blockIds.isEmpty()) {
         break;
@@ -240,6 +273,17 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     }
   }
 
+  private void checkIfBlocksFailed() {
+    Set<Long> failedBlockIds = shuffleManager.getFailedBlockIds(taskId);
+    if (!failedBlockIds.isEmpty()) {
+      String errorMsg = "Send failed: Task[" + taskId + "]"
+          + " failed because " + failedBlockIds.size()
+          + " blocks can't be sent to shuffle server.";
+      LOG.error(errorMsg);
+      throw new RssException(errorMsg);
+    }
+  }
+
   @VisibleForTesting
   protected void sendCommit() {
     ExecutorService executor = Executors.newSingleThreadExecutor();
diff --git 
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java 
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index f2443bdb..61dbd1b5 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -20,6 +20,7 @@ package org.apache.uniffle.client.api;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Supplier;
 
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
@@ -33,7 +34,8 @@ import org.apache.uniffle.common.ShuffleServerInfo;
 
 public interface ShuffleWriteClient {
 
-  SendShuffleDataResult sendShuffleData(String appId, List<ShuffleBlockInfo> 
shuffleBlockInfoList);
+  SendShuffleDataResult sendShuffleData(String appId, List<ShuffleBlockInfo> 
shuffleBlockInfoList,
+      Supplier<Boolean> needCancelRequest);
 
   void sendAppHeartbeat(String appId, long timeoutMs);
 
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 80096a9c..def845fd 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -17,18 +17,20 @@
 
 package org.apache.uniffle.client.impl;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.Callable;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ForkJoinPool;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
@@ -72,6 +74,7 @@ import 
org.apache.uniffle.client.response.RssSendCommitResponse;
 import org.apache.uniffle.client.response.RssSendShuffleDataResponse;
 import org.apache.uniffle.client.response.RssUnregisterShuffleResponse;
 import org.apache.uniffle.client.response.SendShuffleDataResult;
+import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
@@ -98,7 +101,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
   private int replicaRead;
   private boolean replicaSkipEnabled;
   private int dataCommitPoolSize = -1;
-  private final ForkJoinPool dataTransferPool;
+  private final ExecutorService dataTransferPool;
   private final int unregisterThreadPoolSize;
   private final int unregisterRequestTimeSec;
 
@@ -125,7 +128,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
     this.replicaWrite = replicaWrite;
     this.replicaRead = replicaRead;
     this.replicaSkipEnabled = replicaSkipEnabled;
-    this.dataTransferPool = new ForkJoinPool(dataTranferPoolSize);
+    this.dataTransferPool = Executors.newFixedThreadPool(dataTranferPoolSize);
     this.dataCommitPoolSize = dataCommitPoolSize;
     this.unregisterThreadPoolSize = unregisterThreadPoolSize;
     this.unregisterRequestTimeSec = unregisterRequestTimeSec;
@@ -135,11 +138,22 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       String appId,
       Map<ShuffleServerInfo, Map<Integer, Map<Integer, 
List<ShuffleBlockInfo>>>> serverToBlocks,
       Map<ShuffleServerInfo, List<Long>> serverToBlockIds,
-      Map<Long, AtomicInteger> blockIdsTracker) {
+      Map<Long, AtomicInteger> blockIdsTracker, boolean allowFastFail,
+      Supplier<Boolean> needCancelRequest) {
+
+    if (serverToBlockIds == null) {
+      return true;
+    }
+
     // If one or more servers is failed, the sending is not totally successful.
-    AtomicBoolean isAllServersSuccess = new AtomicBoolean(true);
-    if (serverToBlocks != null) {
-      dataTransferPool.submit(() -> 
serverToBlocks.entrySet().parallelStream().forEach(entry -> {
+    List<CompletableFuture<Boolean>> futures = new ArrayList<>();
+    for (Map.Entry<ShuffleServerInfo, Map<Integer, Map<Integer, 
List<ShuffleBlockInfo>>>> entry :
+        serverToBlocks.entrySet()) {
+      CompletableFuture<Boolean> future = CompletableFuture.supplyAsync(() -> {
+        if (needCancelRequest.get()) {
+          LOG.info("The upstream task has been failed. Abort this data send.");
+          return true;
+        }
         ShuffleServerInfo ssi = entry.getKey();
         try {
           Map<Integer, Map<Integer, List<ShuffleBlockInfo>>> shuffleIdToBlocks 
= entry.getValue();
@@ -157,16 +171,24 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
             serverToBlockIds.get(ssi).forEach(block -> 
blockIdsTracker.get(block).incrementAndGet());
             LOG.info("{} successfully.", logMsg);
           } else {
-            isAllServersSuccess.set(false);
             LOG.warn("{}, it failed wth statusCode[{}]", logMsg, 
response.getStatusCode());
+            return false;
           }
         } catch (Exception e) {
-          isAllServersSuccess.set(false);
           LOG.warn("Send: " + serverToBlockIds.get(ssi).size() + " blocks to 
[" + ssi.getId() + "] failed.", e);
+          return false;
         }
-      })).join();
+        return true;
+      }, dataTransferPool);
+      futures.add(future);
     }
-    return isAllServersSuccess.get();
+
+    boolean result = ClientUtils.waitUntilDoneOrFail(futures, allowFastFail);
+    if (!result) {
+      LOG.error("Some shuffle data can't be sent to shuffle-server, is fast 
fail: {}, cancelled task size: {}",
+          allowFastFail, futures.size());
+    }
+    return result;
   }
 
   private void genServerToBlocks(ShuffleBlockInfo sbi, List<ShuffleServerInfo> 
serverList,
@@ -196,8 +218,12 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
     }
   }
 
+  /**
+   * The batch of sending belongs to the same task
+   */
   @Override
-  public SendShuffleDataResult sendShuffleData(String appId, 
List<ShuffleBlockInfo> shuffleBlockInfoList) {
+  public SendShuffleDataResult sendShuffleData(String appId, 
List<ShuffleBlockInfo> shuffleBlockInfoList,
+      Supplier<Boolean> needCancelRequest) {
 
     // shuffleServer -> shuffleId -> partitionId -> blocks
     Map<ShuffleServerInfo, Map<Integer,
@@ -247,15 +273,28 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
 
     // sent the primary round of blocks.
     boolean isAllSuccess = sendShuffleDataAsync(
-        appId, primaryServerToBlocks, primaryServerToBlockIds, 
blockIdsTracker);
+        appId,
+        primaryServerToBlocks,
+        primaryServerToBlockIds,
+        blockIdsTracker,
+        secondaryServerToBlocks.isEmpty(),
+        needCancelRequest
+    );
 
     // The secondary round of blocks is sent only when the primary group 
issues failed sending.
     // This should be infrequent.
     // Even though the secondary round may send blocks more than replicaWrite 
replicas,
     // we do not apply complicated skipping logic, because server crash is 
rare in production environment.
-    if (!isAllSuccess && !secondaryServerToBlocks.isEmpty()) {
+    if (!isAllSuccess && !secondaryServerToBlocks.isEmpty() && 
!needCancelRequest.get()) {
       LOG.info("The sending of primary round is failed partially, so start the 
secondary round");
-      sendShuffleDataAsync(appId, secondaryServerToBlocks, 
secondaryServerToBlockIds, blockIdsTracker);
+      sendShuffleDataAsync(
+          appId,
+          secondaryServerToBlocks,
+          secondaryServerToBlockIds,
+          blockIdsTracker,
+          true,
+          needCancelRequest
+      );
     }
 
     // check success and failed blocks according to the replicaWrite
diff --git 
a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java 
b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
index 3c394d23..eb2ee933 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
+++ b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
@@ -17,6 +17,12 @@
 
 package org.apache.uniffle.client.util;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.util.Constants;
@@ -70,4 +76,42 @@ public class ClientUtils {
         || StorageType.HDFS.name().equals(storageType)
         || StorageType.LOCALFILE_HDFS.name().equals(storageType);
   }
+
+  public static boolean waitUntilDoneOrFail(List<CompletableFuture<Boolean>> 
futures, boolean allowFastFail) {
+    int expected = futures.size();
+    int failed = 0;
+
+    CompletableFuture allFutures = CompletableFuture.allOf(futures.toArray(new 
CompletableFuture[0]));
+
+    List<Future> finished = new ArrayList<>();
+    while (true) {
+      for (Future<Boolean> future : futures) {
+        if (future.isDone() && !finished.contains(future)) {
+          finished.add(future);
+          try {
+            if (!future.get()) {
+              failed++;
+            }
+          } catch (Exception e) {
+            failed++;
+          }
+        }
+      }
+
+      if (expected == finished.size()) {
+        return failed <= 0;
+      }
+
+      if (failed > 0 && allowFastFail) {
+        futures.stream().filter(x -> !x.isDone()).forEach(x -> x.cancel(true));
+        return false;
+      }
+
+      try {
+        allFutures.get(10, TimeUnit.MILLISECONDS);
+      } catch (Exception e) {
+        // ignore
+      }
+    }
+  }
 }
diff --git 
a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java 
b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
index 03139f65..5f26ae17 100644
--- a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
+++ b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
@@ -17,15 +17,30 @@
 
 package org.apache.uniffle.client;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.util.ClientUtils;
 
+import static org.apache.uniffle.client.util.ClientUtils.waitUntilDoneOrFail;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
 
 public class ClientUtilsTest {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(ClientUtilsTest.class);
+
+  private ExecutorService executorService = Executors.newFixedThreadPool(10);
 
   @Test
   public void getBlockIdTest() {
@@ -48,4 +63,49 @@ public class ClientUtilsTest {
     final Throwable e3 = assertThrows(IllegalArgumentException.class, () -> 
ClientUtils.getBlockId(0, 0, 262144));
     assertTrue(e3.getMessage().contains("Can't support sequence[262144], the 
max value should be 262143"));
   }
+
+  private List<CompletableFuture<Boolean>> getFutures(boolean fail) {
+    List<CompletableFuture<Boolean>> futures = new ArrayList<>();
+    for (int i = 0; i < 3; i++) {
+      final int index = i;
+      CompletableFuture<Boolean> future = CompletableFuture.supplyAsync(() -> {
+        if (index == 2) {
+          try {
+            Thread.sleep(3000);
+          } catch (InterruptedException interruptedException) {
+            LOGGER.info("Capture the InterruptedException");
+            return false;
+          }
+          LOGGER.info("Finished index: " + index);
+          return true;
+        }
+        if (fail && index == 1) {
+          return false;
+        }
+        return true;
+      }, executorService);
+      futures.add(future);
+    }
+    return futures;
+  }
+
+  @Test
+  public void testWaitUntilDoneOrFail() {
+    // case1: enable fail fast
+    List<CompletableFuture<Boolean>> futures1 = getFutures(true);
+    Awaitility.await().timeout(2, TimeUnit.SECONDS).until(() -> 
!waitUntilDoneOrFail(futures1, true));
+
+    // case2: disable fail fast
+    List<CompletableFuture<Boolean>> futures2 = getFutures(true);
+    try {
+      Awaitility.await().timeout(2, TimeUnit.SECONDS).until(() -> 
!waitUntilDoneOrFail(futures2, false));
+      fail();
+    } catch (Exception e) {
+      // ignore
+    }
+
+    // case3: all succeed
+    List<CompletableFuture<Boolean>> futures3 = getFutures(false);
+    Awaitility.await().timeout(4, TimeUnit.SECONDS).until(() -> 
waitUntilDoneOrFail(futures3, true));
+  }
 }
diff --git 
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
 
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
index 414d203f..47cb1e1f 100644
--- 
a/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
+++ 
b/client/src/test/java/org/apache/uniffle/client/impl/ShuffleWriteClientImplTest.java
@@ -18,10 +18,13 @@
 package org.apache.uniffle.client.impl;
 
 import java.util.List;
+import java.util.concurrent.TimeUnit;
 
 import com.google.common.collect.Lists;
+import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Test;
 import org.mockito.Mockito;
+import org.mockito.stubbing.Answer;
 
 import org.apache.uniffle.client.api.ShuffleServerClient;
 import org.apache.uniffle.client.response.ResponseStatusCode;
@@ -39,6 +42,31 @@ import static org.mockito.Mockito.when;
 
 public class ShuffleWriteClientImplTest {
 
+  @Test
+  public void testAbandonEventWhenTaskFailed() {
+    ShuffleWriteClientImpl shuffleWriteClient =
+        new ShuffleWriteClientImpl("GRPC", 3, 2000, 4, 1, 1, 1, true, 1, 1, 
10, 10);
+    ShuffleServerClient mockShuffleServerClient = 
mock(ShuffleServerClient.class);
+    ShuffleWriteClientImpl spyClient = Mockito.spy(shuffleWriteClient);
+    
doReturn(mockShuffleServerClient).when(spyClient).getShuffleServerClient(any());
+
+    
when(mockShuffleServerClient.sendShuffleData(any())).thenAnswer((Answer<String>)
 invocation -> {
+      Thread.sleep(50000);
+      return "ABCD1234";
+    });
+
+    List<ShuffleServerInfo> shuffleServerInfoList =
+        Lists.newArrayList(new ShuffleServerInfo("id", "host", 0));
+    List<ShuffleBlockInfo> shuffleBlockInfoList = Lists.newArrayList(new 
ShuffleBlockInfo(
+        0, 0, 10, 10, 10, new byte[]{1}, shuffleServerInfoList, 10, 100, 0));
+
+    // It should directly exit and won't make an RPC request.
+    Awaitility.await().timeout(1, TimeUnit.SECONDS).until(() -> {
+      spyClient.sendShuffleData("appId", shuffleBlockInfoList, () -> true);
+      return true;
+    });
+  }
+
   @Test
   public void testSendData() {
     ShuffleWriteClientImpl shuffleWriteClient =
@@ -53,7 +81,7 @@ public class ShuffleWriteClientImplTest {
         Lists.newArrayList(new ShuffleServerInfo("id", "host", 0));
     List<ShuffleBlockInfo> shuffleBlockInfoList = Lists.newArrayList(new 
ShuffleBlockInfo(
         0, 0, 10, 10, 10, new byte[]{1}, shuffleServerInfoList, 10, 100, 0));
-    SendShuffleDataResult result = spyClient.sendShuffleData("appId", 
shuffleBlockInfoList);
+    SendShuffleDataResult result = spyClient.sendShuffleData("appId", 
shuffleBlockInfoList, () -> false);
 
     assertTrue(result.getFailedBlockIds().contains(10L));
   }
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java 
b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
index f4791df9..7379e8f6 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/QuorumTest.java
@@ -70,7 +70,7 @@ public class QuorumTest extends ShuffleReadWriteBase {
   private static ShuffleServerInfo fakedShuffleServerInfo2;
   private static ShuffleServerInfo fakedShuffleServerInfo3;
   private static ShuffleServerInfo fakedShuffleServerInfo4;
-  private ShuffleWriteClientImpl shuffleWriteClientImpl;
+  private MockedShuffleWriteClientImpl shuffleWriteClientImpl;
 
   public static MockedShuffleServer createServer(int id) throws Exception {
     ShuffleServerConf shuffleServerConf = getShuffleServerConf();
@@ -263,10 +263,24 @@ public class QuorumTest extends ShuffleReadWriteBase {
         .disableMockedTimeout();
   }
 
+  static class MockedShuffleWriteClientImpl extends ShuffleWriteClientImpl {
+    MockedShuffleWriteClientImpl(String clientType, int retryMax, long 
retryIntervalMax, int heartBeatThreadNum,
+        int replica, int replicaWrite, int replicaRead, boolean 
replicaSkipEnabled, int dataTranferPoolSize,
+        int dataCommitPoolSize, int unregisterThreadPoolSize, int 
unregisterRequestTimeSec) {
+      super(clientType, retryMax, retryIntervalMax, heartBeatThreadNum, 
replica, replicaWrite, replicaRead,
+          replicaSkipEnabled, dataTranferPoolSize, dataCommitPoolSize, 
unregisterThreadPoolSize,
+          unregisterRequestTimeSec);
+    }
+
+    public SendShuffleDataResult sendShuffleData(String appId, 
List<ShuffleBlockInfo> shuffleBlockInfoList) {
+      return super.sendShuffleData(appId, shuffleBlockInfoList, () -> false);
+    }
+  }
+
   private void registerShuffleServer(String testAppId,
       int replica, int replicaWrite, int replicaRead, boolean replicaSkip) {
 
-    shuffleWriteClientImpl = new 
ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1,
+    shuffleWriteClientImpl = new 
MockedShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1,
       replica, replicaWrite, replicaRead, replicaSkip, 1, 1, 10, 10);
 
     List<ShuffleServerInfo> allServers = 
Lists.newArrayList(shuffleServerInfo0, shuffleServerInfo1,
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
index 3d54f790..fdb83a9f 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
@@ -515,7 +515,7 @@ public class ShuffleServerGrpcTest extends 
IntegrationTestBase {
 
   @Disabled("flaky test")
   @Test
-  public void rpcMetricsTest() {
+  public void rpcMetricsTest() throws Exception {
     String appId = "rpcMetricsTest";
     int shuffleId = 0;
     final double oldGrpcTotal = 
shuffleServers.get(0).getGrpcMetrics().getCounterGrpcTotal().get();
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
index b9e76f70..ca62f7bf 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleWithRssClientTest.java
@@ -122,7 +122,7 @@ public class ShuffleWithRssClientTest extends 
ShuffleReadWriteBase {
     List<ShuffleBlockInfo> blocks = createShuffleBlockList(
         0, 0, 0, 3, 25, blockIdBitmap,
         expectedData, Lists.newArrayList(shuffleServerInfo1, 
fakeShuffleServerInfo));
-    SendShuffleDataResult result = 
shuffleWriteClientImpl.sendShuffleData(testAppId, blocks);
+    SendShuffleDataResult result = 
shuffleWriteClientImpl.sendShuffleData(testAppId, blocks, () -> false);
     Roaring64NavigableMap failedBlockIdBitmap = 
Roaring64NavigableMap.bitmapOf();
     Roaring64NavigableMap succBlockIdBitmap = Roaring64NavigableMap.bitmapOf();
     for (Long blockId : result.getFailedBlockIds()) {
@@ -254,7 +254,7 @@ public class ShuffleWithRssClientTest extends 
ShuffleReadWriteBase {
     List<ShuffleBlockInfo> blocks = createShuffleBlockList(
         0, 0, 0, 3, 25, blockIdBitmap,
         expectedData, Lists.newArrayList(shuffleServerInfo1, 
shuffleServerInfo2));
-    shuffleWriteClientImpl.sendShuffleData(testAppId, blocks);
+    shuffleWriteClientImpl.sendShuffleData(testAppId, blocks, () -> false);
     // send 1st commit, finish commit won't be sent to Shuffle server and data 
won't be persisted to disk
     boolean commitResult = shuffleWriteClientImpl
         .sendCommit(Sets.newHashSet(shuffleServerInfo1, shuffleServerInfo2), 
testAppId, 0, 2);
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 4adf8229..82a590db 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -117,7 +117,7 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
     super(host, port, maxRetryAttempts, usePlaintext);
     blockingStub = ShuffleServerGrpc.newBlockingStub(channel);
   }
-  
+
   public ShuffleServerBlockingStub getBlockingStub() {
     return blockingStub.withDeadlineAfter(rpcTimeout, TimeUnit.MILLISECONDS);
   }
@@ -190,7 +190,7 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
           + retry + "] again");
       if (retry >= retryMax) {
         LOG.warn("ShuffleServer " + host + ":" + port + " is full and can't 
send shuffle"
-            + " data successfully after retry " + retryMax + " times, cost: 
{}(ms)",
+                + " data successfully after retry " + retryMax + " times, 
cost: {}(ms)",
             System.currentTimeMillis() - start);
         return result;
       }


Reply via email to