This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new fe0ff7e60 [#2586] fix(spark): Support writer switching servers on partition split with LOAD_BALANCE mode without reassign (#2587)
fe0ff7e60 is described below

commit fe0ff7e60efab90a556554f4afb46f0caea6c2ce
Author: Junfan Zhang <[email protected]>
AuthorDate: Fri Aug 22 10:08:12 2025 +0800

    [#2586] fix(spark): Support writer switching servers on partition split with LOAD_BALANCE mode without reassign (#2587)
    
    ### What changes were proposed in this pull request?
    
    This PR lets the shuffle writer switch to other shuffle servers when a partition has been split in LOAD_BALANCE mode, without triggering a partition reassignment.
    
    ### Why are the changes needed?
    
    fix #2586
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Unit tests (adds `ShuffleServerInfoTest#testFromServerId`).
---
 .../shuffle/handle/MutableShuffleHandleInfo.java   | 13 ++++----
 .../spark/shuffle/handle/ShuffleHandleInfo.java    |  8 ++++-
 .../shuffle/handle/SimpleShuffleHandleInfo.java    |  8 ++++-
 .../handle/StageAttemptShuffleHandleInfo.java      |  5 +--
 .../shuffle/writer/TaskAttemptAssignment.java      | 39 ++++++++++++++++++++--
 .../spark/shuffle/writer/RssShuffleWriter.java     | 10 ++++--
 .../apache/uniffle/common/ShuffleServerInfo.java   | 25 ++++++++++++--
 .../uniffle/common/ShuffleServerInfoTest.java      | 18 ++++++++++
 8 files changed, 108 insertions(+), 18 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
index 0e22c017c..35aaa50be 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
@@ -202,21 +202,21 @@ public class MutableShuffleHandleInfo extends 
ShuffleHandleInfoBase {
   }
 
   @Override
-  public Map<Integer, List<ShuffleServerInfo>> 
getAvailablePartitionServersForWriter() {
+  public Map<Integer, List<ShuffleServerInfo>> 
getAvailablePartitionServersForWriter(
+      Map<Integer, List<ShuffleServerInfo>> exclusiveServers) {
+    Map<Integer, List<ShuffleServerInfo>> requestExclusiveServers =
+        exclusiveServers == null ? Collections.emptyMap() : exclusiveServers;
     Map<Integer, List<ShuffleServerInfo>> assignment = new HashMap<>();
     for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
         partitionReplicaAssignedServers.entrySet()) {
       int partitionId = entry.getKey();
+      List<ShuffleServerInfo> partitionExclusiveServers =
+          requestExclusiveServers.getOrDefault(partitionId, 
Collections.emptyList());
       Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
       PartitionSplitInfo splitInfo = this.getPartitionSplitInfo(partitionId);
       for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
           replicaServers.entrySet()) {
 
-        // For normal partition reassignment, the latest replacement shuffle 
server is always used.
-        Optional<ShuffleServerInfo> candidateOptional =
-            replicaServerEntry.getValue().stream()
-                .filter(x -> 
!excludedServerToReplacements.containsKey(x.getId()))
-                .findFirst();
         // Get the unexcluded server for each replica writing
         ShuffleServerInfo candidate =
             
replicaServerEntry.getValue().get(replicaServerEntry.getValue().size() - 1);
@@ -230,6 +230,7 @@ public class MutableShuffleHandleInfo extends 
ShuffleHandleInfoBase {
           List<ShuffleServerInfo> servers =
               replicaServerEntry.getValue().stream()
                   .filter(x -> 
!excludedServerToReplacements.containsKey(x.getId()))
+                  .filter(x -> !partitionExclusiveServers.contains(x))
                   .collect(Collectors.toList());
 
           // 2. exclude the first partition split triggered node.
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
index e51ff8029..b794acb68 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.shuffle.handle;
 
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -36,7 +37,12 @@ public interface ShuffleHandleInfo {
    * shuffleServers. Implementations might return dynamic, up-to-date 
information here. Returns
    * partitionId -> [replica1, replica2, ...]
    */
-  Map<Integer, List<ShuffleServerInfo>> 
getAvailablePartitionServersForWriter();
+  default Map<Integer, List<ShuffleServerInfo>> 
getAvailablePartitionServersForWriter() {
+    return getAvailablePartitionServersForWriter(Collections.emptyMap());
+  }
+
+  Map<Integer, List<ShuffleServerInfo>> getAvailablePartitionServersForWriter(
+      Map<Integer, List<ShuffleServerInfo>> exclusivePartitionServers);
 
   /**
    * Get all servers ever assigned to writers group by partitionId for reader 
to get the data
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/SimpleShuffleHandleInfo.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/SimpleShuffleHandleInfo.java
index 20b5c13ec..13a83c7d6 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/SimpleShuffleHandleInfo.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/SimpleShuffleHandleInfo.java
@@ -28,6 +28,7 @@ import 
org.apache.spark.shuffle.handle.split.PartitionSplitInfo;
 import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
 
 /**
  * Class for holding, 1. partition ID -> shuffle servers mapping. 2. remote 
storage info
@@ -54,7 +55,12 @@ public class SimpleShuffleHandleInfo extends 
ShuffleHandleInfoBase implements Se
   }
 
   @Override
-  public Map<Integer, List<ShuffleServerInfo>> 
getAvailablePartitionServersForWriter() {
+  public Map<Integer, List<ShuffleServerInfo>> 
getAvailablePartitionServersForWriter(
+      Map<Integer, List<ShuffleServerInfo>> exclusiveServers) {
+    if (exclusiveServers != null && !exclusiveServers.isEmpty()) {
+      throw new RssException(
+          "Exclusive servers are not supported when getting available 
partition servers for shuffle writer");
+    }
     return partitionToServers;
   }
 
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
index 4ca7b8153..90fd51820 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
@@ -63,8 +63,9 @@ public class StageAttemptShuffleHandleInfo extends 
ShuffleHandleInfoBase {
   }
 
   @Override
-  public Map<Integer, List<ShuffleServerInfo>> 
getAvailablePartitionServersForWriter() {
-    return current.getAvailablePartitionServersForWriter();
+  public Map<Integer, List<ShuffleServerInfo>> 
getAvailablePartitionServersForWriter(
+      Map<Integer, List<ShuffleServerInfo>> exclusiveServers) {
+    return current.getAvailablePartitionServersForWriter(exclusiveServers);
   }
 
   @Override
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
index 18eee16e2..7e46528c4 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
@@ -17,9 +17,13 @@
 
 package org.apache.spark.shuffle.writer;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentSkipListSet;
+import java.util.stream.Collectors;
 
 import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.split.PartitionSplitInfo;
@@ -29,6 +33,7 @@ import org.slf4j.LoggerFactory;
 import org.apache.uniffle.common.PartitionSplitMode;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.JavaUtils;
 
 /** This class is to get the partition assignment for ShuffleWriter. */
 public class TaskAttemptAssignment {
@@ -38,7 +43,12 @@ public class TaskAttemptAssignment {
   private ShuffleHandleInfo handle;
   private final long taskAttemptId;
 
+  // key: partitionId, values: exclusive servers.
+  // this is for the partition split mechanism with load balance mode
+  private final Map<Integer, Set<ShuffleServerInfo>> 
exclusiveServersForPartition;
+
   public TaskAttemptAssignment(long taskAttemptId, ShuffleHandleInfo 
shuffleHandleInfo) {
+    this.exclusiveServersForPartition = JavaUtils.newConcurrentMap();
     this.update(shuffleHandleInfo);
     this.handle = shuffleHandleInfo;
     this.taskAttemptId = taskAttemptId;
@@ -58,16 +68,39 @@ public class TaskAttemptAssignment {
     if (handle == null) {
       throw new RssException("Errors on updating shuffle handle by the empty 
handleInfo.");
     }
-    this.assignment = handle.getAvailablePartitionServersForWriter();
+    this.assignment =
+        handle.getAvailablePartitionServersForWriter(
+            this.exclusiveServersForPartition.entrySet().stream()
+                .collect(Collectors.toMap(Map.Entry::getKey, x -> new 
ArrayList<>(x.getValue()))));
     this.handle = handle;
   }
 
-  public boolean isSkipPartitionSplit(int partitionId) {
-    // for those load balance partition split, once split, skip the following 
split.
+  private boolean hasBeenLoadBalanced(int partitionId) {
     PartitionSplitInfo splitInfo = 
this.handle.getPartitionSplitInfo(partitionId);
     return splitInfo.isSplit() && splitInfo.getMode() == 
PartitionSplitMode.LOAD_BALANCE;
   }
 
+  /**
+   * If partition has been load balanced and marked as split, it could update 
assignment by the next
+   * servers. Otherwise, it will directly return false that will trigger 
reassignment.
+   *
+   * @param partitionId
+   * @param exclusiveServers
+   * @return
+   */
+  public boolean updatePartitionSplitAssignment(
+      int partitionId, List<ShuffleServerInfo> exclusiveServers) {
+    if (hasBeenLoadBalanced(partitionId)) {
+      Set<ShuffleServerInfo> servers =
+          this.exclusiveServersForPartition.computeIfAbsent(
+              partitionId, k -> new ConcurrentSkipListSet<>());
+      servers.addAll(exclusiveServers);
+      update(this.handle);
+      return true;
+    }
+    return false;
+  }
+
   /**
    * @param partitionId
    * @return all assigned shuffle servers for one partition id
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 382643c52..b51ab00d5 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -704,9 +704,13 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
         failurePartitionToServers.entrySet()) {
       int partitionId = entry.getKey();
-      boolean isSkip = taskAttemptAssignment.isSkipPartitionSplit(partitionId);
-      if (!isSkip) {
-        partitionToServersReassignList.put(partitionId, entry.getValue());
+      List<ReceivingFailureServer> failureServers = entry.getValue();
+      if (!taskAttemptAssignment.updatePartitionSplitAssignment(
+          partitionId,
+          failureServers.stream()
+              .map(x -> ShuffleServerInfo.from(x.getServerId()))
+              .collect(Collectors.toList()))) {
+        partitionToServersReassignList.put(partitionId, failureServers);
       }
     }
 
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
index 300a7369c..b414d48c5 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
@@ -22,11 +22,16 @@ import java.util.List;
 import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.proto.RssProtos;
 
 public class ShuffleServerInfo implements Serializable {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(ShuffleServerInfo.class);
+  private static final String DELIMITER = "-";
   private static final long serialVersionUID = 0L;
+
   private String id;
 
   private String host;
@@ -37,14 +42,14 @@ public class ShuffleServerInfo implements Serializable {
 
   @VisibleForTesting
   public ShuffleServerInfo(String host, int port) {
-    this.id = host + "-" + port;
+    this.id = host + DELIMITER + port;
     this.host = host;
     this.grpcPort = port;
   }
 
   @VisibleForTesting
   public ShuffleServerInfo(String host, int grpcPort, int nettyPort) {
-    this.id = host + "-" + grpcPort + "-" + nettyPort;
+    this.id = host + DELIMITER + grpcPort + DELIMITER + nettyPort;
     this.host = host;
     this.grpcPort = grpcPort;
     this.nettyPort = nettyPort;
@@ -149,4 +154,20 @@ public class ShuffleServerInfo implements Serializable {
         .map(server -> convertToShuffleServerId(server))
         .collect(Collectors.toList());
   }
+
+  public static ShuffleServerInfo from(String serverId) {
+    if (serverId == null || serverId.isEmpty()) {
+      LOGGER.warn("Server id is null or empty");
+      return null;
+    }
+    String[] parts = serverId.split(DELIMITER);
+    if (parts.length == 2) {
+      return new ShuffleServerInfo(parts[0], Integer.parseInt(parts[1]));
+    } else if (parts.length == 3) {
+      return new ShuffleServerInfo(
+          parts[0], Integer.parseInt(parts[1]), Integer.parseInt(parts[2]));
+    }
+    LOGGER.warn("Server id is invalid. {}", serverId);
+    return null;
+  }
 }
diff --git 
a/common/src/test/java/org/apache/uniffle/common/ShuffleServerInfoTest.java 
b/common/src/test/java/org/apache/uniffle/common/ShuffleServerInfoTest.java
index feba4a4f5..4883cc4e9 100644
--- a/common/src/test/java/org/apache/uniffle/common/ShuffleServerInfoTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/ShuffleServerInfoTest.java
@@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
 
 public class ShuffleServerInfoTest {
 
@@ -60,4 +61,21 @@ public class ShuffleServerInfoTest {
             + "]}",
         newInfo.toString());
   }
+
+  @Test
+  public void testFromServerId() {
+    // case1
+    ShuffleServerInfo info = new ShuffleServerInfo("localhost", 1234);
+    String serverId = info.getId();
+    assertEquals(info, ShuffleServerInfo.from(serverId));
+
+    // case2
+    info = new ShuffleServerInfo("localhost", 1234, 5678);
+    serverId = info.getId();
+    assertEquals(info, ShuffleServerInfo.from(serverId));
+
+    // case3: illegal server id
+    String illegalServerId = "illegal";
+    assertNull(ShuffleServerInfo.from(illegalServerId));
+  }
 }

Reply via email to