This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.3 by this push:
     new abea3122e [CELEBORN-745] Match TransportMessage type use number 
instead of enum
abea3122e is described below

commit abea3122eae00b33811600c8e6a976b23fda8879
Author: Angerszhuuuu <[email protected]>
AuthorDate: Sat Jul 1 18:50:02 2023 +0800

    [CELEBORN-745] Match TransportMessage type use number instead of enum
    
    ### What changes were proposed in this pull request?
    Match the TransportMessage type by its numeric value instead of by enum 
constant, so that MessageType entries can be renamed safely after this PR.
    
    ### Why are the changes needed?
    
    ### Does this PR introduce _any_ user-facing change?
    
    ### How was this patch tested?
    
    Closes #1658 from AngersZhuuuu/CELEBORN-745.
    
    Lead-authored-by: Angerszhuuuu <[email protected]>
    Co-authored-by: Shuang <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
    (cherry picked from commit 7880c52fff9586453cd923945e57cc87c1769103)
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../common/network/protocol/TransportMessage.java  |  8 +-
 .../common/protocol/message/ControlMessages.scala  | 96 ++++++++++++----------
 2 files changed, 58 insertions(+), 46 deletions(-)

diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index b3b0f0f7b..a139fc08e 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -23,11 +23,13 @@ import org.apache.celeborn.common.protocol.MessageType;
 
 public class TransportMessage implements Serializable {
   private static final long serialVersionUID = -3259000920699629773L;
-  private final MessageType type;
+  @Deprecated private final MessageType type;
+  private final int messageTypeValue;
   private final byte[] payload;
 
   public TransportMessage(MessageType type, byte[] payload) {
     this.type = type;
+    this.messageTypeValue = type.getNumber();
     this.payload = payload;
   }
 
@@ -35,6 +37,10 @@ public class TransportMessage implements Serializable {
     return type;
   }
 
+  public int getMessageTypeValue() {
+    return messageTypeValue;
+  }
+
   public byte[] getPayload() {
     return payload;
   }
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 571ff5381..2c31d2914 100644
--- 
a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ 
b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -781,16 +781,22 @@ object ControlMessages extends Logging {
 
   // TODO change return type to GeneratedMessageV3
   def fromTransportMessage(message: TransportMessage): Any = {
-    message.getType match {
-      case UNKNOWN_MESSAGE | UNRECOGNIZED =>
+    // This can be removed when Transport Message removes type field support 
later.
+    val messageTypeValue = message.getMessageTypeValue match {
+      case UNKNOWN_MESSAGE_VALUE => message.getType.getNumber
+      case _ => message.getMessageTypeValue
+    }
+
+    messageTypeValue match {
+      case UNKNOWN_MESSAGE_VALUE =>
         val msg = s"received unknown message $message"
         logError(msg)
         throw new UnsupportedOperationException(msg)
 
-      case REGISTER_WORKER =>
+      case REGISTER_WORKER_VALUE =>
         PbRegisterWorker.parseFrom(message.getPayload)
 
-      case HEARTBEAT_FROM_WORKER =>
+      case HEARTBEAT_FROM_WORKER_VALUE =>
         val pbHeartbeatFromWorker = 
PbHeartbeatFromWorker.parseFrom(message.getPayload)
         val estimatedAppDiskUsage = new util.HashMap[String, java.lang.Long]()
         val userResourceConsumption = 
PbSerDeUtils.fromPbUserResourceConsumption(
@@ -815,7 +821,7 @@ object ControlMessages extends Logging {
           estimatedAppDiskUsage,
           pbHeartbeatFromWorker.getRequestId)
 
-      case HEARTBEAT_RESPONSE =>
+      case HEARTBEAT_RESPONSE_VALUE =>
         val pbHeartbeatFromWorkerResponse =
           PbHeartbeatFromWorkerResponse.parseFrom(message.getPayload)
         val expiredShuffleKeys = new util.HashSet[String]()
@@ -824,16 +830,16 @@ object ControlMessages extends Logging {
         }
         HeartbeatFromWorkerResponse(expiredShuffleKeys, 
pbHeartbeatFromWorkerResponse.getRegistered)
 
-      case REGISTER_SHUFFLE =>
+      case REGISTER_SHUFFLE_VALUE =>
         PbRegisterShuffle.parseFrom(message.getPayload)
 
-      case REGISTER_MAP_PARTITION_TASK =>
+      case REGISTER_MAP_PARTITION_TASK_VALUE =>
         PbRegisterMapPartitionTask.parseFrom(message.getPayload)
 
-      case REGISTER_SHUFFLE_RESPONSE =>
+      case REGISTER_SHUFFLE_RESPONSE_VALUE =>
         PbRegisterShuffleResponse.parseFrom(message.getPayload)
 
-      case REQUEST_SLOTS =>
+      case REQUEST_SLOTS_VALUE =>
         val pbRequestSlots = PbRequestSlots.parseFrom(message.getPayload)
         val userIdentifier = 
PbSerDeUtils.fromPbUserIdentifier(pbRequestSlots.getUserIdentifier)
         RequestSlots(
@@ -846,7 +852,7 @@ object ControlMessages extends Logging {
           userIdentifier,
           pbRequestSlots.getRequestId)
 
-      case RELEASE_SLOTS =>
+      case RELEASE_SLOTS_VALUE =>
         val pbReleaseSlots = PbReleaseSlots.parseFrom(message.getPayload)
         val slotsList = pbReleaseSlots.getSlotsList.asScala.map(pbSlot =>
           new util.HashMap[String, Integer](pbSlot.getSlotMap)).toList.asJava
@@ -857,24 +863,24 @@ object ControlMessages extends Logging {
           new util.ArrayList[util.Map[String, Integer]](slotsList),
           pbReleaseSlots.getRequestId)
 
-      case RELEASE_SLOTS_RESPONSE =>
+      case RELEASE_SLOTS_RESPONSE_VALUE =>
         val pbReleaseSlotsResponse = 
PbReleaseSlotsResponse.parseFrom(message.getPayload)
         
ReleaseSlotsResponse(Utils.toStatusCode(pbReleaseSlotsResponse.getStatus))
 
-      case REQUEST_SLOTS_RESPONSE =>
+      case REQUEST_SLOTS_RESPONSE_VALUE =>
         val pbRequestSlotsResponse = 
PbRequestSlotsResponse.parseFrom(message.getPayload)
         RequestSlotsResponse(
           Utils.toStatusCode(pbRequestSlotsResponse.getStatus),
           PbSerDeUtils.fromPbWorkerResource(
             pbRequestSlotsResponse.getWorkerResourceMap))
 
-      case REVIVE =>
+      case REVIVE_VALUE =>
         PbRevive.parseFrom(message.getPayload)
 
-      case CHANGE_LOCATION_RESPONSE =>
+      case CHANGE_LOCATION_RESPONSE_VALUE =>
         PbChangeLocationResponse.parseFrom(message.getPayload)
 
-      case MAPPER_END =>
+      case MAPPER_END_VALUE =>
         val pbMapperEnd = PbMapperEnd.parseFrom(message.getPayload)
         MapperEnd(
           pbMapperEnd.getShuffleId,
@@ -883,16 +889,16 @@ object ControlMessages extends Logging {
           pbMapperEnd.getNumMappers,
           pbMapperEnd.getPartitionId)
 
-      case MAPPER_END_RESPONSE =>
+      case MAPPER_END_RESPONSE_VALUE =>
         val pbMapperEndResponse = 
PbMapperEndResponse.parseFrom(message.getPayload)
         MapperEndResponse(Utils.toStatusCode(pbMapperEndResponse.getStatus))
 
-      case GET_REDUCER_FILE_GROUP =>
+      case GET_REDUCER_FILE_GROUP_VALUE =>
         val pbGetReducerFileGroup = 
PbGetReducerFileGroup.parseFrom(message.getPayload)
         GetReducerFileGroup(
           pbGetReducerFileGroup.getShuffleId)
 
-      case GET_REDUCER_FILE_GROUP_RESPONSE =>
+      case GET_REDUCER_FILE_GROUP_RESPONSE_VALUE =>
         val pbGetReducerFileGroupResponse = PbGetReducerFileGroupResponse
           .parseFrom(message.getPayload)
         val fileGroup = 
pbGetReducerFileGroupResponse.getFileGroupsMap.asScala.map {
@@ -911,21 +917,21 @@ object ControlMessages extends Logging {
           attempts,
           partitionIds)
 
-      case UNREGISTER_SHUFFLE =>
+      case UNREGISTER_SHUFFLE_VALUE =>
         PbUnregisterShuffle.parseFrom(message.getPayload)
 
-      case UNREGISTER_SHUFFLE_RESPONSE =>
+      case UNREGISTER_SHUFFLE_RESPONSE_VALUE =>
         PbUnregisterShuffleResponse.parseFrom(message.getPayload)
 
-      case APPLICATION_LOST =>
+      case APPLICATION_LOST_VALUE =>
         val pbApplicationLost = PbApplicationLost.parseFrom(message.getPayload)
         ApplicationLost(pbApplicationLost.getAppId, 
pbApplicationLost.getRequestId)
 
-      case APPLICATION_LOST_RESPONSE =>
+      case APPLICATION_LOST_RESPONSE_VALUE =>
         val pbApplicationLostResponse = 
PbApplicationLostResponse.parseFrom(message.getPayload)
         
ApplicationLostResponse(Utils.toStatusCode(pbApplicationLostResponse.getStatus))
 
-      case HEARTBEAT_FROM_APPLICATION =>
+      case HEARTBEAT_FROM_APPLICATION_VALUE =>
         val pbHeartbeatFromApplication = 
PbHeartbeatFromApplication.parseFrom(message.getPayload)
         HeartbeatFromApplication(
           pbHeartbeatFromApplication.getAppId,
@@ -937,7 +943,7 @@ object ControlMessages extends Logging {
           pbHeartbeatFromApplication.getRequestId,
           pbHeartbeatFromApplication.getShouldResponse)
 
-      case HEARTBEAT_FROM_APPLICATION_RESPONSE =>
+      case HEARTBEAT_FROM_APPLICATION_RESPONSE_VALUE =>
         val pbHeartbeatFromApplicationResponse =
           PbHeartbeatFromApplicationResponse.parseFrom(message.getPayload)
         HeartbeatFromApplicationResponse(
@@ -949,13 +955,13 @@ object ControlMessages extends Logging {
           pbHeartbeatFromApplicationResponse.getShuttingWorkersList.asScala
             .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava)
 
-      case GET_BLACKLIST =>
+      case GET_BLACKLIST_VALUE =>
         val pbGetBlacklist = PbGetBlacklist.parseFrom(message.getPayload)
         GetBlacklist(
           new 
util.ArrayList[WorkerInfo](pbGetBlacklist.getLocalExcludedWorkersList.asScala
             .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava))
 
-      case GET_BLACKLIST_RESPONSE =>
+      case GET_BLACKLIST_RESPONSE_VALUE =>
         val pbGetBlacklistResponse = 
PbGetBlacklistResponse.parseFrom(message.getPayload)
         GetBlacklistResponse(
           Utils.toStatusCode(pbGetBlacklistResponse.getStatus),
@@ -964,28 +970,28 @@ object ControlMessages extends Logging {
           pbGetBlacklistResponse.getUnknownWorkersList.asScala
             .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava)
 
-      case CHECK_QUOTA =>
+      case CHECK_QUOTA_VALUE =>
         val pbCheckAvailable = PbCheckQuota.parseFrom(message.getPayload)
         
CheckQuota(PbSerDeUtils.fromPbUserIdentifier(pbCheckAvailable.getUserIdentifier))
 
-      case CHECK_QUOTA_RESPONSE =>
+      case CHECK_QUOTA_RESPONSE_VALUE =>
         val pbCheckAvailableResponse = PbCheckQuotaResponse
           .parseFrom(message.getPayload)
         CheckQuotaResponse(
           pbCheckAvailableResponse.getAvailable,
           pbCheckAvailableResponse.getReason)
 
-      case REPORT_WORKER_FAILURE =>
+      case REPORT_WORKER_FAILURE_VALUE =>
         val pbReportWorkerUnavailable = 
PbReportWorkerUnavailable.parseFrom(message.getPayload)
         ReportWorkerUnavailable(
           new 
util.ArrayList[WorkerInfo](pbReportWorkerUnavailable.getUnavailableList
             .asScala.map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava),
           pbReportWorkerUnavailable.getRequestId)
 
-      case REGISTER_WORKER_RESPONSE =>
+      case REGISTER_WORKER_RESPONSE_VALUE =>
         PbRegisterWorkerResponse.parseFrom(message.getPayload)
 
-      case RESERVE_SLOTS =>
+      case RESERVE_SLOTS_VALUE =>
         val pbReserveSlots = PbReserveSlots.parseFrom(message.getPayload)
         val userIdentifier = 
PbSerDeUtils.fromPbUserIdentifier(pbReserveSlots.getUserIdentifier)
         ReserveSlots(
@@ -1002,13 +1008,13 @@ object ControlMessages extends Logging {
           userIdentifier,
           pbReserveSlots.getPushDataTimeout)
 
-      case RESERVE_SLOTS_RESPONSE =>
+      case RESERVE_SLOTS_RESPONSE_VALUE =>
         val pbReserveSlotsResponse = 
PbReserveSlotsResponse.parseFrom(message.getPayload)
         ReserveSlotsResponse(
           Utils.toStatusCode(pbReserveSlotsResponse.getStatus),
           pbReserveSlotsResponse.getReason)
 
-      case COMMIT_FILES =>
+      case COMMIT_FILES_VALUE =>
         val pbCommitFiles = PbCommitFiles.parseFrom(message.getPayload)
         CommitFiles(
           pbCommitFiles.getApplicationId,
@@ -1018,7 +1024,7 @@ object ControlMessages extends Logging {
           pbCommitFiles.getMapAttemptsList.asScala.map(_.toInt).toArray,
           pbCommitFiles.getEpoch)
 
-      case COMMIT_FILES_RESPONSE =>
+      case COMMIT_FILES_RESPONSE_VALUE =>
         val pbCommitFilesResponse = 
PbCommitFilesResponse.parseFrom(message.getPayload)
         val committedPrimaryStorageInfos = new util.HashMap[String, 
StorageInfo]()
         val committedReplicaStorageInfos = new util.HashMap[String, 
StorageInfo]()
@@ -1042,46 +1048,46 @@ object ControlMessages extends Logging {
           pbCommitFilesResponse.getTotalWritten,
           pbCommitFilesResponse.getFileCount)
 
-      case DESTROY =>
+      case DESTROY_VALUE =>
         val pbDestroy = PbDestroyWorkerSlots.parseFrom(message.getPayload)
         DestroyWorkerSlots(
           pbDestroy.getShuffleKey,
           pbDestroy.getPrimaryLocationsList,
           pbDestroy.getReplicaLocationList)
 
-      case DESTROY_RESPONSE =>
+      case DESTROY_RESPONSE_VALUE =>
         val pbDestroyResponse = 
PbDestroyWorkerSlotsResponse.parseFrom(message.getPayload)
         DestroyWorkerSlotsResponse(
           Utils.toStatusCode(pbDestroyResponse.getStatus),
           pbDestroyResponse.getFailedPrimariesList,
           pbDestroyResponse.getFailedReplicasList)
 
-      case REMOVE_EXPIRED_SHUFFLE =>
+      case REMOVE_EXPIRED_SHUFFLE_VALUE =>
         RemoveExpiredShuffle
 
-      case ONE_WAY_MESSAGE_RESPONSE =>
+      case ONE_WAY_MESSAGE_RESPONSE_VALUE =>
         OneWayMessageResponse
 
-      case CHECK_FOR_WORKER_TIMEOUT =>
+      case CHECK_FOR_WORKER_TIMEOUT_VALUE =>
         pbCheckForWorkerTimeout
 
-      case CHECK_FOR_APPLICATION_TIMEOUT =>
+      case CHECK_FOR_APPLICATION_TIMEOUT_VALUE =>
         CheckForApplicationTimeOut
 
-      case WORKER_LOST =>
+      case WORKER_LOST_VALUE =>
         PbWorkerLost.parseFrom(message.getPayload)
 
-      case WORKER_LOST_RESPONSE =>
+      case WORKER_LOST_RESPONSE_VALUE =>
         PbWorkerLostResponse.parseFrom(message.getPayload)
 
-      case STAGE_END =>
+      case STAGE_END_VALUE =>
         val pbStageEnd = PbStageEnd.parseFrom(message.getPayload)
         StageEnd(pbStageEnd.getShuffleId)
 
-      case PARTITION_SPLIT =>
+      case PARTITION_SPLIT_VALUE =>
         PbPartitionSplit.parseFrom(message.getPayload)
 
-      case STAGE_END_RESPONSE =>
+      case STAGE_END_RESPONSE_VALUE =>
         val pbStageEndResponse = 
PbStageEndResponse.parseFrom(message.getPayload)
         StageEndResponse(Utils.toStatusCode(pbStageEndResponse.getStatus))
     }

Reply via email to