This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a227da  [Improvement] Introduce config to customize assignment server 
numbers in client side (#100)
9a227da is described below

commit 9a227da1d4f6de2f2267225ae1d547414be5e234
Author: Junfan Zhang <[email protected]>
AuthorDate: Mon Aug 1 20:12:17 2022 +0800

    [Improvement] Introduce config to customize assignment server numbers in 
client side (#100)
    
    ### What changes were proposed in this pull request?
    [Improvement] Introduce config to customize assignment server numbers in 
client side.
    
    **Changelog**
    1. Introduce the config
       `<client_type>.rss.client.assignment.shuffle.nodes.max`
    
    
    ### Why are the changes needed?
    Currently the number of assigned shuffle servers is determined solely by the
    coordinator's config rss.coordinator.shuffle.nodes.max, but a single
    cluster-wide value is not suitable for all Spark jobs.
    
    We should introduce a new config that lets the client specify the number of
    shuffle servers to be assigned. rss.coordinator.shuffle.nodes.max then serves
    as the upper limit for the client-requested number.
    
    ### Does this PR introduce _any_ user-facing change?
    YES.
    
    ### How was this patch tested?
    UT.
---
 .../org/apache/hadoop/mapreduce/RssMRConfig.java   |   7 +-
 .../hadoop/mapreduce/v2/app/RssMRAppMaster.java    |  17 ++-
 .../hadoop/mapred/SortWriteBufferManagerTest.java  |   2 +-
 .../hadoop/mapreduce/task/reduce/FetcherTest.java  |   2 +-
 .../org/apache/spark/shuffle/RssSparkConfig.java   |   4 +
 .../apache/spark/shuffle/RssShuffleManager.java    |   4 +-
 .../apache/spark/shuffle/RssShuffleManager.java    |   5 +-
 .../uniffle/client/api/ShuffleWriteClient.java     |   2 +-
 .../client/impl/ShuffleWriteClientImpl.java        |   4 +-
 .../uniffle/client/util/RssClientConfig.java       |   4 +
 .../uniffle/coordinator/AssignmentStrategy.java    |   2 +-
 .../coordinator/BasicAssignmentStrategy.java       |   8 +-
 .../uniffle/coordinator/CoordinatorConf.java       |   2 +-
 .../coordinator/CoordinatorGrpcService.java        |  10 +-
 .../PartitionBalanceAssignmentStrategy.java        |  12 +-
 .../coordinator/BasicAssignmentStrategyTest.java   | 104 ++++++++++++++++-
 .../PartitionBalanceAssignmentStrategyTest.java    | 124 ++++++++++++++++++---
 docs/client_guide.md                               |   1 +
 .../test/AssignmentServerNodesNumberTest.java      | 106 ++++++++++++++++++
 .../uniffle/test/AssignmentWithTagsTest.java       |  10 +-
 .../client/impl/grpc/CoordinatorGrpcClient.java    |  12 +-
 .../request/RssGetShuffleAssignmentsRequest.java   |  14 +++
 proto/src/main/proto/Rss.proto                     |   1 +
 23 files changed, 409 insertions(+), 48 deletions(-)

diff --git 
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java 
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
index ed1f90f..ef47e21 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
@@ -144,7 +144,12 @@ public class RssMRConfig {
   public static final int RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE = 
RssClientConfig.RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE;
 
   public static final String RSS_CLIENT_ASSIGNMENT_TAGS =
-          MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_TAGS;
+      MR_RSS_CONFIG_PREFIX + RssClientConfig.RSS_CLIENT_ASSIGNMENT_TAGS;
+
+  public static final String RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER =
+      RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER;
+  public static final int 
RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE =
+      
RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE;
 
   public static final String RSS_CONF_FILE = "rss_conf.xml";
 
diff --git 
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
 
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
index e163eec..c65f2a2 100644
--- 
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
+++ 
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
@@ -128,10 +128,23 @@ public class RssMRAppMaster extends MRAppMaster {
       }
       assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
 
+      int requiredAssignmentShuffleServersNum = conf.getInt(
+          RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER,
+          RssMRConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE
+      );
+
       ApplicationAttemptId applicationAttemptId = 
RssMRUtils.getApplicationAttemptId();
       String appId = applicationAttemptId.toString();
-      ShuffleAssignmentsInfo response = client.getShuffleAssignments(
-          appId, 0, numReduceTasks, 1, Sets.newHashSet(assignmentTags));
+
+      ShuffleAssignmentsInfo response =
+          client.getShuffleAssignments(
+              appId,
+              0,
+              numReduceTasks,
+              1,
+              Sets.newHashSet(assignmentTags),
+              requiredAssignmentShuffleServersNum
+          );
 
       Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges = 
response.getServerToPartitionRanges();
       final ScheduledExecutorService scheduledExecutorService = 
Executors.newSingleThreadScheduledExecutor(
diff --git 
a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
 
b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 398b52a..7613883 100644
--- 
a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ 
b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -314,7 +314,7 @@ public class SortWriteBufferManagerTest {
     }
 
     @Override
-    public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int 
shuffleId, int partitionNum, int partitionNumPerRange, Set<String> 
requiredTags) {
+    public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int 
shuffleId, int partitionNum, int partitionNumPerRange, Set<String> 
requiredTags, int assignmentShuffleServerNumber) {
       return null;
     }
 
diff --git 
a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
 
b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index ee2539b..cc622f0 100644
--- 
a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ 
b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -401,7 +401,7 @@ public class FetcherTest {
     }
 
     @Override
-    public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int 
shuffleId, int partitionNum, int partitionNumPerRange, Set<String> 
requiredTags) {
+    public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int 
shuffleId, int partitionNum, int partitionNumPerRange, Set<String> 
requiredTags, int assignmentShuffleServerNumber) {
       return null;
     }
 
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 875c9a5..6b549b1 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -203,6 +203,10 @@ public class RssSparkConfig {
               + "whether this conf is set or not"))
       .createWithDefault("");
 
+  public static final ConfigEntry<Integer> 
RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER = createIntegerBuilder(
+      new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + 
RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER))
+      
.createWithDefault(RssClientConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE);
+
   public static final ConfigEntry<String> RSS_COORDINATOR_QUORUM = 
createStringBuilder(
       new ConfigBuilder(SPARK_RSS_CONFIG_PREFIX + 
RssClientConfig.RSS_COORDINATOR_QUORUM)
           .doc("Coordinator quorum"))
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index c313747..ec84308 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -218,9 +218,11 @@ public class RssShuffleManager implements ShuffleManager {
     // get all register info according to coordinator's response
     Set<String> assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
 
+    int requiredShuffleServerNumber = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
+
     ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
         appId, shuffleId, dependency.partitioner().numPartitions(),
-        partitionNumPerRange, assignmentTags);
+        partitionNumPerRange, assignmentTags, requiredShuffleServerNumber);
     Map<Integer, List<ShuffleServerInfo>> partitionToServers = 
response.getPartitionToServers();
 
     startHeartbeat();
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 32239b3..80fac99 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -256,12 +256,15 @@ public class RssShuffleManager implements ShuffleManager {
 
     Set<String> assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
 
+    int requiredShuffleServerNumber = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
+
     ShuffleAssignmentsInfo response = shuffleWriteClient.getShuffleAssignments(
         id.get(),
         shuffleId,
         dependency.partitioner().numPartitions(),
         1,
-        assignmentTags);
+        assignmentTags,
+        requiredShuffleServerNumber);
     Map<Integer, List<ShuffleServerInfo>> partitionToServers = 
response.getPartitionToServers();
 
     startHeartbeat();
diff --git 
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java 
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index 2cf3685..d5981c4 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -60,7 +60,7 @@ public interface ShuffleWriteClient {
       int bitmapNum);
 
   ShuffleAssignmentsInfo getShuffleAssignments(String appId, int shuffleId, 
int partitionNum,
-      int partitionNumPerRange, Set<String> requiredTags);
+      int partitionNumPerRange, Set<String> requiredTags, int 
assignmentShuffleServerNumber);
 
   Roaring64NavigableMap getShuffleResult(String clientType, 
Set<ShuffleServerInfo> shuffleServerInfoSet,
       String appId, int shuffleId, int partitionId);
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index 43d8b3b..c6c13e0 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -375,9 +375,9 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
 
   @Override
   public ShuffleAssignmentsInfo getShuffleAssignments(String appId, int 
shuffleId, int partitionNum,
-      int partitionNumPerRange, Set<String> requiredTags) {
+      int partitionNumPerRange, Set<String> requiredTags, int 
assignmentShuffleServerNumber) {
     RssGetShuffleAssignmentsRequest request = new 
RssGetShuffleAssignmentsRequest(
-        appId, shuffleId, partitionNum, partitionNumPerRange, replica, 
requiredTags);
+        appId, shuffleId, partitionNum, partitionNumPerRange, replica, 
requiredTags, assignmentShuffleServerNumber);
 
     RssGetShuffleAssignmentsResponse response = new 
RssGetShuffleAssignmentsResponse(ResponseStatusCode.INTERNAL_ERROR);
     for (CoordinatorClient coordinatorClient : coordinatorClients) {
diff --git 
a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java 
b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
index 0b42d49..eb6006a 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
+++ b/client/src/main/java/org/apache/uniffle/client/util/RssClientConfig.java
@@ -65,4 +65,8 @@ public class RssClientConfig {
   public static final int RSS_ACCESS_TIMEOUT_MS_DEFAULT_VALUE = 10000;
   public static final String RSS_DYNAMIC_CLIENT_CONF_ENABLED = 
"rss.dynamicClientConf.enabled";
   public static final boolean RSS_DYNAMIC_CLIENT_CONF_ENABLED_DEFAULT_VALUE = 
true;
+
+  public static final String RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER =
+      "rss.client.assignment.shuffle.nodes.max";
+  public static final int 
RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER_DEFAULT_VALUE = -1;
 }
diff --git 
a/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java
 
b/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java
index 86ddb18..36d1908 100644
--- 
a/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java
+++ 
b/coordinator/src/main/java/org/apache/uniffle/coordinator/AssignmentStrategy.java
@@ -22,6 +22,6 @@ import java.util.Set;
 public interface AssignmentStrategy {
 
   PartitionRangeAssignment assign(int totalPartitionNum, int 
partitionNumPerRange,
-      int replica, Set<String> requiredTags);
+      int replica, Set<String> requiredTags, int requiredShuffleServerNumber);
 
 }
diff --git 
a/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java
 
b/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java
index 8d4eb54..54ca2c2 100644
--- 
a/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java
+++ 
b/coordinator/src/main/java/org/apache/uniffle/coordinator/BasicAssignmentStrategy.java
@@ -41,10 +41,14 @@ public class BasicAssignmentStrategy implements 
AssignmentStrategy {
 
   @Override
   public PartitionRangeAssignment assign(int totalPartitionNum, int 
partitionNumPerRange,
-      int replica, Set<String> requiredTags) {
+      int replica, Set<String> requiredTags, int requiredShuffleServerNumber) {
     List<PartitionRange> ranges = 
CoordinatorUtils.generateRanges(totalPartitionNum, partitionNumPerRange);
     int shuffleNodesMax = clusterManager.getShuffleNodesMax();
-    List<ServerNode> servers = getRequiredServers(requiredTags, 
shuffleNodesMax);
+    int expectedShuffleNodesNum = shuffleNodesMax;
+    if (requiredShuffleServerNumber < shuffleNodesMax && 
requiredShuffleServerNumber > 0) {
+      expectedShuffleNodesNum = requiredShuffleServerNumber;
+    }
+    List<ServerNode> servers = getRequiredServers(requiredTags, 
expectedShuffleNodesNum);
     if (servers.isEmpty() || servers.size() < replica) {
       return new PartitionRangeAssignment(null);
     }
diff --git 
a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java 
b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java
index bdad2d6..ebb50ce 100644
--- 
a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java
+++ 
b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorConf.java
@@ -61,7 +61,7 @@ public class CoordinatorConf extends RssBaseConf {
       .key("rss.coordinator.shuffle.nodes.max")
       .intType()
       .defaultValue(9)
-      .withDescription("The max number of shuffle server when do the 
assignment");
+      .withDescription("The max limitation number of shuffle server when do 
the assignment");
   public static final ConfigOption<List<String>> COORDINATOR_ACCESS_CHECKERS = 
ConfigOptions
       .key("rss.coordinator.access.checkers")
       .stringType()
diff --git 
a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java
 
b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java
index ce14458..d2c3fc2 100644
--- 
a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java
+++ 
b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java
@@ -109,15 +109,21 @@ public class CoordinatorGrpcService extends 
CoordinatorServerGrpc.CoordinatorSer
     final int partitionNumPerRange = request.getPartitionNumPerRange();
     final int replica = request.getDataReplica();
     final Set<String> requiredTags = 
Sets.newHashSet(request.getRequireTagsList());
+    final int requiredShuffleServerNumber = 
request.getAssignmentShuffleServerNumber();
 
     LOG.info("Request of getShuffleAssignments for appId[" + appId
         + "], shuffleId[" + shuffleId + "], partitionNum[" + partitionNum
-        + "], partitionNumPerRange[" + partitionNumPerRange + "], replica[" + 
replica + "]");
+        + "], partitionNumPerRange[" + partitionNumPerRange + "], replica[" + 
replica
+        + "], requiredTags[" + requiredTags
+        + "], requiredShuffleServerNumber[" + requiredShuffleServerNumber + "]"
+    );
 
     GetShuffleAssignmentsResponse response;
     try {
       final PartitionRangeAssignment pra =
-          coordinatorServer.getAssignmentStrategy().assign(partitionNum, 
partitionNumPerRange, replica, requiredTags);
+          coordinatorServer
+              .getAssignmentStrategy()
+              .assign(partitionNum, partitionNumPerRange, replica, 
requiredTags, requiredShuffleServerNumber);
       response =
           CoordinatorUtils.toGetShuffleAssignmentsResponse(pra);
       logAssignmentResult(appId, shuffleId, pra);
diff --git 
a/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java
 
b/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java
index ba92477..d074b8c 100644
--- 
a/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java
+++ 
b/coordinator/src/main/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategy.java
@@ -66,7 +66,8 @@ public class PartitionBalanceAssignmentStrategy implements 
AssignmentStrategy {
       int totalPartitionNum,
       int partitionNumPerRange,
       int replica,
-      Set<String> requiredTags) {
+      Set<String> requiredTags,
+      int requiredShuffleServerNumber) {
 
     if (partitionNumPerRange != 1) {
       throw new RuntimeException("PartitionNumPerRange must be one");
@@ -107,8 +108,13 @@ public class PartitionBalanceAssignmentStrategy implements 
AssignmentStrategy {
         throw new RuntimeException("There isn't enough shuffle servers");
       }
 
-      int expectNum = clusterManager.getShuffleNodesMax();
-      if (nodes.size() < clusterManager.getShuffleNodesMax()) {
+      final int assignmentMaxNum = clusterManager.getShuffleNodesMax();
+      int expectNum = assignmentMaxNum;
+      if (requiredShuffleServerNumber < assignmentMaxNum && 
requiredShuffleServerNumber > 0) {
+        expectNum = requiredShuffleServerNumber;
+      }
+
+      if (nodes.size() < expectNum) {
         LOG.warn("Can't get expected servers [" + expectNum + "] and found 
only [" + nodes.size() + "]");
         expectNum = nodes.size();
       }
diff --git 
a/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java
 
b/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java
index 6f80eb3..a1f79cf 100644
--- 
a/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java
+++ 
b/coordinator/src/test/java/org/apache/uniffle/coordinator/BasicAssignmentStrategyTest.java
@@ -22,6 +22,8 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertNull;
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import com.google.common.collect.Sets;
+import java.util.Collection;
+import java.util.stream.Collectors;
 import org.apache.uniffle.common.PartitionRange;
 
 import java.io.IOException;
@@ -64,7 +66,7 @@ public class BasicAssignmentStrategyTest {
           20 - i, 0, tags, true));
     }
 
-    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags);
+    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1);
     SortedMap<PartitionRange, List<ServerNode>> assignments = 
pra.getAssignments();
     assertEquals(10, assignments.size());
 
@@ -90,14 +92,14 @@ public class BasicAssignmentStrategyTest {
       clusterManager.add(new ServerNode(String.valueOf(i), "", 0, 0, 0,
           0, 0, tags, true));
     }
-    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags);
+    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1);
     SortedMap<PartitionRange, List<ServerNode>> assignments = 
pra.getAssignments();
     Set<ServerNode> serverNodes1 = Sets.newHashSet();
     for (Map.Entry<PartitionRange, List<ServerNode>> assignment : 
assignments.entrySet()) {
       serverNodes1.addAll(assignment.getValue());
     }
 
-    pra = strategy.assign(100, 10, 2, tags);
+    pra = strategy.assign(100, 10, 2, tags, -1);
     assignments = pra.getAssignments();
     Set<ServerNode> serverNodes2 = Sets.newHashSet();
     for (Map.Entry<PartitionRange, List<ServerNode>> assignment : 
assignments.entrySet()) {
@@ -118,13 +120,13 @@ public class BasicAssignmentStrategyTest {
         0, 0, tags, true);
 
     clusterManager.add(sn1);
-    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags);
+    PartitionRangeAssignment pra = strategy.assign(100, 10, 2, tags, -1);
     // nodeNum < replica
     assertNull(pra.getAssignments());
 
     // nodeNum = replica
     clusterManager.add(sn2);
-    pra = strategy.assign(100, 10, 2, tags);
+    pra = strategy.assign(100, 10, 2, tags, -1);
     SortedMap<PartitionRange, List<ServerNode>> assignments = 
pra.getAssignments();
     Set<ServerNode> serverNodes = Sets.newHashSet();
     for (Map.Entry<PartitionRange, List<ServerNode>> assignment : 
assignments.entrySet()) {
@@ -136,7 +138,7 @@ public class BasicAssignmentStrategyTest {
 
     // nodeNum > replica & nodeNum < shuffleNodesMax
     clusterManager.add(sn3);
-    pra = strategy.assign(100, 10, 2, tags);
+    pra = strategy.assign(100, 10, 2, tags, -1);
     assignments = pra.getAssignments();
     serverNodes = Sets.newHashSet();
     for (Map.Entry<PartitionRange, List<ServerNode>> assignment : 
assignments.entrySet()) {
@@ -147,4 +149,94 @@ public class BasicAssignmentStrategyTest {
     assertTrue(serverNodes.contains(sn2));
     assertTrue(serverNodes.contains(sn3));
   }
+
+  @Test
+  public void testAssignmentShuffleNodesNum() {
+    Set<String> serverTags = Sets.newHashSet("tag-1");
+
+    for (int i = 0; i < 20; ++i) {
+      clusterManager.add(new ServerNode("t1-" + i, "", 0, 0, 0,
+          20 - i, 0, serverTags, true));
+    }
+
+    /**
+     * case1: user specify the illegal shuffle node num(<0)
+     * it will use the default shuffle nodes num when having enough servers.
+     */
+    PartitionRangeAssignment pra = strategy.assign(100, 10, 1, serverTags, -1);
+    assertEquals(
+        shuffleNodesMax,
+        pra.getAssignments()
+            .values()
+            .stream()
+            .flatMap(Collection::stream)
+            .collect(Collectors.toSet())
+            .size()
+    );
+
+    /**
+     * case2: user specify the illegal shuffle node num(==0)
+     * it will use the default shuffle nodes num when having enough servers.
+     */
+    pra = strategy.assign(100, 10, 1, serverTags, 0);
+    assertEquals(
+        shuffleNodesMax,
+        pra.getAssignments()
+            .values()
+            .stream()
+            .flatMap(Collection::stream)
+            .collect(Collectors.toSet())
+            .size()
+    );
+
+    /**
+     * case3: user specify the illegal shuffle node num(>default max 
limitation)
+     * it will use the default shuffle nodes num when having enough servers
+     */
+    pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax + 10);
+    assertEquals(
+        shuffleNodesMax,
+        pra.getAssignments()
+            .values()
+            .stream()
+            .flatMap(Collection::stream)
+            .collect(Collectors.toSet())
+            .size()
+    );
+
+    /**
+     * case4: user specify the legal shuffle node num,
+     * it will use the customized shuffle nodes num when having enough servers
+     */
+    pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax - 1);
+    assertEquals(
+        shuffleNodesMax - 1,
+        pra.getAssignments()
+            .values()
+            .stream()
+            .flatMap(Collection::stream)
+            .collect(Collectors.toSet())
+            .size()
+    );
+
+    /**
+     * case5: user specify the legal shuffle node num, but cluster dont have 
enough servers,
+     * it will return the remaining servers.
+     */
+    serverTags = Sets.newHashSet("tag-2");
+    for (int i = 0; i < shuffleNodesMax - 1; ++i) {
+      clusterManager.add(new ServerNode("t2-" + i, "", 0, 0, 0,
+          20 - i, 0, serverTags, true));
+    }
+    pra = strategy.assign(100, 10, 1, serverTags, shuffleNodesMax);
+    assertEquals(
+        shuffleNodesMax - 1,
+        pra.getAssignments()
+            .values()
+            .stream()
+            .flatMap(Collection::stream)
+            .collect(Collectors.toSet())
+            .size()
+    );
+  }
 }
diff --git 
a/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java
 
b/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java
index ae3b6e3..47cc26f 100644
--- 
a/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java
+++ 
b/coordinator/src/test/java/org/apache/uniffle/coordinator/PartitionBalanceAssignmentStrategyTest.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.coordinator;
 
 import java.io.IOException;
+import java.util.Collection;
 import java.util.Comparator;
 import java.util.List;
 import java.util.Set;
@@ -27,6 +28,7 @@ import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.Uninterruptibles;
 
+import java.util.stream.Collectors;
 import org.apache.hadoop.conf.Configuration;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
@@ -60,32 +62,32 @@ public class PartitionBalanceAssignmentStrategyTest {
     updateServerResource(list);
     boolean isThrown = false;
     try {
-      strategy.assign(100, 2, 1, tags);
+      strategy.assign(100, 2, 1, tags, -1);
     } catch (Exception e) {
       isThrown = true;
     }
     assertTrue(isThrown);
     try {
-      strategy.assign(0, 1, 1, tags);
+      strategy.assign(0, 1, 1, tags, -1);
     } catch (Exception e) {
       fail();
     }
     isThrown = false;
     try {
-      strategy.assign(10, 1, 1, Sets.newHashSet("fake"));
+      strategy.assign(10, 1, 1, Sets.newHashSet("fake"), 1);
     } catch (Exception e) {
       isThrown = true;
     }
     assertTrue(isThrown);
-    strategy.assign(100, 1, 1, tags);
+    strategy.assign(100, 1, 1, tags, -1);
     List<Long> expect = Lists.newArrayList(20L, 20L, 20L, 20L, 20L, 0L, 0L, 
0L, 0L,
         0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L);
     valid(expect);
-    strategy.assign(75, 1, 1, tags);
+    strategy.assign(75, 1, 1, tags, -1);
     expect = Lists.newArrayList(20L, 20L, 20L, 20L, 20L, 15L, 15L, 15L, 15L,
         15L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L);
     valid(expect);
-    strategy.assign(100, 1, 1, tags);
+    strategy.assign(100, 1, 1, tags, -1);
     expect = Lists.newArrayList(20L, 20L, 20L, 20L, 20L, 15L, 15L, 15L, 15L,
         15L, 20L, 20L, 20L, 20L, 20L, 0L, 0L, 0L, 0L, 0L);
     valid(expect);
@@ -94,16 +96,16 @@ public class PartitionBalanceAssignmentStrategyTest {
     list = Lists.newArrayList(7L, 18L, 7L, 3L, 19L, 15L, 11L, 10L, 16L, 11L,
         14L, 17L, 15L, 17L, 8L, 1L, 3L, 3L, 6L, 12L);
     updateServerResource(list);
-    strategy.assign(100, 1, 1, tags);
+    strategy.assign(100, 1, 1, tags, -1);
     expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 0L, 0L, 0L, 20L, 0L,
         0L, 20L, 0L, 20L, 0L, 0L, 0L, 0L, 0L, 0L);
     valid(expect);
-    strategy.assign(50, 1, 1, tags);
+    strategy.assign(50, 1, 1, tags, -1);
     expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 10L, 10L, 0L, 20L, 0L,
         10L, 20L, 10L, 20L, 0L, 0L, 0L, 0L, 0L, 10L);
     valid(expect);
 
-    strategy.assign(75, 1, 1, tags);
+    strategy.assign(75, 1, 1, tags, -1);
     expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 25L, 10L, 15L, 20L, 15L,
         25L, 20L, 25L, 20L, 0L, 0L, 0L, 0L, 0L, 10L);
     valid(expect);
@@ -112,15 +114,15 @@ public class PartitionBalanceAssignmentStrategyTest {
     list = Lists.newArrayList(7L, 18L, 7L, 3L, 19L, 15L, 11L, 10L, 16L, 11L,
         14L, 17L, 15L, 17L, 8L, 1L, 3L, 3L, 6L, 12L);
     updateServerResource(list);
-    strategy.assign(50, 1, 2, tags);
+    strategy.assign(50, 1, 2, tags, -1);
     expect = Lists.newArrayList(0L, 20L, 0L, 0L, 20L, 0L, 0L, 0L, 20L, 0L,
         0L, 20L, 0L, 20L, 0L, 0L, 0L, 0L, 0L, 0L);
     valid(expect);
-    strategy.assign(75, 1, 2, tags);
+    strategy.assign(75, 1, 2, tags, -1);
     expect = Lists.newArrayList(0L, 20L, 0L, 0L, 50L, 30L, 0L, 0L, 20L, 0L,
         30L, 20L, 30L, 20L, 0L, 0L, 0L, 0L, 0L, 30L);
     valid(expect);
-    strategy.assign(33, 1, 2, tags);
+    strategy.assign(33, 1, 2, tags, -1);
     expect = Lists.newArrayList(0L, 33L, 0L, 0L, 50L, 30L, 14L, 13L, 20L, 13L,
         30L, 20L, 30L, 20L, 13L, 0L, 0L, 0L, 0L, 30L);
     valid(expect);
@@ -136,19 +138,19 @@ public class PartitionBalanceAssignmentStrategyTest {
 
     Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS);
     updateServerResource(list);
-    strategy.assign(33, 1, 1, tags);
+    strategy.assign(33, 1, 1, tags, -1);
     expect = Lists.newArrayList(0L, 7L, 0L, 7L, 0L, 7L, 0L, 6L, 0L, 6L, 0L, 0L,
         0L, 0L, 0L, 0L, 0L, 0L, 0L, 0L);
     valid(expect);
-    strategy.assign(41, 1, 2, tags);
+    strategy.assign(41, 1, 2, tags, -1);
     expect = Lists.newArrayList(0L, 7L, 0L, 7L, 0L, 7L, 0L, 6L, 0L, 6L, 0L, 
17L,
         0L, 17L, 0L, 16L, 0L, 16L, 0L, 16L);
     valid(expect);
-    strategy.assign(23, 1, 1, tags);
+    strategy.assign(23, 1, 1, tags, -1);
     expect = Lists.newArrayList(5L, 7L, 5L, 7L, 5L, 7L, 4L, 6L, 4L, 6L, 0L, 
17L,
         0L, 17L, 0L, 16L, 0L, 16L, 0L, 16L);
     valid(expect);
-    strategy.assign(11, 1, 3, tags);
+    strategy.assign(11, 1, 3, tags, -1);
     expect = Lists.newArrayList(5L, 7L, 5L, 7L, 5L, 7L, 4L, 13L, 4L, 13L, 7L, 
17L,
         6L, 17L, 6L, 16L, 0L, 16L, 0L, 16L);
     valid(expect);
@@ -191,4 +193,94 @@ public class PartitionBalanceAssignmentStrategyTest {
       clusterManager.add(node);
     }
   }
+
+  @Test
+  public void testAssignmentShuffleNodesNum() {
+    Set<String> serverTags = Sets.newHashSet("tag-1");
+
+    for (int i = 0; i < 20; ++i) {
+      clusterManager.add(new ServerNode("t1-" + i, "", 0, 0, 0,
+          20 - i, 0, serverTags, true));
+    }
+
+    /**
+     * Purpose of this test: check how the requested shuffle node num (last argument
+     * of assign) affects the number of distinct entries in the returned assignment.
+     * The loop above registers 20 nodes tagged "tag-1".
+     *
+     * case1: the user specifies an illegal shuffle node num (<0);
+     * the default shuffle nodes num is used when there are enough servers.
+     */
+    PartitionRangeAssignment pra = strategy.assign(100, 1, 1, serverTags, -1);
+    assertEquals(
+        shuffleNodesMax,
+        pra.getAssignments()
+            .values()
+            .stream()
+            .flatMap(Collection::stream)
+            .collect(Collectors.toSet())
+            .size()
+    );
+
+    /**
+     * case2: the user specifies an illegal shuffle node num (==0);
+     * the default shuffle nodes num is used when there are enough servers.
+     */
+    pra = strategy.assign(100, 1, 1, serverTags, 0);
+    assertEquals(
+        shuffleNodesMax,
+        pra.getAssignments()
+            .values()
+            .stream()
+            .flatMap(Collection::stream)
+            .collect(Collectors.toSet())
+            .size()
+    );
+
+    /**
+     * case3: the user specifies an illegal shuffle node num (> default max limitation);
+     * the default shuffle nodes num is used when there are enough servers.
+     */
+    pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax + 10);
+    assertEquals(
+        shuffleNodesMax,
+        pra.getAssignments()
+            .values()
+            .stream()
+            .flatMap(Collection::stream)
+            .collect(Collectors.toSet())
+            .size()
+    );
+
+    /**
+     * case4: the user specifies a legal shuffle node num;
+     * the customized shuffle nodes num is used when there are enough servers.
+     */
+    pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax - 1);
+    assertEquals(
+        shuffleNodesMax - 1,
+        pra.getAssignments()
+            .values()
+            .stream()
+            .flatMap(Collection::stream)
+            .collect(Collectors.toSet())
+            .size()
+    );
+
+    /**
+     * case5: the user specifies a legal shuffle node num, but the cluster doesn't have
+     * enough servers (only shuffleNodesMax - 1 nodes are registered with "tag-2");
+     * all of the available servers are returned.
+     */
+    serverTags = Sets.newHashSet("tag-2");
+    for (int i = 0; i < shuffleNodesMax - 1; ++i) {
+      clusterManager.add(new ServerNode("t2-" + i, "", 0, 0, 0,
+          20 - i, 0, serverTags, true));
+    }
+    pra = strategy.assign(100, 1, 1, serverTags, shuffleNodesMax);
+    assertEquals(
+        shuffleNodesMax - 1,
+        pra.getAssignments()
+            .values()
+            .stream()
+            .flatMap(Collection::stream)
+            .collect(Collectors.toSet())
+            .size()
+    );
+  }
 }
diff --git a/docs/client_guide.md b/docs/client_guide.md
index eee239f..b97474e 100644
--- a/docs/client_guide.md
+++ b/docs/client_guide.md
@@ -88,6 +88,7 @@ These configurations are shared by all types of clients.
 |<client_type>.rss.client.send.threadPool.size|5|The thread size for send 
shuffle data to shuffle server|
 |<client_type>.rss.client.assignment.tags|-|The comma-separated list of tags 
for deciding assignment shuffle servers. Notice that the SHUFFLE_SERVER_VERSION 
will always as the assignment tag whether this conf is set or not|
 |<client_type>.rss.client.data.commit.pool.size|The number of assigned shuffle 
servers|The thread size for sending commit to shuffle servers|
+|<client_type>.rss.client.assignment.shuffle.nodes.max|-1|The number of 
shuffle servers requested for assignment. If it is less than or equal to 0, or 
greater than the coordinator's "rss.coordinator.shuffle.nodes.max" setting, 
the value of "rss.coordinator.shuffle.nodes.max" is used instead.|
 Notice:
 
 1. `<client_type>` should be `spark` or `mapreduce`
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentServerNodesNumberTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentServerNodesNumberTest.java
new file mode 100644
index 0000000..57bf341
--- /dev/null
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentServerNodesNumberTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import com.google.common.collect.Sets;
+import com.google.common.io.Files;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.HashSet;
+import org.apache.uniffle.client.impl.ShuffleWriteClientImpl;
+import org.apache.uniffle.client.util.ClientType;
+import org.apache.uniffle.common.ShuffleAssignmentsInfo;
+import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.storage.util.StorageType;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class AssignmentServerNodesNumberTest extends CoordinatorTestBase {
+  private static final Logger LOG = 
LoggerFactory.getLogger(AssignmentServerNodesNumberTest.class);
+  private static final int SHUFFLE_NODES_MAX = 10; // written to COORDINATOR_SHUFFLE_NODES_MAX in setupServers
+  private static final int SERVER_NUM = 10; // number of shuffle servers created in setupServers
+  private static final HashSet<String> TAGS = Sets.newHashSet("t1");
+
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    CoordinatorConf coordinatorConf = getCoordinatorConf();
+    coordinatorConf.setLong(CoordinatorConf.COORDINATOR_APP_EXPIRED, 2000);
+    coordinatorConf.setInteger(CoordinatorConf.COORDINATOR_SHUFFLE_NODES_MAX, 
SHUFFLE_NODES_MAX);
+    createCoordinatorServer(coordinatorConf);
+
+    for (int i = 0; i < SERVER_NUM; i++){
+      ShuffleServerConf shuffleServerConf = getShuffleServerConf();
+      File tmpDir = Files.createTempDir();
+      File dataDir1 = new File(tmpDir, "data1");
+      String basePath = dataDir1.getAbsolutePath();
+      shuffleServerConf.set(ShuffleServerConf.RSS_STORAGE_TYPE, 
StorageType.MEMORY_LOCALFILE_HDFS.name());
+      shuffleServerConf.set(ShuffleServerConf.RSS_STORAGE_BASE_PATH, basePath);
+      shuffleServerConf.set(RssBaseConf.RPC_METRICS_ENABLED, true);
+      
shuffleServerConf.set(ShuffleServerConf.SERVER_APP_EXPIRED_WITHOUT_HEARTBEAT, 
2000L);
+      shuffleServerConf.set(ShuffleServerConf.SERVER_PRE_ALLOCATION_EXPIRED, 
5000L);
+      shuffleServerConf.setInteger(RssBaseConf.RPC_SERVER_PORT, 18001 + i); // distinct RPC port per server
+      shuffleServerConf.setInteger(RssBaseConf.JETTY_HTTP_PORT, 19010 + i); // distinct HTTP port per server
+      shuffleServerConf.set(ShuffleServerConf.TAGS, new ArrayList<>(TAGS));
+      createShuffleServer(shuffleServerConf);
+    }
+    startServers();
+
+    Thread.sleep(1000 * 5); // presumably lets the shuffle servers register with the coordinator - TODO confirm
+  }
+
+  @Test
+  public void testAssignmentServerNodesNumber() throws Exception {
+    ShuffleWriteClientImpl shuffleWriteClient = new 
ShuffleWriteClientImpl(ClientType.GRPC.name(), 3, 1000, 1,
+        1, 1, 1, true, 1, 1);
+    shuffleWriteClient.registerCoordinators(COORDINATOR_QUORUM);
+
+    /**
+     * Purpose of this test: request assignments with different values of the last
+     * getShuffleAssignments argument (the requested shuffle node num) and check how
+     * many servers appear in the result.
+     *
+     * case1: the user specifies an illegal shuffle node num (<0);
+     * the default shuffle nodes num is used when there are enough servers.
+     */
+    ShuffleAssignmentsInfo info = 
shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, -1);
+    assertEquals(SHUFFLE_NODES_MAX, 
info.getServerToPartitionRanges().keySet().size());
+
+    /**
+     * case2: the user specifies an illegal shuffle node num (==0);
+     * the default shuffle nodes num is used when there are enough servers.
+     */
+    info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, 0);
+    assertEquals(SHUFFLE_NODES_MAX, 
info.getServerToPartitionRanges().keySet().size());
+
+    /**
+     * case3: the user specifies an illegal shuffle node num (> default max limitation);
+     * the default shuffle nodes num is used when there are enough servers.
+     * NOTE(review): the requested value is built from SERVER_NUM although the limit
+     * under test is SHUFFLE_NODES_MAX; both are 10 here. Consider SHUFFLE_NODES_MAX + 10.
+     */
+    info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, 
SERVER_NUM + 10);
+    assertEquals(SHUFFLE_NODES_MAX, 
info.getServerToPartitionRanges().keySet().size());
+
+    /**
+     * case4: the user specifies a legal shuffle node num;
+     * the customized shuffle nodes num is used when there are enough servers.
+     * NOTE(review): SERVER_NUM - 1 is requested but SHUFFLE_NODES_MAX - 1 is expected;
+     * the assertion holds only while the two constants stay equal.
+     */
+    info = shuffleWriteClient.getShuffleAssignments("app1", 0, 10, 1, TAGS, 
SERVER_NUM - 1);
+    assertEquals(SHUFFLE_NODES_MAX - 1, 
info.getServerToPartitionRanges().keySet().size());
+  }
+}
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
index 9ab84d4..4ca0c20 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/AssignmentWithTagsTest.java
@@ -153,7 +153,7 @@ public class AssignmentWithTagsTest extends 
CoordinatorTestBase {
         // Case1 : only set the single default shuffle version tag
         ShuffleAssignmentsInfo assignmentsInfo =
                 shuffleWriteClient.getShuffleAssignments("app-1",
-                        1, 1, 1, 
Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
+                        1, 1, 1, 
Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 1);
 
         List<Integer> assignedServerPorts = assignmentsInfo
                 .getPartitionToServers()
@@ -168,7 +168,7 @@ public class AssignmentWithTagsTest extends 
CoordinatorTestBase {
         // Case2: Set the single non-exist shuffle server tag
         try {
             assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-2",
-                    1, 1, 1, Sets.newHashSet("non-exist"));
+                    1, 1, 1, Sets.newHashSet("non-exist"), 1);
             fail();
         } catch (Exception e) {
             assertTrue(e.getMessage().startsWith("Error happened when 
getShuffleAssignments with"));
@@ -176,7 +176,7 @@ public class AssignmentWithTagsTest extends 
CoordinatorTestBase {
 
         // Case3: Set the single fixed tag
         assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-3",
-                1, 1, 1, Sets.newHashSet("fixed"));
+                1, 1, 1, Sets.newHashSet("fixed"), 1);
         assignedServerPorts = assignmentsInfo
                 .getPartitionToServers()
                 .values()
@@ -189,7 +189,7 @@ public class AssignmentWithTagsTest extends 
CoordinatorTestBase {
 
         // case4: Set the multiple tags if exists
         assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-4",
-                1, 1, 1, Sets.newHashSet("fixed", 
Constants.SHUFFLE_SERVER_VERSION));
+                1, 1, 1, Sets.newHashSet("fixed", 
Constants.SHUFFLE_SERVER_VERSION), 1);
         assignedServerPorts = assignmentsInfo
                 .getPartitionToServers()
                 .values()
@@ -203,7 +203,7 @@ public class AssignmentWithTagsTest extends 
CoordinatorTestBase {
         // case5: Set the multiple tags if non-exist
         try {
             assignmentsInfo = shuffleWriteClient.getShuffleAssignments("app-5",
-                    1, 1, 1, Sets.newHashSet("fixed", "elastic", 
Constants.SHUFFLE_SERVER_VERSION));
+                    1, 1, 1, Sets.newHashSet("fixed", "elastic", 
Constants.SHUFFLE_SERVER_VERSION), 1);
             fail();
         } catch (Exception e) {
             assertTrue(e.getMessage().startsWith("Error happened when 
getShuffleAssignments with"));
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
index dc1fa47..53e0922 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java
@@ -153,7 +153,13 @@ public class CoordinatorGrpcClient extends GrpcClient 
implements CoordinatorClie
   }
 
   public RssProtos.GetShuffleAssignmentsResponse doGetShuffleAssignments(
-      String appId, int shuffleId, int numMaps, int partitionNumPerRange, int 
dataReplica, Set<String> requiredTags) {
+      String appId,
+      int shuffleId,
+      int numMaps,
+      int partitionNumPerRange,
+      int dataReplica,
+      Set<String> requiredTags,
+      int assignmentShuffleServerNumber) {
 
     RssProtos.GetShuffleServerRequest getServerRequest = 
RssProtos.GetShuffleServerRequest.newBuilder()
         .setApplicationId(appId)
@@ -162,6 +168,7 @@ public class CoordinatorGrpcClient extends GrpcClient 
implements CoordinatorClie
         .setPartitionNumPerRange(partitionNumPerRange)
         .setDataReplica(dataReplica)
         .addAllRequireTags(requiredTags)
+        .setAssignmentShuffleServerNumber(assignmentShuffleServerNumber)
         .build();
 
     return blockingStub.getShuffleAssignments(getServerRequest);
@@ -221,7 +228,8 @@ public class CoordinatorGrpcClient extends GrpcClient 
implements CoordinatorClie
         request.getPartitionNum(),
         request.getPartitionNumPerRange(),
         request.getDataReplica(),
-        request.getRequiredTags());
+        request.getRequiredTags(),
+        request.getAssignmentShuffleServerNumber());
 
     RssGetShuffleAssignmentsResponse response;
     StatusCode statusCode = rpcResponse.getStatus();
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
index acf0e3d..d0971cb 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleAssignmentsRequest.java
@@ -19,6 +19,8 @@ package org.apache.uniffle.client.request;
 
 import java.util.Set;
 
+import com.google.common.annotations.VisibleForTesting;
+
 public class RssGetShuffleAssignmentsRequest {
 
   private String appId;
@@ -27,15 +29,23 @@ public class RssGetShuffleAssignmentsRequest {
   private int partitionNumPerRange;
   private int dataReplica;
   private Set<String> requiredTags;
+  private int assignmentShuffleServerNumber;
 
+  @VisibleForTesting // convenience overload: delegates with assignmentShuffleServerNumber = -1
   public RssGetShuffleAssignmentsRequest(String appId, int shuffleId, int 
partitionNum,
       int partitionNumPerRange, int dataReplica, Set<String> requiredTags) {
+    this(appId, shuffleId, partitionNum, partitionNumPerRange, dataReplica, 
requiredTags, -1);
+  }
+
+  public RssGetShuffleAssignmentsRequest(String appId, int shuffleId, int 
partitionNum,
+      int partitionNumPerRange, int dataReplica, Set<String> requiredTags, int 
assignmentShuffleServerNumber) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.partitionNum = partitionNum;
     this.partitionNumPerRange = partitionNumPerRange;
     this.dataReplica = dataReplica;
     this.requiredTags = requiredTags;
+    this.assignmentShuffleServerNumber = assignmentShuffleServerNumber; // stored as given; no range check is done here
   }
 
   public String getAppId() {
@@ -61,4 +71,8 @@ public class RssGetShuffleAssignmentsRequest {
   public Set<String> getRequiredTags() {
     return requiredTags;
   }
+
+  public int getAssignmentShuffleServerNumber() {
+    return assignmentShuffleServerNumber; // value given at construction (-1 via the @VisibleForTesting overload)
+  }
 }
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 647430d..d4d979c 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -295,6 +295,7 @@ message GetShuffleServerRequest {
   int32 partitionNumPerRange = 7;
   int32 dataReplica = 8;
   repeated string requireTags = 9;
+  int32 assignmentShuffleServerNumber = 10;
 }
 
 message PartitionRangeAssignment {

Reply via email to