This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f2bee909 [#2090] fix(client): Ensure thread-safe to store the sending status (#2091)
4f2bee909 is described below

commit 4f2bee909d1bf0e537b620a0921f9e8c7db20337
Author: xianjingfeng <xianjingfeng...@gmail.com>
AuthorDate: Tue Sep 3 14:48:08 2024 +0800

    [#2090] fix(client): Ensure thread-safe to store the sending status (#2091)
    
    ### What changes were proposed in this pull request?
    
    Replace the LinkedList used to store the sending status with a
    synchronized list (`Collections.synchronizedList`), and synchronize on
    that list wherever it is iterated.
    
    ### Why are the changes needed?
    
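    Fix: #2090
    
    The per-block list of sending statuses is shared between threads: entries
    can be appended while other code reads or iterates over it, and a plain
    LinkedList is not thread-safe.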
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    CI, plus a new unit test (`FailedBlockSendTrackerTest`).
---
 .../spark/shuffle/writer/RssShuffleWriter.java     | 59 ++++++++---------
 .../client/impl/FailedBlockSendTracker.java        | 32 ++++++---
 .../client/impl/FailedBlockSendTrackerTest.java    | 75 ++++++++++++++++++++++
 3 files changed, 127 insertions(+), 39 deletions(-)

diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 24a3b8c1c..b3a5ccf09 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -25,7 +25,6 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
@@ -567,47 +566,49 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     // to check whether the blocks resent exceed the max resend count.
     for (Long blockId : failedBlockIds) {
       List<TrackingBlockStatus> failedBlockStatus = 
failedTracker.getFailedBlockStatus(blockId);
-      int retryIndex =
-          failedBlockStatus.stream()
-              .map(x -> x.getShuffleBlockInfo().getRetryCnt())
-              .max(Comparator.comparing(Integer::valueOf))
-              .get();
-      if (retryIndex >= blockFailSentRetryMaxTimes) {
-        LOG.error(
-            "Partial blocks for taskId: [{}] retry exceeding the max retry 
times: [{}]. Fast fail! faulty server list: {}",
-            taskId,
-            blockFailSentRetryMaxTimes,
+      synchronized (failedBlockStatus) {
+        int retryIndex =
             failedBlockStatus.stream()
-                .map(x -> x.getShuffleServerInfo())
-                .collect(Collectors.toSet()));
-        isFastFail = true;
-        break;
-      }
-
-      for (TrackingBlockStatus status : failedBlockStatus) {
-        StatusCode code = status.getStatusCode();
-        if (STATUS_CODE_WITHOUT_BLOCK_RESEND.contains(code)) {
+                .map(x -> x.getShuffleBlockInfo().getRetryCnt())
+                .max(Comparator.comparing(Integer::valueOf))
+                .get();
+        if (retryIndex >= blockFailSentRetryMaxTimes) {
           LOG.error(
-              "Partial blocks for taskId: [{}] failed on the illegal status 
code: [{}] without resend on server: {}",
+              "Partial blocks for taskId: [{}] retry exceeding the max retry 
times: [{}]. Fast fail! faulty server list: {}",
               taskId,
-              code,
-              status.getShuffleServerInfo());
+              blockFailSentRetryMaxTimes,
+              failedBlockStatus.stream()
+                  .map(x -> x.getShuffleServerInfo())
+                  .collect(Collectors.toSet()));
           isFastFail = true;
           break;
         }
-      }
 
-      // todo: if setting multi replica and another replica is succeed to 
send, no need to resend
-      resendCandidates.addAll(failedBlockStatus);
+        for (TrackingBlockStatus status : failedBlockStatus) {
+          StatusCode code = status.getStatusCode();
+          if (STATUS_CODE_WITHOUT_BLOCK_RESEND.contains(code)) {
+            LOG.error(
+                "Partial blocks for taskId: [{}] failed on the illegal status 
code: [{}] without resend on server: {}",
+                taskId,
+                code,
+                status.getShuffleServerInfo());
+            isFastFail = true;
+            break;
+          }
+        }
+
+        // todo: if setting multi replica and another replica is succeed to 
send, no need to resend
+        resendCandidates.addAll(failedBlockStatus);
+      }
     }
 
     if (isFastFail) {
       // release data and allocated memory
       for (Long blockId : failedBlockIds) {
         List<TrackingBlockStatus> failedBlockStatus = 
failedTracker.getFailedBlockStatus(blockId);
-        Optional<TrackingBlockStatus> blockStatus = 
failedBlockStatus.stream().findFirst();
-        if (blockStatus.isPresent()) {
-          
blockStatus.get().getShuffleBlockInfo().executeCompletionCallback(true);
+        if (CollectionUtils.isNotEmpty(failedBlockStatus)) {
+          TrackingBlockStatus blockStatus = failedBlockStatus.get(0);
+          blockStatus.getShuffleBlockInfo().executeCompletionCallback(true);
         }
       }
 
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
index 93e20dd02..c0ff6d5bd 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
@@ -17,14 +17,14 @@
 
 package org.apache.uniffle.client.impl;
 
-import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.stream.Collectors;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
 
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
@@ -49,7 +49,8 @@ public class FailedBlockSendTracker {
       ShuffleServerInfo shuffleServerInfo,
       StatusCode statusCode) {
     trackingBlockStatusMap
-        .computeIfAbsent(shuffleBlockInfo.getBlockId(), s -> 
Lists.newLinkedList())
+        .computeIfAbsent(
+            shuffleBlockInfo.getBlockId(), s -> 
Collections.synchronizedList(Lists.newArrayList()))
         .add(new TrackingBlockStatus(shuffleBlockInfo, shuffleServerInfo, 
statusCode));
   }
 
@@ -62,9 +63,14 @@ public class FailedBlockSendTracker {
   }
 
   public void clearAndReleaseBlockResources() {
-    trackingBlockStatusMap.values().stream()
-        .flatMap(x -> x.stream())
-        .forEach(x -> x.getShuffleBlockInfo().executeCompletionCallback(true));
+    trackingBlockStatusMap
+        .values()
+        .forEach(
+            l -> {
+              synchronized (l) {
+                l.forEach(x -> 
x.getShuffleBlockInfo().executeCompletionCallback(false));
+              }
+            });
     trackingBlockStatusMap.clear();
   }
 
@@ -77,9 +83,15 @@ public class FailedBlockSendTracker {
   }
 
   public Set<ShuffleServerInfo> getFaultyShuffleServers() {
-    return trackingBlockStatusMap.values().stream()
-        .flatMap(Collection::stream)
-        .map(s -> s.getShuffleServerInfo())
-        .collect(Collectors.toSet());
+    Set<ShuffleServerInfo> shuffleServerInfos = Sets.newHashSet();
+    trackingBlockStatusMap.values().stream()
+        .forEach(
+            l -> {
+              synchronized (l) {
+                l.stream()
+                    .forEach((status) -> 
shuffleServerInfos.add(status.getShuffleServerInfo()));
+              }
+            });
+    return shuffleServerInfos;
   }
 }
diff --git 
a/client/src/test/java/org/apache/uniffle/client/impl/FailedBlockSendTrackerTest.java
 
b/client/src/test/java/org/apache/uniffle/client/impl/FailedBlockSendTrackerTest.java
new file mode 100644
index 000000000..ec5b0ab16
--- /dev/null
+++ 
b/client/src/test/java/org/apache/uniffle/client/impl/FailedBlockSendTrackerTest.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.impl;
+
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.collections4.CollectionUtils;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.rpc.StatusCode;
+
+import static org.awaitility.Awaitility.await;
+
+public class FailedBlockSendTrackerTest {
+  @Test
+  public void test() throws Exception {
+    FailedBlockSendTracker tracker = new FailedBlockSendTracker();
+    ShuffleServerInfo shuffleServerInfo1 = new ShuffleServerInfo("host1", 
19999);
+    ShuffleServerInfo shuffleServerInfo2 = new ShuffleServerInfo("host2", 
19999);
+    ShuffleServerInfo shuffleServerInfo3 = new ShuffleServerInfo("host3", 
19999);
+    List<ShuffleServerInfo> shuffleServerInfos1 =
+        Lists.newArrayList(shuffleServerInfo1, shuffleServerInfo2);
+    ShuffleBlockInfo shuffleBlockInfo1 =
+        new ShuffleBlockInfo(0, 0, 1L, 0, 0L, new byte[] {}, 
shuffleServerInfos1, 0, 0L, 0L);
+    List<ShuffleServerInfo> shuffleServerInfos2 =
+        Lists.newArrayList(shuffleServerInfo3, shuffleServerInfo2);
+    ShuffleBlockInfo shuffleBlockInfo2 =
+        new ShuffleBlockInfo(0, 0, 2L, 0, 0L, new byte[] {}, 
shuffleServerInfos2, 0, 0L, 0L);
+    new Thread(
+            () -> {
+              tracker.add(shuffleBlockInfo1, shuffleServerInfo1, 
StatusCode.INTERNAL_ERROR);
+              tracker.add(shuffleBlockInfo1, shuffleServerInfo2, 
StatusCode.INTERNAL_ERROR);
+              tracker.add(shuffleBlockInfo2, shuffleServerInfo3, 
StatusCode.INTERNAL_ERROR);
+              tracker.add(shuffleBlockInfo2, shuffleServerInfo2, 
StatusCode.INTERNAL_ERROR);
+            })
+        .start();
+    List<String> expected =
+        Lists.newArrayList(
+            shuffleServerInfo1.getId(), shuffleServerInfo2.getId(), 
shuffleServerInfo3.getId());
+    await()
+        .atMost(5, TimeUnit.SECONDS)
+        .until(
+            () -> {
+              if (tracker.getFailedBlockIds().size() != 2) {
+                return false;
+              }
+              List<String> actual =
+                  tracker.getFaultyShuffleServers().stream()
+                      .map(ShuffleServerInfo::getId)
+                      .sorted()
+                      .collect(Collectors.toList());
+              return CollectionUtils.isEqualCollection(expected, actual);
+            });
+  }
+}

Reply via email to