This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 5a45daaf3 [#2502] feat(spark3): Fast fail for those stale assign 
blocks for partition reassign (#2516)
5a45daaf3 is described below

commit 5a45daaf38046a196f47acf4f706a4ff938b1fbf
Author: Junfan Zhang <[email protected]>
AuthorDate: Tue Jun 24 17:29:47 2025 +0800

    [#2502] feat(spark3): Fast fail for those stale assign blocks for partition 
reassign (#2516)
    
    ### What changes were proposed in this pull request?
    
    Fast fail for those stale assign blocks for partition reassign
    
    ### Why are the changes needed?
    
    for #2502
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Add unit tests
---
 .../apache/spark/shuffle/writer/DataPusher.java    | 50 ++++++++++++++++++++--
 .../spark/shuffle/writer/WriteBufferManager.java   |  3 +-
 .../spark/shuffle/writer/DataPusherTest.java       | 35 +++++++++++++++
 .../client/impl/FailedBlockSendTracker.java        |  7 +++
 .../apache/uniffle/common/ShuffleBlockInfo.java    | 48 +++++++++++++++++++++
 .../uniffle/common/ShuffleBlockInfoTest.java       | 28 ++++++++++++
 6 files changed, 166 insertions(+), 5 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index 2da75144e..1926f31b2 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.writer;
 
 import java.io.Closeable;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -31,6 +32,7 @@ import java.util.concurrent.TimeUnit;
 
 import com.google.common.collect.Queues;
 import com.google.common.collect.Sets;
+import org.apache.commons.collections4.CollectionUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -39,7 +41,9 @@ import 
org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
 import org.apache.uniffle.client.impl.FailedBlockSendTracker;
 import org.apache.uniffle.client.response.SendShuffleDataResult;
 import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.ThreadUtils;
 
 /**
@@ -89,13 +93,20 @@ public class DataPusher implements Closeable {
             () -> {
               String taskId = event.getTaskId();
               List<ShuffleBlockInfo> shuffleBlockInfoList = 
event.getShuffleDataInfoList();
+              // filter out the shuffle blocks with stale assignment
+              List<ShuffleBlockInfo> validBlocks =
+                  filterOutStaleAssignmentBlocks(taskId, shuffleBlockInfoList);
+              if (CollectionUtils.isEmpty(validBlocks)) {
+                return 0L;
+              }
+
               SendShuffleDataResult result = null;
               try {
                 result =
                     shuffleWriteClient.sendShuffleData(
                         rssAppId,
                         event.getStageAttemptNumber(),
-                        shuffleBlockInfoList,
+                        validBlocks,
                         () -> !isValidTask(taskId));
                 putBlockId(taskToSuccessBlockIds, taskId, 
result.getSuccessBlockIds());
                 putFailedBlockSendTracker(
@@ -109,7 +120,7 @@ public class DataPusher implements Closeable {
                 }
 
                 Set<Long> succeedBlockIds = getSucceedBlockIds(result);
-                for (ShuffleBlockInfo block : shuffleBlockInfoList) {
+                for (ShuffleBlockInfo block : validBlocks) {
                   
block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
                 }
 
@@ -120,7 +131,7 @@ public class DataPusher implements Closeable {
                 }
               }
               Set<Long> succeedBlockIds = getSucceedBlockIds(result);
-              return shuffleBlockInfoList.stream()
+              return validBlocks.stream()
                   .filter(x -> succeedBlockIds.contains(x.getBlockId()))
                   .map(x -> x.getFreeMemory())
                   .reduce((a, b) -> a + b)
@@ -134,6 +145,37 @@ public class DataPusher implements Closeable {
             });
   }
 
+  /**
+   * This method is only valid for the single-replica case. If a block's 
assignment is stale, the
+   * block will be filtered out and retried. If partition reassignment is 
disabled, this method
+   * never filters out any blocks.
+   *
+   * @param taskId
+   * @param blocks
+   * @return the valid shuffle blocks
+   */
+  private List<ShuffleBlockInfo> filterOutStaleAssignmentBlocks(
+      String taskId, List<ShuffleBlockInfo> blocks) {
+    FailedBlockSendTracker staleBlockTracker = new FailedBlockSendTracker();
+    List<ShuffleBlockInfo> validBlocks = new ArrayList<>();
+    for (ShuffleBlockInfo block : blocks) {
+      List<ShuffleServerInfo> servers = block.getShuffleServerInfos();
+      // skip the multi-replica cases.
+      if (servers == null || servers.size() != 1) {
+        validBlocks.add(block);
+      } else {
+        if (block.isStaleAssignment()) {
+          staleBlockTracker.add(
+              block, block.getShuffleServerInfos().get(0), 
StatusCode.INTERNAL_ERROR);
+        } else {
+          validBlocks.add(block);
+        }
+      }
+    }
+    putFailedBlockSendTracker(taskToFailedBlockSendTracker, taskId, 
staleBlockTracker);
+    return validBlocks;
+  }
+
   private Set<Long> getSucceedBlockIds(SendShuffleDataResult result) {
     if (result == null || result.getSuccessBlockIds() == null) {
       return Collections.emptySet();
@@ -155,7 +197,7 @@ public class DataPusher implements Closeable {
       Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker,
       String taskAttemptId,
       FailedBlockSendTracker failedBlockSendTracker) {
-    if (failedBlockSendTracker == null) {
+    if (failedBlockSendTracker == null || failedBlockSendTracker.isEmpty()) {
       return;
     }
     taskToFailedBlockSendTracker
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index adfdaa8c7..5b412300a 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -448,7 +448,8 @@ public class WriteBufferManager extends MemoryConsumer {
         partitionAssignmentRetrieveFunc.apply(partitionId),
         uncompressLength,
         wb.getMemoryUsed(),
-        taskAttemptId);
+        taskAttemptId,
+        partitionAssignmentRetrieveFunc);
   }
 
   // it's run in single thread, and is not thread safe
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
index 4d0d661c3..0bbf21fdf 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java
@@ -18,6 +18,7 @@
 package org.apache.spark.shuffle.writer;
 
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -88,6 +89,40 @@ public class DataPusherTest {
     }
   }
 
+  @Test
+  public void testFilterOutStaleAssignmentBlocks() {
+    FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient();
+
+    Map<String, Set<Long>> taskToSuccessBlockIds = Maps.newConcurrentMap();
+    Map<String, FailedBlockSendTracker> taskToFailedBlockSendTracker = 
JavaUtils.newConcurrentMap();
+    Set<String> failedTaskIds = new HashSet<>();
+
+    DataPusher dataPusher =
+        new DataPusher(
+            shuffleWriteClient,
+            taskToSuccessBlockIds,
+            taskToFailedBlockSendTracker,
+            failedTaskIds,
+            1,
+            2);
+    dataPusher.setRssAppId("testFilterOutStaleAssignmentBlocks");
+
+    String taskId = "taskId1";
+    List<ShuffleServerInfo> server1 =
+        Collections.singletonList(new ShuffleServerInfo("0", "localhost", 
1234));
+    ShuffleBlockInfo staleBlock1 =
+        new ShuffleBlockInfo(
+            1, 1, 3, 1, 1, new byte[1], server1, 1, 100, 1, integer -> 
Collections.emptyList());
+
+    // case1: will fast fail due to the stale assignment
+    AddBlockEvent event = new AddBlockEvent(taskId, 
Arrays.asList(staleBlock1));
+    CompletableFuture<Long> f1 = dataPusher.send(event);
+    assertEquals(f1.join(), 0);
+    Set<Long> failedBlockIds = 
taskToFailedBlockSendTracker.get(taskId).getFailedBlockIds();
+    assertEquals(1, failedBlockIds.size());
+    assertEquals(3, failedBlockIds.stream().findFirst().get());
+  }
+
   @Test
   public void testSendData() throws ExecutionException, InterruptedException {
     FakedShuffleWriteClient shuffleWriteClient = new FakedShuffleWriteClient();
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
index 12856faf7..22c6d21fb 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/FailedBlockSendTracker.java
@@ -119,4 +119,11 @@ public class FailedBlockSendTracker {
     trackingNeedSplitPartitionStatusQueue.drainTo(trackingPartitionStatusList);
     return trackingPartitionStatusList;
   }
+
+  public boolean isEmpty() {
+    if (trackingBlockStatusMap.isEmpty() && 
trackingNeedSplitPartitionStatusQueue.isEmpty()) {
+      return true;
+    }
+    return false;
+  }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
index 36dec5e25..a38e9d206 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.common;
 
 import java.util.List;
+import java.util.function.Function;
 
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
@@ -40,6 +41,34 @@ public class ShuffleBlockInfo {
 
   private transient BlockCompletionCallback completionCallback;
 
+  private Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc;
+
+  public ShuffleBlockInfo(
+      int shuffleId,
+      int partitionId,
+      long blockId,
+      int length,
+      long crc,
+      byte[] data,
+      List<ShuffleServerInfo> shuffleServerInfos,
+      int uncompressLength,
+      long freeMemory,
+      long taskAttemptId,
+      Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc) {
+    this(
+        shuffleId,
+        partitionId,
+        blockId,
+        length,
+        crc,
+        data,
+        shuffleServerInfos,
+        uncompressLength,
+        freeMemory,
+        taskAttemptId);
+    this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
+  }
+
   public ShuffleBlockInfo(
       int shuffleId,
       int partitionId,
@@ -182,4 +211,23 @@ public class ShuffleBlockInfo {
     }
     completionCallback.onBlockCompletion(this, isSuccessful);
   }
+
+  public boolean isStaleAssignment() {
+    if (partitionAssignmentRetrieveFunc == null) {
+      return false;
+    }
+    List<ShuffleServerInfo> latestAssignment = 
partitionAssignmentRetrieveFunc.apply(partitionId);
+    if (latestAssignment == null || shuffleServerInfos == null) {
+      return false;
+    }
+    if (latestAssignment.size() != shuffleServerInfos.size()) {
+      return true;
+    }
+    for (int i = 0; i < latestAssignment.size(); i++) {
+      if 
(!latestAssignment.get(i).getId().equals(shuffleServerInfos.get(i).getId())) {
+        return true;
+      }
+    }
+    return false;
+  }
 }
diff --git 
a/common/src/test/java/org/apache/uniffle/common/ShuffleBlockInfoTest.java 
b/common/src/test/java/org/apache/uniffle/common/ShuffleBlockInfoTest.java
index 71db702c8..ea3b2a941 100644
--- a/common/src/test/java/org/apache/uniffle/common/ShuffleBlockInfoTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/ShuffleBlockInfoTest.java
@@ -19,13 +19,41 @@ package org.apache.uniffle.common;
 
 import java.util.Collections;
 import java.util.List;
+import java.util.function.Function;
 
 import org.junit.jupiter.api.Test;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class ShuffleBlockInfoTest {
 
+  @Test
+  public void testStaleAssignment() throws Exception {
+    List<ShuffleServerInfo> servers =
+        Collections.singletonList(new ShuffleServerInfo("0", "localhost", 
1234));
+    ShuffleBlockInfo blockInfo =
+        new ShuffleBlockInfo(1, 2, 3, 4, 5, new byte[1], servers, 9, 1, 9, 
null);
+    // case1: null partition assignment function, it should always be false.
+    assertFalse(blockInfo.isStaleAssignment());
+
+    // case2: stale assignment
+    Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc 
=
+        integer -> Collections.singletonList(new ShuffleServerInfo("1", 
"localhost", 1234));
+    blockInfo =
+        new ShuffleBlockInfo(
+            1, 2, 3, 4, 5, new byte[1], servers, 9, 1, 9, 
partitionAssignmentRetrieveFunc);
+    assertTrue(blockInfo.isStaleAssignment());
+
+    // case3: same assignment
+    partitionAssignmentRetrieveFunc = integer -> servers;
+    blockInfo =
+        new ShuffleBlockInfo(
+            1, 2, 3, 4, 5, new byte[1], servers, 9, 1, 9, 
partitionAssignmentRetrieveFunc);
+    assertFalse(blockInfo.isStaleAssignment());
+  }
+
   @Test
   public void testToString() {
     List<ShuffleServerInfo> shuffleServerInfos =

Reply via email to