This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 2cb22ff5 [#719] feat(netty): Optimize allocation strategy (#739)
2cb22ff5 is described below

commit 2cb22ff5869fadf59f6013357de0827dc12f215d
Author: jokercurry <[email protected]>
AuthorDate: Mon Mar 20 22:11:31 2023 +0800

    [#719] feat(netty): Optimize allocation strategy (#739)
    
    ### What changes were proposed in this pull request?
    Users can choose between Netty and gRPC as the data transmission method 
through client configuration.
    
    ### Why are the changes needed?
    Fix: #719
    
    ### Does this PR introduce _any_ user-facing change?
    No. However, if users want to use `netty` as a data transfer method, they 
need to enable `spark.rss.client.type=GRPC_NETTY` or 
`mapreduce.rss.client.type=GRPC_NETTY`
    
    ### How was this patch tested?
    New unit tests.
---
 .../hadoop/mapreduce/v2/app/RssMRAppMaster.java    |  3 +
 .../apache/spark/shuffle/RssShuffleManager.java    |  3 +
 .../apache/spark/shuffle/RssShuffleManager.java    |  2 +
 .../apache/uniffle/client/util/ClientUtils.java    | 11 +++
 .../org/apache/uniffle/client/ClientUtilsTest.java | 13 ++++
 .../coordinator/SimpleClusterManagerTest.java      | 78 ++++++++++++++++++++--
 .../apache/uniffle/test/CoordinatorGrpcTest.java   | 27 +++++++-
 .../apache/uniffle/test/MRIntegrationTestBase.java |  3 +
 .../org/apache/uniffle/server/ShuffleServer.java   | 10 +++
 9 files changed, 141 insertions(+), 9 deletions(-)

diff --git 
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
 
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
index 15981aac..044eae0b 100644
--- 
a/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
+++ 
b/client-mr/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
@@ -130,6 +130,9 @@ public class RssMRAppMaster extends MRAppMaster {
         assignmentTags.addAll(Arrays.asList(rawTags.split(",")));
       }
       assignmentTags.add(Constants.SHUFFLE_SERVER_VERSION);
+      String clientType = conf.get(RssMRConfig.RSS_CLIENT_TYPE);
+      ClientUtils.validateClientType(clientType);
+      assignmentTags.add(clientType);
 
       final ScheduledExecutorService scheduledExecutorService = 
Executors.newSingleThreadScheduledExecutor(
           new ThreadFactory() {
diff --git 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 2ecb6f8f..61e201cb 100644
--- 
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -263,12 +263,15 @@ public class RssShuffleManager implements ShuffleManager {
 
     // get all register info according to coordinator's response
     Set<String> assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+    ClientUtils.validateClientType(clientType);
+    assignmentTags.add(clientType);
 
     int requiredShuffleServerNumber = 
RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
 
     // retryInterval must bigger than `rss.server.heartbeat.timeout`, or maybe 
it will return the same result
     long retryInterval = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
     int retryTimes = 
sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
+
     Map<Integer, List<ShuffleServerInfo>> partitionToServers;
     try {
       partitionToServers = RetryUtils.retry(() -> {
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index e70026ac..4a574bc5 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -339,6 +339,8 @@ public class RssShuffleManager implements ShuffleManager {
         id.get(), defaultRemoteStorage, dynamicConfEnabled, storageType, 
shuffleWriteClient);
 
     Set<String> assignmentTags = 
RssSparkShuffleUtils.getAssignmentTags(sparkConf);
+    ClientUtils.validateClientType(clientType);
+    assignmentTags.add(clientType);
 
     int requiredShuffleServerNumber = 
RssSparkShuffleUtils.getRequiredShuffleServerNumber(sparkConf);
 
diff --git 
a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java 
b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
index 0bdf7cf0..d9b51883 100644
--- a/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
+++ b/client/src/main/java/org/apache/uniffle/client/util/ClientUtils.java
@@ -18,12 +18,16 @@
 package org.apache.uniffle.client.util;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.storage.util.StorageType;
@@ -122,4 +126,11 @@ public class ClientUtils {
               + "because of the poor performance of these two types.");
     }
   }
+
+  public static void validateClientType(String clientType) {
+    Set<String> types = 
Arrays.stream(ClientType.values()).map(Enum::name).collect(Collectors.toSet());
+    if (!types.contains(clientType)) {
+      throw new IllegalArgumentException(String.format("The value of %s should 
be one of %s", clientType, types));
+    }
+  }
 }
diff --git 
a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java 
b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
index 77f9cba5..577162a4 100644
--- a/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
+++ b/client/src/test/java/org/apache/uniffle/client/ClientUtilsTest.java
@@ -134,4 +134,17 @@ public class ClientUtilsTest {
     List<CompletableFuture<Boolean>> futures3 = getFutures(false);
     Awaitility.await().timeout(4, TimeUnit.SECONDS).until(() -> 
waitUntilDoneOrFail(futures3, true));
   }
+
+  @Test
+  public void testValidateClientType() {
+    String clientType = "GRPC_NETTY";
+    ClientUtils.validateClientType(clientType);
+    clientType = "test";
+    try {
+      ClientUtils.validateClientType(clientType);
+      fail();
+    } catch (Exception e) {
+      // Ignore
+    }
+  }
 }
diff --git 
a/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java
 
b/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java
index c8dccf08..123dca4d 100644
--- 
a/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java
+++ 
b/coordinator/src/test/java/org/apache/uniffle/coordinator/SimpleClusterManagerTest.java
@@ -27,6 +27,7 @@ import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
+import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import org.apache.hadoop.conf.Configuration;
 import org.junit.jupiter.api.AfterEach;
@@ -36,6 +37,8 @@ import org.junit.jupiter.api.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.ServerStatus;
 import org.apache.uniffle.coordinator.metric.CoordinatorMetrics;
 
 import static org.awaitility.Awaitility.await;
@@ -48,6 +51,8 @@ public class SimpleClusterManagerTest {
   private static final Logger LOG = 
LoggerFactory.getLogger(SimpleClusterManagerTest.class);
 
   private final Set<String> testTags = Sets.newHashSet("test");
+  private final Set<String> nettyTags = Sets.newHashSet("test", 
ClientType.GRPC_NETTY.name());
+  private final Set<String> grpcTags = Sets.newHashSet("test", 
ClientType.GRPC.name());
 
   @BeforeEach
   public void setUp() {
@@ -79,15 +84,15 @@ public class SimpleClusterManagerTest {
     try (SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, 
new Configuration())) {
 
       ServerNode sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
-              10, testTags, true);
+              10, grpcTags, true);
       ServerNode sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
-              10, testTags, true);
+              10, grpcTags, true);
       ServerNode sn3 = new ServerNode("sn3", "ip", 0, 100L, 50L, 20,
-              11, testTags, true);
+              11, grpcTags, true);
       clusterManager.add(sn1);
       clusterManager.add(sn2);
       clusterManager.add(sn3);
-      List<ServerNode> serverNodes = clusterManager.getServerList(testTags);
+      List<ServerNode> serverNodes = clusterManager.getServerList(grpcTags);
       assertEquals(3, serverNodes.size());
       Set<String> expectedIds = Sets.newHashSet("sn1", "sn2", "sn3");
       assertEquals(expectedIds, 
serverNodes.stream().map(ServerNode::getId).collect(Collectors.toSet()));
@@ -98,7 +103,7 @@ public class SimpleClusterManagerTest {
       sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
               10, Sets.newHashSet("test", "new_tag"), true);
       ServerNode sn4 = new ServerNode("sn4", "ip", 0, 100L, 51L, 20,
-              10, testTags, true);
+              10, grpcTags, true);
       clusterManager.add(sn1);
       clusterManager.add(sn2);
       clusterManager.add(sn4);
@@ -109,7 +114,7 @@ public class SimpleClusterManagerTest {
       assertTrue(serverNodes.contains(sn4));
 
       Map<String, Set<ServerNode>> tagToNodes = clusterManager.getTagToNodes();
-      assertEquals(2, tagToNodes.size());
+      assertEquals(3, tagToNodes.size());
 
       Set<ServerNode> newTagNodes = tagToNodes.get("new_tag");
       assertEquals(2, newTagNodes.size());
@@ -124,6 +129,67 @@ public class SimpleClusterManagerTest {
     }
   }
 
+  @Test
+  public void getServerListForNettyTest() throws Exception {
+    CoordinatorConf ssc = new CoordinatorConf();
+    ssc.setLong(CoordinatorConf.COORDINATOR_HEARTBEAT_TIMEOUT, 30 * 1000L);
+    try (SimpleClusterManager clusterManager = new SimpleClusterManager(ssc, 
new Configuration())) {
+
+      ServerNode sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
+          10, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 
1);
+      ServerNode sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
+          10, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 
1);
+      ServerNode sn3 = new ServerNode("sn3", "ip", 0, 100L, 50L, 20,
+          11, nettyTags, true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 
1);
+      ServerNode sn4 = new ServerNode("sn4", "ip", 0, 100L, 50L, 20,
+          11, grpcTags, true);
+      clusterManager.add(sn1);
+      clusterManager.add(sn2);
+      clusterManager.add(sn3);
+      clusterManager.add(sn4);
+
+      List<ServerNode> serverNodes2 = clusterManager.getServerList(nettyTags);
+      assertEquals(3, serverNodes2.size());
+
+      List<ServerNode> serverNodes3 = clusterManager.getServerList(grpcTags);
+      assertEquals(1, serverNodes3.size());
+
+      List<ServerNode> serverNodes4 = clusterManager.getServerList(testTags);
+      assertEquals(4, serverNodes4.size());
+
+      Map<String, Set<ServerNode>> tagToNodes = clusterManager.getTagToNodes();
+      assertEquals(3, tagToNodes.size());
+
+      // tag changes
+      sn1 = new ServerNode("sn1", "ip", 0, 100L, 50L, 20,
+          10, Sets.newHashSet("new_tag"), true, ServerStatus.ACTIVE, 
Maps.newConcurrentMap(), 1);
+      sn2 = new ServerNode("sn2", "ip", 0, 100L, 50L, 21,
+          10, Sets.newHashSet("test", "new_tag"),
+          true, ServerStatus.ACTIVE, Maps.newConcurrentMap(), 1);
+      sn4 = new ServerNode("sn4", "ip", 0, 100L, 51L, 20,
+          10, grpcTags, true);
+      clusterManager.add(sn1);
+      clusterManager.add(sn2);
+      clusterManager.add(sn4);
+      Set<ServerNode> testTagNodesForNetty = 
tagToNodes.get(ClientType.GRPC_NETTY.name());
+      assertEquals(1, testTagNodesForNetty.size());
+
+      List<ServerNode> serverNodes = clusterManager.getServerList(grpcTags);
+      assertEquals(1, serverNodes.size());
+      assertTrue(serverNodes.contains(sn4));
+
+      Set<ServerNode> newTagNodes = tagToNodes.get("new_tag");
+      assertEquals(2, newTagNodes.size());
+      assertTrue(newTagNodes.contains(sn1));
+      assertTrue(newTagNodes.contains(sn2));
+      Set<ServerNode> testTagNodes = tagToNodes.get("test");
+      assertEquals(3, testTagNodes.size());
+      assertTrue(testTagNodes.contains(sn2));
+      assertTrue(testTagNodes.contains(sn3));
+      assertTrue(testTagNodes.contains(sn4));
+    }
+  }
+
   @Test
   public void 
testGetCorrectServerNodesWhenOneNodeRemovedAndUnhealthyNodeFound() throws 
Exception {
     CoordinatorConf ssc = new CoordinatorConf();
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java
index 9b63df7c..800184b2 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/CoordinatorGrpcTest.java
@@ -17,10 +17,12 @@
 
 package org.apache.uniffle.test;
 
+import java.lang.reflect.Field;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
@@ -31,10 +33,12 @@ import 
org.apache.uniffle.client.request.RssApplicationInfoRequest;
 import org.apache.uniffle.client.request.RssGetShuffleAssignmentsRequest;
 import org.apache.uniffle.client.response.RssApplicationInfoResponse;
 import org.apache.uniffle.client.response.RssGetShuffleAssignmentsResponse;
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.ShuffleRegisterInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.storage.StorageInfo;
 import org.apache.uniffle.common.storage.StorageMedia;
@@ -121,11 +125,28 @@ public class CoordinatorGrpcTest extends 
CoordinatorTestBase {
 
   @Test
   public void getShuffleAssignmentsTest() throws Exception {
-    String appId = "getShuffleAssignmentsTest";
+    final String appId = "getShuffleAssignmentsTest";
     CoordinatorTestUtils.waitForRegister(coordinatorClient,2);
+    // When the shuffleServerHeartbeat Test is completed before the current 
test,
+    // the server's tags will be [ss_v4, GRPC_NETTY] and [ss_v4, GRPC], 
respectively.
+    // We need to remove the first machine's tag from GRPC_NETTY to GRPC
+    shuffleServers.get(0).stopServer();
+    RssConf shuffleServerConf = shuffleServers.get(0).getShuffleServerConf();
+    Class<RssConf> clazz = RssConf.class;
+    Field field = clazz.getDeclaredField("settings");
+    field.setAccessible(true);
+    ((ConcurrentHashMap) 
field.get(shuffleServerConf)).remove(ShuffleServerConf.NETTY_SERVER_PORT.key());
+    String storageTypeJsonSource = String.format("{\"%s\": \"ssd\"}", baseDir);
+    withEnvironmentVariables("RSS_ENV_KEY", storageTypeJsonSource).execute(() 
-> {
+      ShuffleServer ss = new ShuffleServer((ShuffleServerConf) 
shuffleServerConf);
+      ss.start();
+      shuffleServers.set(0, ss);
+    });
+    Thread.sleep(5000);
+    // add tag when ClientType is `GRPC`
     RssGetShuffleAssignmentsRequest request = new 
RssGetShuffleAssignmentsRequest(
         appId, 1, 10, 4, 1,
-        Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
+        Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION, 
ClientType.GRPC.name()));
     RssGetShuffleAssignmentsResponse response = 
coordinatorClient.getShuffleAssignments(request);
     Set<Integer> expectedStart = Sets.newHashSet(0, 4, 8);
 
@@ -157,7 +178,7 @@ public class CoordinatorGrpcTest extends 
CoordinatorTestBase {
 
     request = new RssGetShuffleAssignmentsRequest(
         appId, 1, 10, 4, 2,
-        Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION));
+        Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION, 
ClientType.GRPC.name()));
     response = coordinatorClient.getShuffleAssignments(request);
     serverToPartitionRanges = response.getServerToPartitionRanges();
     assertEquals(2, serverToPartitionRanges.size());
diff --git 
a/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java
 
b/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java
index 1ceab307..ed22f953 100644
--- 
a/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java
+++ 
b/integration-test/mr/src/test/java/org/apache/uniffle/test/MRIntegrationTestBase.java
@@ -44,6 +44,8 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
 
+import org.apache.uniffle.common.ClientType;
+
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 
@@ -165,6 +167,7 @@ public class MRIntegrationTestBase extends 
IntegrationTestBase {
     jobConf.set(MRJobConfig.MAPREDUCE_APPLICATION_CLASSPATH,
         "$PWD/rss.jar/" + localFile.getName() + "," + 
MRJobConfig.DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH);
     jobConf.set(RssMRConfig.RSS_COORDINATOR_QUORUM, COORDINATOR_QUORUM);
+    jobConf.set(RssMRConfig.RSS_CLIENT_TYPE, ClientType.GRPC.name());
     updateRssConfiguration(jobConf);
     runMRApp(jobConf, getTestTool(), getTestArgs());
 
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
index 3a47a81d..263884c1 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServer.java
@@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory;
 import picocli.CommandLine;
 
 import org.apache.uniffle.common.Arguments;
+import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.ServerStatus;
 import org.apache.uniffle.common.exception.InvalidRequestException;
 import org.apache.uniffle.common.metrics.GRPCMetrics;
@@ -248,9 +249,18 @@ public class ShuffleServer {
     if (CollectionUtils.isNotEmpty(configuredTags)) {
       tags.addAll(configuredTags);
     }
+    tagServer();
     LOG.info("Server tags: {}", tags);
   }
 
+  private void tagServer() {
+    if (nettyServerEnabled) {
+      tags.add(ClientType.GRPC_NETTY.name());
+    } else {
+      tags.add(ClientType.GRPC.name());
+    }
+  }
+
   private void registerMetrics() throws Exception {
     LOG.info("Register metrics");
     CollectorRegistry shuffleServerCollectorRegistry = new 
CollectorRegistry(true);

Reply via email to