This is an automated email from the ASF dual-hosted git repository.

xianjingfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 722d3079a [#1887] improvement: reject all requests from unregistered apps in shuffle server (#1923)
722d3079a is described below

commit 722d3079aad4a8b0a0b2f3eb9b956658889e0aa2
Author: xianjingfeng <xianjingfeng...@gmail.com>
AuthorDate: Tue Jul 23 15:28:02 2024 +0800

    [#1887] improvement: reject all requests from unregistered apps in shuffle server (#1923)
    
    ### What changes were proposed in this pull request?
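    Reject all requests from unregistered apps in the shuffle server: the gRPC and Netty request handlers now check the app first and return `NO_REGISTER` immediately if it is unknown or expired.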
    
    ### Why are the changes needed?
    For better performance: requests from unregistered apps are rejected up front instead of being processed until they fail.
    Fix: #1887
    
    ### Does this PR introduce any user-facing change?
    No.
    
    ### How was this patch tested?
    Existing UTs; the assertions in `ShuffleServerGrpcTest` were updated to expect `NO_REGISTER`.
---
 .../apache/uniffle/test/ShuffleServerGrpcTest.java |   4 +-
 .../uniffle/server/ShuffleServerGrpcService.java   | 164 ++++++++++++++++++---
 .../apache/uniffle/server/ShuffleTaskManager.java  |   8 +-
 .../server/netty/ShuffleServerNettyHandler.java    |  46 +++++-
 4 files changed, 192 insertions(+), 30 deletions(-)

diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
index 8a6e0cecf..df3e29971 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerGrpcTest.java
@@ -259,7 +259,7 @@ public class ShuffleServerGrpcTest extends 
IntegrationTestBase {
       grpcShuffleServerClient.reportShuffleResult(request);
       fail("Exception should be thrown");
     } catch (Exception e) {
-      assertTrue(e.getMessage().contains("error happened when report shuffle 
result"));
+      assertTrue(e.getMessage().contains("NO_REGISTER"));
     }
 
     RssGetShuffleResultRequest req =
@@ -268,7 +268,7 @@ public class ShuffleServerGrpcTest extends 
IntegrationTestBase {
       grpcShuffleServerClient.getShuffleResult(req);
       fail("Exception should be thrown");
     } catch (Exception e) {
-      assertTrue(e.getMessage().contains("Can't get shuffle result"));
+      assertTrue(e.getMessage().contains("NO_REGISTER"));
     }
 
     RssRegisterShuffleRequest rrsr =
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index aea43d24e..06e2d8b92 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -105,19 +105,28 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
       RssProtos.ShuffleUnregisterByAppIdRequest request,
       StreamObserver<RssProtos.ShuffleUnregisterByAppIdResponse> 
responseStreamObserver) {
     String appId = request.getAppId();
-
-    StatusCode result = StatusCode.SUCCESS;
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      RssProtos.ShuffleUnregisterByAppIdResponse reply =
+          RssProtos.ShuffleUnregisterByAppIdResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseStreamObserver.onNext(reply);
+      responseStreamObserver.onCompleted();
+      return;
+    }
     String responseMessage = "OK";
     try {
       shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId);
 
     } catch (Exception e) {
-      result = StatusCode.INTERNAL_ERROR;
+      status = StatusCode.INTERNAL_ERROR;
     }
 
     RssProtos.ShuffleUnregisterByAppIdResponse reply =
         RssProtos.ShuffleUnregisterByAppIdResponse.newBuilder()
-            .setStatus(result.toProto())
+            .setStatus(status.toProto())
             .setRetMsg(responseMessage)
             .build();
     responseStreamObserver.onNext(reply);
@@ -129,19 +138,29 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
       RssProtos.ShuffleUnregisterRequest request,
       StreamObserver<RssProtos.ShuffleUnregisterResponse> 
responseStreamObserver) {
     String appId = request.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      RssProtos.ShuffleUnregisterResponse reply =
+          RssProtos.ShuffleUnregisterResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseStreamObserver.onNext(reply);
+      responseStreamObserver.onCompleted();
+      return;
+    }
     int shuffleId = request.getShuffleId();
 
-    StatusCode result = StatusCode.SUCCESS;
     String responseMessage = "OK";
     try {
       shuffleServer.getShuffleTaskManager().removeShuffleDataAsync(appId, 
shuffleId);
     } catch (Exception e) {
-      result = StatusCode.INTERNAL_ERROR;
+      status = StatusCode.INTERNAL_ERROR;
     }
 
     RssProtos.ShuffleUnregisterResponse reply =
         RssProtos.ShuffleUnregisterResponse.newBuilder()
-            .setStatus(result.toProto())
+            .setStatus(status.toProto())
             .setRetMsg(responseMessage)
             .build();
     responseStreamObserver.onNext(reply);
@@ -430,12 +449,20 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
   @Override
   public void commitShuffleTask(
       ShuffleCommitRequest req, StreamObserver<ShuffleCommitResponse> 
responseObserver) {
-
-    ShuffleCommitResponse reply;
     String appId = req.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      ShuffleCommitResponse response =
+          ShuffleCommitResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseObserver.onNext(response);
+      responseObserver.onCompleted();
+      return;
+    }
     int shuffleId = req.getShuffleId();
 
-    StatusCode status = StatusCode.SUCCESS;
     String msg = "OK";
     int commitCount = 0;
 
@@ -460,7 +487,7 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
       LOG.error(msg, e);
     }
 
-    reply =
+    ShuffleCommitResponse reply =
         ShuffleCommitResponse.newBuilder()
             .setCommitCount(commitCount)
             .setStatus(status.toProto())
@@ -474,8 +501,18 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
   public void finishShuffle(
       FinishShuffleRequest req, StreamObserver<FinishShuffleResponse> 
responseObserver) {
     String appId = req.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      FinishShuffleResponse response =
+          FinishShuffleResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseObserver.onNext(response);
+      responseObserver.onCompleted();
+      return;
+    }
     int shuffleId = req.getShuffleId();
-    StatusCode status;
     String msg = "OK";
     String errorMsg =
         "Fail to finish shuffle for appId["
@@ -506,8 +543,18 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
   public void requireBuffer(
       RequireBufferRequest request, StreamObserver<RequireBufferResponse> 
responseObserver) {
     String appId = request.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      RequireBufferResponse response =
+          RequireBufferResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseObserver.onNext(response);
+      responseObserver.onCompleted();
+      return;
+    }
     long requireBufferId = -1;
-    StatusCode status = StatusCode.SUCCESS;
     try {
       if (StringUtils.isEmpty(appId)) {
         // To be compatible with older client version
@@ -548,6 +595,17 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
   public void appHeartbeat(
       AppHeartBeatRequest request, StreamObserver<AppHeartBeatResponse> 
responseObserver) {
     String appId = request.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      AppHeartBeatResponse response =
+          AppHeartBeatResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseObserver.onNext(response);
+      responseObserver.onCompleted();
+      return;
+    }
     shuffleServer.getShuffleTaskManager().refreshAppId(appId);
     AppHeartBeatResponse response =
         AppHeartBeatResponse.newBuilder()
@@ -572,12 +630,22 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
       ReportShuffleResultRequest request,
       StreamObserver<ReportShuffleResultResponse> responseObserver) {
     String appId = request.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      ReportShuffleResultResponse response =
+          ReportShuffleResultResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseObserver.onNext(response);
+      responseObserver.onCompleted();
+      return;
+    }
     int shuffleId = request.getShuffleId();
     long taskAttemptId = request.getTaskAttemptId();
     int bitmapNum = request.getBitmapNum();
     Map<Integer, long[]> partitionToBlockIds =
         toPartitionBlocksMap(request.getPartitionToBlockIdsList());
-    StatusCode status = StatusCode.SUCCESS;
     String msg = "OK";
     ReportShuffleResultResponse reply;
     String requestInfo =
@@ -617,6 +685,17 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
   public void getShuffleResult(
       GetShuffleResultRequest request, 
StreamObserver<GetShuffleResultResponse> responseObserver) {
     String appId = request.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      GetShuffleResultResponse response =
+          GetShuffleResultResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseObserver.onNext(response);
+      responseObserver.onCompleted();
+      return;
+    }
     int shuffleId = request.getShuffleId();
     int partitionId = request.getPartitionId();
     BlockIdLayout blockIdLayout =
@@ -624,7 +703,6 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
             request.getBlockIdLayout().getSequenceNoBits(),
             request.getBlockIdLayout().getPartitionIdBits(),
             request.getBlockIdLayout().getTaskAttemptIdBits());
-    StatusCode status = StatusCode.SUCCESS;
     String msg = "OK";
     GetShuffleResultResponse reply;
     byte[] serializedBlockIds = null;
@@ -665,6 +743,17 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
       GetShuffleResultForMultiPartRequest request,
       StreamObserver<GetShuffleResultForMultiPartResponse> responseObserver) {
     String appId = request.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      GetShuffleResultForMultiPartResponse response =
+          GetShuffleResultForMultiPartResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseObserver.onNext(response);
+      responseObserver.onCompleted();
+      return;
+    }
     int shuffleId = request.getShuffleId();
     List<Integer> partitionsList = request.getPartitionsList();
     BlockIdLayout blockIdLayout =
@@ -673,7 +762,6 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
             request.getBlockIdLayout().getPartitionIdBits(),
             request.getBlockIdLayout().getTaskAttemptIdBits());
 
-    StatusCode status = StatusCode.SUCCESS;
     String msg = "OK";
     GetShuffleResultForMultiPartResponse reply;
     byte[] serializedBlockIds = null;
@@ -715,6 +803,17 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
       GetLocalShuffleDataRequest request,
       StreamObserver<GetLocalShuffleDataResponse> responseObserver) {
     String appId = request.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      GetLocalShuffleDataResponse response =
+          GetLocalShuffleDataResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseObserver.onNext(response);
+      responseObserver.onCompleted();
+      return;
+    }
     int shuffleId = request.getShuffleId();
     int partitionId = request.getPartitionId();
     int partitionNumPerRange = request.getPartitionNumPerRange();
@@ -732,7 +831,6 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
     }
     String storageType =
         
shuffleServer.getShuffleServerConf().get(RssBaseConf.RSS_STORAGE_TYPE).name();
-    StatusCode status = StatusCode.SUCCESS;
     String msg = "OK";
     GetLocalShuffleDataResponse reply = null;
     ShuffleDataResult sdr = null;
@@ -831,11 +929,21 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
       GetLocalShuffleIndexRequest request,
       StreamObserver<GetLocalShuffleIndexResponse> responseObserver) {
     String appId = request.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      GetLocalShuffleIndexResponse reply =
+          GetLocalShuffleIndexResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseObserver.onNext(reply);
+      responseObserver.onCompleted();
+      return;
+    }
     int shuffleId = request.getShuffleId();
     int partitionId = request.getPartitionId();
     int partitionNumPerRange = request.getPartitionNumPerRange();
     int partitionNum = request.getPartitionNum();
-    StatusCode status = StatusCode.SUCCESS;
     String msg = "OK";
     GetLocalShuffleIndexResponse reply;
     String requestInfo =
@@ -928,6 +1036,17 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
       GetMemoryShuffleDataRequest request,
       StreamObserver<GetMemoryShuffleDataResponse> responseObserver) {
     String appId = request.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      GetMemoryShuffleDataResponse reply =
+          GetMemoryShuffleDataResponse.newBuilder()
+              .setStatus(status.toProto())
+              .setRetMsg(status.toString())
+              .build();
+      responseObserver.onNext(reply);
+      responseObserver.onCompleted();
+      return;
+    }
     int shuffleId = request.getShuffleId();
     int partitionId = request.getPartitionId();
     long blockId = request.getLastBlockId();
@@ -943,7 +1062,6 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
                 ShuffleServerGrpcMetrics.GET_MEMORY_SHUFFLE_DATA_METHOD, 
transportTime);
       }
     }
-    StatusCode status = StatusCode.SUCCESS;
     String msg = "OK";
     GetMemoryShuffleDataResponse reply;
     String requestInfo =
@@ -1108,4 +1226,12 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
     }
     return shuffleDataBlockSegments;
   }
+
+  private StatusCode verifyRequest(String appId) {
+    if (StringUtils.isNotBlank(appId)
+        && shuffleServer.getShuffleTaskManager().isAppExpired(appId)) {
+      return StatusCode.NO_REGISTER;
+    }
+    return StatusCode.SUCCESS;
+  }
 }
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index 8fe597d03..b258c8a11 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -725,12 +725,12 @@ public class ShuffleTaskManager {
     }
   }
 
-  private boolean isAppExpired(String appId) {
-    if (shuffleTaskInfos.get(appId) == null) {
+  public boolean isAppExpired(String appId) {
+    ShuffleTaskInfo shuffleTaskInfo = shuffleTaskInfos.get(appId);
+    if (shuffleTaskInfo == null) {
       return true;
     }
-    return System.currentTimeMillis() - 
shuffleTaskInfos.get(appId).getCurrentTimes()
-        > appExpiredWithoutHB;
+    return System.currentTimeMillis() - shuffleTaskInfo.getCurrentTimes() > 
appExpiredWithoutHB;
   }
 
   /**
diff --git 
a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
 
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
index 27a3f1dc4..cca6a3935 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java
@@ -28,6 +28,7 @@ import io.netty.channel.ChannelFuture;
 import io.netty.channel.ChannelFutureListener;
 import org.apache.commons.collections.MapUtils;
 import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.lang3.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -69,7 +70,6 @@ import org.apache.uniffle.storage.util.ShuffleStorageUtils;
 public class ShuffleServerNettyHandler implements BaseMessageHandler {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(ShuffleServerNettyHandler.class);
-  private static final int RPC_TIMEOUT = 60000;
   private final ShuffleServer shuffleServer;
 
   public ShuffleServerNettyHandler(ShuffleServer shuffleServer) {
@@ -335,6 +335,18 @@ public class ShuffleServerNettyHandler implements 
BaseMessageHandler {
   public void handleGetMemoryShuffleDataRequest(
       TransportClient client, GetMemoryShuffleDataRequest req) {
     String appId = req.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      GetMemoryShuffleDataResponse response =
+          new GetMemoryShuffleDataResponse(
+              req.getRequestId(),
+              status,
+              status.toString(),
+              Lists.newArrayList(),
+              Unpooled.EMPTY_BUFFER);
+      client.getChannel().writeAndFlush(response);
+      return;
+    }
     int shuffleId = req.getShuffleId();
     int partitionId = req.getPartitionId();
     long blockId = req.getLastBlockId();
@@ -349,7 +361,6 @@ public class ShuffleServerNettyHandler implements 
BaseMessageHandler {
             .recordTransportTime(GetMemoryShuffleDataRequest.class.getName(), 
transportTime);
       }
     }
-    StatusCode status = StatusCode.SUCCESS;
     String msg = "OK";
     GetMemoryShuffleDataResponse response;
     String requestInfo =
@@ -417,11 +428,18 @@ public class ShuffleServerNettyHandler implements 
BaseMessageHandler {
   public void handleGetLocalShuffleIndexRequest(
       TransportClient client, GetLocalShuffleIndexRequest req) {
     String appId = req.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      GetLocalShuffleIndexResponse response =
+          new GetLocalShuffleIndexResponse(
+              req.getRequestId(), status, status.toString(), 
Unpooled.EMPTY_BUFFER, 0L);
+      client.getChannel().writeAndFlush(response);
+      return;
+    }
     int shuffleId = req.getShuffleId();
     int partitionId = req.getPartitionId();
     int partitionNumPerRange = req.getPartitionNumPerRange();
     int partitionNum = req.getPartitionNum();
-    StatusCode status = StatusCode.SUCCESS;
     String msg = "OK";
     GetLocalShuffleIndexResponse response;
     String requestInfo =
@@ -501,7 +519,19 @@ public class ShuffleServerNettyHandler implements 
BaseMessageHandler {
   }
 
   public void handleGetLocalShuffleData(TransportClient client, 
GetLocalShuffleDataRequest req) {
+    GetLocalShuffleDataResponse response;
     String appId = req.getAppId();
+    StatusCode status = verifyRequest(appId);
+    if (status != StatusCode.SUCCESS) {
+      response =
+          new GetLocalShuffleDataResponse(
+              req.getRequestId(),
+              status,
+              status.toString(),
+              new NettyManagedBuffer(Unpooled.EMPTY_BUFFER));
+      client.getChannel().writeAndFlush(response);
+      return;
+    }
     int shuffleId = req.getShuffleId();
     int partitionId = req.getPartitionId();
     int partitionNumPerRange = req.getPartitionNumPerRange();
@@ -519,9 +549,7 @@ public class ShuffleServerNettyHandler implements 
BaseMessageHandler {
     }
     String storageType =
         
shuffleServer.getShuffleServerConf().get(RssBaseConf.RSS_STORAGE_TYPE).name();
-    StatusCode status = StatusCode.SUCCESS;
     String msg = "OK";
-    GetLocalShuffleDataResponse response;
     String requestInfo =
         "appId["
             + appId
@@ -625,6 +653,14 @@ public class ShuffleServerNettyHandler implements 
BaseMessageHandler {
     return ret;
   }
 
+  private StatusCode verifyRequest(String appId) {
+    if (StringUtils.isNotBlank(appId)
+        && shuffleServer.getShuffleTaskManager().isAppExpired(appId)) {
+      return StatusCode.NO_REGISTER;
+    }
+    return StatusCode.SUCCESS;
+  }
+
   class ReleaseMemoryAndRecordReadTimeListener implements 
ChannelFutureListener {
     private final long readStartedTime;
     private final long readBufferSize;

Reply via email to