This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new fe0ff7e60 [#2586] fix(spark): Support writer switching servers on
partition split with LOAD_BALANCE mode without reassign (#2587)
fe0ff7e60 is described below
commit fe0ff7e60efab90a556554f4afb46f0caea6c2ce
Author: Junfan Zhang <[email protected]>
AuthorDate: Fri Aug 22 10:08:12 2025 +0800
[#2586] fix(spark): Support writer switching servers on partition split
with LOAD_BALANCE mode without reassign (#2587)
### What changes were proposed in this pull request?
This PR supports the writer switching to other servers on partition split with
LOAD_BALANCE mode, without triggering a full reassignment.
### Why are the changes needed?
fix #2586
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests
---
.../shuffle/handle/MutableShuffleHandleInfo.java | 13 ++++----
.../spark/shuffle/handle/ShuffleHandleInfo.java | 8 ++++-
.../shuffle/handle/SimpleShuffleHandleInfo.java | 8 ++++-
.../handle/StageAttemptShuffleHandleInfo.java | 5 +--
.../shuffle/writer/TaskAttemptAssignment.java | 39 ++++++++++++++++++++--
.../spark/shuffle/writer/RssShuffleWriter.java | 10 ++++--
.../apache/uniffle/common/ShuffleServerInfo.java | 25 ++++++++++++--
.../uniffle/common/ShuffleServerInfoTest.java | 18 ++++++++++
8 files changed, 108 insertions(+), 18 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
index 0e22c017c..35aaa50be 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/MutableShuffleHandleInfo.java
@@ -202,21 +202,21 @@ public class MutableShuffleHandleInfo extends
ShuffleHandleInfoBase {
}
@Override
- public Map<Integer, List<ShuffleServerInfo>>
getAvailablePartitionServersForWriter() {
+ public Map<Integer, List<ShuffleServerInfo>>
getAvailablePartitionServersForWriter(
+ Map<Integer, List<ShuffleServerInfo>> exclusiveServers) {
+ Map<Integer, List<ShuffleServerInfo>> requestExclusiveServers =
+ exclusiveServers == null ? Collections.emptyMap() : exclusiveServers;
Map<Integer, List<ShuffleServerInfo>> assignment = new HashMap<>();
for (Map.Entry<Integer, Map<Integer, List<ShuffleServerInfo>>> entry :
partitionReplicaAssignedServers.entrySet()) {
int partitionId = entry.getKey();
+ List<ShuffleServerInfo> partitionExclusiveServers =
+ requestExclusiveServers.getOrDefault(partitionId,
Collections.emptyList());
Map<Integer, List<ShuffleServerInfo>> replicaServers = entry.getValue();
PartitionSplitInfo splitInfo = this.getPartitionSplitInfo(partitionId);
for (Map.Entry<Integer, List<ShuffleServerInfo>> replicaServerEntry :
replicaServers.entrySet()) {
- // For normal partition reassignment, the latest replacement shuffle
server is always used.
- Optional<ShuffleServerInfo> candidateOptional =
- replicaServerEntry.getValue().stream()
- .filter(x ->
!excludedServerToReplacements.containsKey(x.getId()))
- .findFirst();
// Get the unexcluded server for each replica writing
ShuffleServerInfo candidate =
replicaServerEntry.getValue().get(replicaServerEntry.getValue().size() - 1);
@@ -230,6 +230,7 @@ public class MutableShuffleHandleInfo extends
ShuffleHandleInfoBase {
List<ShuffleServerInfo> servers =
replicaServerEntry.getValue().stream()
.filter(x ->
!excludedServerToReplacements.containsKey(x.getId()))
+ .filter(x -> !partitionExclusiveServers.contains(x))
.collect(Collectors.toList());
// 2. exclude the first partition split triggered node.
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
index e51ff8029..b794acb68 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/ShuffleHandleInfo.java
@@ -17,6 +17,7 @@
package org.apache.spark.shuffle.handle;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -36,7 +37,12 @@ public interface ShuffleHandleInfo {
* shuffleServers. Implementations might return dynamic, up-to-date
information here. Returns
* partitionId -> [replica1, replica2, ...]
*/
- Map<Integer, List<ShuffleServerInfo>>
getAvailablePartitionServersForWriter();
+ default Map<Integer, List<ShuffleServerInfo>>
getAvailablePartitionServersForWriter() {
+ return getAvailablePartitionServersForWriter(Collections.emptyMap());
+ }
+
+ Map<Integer, List<ShuffleServerInfo>> getAvailablePartitionServersForWriter(
+ Map<Integer, List<ShuffleServerInfo>> exclusivePartitionServers);
/**
* Get all servers ever assigned to writers group by partitionId for reader
to get the data
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/SimpleShuffleHandleInfo.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/SimpleShuffleHandleInfo.java
index 20b5c13ec..13a83c7d6 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/SimpleShuffleHandleInfo.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/SimpleShuffleHandleInfo.java
@@ -28,6 +28,7 @@ import
org.apache.spark.shuffle.handle.split.PartitionSplitInfo;
import org.apache.uniffle.client.PartitionDataReplicaRequirementTracking;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.exception.RssException;
/**
* Class for holding, 1. partition ID -> shuffle servers mapping. 2. remote
storage info
@@ -54,7 +55,12 @@ public class SimpleShuffleHandleInfo extends
ShuffleHandleInfoBase implements Se
}
@Override
- public Map<Integer, List<ShuffleServerInfo>>
getAvailablePartitionServersForWriter() {
+ public Map<Integer, List<ShuffleServerInfo>>
getAvailablePartitionServersForWriter(
+ Map<Integer, List<ShuffleServerInfo>> exclusiveServers) {
+ if (exclusiveServers != null && !exclusiveServers.isEmpty()) {
+ throw new RssException(
+ "Exclusive servers are not supported when getting available
partition servers for shuffle writer");
+ }
return partitionToServers;
}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
index 4ca7b8153..90fd51820 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/handle/StageAttemptShuffleHandleInfo.java
@@ -63,8 +63,9 @@ public class StageAttemptShuffleHandleInfo extends
ShuffleHandleInfoBase {
}
@Override
- public Map<Integer, List<ShuffleServerInfo>>
getAvailablePartitionServersForWriter() {
- return current.getAvailablePartitionServersForWriter();
+ public Map<Integer, List<ShuffleServerInfo>>
getAvailablePartitionServersForWriter(
+ Map<Integer, List<ShuffleServerInfo>> exclusiveServers) {
+ return current.getAvailablePartitionServersForWriter(exclusiveServers);
}
@Override
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
index 18eee16e2..7e46528c4 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/TaskAttemptAssignment.java
@@ -17,9 +17,13 @@
package org.apache.spark.shuffle.writer;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentSkipListSet;
+import java.util.stream.Collectors;
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.handle.split.PartitionSplitInfo;
@@ -29,6 +33,7 @@ import org.slf4j.LoggerFactory;
import org.apache.uniffle.common.PartitionSplitMode;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.common.util.JavaUtils;
/** This class is to get the partition assignment for ShuffleWriter. */
public class TaskAttemptAssignment {
@@ -38,7 +43,12 @@ public class TaskAttemptAssignment {
private ShuffleHandleInfo handle;
private final long taskAttemptId;
+ // key: partitionId, values: exclusive servers.
+ // this is for the partition split mechanism with load balance mode
+ private final Map<Integer, Set<ShuffleServerInfo>>
exclusiveServersForPartition;
+
public TaskAttemptAssignment(long taskAttemptId, ShuffleHandleInfo
shuffleHandleInfo) {
+ this.exclusiveServersForPartition = JavaUtils.newConcurrentMap();
this.update(shuffleHandleInfo);
this.handle = shuffleHandleInfo;
this.taskAttemptId = taskAttemptId;
@@ -58,16 +68,39 @@ public class TaskAttemptAssignment {
if (handle == null) {
throw new RssException("Errors on updating shuffle handle by the empty
handleInfo.");
}
- this.assignment = handle.getAvailablePartitionServersForWriter();
+ this.assignment =
+ handle.getAvailablePartitionServersForWriter(
+ this.exclusiveServersForPartition.entrySet().stream()
+ .collect(Collectors.toMap(Map.Entry::getKey, x -> new
ArrayList<>(x.getValue()))));
this.handle = handle;
}
- public boolean isSkipPartitionSplit(int partitionId) {
- // for those load balance partition split, once split, skip the following
split.
+ private boolean hasBeenLoadBalanced(int partitionId) {
PartitionSplitInfo splitInfo =
this.handle.getPartitionSplitInfo(partitionId);
return splitInfo.isSplit() && splitInfo.getMode() ==
PartitionSplitMode.LOAD_BALANCE;
}
+ /**
+   * If the partition has been load balanced and marked as split, the assignment
+   * can be updated with the next servers. Otherwise, this directly returns
+   * false, which will trigger reassignment.
+ *
+ * @param partitionId
+ * @param exclusiveServers
+ * @return
+ */
+ public boolean updatePartitionSplitAssignment(
+ int partitionId, List<ShuffleServerInfo> exclusiveServers) {
+ if (hasBeenLoadBalanced(partitionId)) {
+ Set<ShuffleServerInfo> servers =
+ this.exclusiveServersForPartition.computeIfAbsent(
+ partitionId, k -> new ConcurrentSkipListSet<>());
+ servers.addAll(exclusiveServers);
+ update(this.handle);
+ return true;
+ }
+ return false;
+ }
+
/**
* @param partitionId
* @return all assigned shuffle servers for one partition id
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 382643c52..b51ab00d5 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -704,9 +704,13 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
for (Map.Entry<Integer, List<ReceivingFailureServer>> entry :
failurePartitionToServers.entrySet()) {
int partitionId = entry.getKey();
- boolean isSkip = taskAttemptAssignment.isSkipPartitionSplit(partitionId);
- if (!isSkip) {
- partitionToServersReassignList.put(partitionId, entry.getValue());
+ List<ReceivingFailureServer> failureServers = entry.getValue();
+ if (!taskAttemptAssignment.updatePartitionSplitAssignment(
+ partitionId,
+ failureServers.stream()
+ .map(x -> ShuffleServerInfo.from(x.getServerId()))
+ .collect(Collectors.toList()))) {
+ partitionToServersReassignList.put(partitionId, failureServers);
}
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
index 300a7369c..b414d48c5 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java
@@ -22,11 +22,16 @@ import java.util.List;
import java.util.stream.Collectors;
import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import org.apache.uniffle.proto.RssProtos;
public class ShuffleServerInfo implements Serializable {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(ShuffleServerInfo.class);
+ private static final String DELIMITER = "-";
private static final long serialVersionUID = 0L;
+
private String id;
private String host;
@@ -37,14 +42,14 @@ public class ShuffleServerInfo implements Serializable {
@VisibleForTesting
public ShuffleServerInfo(String host, int port) {
- this.id = host + "-" + port;
+ this.id = host + DELIMITER + port;
this.host = host;
this.grpcPort = port;
}
@VisibleForTesting
public ShuffleServerInfo(String host, int grpcPort, int nettyPort) {
- this.id = host + "-" + grpcPort + "-" + nettyPort;
+ this.id = host + DELIMITER + grpcPort + DELIMITER + nettyPort;
this.host = host;
this.grpcPort = grpcPort;
this.nettyPort = nettyPort;
@@ -149,4 +154,20 @@ public class ShuffleServerInfo implements Serializable {
.map(server -> convertToShuffleServerId(server))
.collect(Collectors.toList());
}
+
+ public static ShuffleServerInfo from(String serverId) {
+ if (serverId == null || serverId.isEmpty()) {
+ LOGGER.warn("Server id is null or empty");
+ return null;
+ }
+ String[] parts = serverId.split(DELIMITER);
+ if (parts.length == 2) {
+ return new ShuffleServerInfo(parts[0], Integer.parseInt(parts[1]));
+ } else if (parts.length == 3) {
+ return new ShuffleServerInfo(
+ parts[0], Integer.parseInt(parts[1]), Integer.parseInt(parts[2]));
+ }
+ LOGGER.warn("Server id is invalid. {}", serverId);
+ return null;
+ }
}
diff --git
a/common/src/test/java/org/apache/uniffle/common/ShuffleServerInfoTest.java
b/common/src/test/java/org/apache/uniffle/common/ShuffleServerInfoTest.java
index feba4a4f5..4883cc4e9 100644
--- a/common/src/test/java/org/apache/uniffle/common/ShuffleServerInfoTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/ShuffleServerInfoTest.java
@@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
public class ShuffleServerInfoTest {
@@ -60,4 +61,21 @@ public class ShuffleServerInfoTest {
+ "]}",
newInfo.toString());
}
+
+ @Test
+ public void testFromServerId() {
+ // case1
+ ShuffleServerInfo info = new ShuffleServerInfo("localhost", 1234);
+ String serverId = info.getId();
+ assertEquals(info, ShuffleServerInfo.from(serverId));
+
+ // case2
+ info = new ShuffleServerInfo("localhost", 1234, 5678);
+ serverId = info.getId();
+ assertEquals(info, ShuffleServerInfo.from(serverId));
+
+ // case3: illegal server id
+ String illegalServerId = "illegal";
+ assertNull(ShuffleServerInfo.from(illegalServerId));
+ }
}