This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 87f9b6f73 [MINOR]improvement(client/server): (RemoteMerger) Refactor 
to use mergeContext collect the arguments related (#2195)
87f9b6f73 is described below

commit 87f9b6f73b2499e241761fb36ca9f3b7edf3375d
Author: maobaolong <[email protected]>
AuthorDate: Tue Oct 22 19:12:19 2024 +0800

    [MINOR]improvement(client/server): (RemoteMerger) Refactor to use 
mergeContext collect the arguments related (#2195)
    
    ### What changes were proposed in this pull request?
    
    Refactor to use a `MergeContext` message that collects the arguments related to the remote merger.
    
    ### Why are the changes needed?
    
    - Make the code cleaner and friendlier to developers who are not concerned with 
`Remote Merger`.
    - Allow `mergeContext` to be extended later without further API changes.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    No additional tests are needed; this is only a refactor.
---
 .../hadoop/mapreduce/v2/app/RssMRAppMaster.java    | 24 +++++-----
 .../hadoop/mapred/SortWriteBufferManagerTest.java  |  7 +--
 .../hadoop/mapreduce/task/reduce/FetcherTest.java  |  7 +--
 .../shuffle/manager/RssShuffleManagerBase.java     |  4 --
 .../tez/dag/app/TezRemoteShuffleManager.java       | 25 ++++++++---
 .../common/sort/buffer/WriteBufferManagerTest.java |  7 +--
 .../uniffle/client/api/ShuffleWriteClient.java     | 11 +----
 .../client/impl/ShuffleWriteClientImpl.java        | 13 ++----
 .../record/reader/MockedShuffleWriteClient.java    |  7 +--
 .../test/RemoteMergeShuffleWithRssClientTest.java  | 49 +++++++++++---------
 ...ShuffleWithRssClientTestWhenShuffleFlushed.java | 49 +++++++++++---------
 .../client/impl/grpc/ShuffleServerGrpcClient.java  | 26 +++--------
 .../client/request/RssRegisterShuffleRequest.java  | 52 +++-------------------
 proto/src/main/proto/Rss.proto                     | 14 +++---
 .../uniffle/server/ShuffleServerGrpcService.java   | 11 +----
 .../uniffle/server/merge/ShuffleMergeManager.java  | 23 ++++------
 .../server/merge/ShuffleMergeManagerTest.java      | 11 ++++-
 17 files changed, 145 insertions(+), 195 deletions(-)

diff --git 
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
 
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
index 4403120d5..06973d079 100644
--- 
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
+++ 
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
@@ -80,6 +80,7 @@ import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.RetryUtils;
 import org.apache.uniffle.hadoop.shim.HadoopShimImpl;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static 
org.apache.hadoop.mapreduce.RssMRConfig.RSS_REMOTE_MERGE_CLASS_LOADER;
@@ -285,17 +286,20 @@ public class RssMRAppMaster extends MRAppMaster {
                                   RssMRConfig.toRssConf(conf)
                                       
.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE),
                                   0,
-                                  remoteMergeEnable ? 
conf.getMapOutputKeyClass().getName() : null,
                                   remoteMergeEnable
-                                      ? conf.getMapOutputValueClass().getName()
-                                      : null,
-                                  remoteMergeEnable
-                                      ? 
conf.getOutputKeyComparator().getClass().getName()
-                                      : null,
-                                  conf.getInt(
-                                      RssMRConfig.RSS_MERGED_BLOCK_SZIE,
-                                      
RssMRConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT),
-                                  conf.get(RSS_REMOTE_MERGE_CLASS_LOADER)));
+                                      ? MergeContext.newBuilder()
+                                          
.setKeyClass(conf.getMapOutputKeyClass().getName())
+                                          
.setValueClass(conf.getMapOutputValueClass().getName())
+                                          .setComparatorClass(
+                                              
conf.getOutputKeyComparator().getClass().getName())
+                                          .setMergedBlockSize(
+                                              conf.getInt(
+                                                  
RssMRConfig.RSS_MERGED_BLOCK_SZIE,
+                                                  
RssMRConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT))
+                                          .setMergeClassLoader(
+                                              
conf.get(RSS_REMOTE_MERGE_CLASS_LOADER, ""))
+                                          .build()
+                                      : null));
                   LOG.info(
                       "Finish register shuffle with "
                           + (System.currentTimeMillis() - start)
diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 813ea1218..0385cb58e 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -62,6 +62,7 @@ import org.apache.uniffle.common.serializer.SerializerFactory;
 import org.apache.uniffle.common.serializer.SerializerInstance;
 import org.apache.uniffle.common.serializer.SerializerUtils;
 import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.proto.RssProtos;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -722,11 +723,7 @@ public class SortWriteBufferManagerTest {
         ShuffleDataDistributionType distributionType,
         int maxConcurrencyPerPartitionToWrite,
         int stageAttemptNumber,
-        String keyClassName,
-        String valueClassName,
-        String comparatorClassName,
-        int mergedBlockSize,
-        String mergeClassLoader) {}
+        RssProtos.MergeContext mergeContext) {}
 
     @Override
     public boolean sendCommit(
diff --git 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index a66f16f56..d2aaebe04 100644
--- 
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ 
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -77,6 +77,7 @@ import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.hadoop.shim.HadoopShimImpl;
+import org.apache.uniffle.proto.RssProtos;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.mockito.Mockito.mock;
@@ -507,11 +508,7 @@ public class FetcherTest {
         ShuffleDataDistributionType distributionType,
         int maxConcurrencyPerPartitionToWrite,
         int stageAttemptNumber,
-        String keyClassName,
-        String valueClassName,
-        String comparatorClassName,
-        int mergedBlockSize,
-        String mergeClassLoader) {}
+        RssProtos.MergeContext mergeContext) {}
 
     @Override
     public boolean sendCommit(
diff --git 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index d82d3a509..47f9e271d 100644
--- 
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ 
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -1028,10 +1028,6 @@ public abstract class RssShuffleManagerBase implements 
RssShuffleManagerInterfac
                   ShuffleDataDistributionType.NORMAL,
                   maxConcurrencyPerPartitionToWrite,
                   stageAttemptNumber,
-                  null,
-                  null,
-                  null,
-                  -1,
                   null);
             });
     LOG.info(
diff --git 
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java 
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
index 85a138c13..f44ad0c5e 100644
--- 
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
+++ 
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
@@ -66,6 +66,7 @@ import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.common.util.RetryUtils;
+import org.apache.uniffle.proto.RssProtos;
 
 import static 
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
 
@@ -305,13 +306,23 @@ public class TezRemoteShuffleManager implements 
ServicePluginLifecycle {
                                           RssTezConfig.toRssConf(conf)
                                               
.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE),
                                           0,
-                                          keyClassName,
-                                          valueClassName,
-                                          comparatorClassName,
-                                          conf.getInt(
-                                              
RssTezConfig.RSS_MERGED_BLOCK_SZIE,
-                                              
RssTezConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT),
-                                          
conf.get(RssTezConfig.RSS_REMOTE_MERGE_CLASS_LOADER)));
+                                          StringUtils.isBlank(keyClassName)
+                                              ? null
+                                              : 
RssProtos.MergeContext.newBuilder()
+                                                  .setKeyClass(keyClassName)
+                                                  
.setValueClass(valueClassName)
+                                                  
.setComparatorClass(comparatorClassName)
+                                                  .setMergedBlockSize(
+                                                      conf.getInt(
+                                                          
RssTezConfig.RSS_MERGED_BLOCK_SZIE,
+                                                          RssTezConfig
+                                                              
.RSS_MERGED_BLOCK_SZIE_DEFAULT))
+                                                  .setMergeClassLoader(
+                                                      conf.get(
+                                                          RssTezConfig
+                                                              
.RSS_REMOTE_MERGE_CLASS_LOADER,
+                                                          ""))
+                                                  .build()));
                           LOG.info(
                               "Finish register shuffle with "
                                   + (System.currentTimeMillis() - start)
diff --git 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index 724cec6c0..ce0458219 100644
--- 
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++ 
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -78,6 +78,7 @@ import org.apache.uniffle.common.serializer.SerializerFactory;
 import org.apache.uniffle.common.serializer.SerializerInstance;
 import org.apache.uniffle.common.serializer.SerializerUtils;
 import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.proto.RssProtos;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -719,11 +720,7 @@ public class WriteBufferManagerTest {
         ShuffleDataDistributionType dataDistributionType,
         int maxConcurrencyPerPartitionToWrite,
         int stageAttemptNumber,
-        String keyClassName,
-        String valueClassName,
-        String comparatorClassName,
-        int mergedBlockSize,
-        String mergeClassLoader) {}
+        RssProtos.MergeContext mergeContext) {}
 
     @Override
     public boolean sendCommit(
diff --git 
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java 
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index 121271e36..d21c7e67b 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -33,6 +33,7 @@ import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
 
 public interface ShuffleWriteClient {
 
@@ -72,10 +73,6 @@ public interface ShuffleWriteClient {
         dataDistributionType,
         maxConcurrencyPerPartitionToWrite,
         0,
-        null,
-        null,
-        null,
-        -1,
         null);
   }
 
@@ -88,11 +85,7 @@ public interface ShuffleWriteClient {
       ShuffleDataDistributionType dataDistributionType,
       int maxConcurrencyPerPartitionToWrite,
       int stageAttemptNumber,
-      String keyClassName,
-      String valueClassName,
-      String comparatorClassName,
-      int mergedBlockSize,
-      String mergeClassLoader);
+      MergeContext mergeContext);
 
   boolean sendCommit(
       Set<ShuffleServerInfo> shuffleServerInfoSet, String appId, int 
shuffleId, int numMaps);
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index cba7ccc06..c81d3c725 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -95,6 +95,7 @@ import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.JavaUtils;
 import org.apache.uniffle.common.util.ThreadUtils;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
 
 public class ShuffleWriteClientImpl implements ShuffleWriteClient {
 
@@ -564,11 +565,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
       ShuffleDataDistributionType dataDistributionType,
       int maxConcurrencyPerPartitionToWrite,
       int stageAttemptNumber,
-      String keyClassName,
-      String valueClassName,
-      String comparatorClassName,
-      int mergedBlockSize,
-      String mergeClassLoader) {
+      MergeContext mergeContext) {
     String user = null;
     try {
       user = UserGroupInformation.getCurrentUser().getShortUserName();
@@ -586,11 +583,7 @@ public class ShuffleWriteClientImpl implements 
ShuffleWriteClient {
             dataDistributionType,
             maxConcurrencyPerPartitionToWrite,
             stageAttemptNumber,
-            keyClassName,
-            valueClassName,
-            comparatorClassName,
-            mergedBlockSize,
-            mergeClassLoader);
+            mergeContext);
     RssRegisterShuffleResponse response =
         getShuffleServerClient(shuffleServerInfo).registerShuffle(request);
 
diff --git 
a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java
 
b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java
index 7d4cbf980..6798a792c 100644
--- 
a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java
+++ 
b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java
@@ -34,6 +34,7 @@ import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.proto.RssProtos;
 
 public class MockedShuffleWriteClient implements ShuffleWriteClient {
 
@@ -63,11 +64,7 @@ public class MockedShuffleWriteClient implements 
ShuffleWriteClient {
       ShuffleDataDistributionType dataDistributionType,
       int maxConcurrencyPerPartitionToWrite,
       int stageAttemptNumber,
-      String keyClassName,
-      String valueClassName,
-      String comparatorClassName,
-      int mergedBlockSize,
-      String mergeClassLoader) {}
+      RssProtos.MergeContext mergeContext) {}
 
   @Override
   public boolean sendCommit(
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
index 17ec0c212..f34117586 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
@@ -64,6 +64,7 @@ import org.apache.uniffle.common.serializer.SerializerUtils;
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.proto.RssProtos;
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.server.buffer.ShuffleBufferType;
 import org.apache.uniffle.storage.util.StorageType;
@@ -174,11 +175,13 @@ public class RemoteMergeShuffleWithRssClientTest extends 
ShuffleReadWriteBase {
         ShuffleDataDistributionType.NORMAL,
         0,
         -1,
-        keyClass.getName(),
-        valueClass.getName(),
-        comparator.getClass().getName(),
-        -1,
-        null);
+        RssProtos.MergeContext.newBuilder()
+            .setKeyClass(keyClass.getName())
+            .setValueClass(valueClass.getName())
+            .setComparatorClass(comparator.getClass().getName())
+            .setMergedBlockSize(-1)
+            .setMergeClassLoader("")
+            .build());
 
     // 3 report shuffle result
     // task 0 attempt 0 generate three blocks
@@ -337,11 +340,13 @@ public class RemoteMergeShuffleWithRssClientTest extends 
ShuffleReadWriteBase {
         ShuffleDataDistributionType.NORMAL,
         0,
         -1,
-        keyClass.getName(),
-        valueClass.getName(),
-        comparator.getClass().getName(),
-        -1,
-        null);
+        RssProtos.MergeContext.newBuilder()
+            .setKeyClass(keyClass.getName())
+            .setValueClass(valueClass.getName())
+            .setComparatorClass(comparator.getClass().getName())
+            .setMergedBlockSize(-1)
+            .setMergeClassLoader("")
+            .build());
 
     // 3 report shuffle result
     // task 0 attempt 0 generate three blocks
@@ -508,11 +513,13 @@ public class RemoteMergeShuffleWithRssClientTest extends 
ShuffleReadWriteBase {
         ShuffleDataDistributionType.NORMAL,
         0,
         -1,
-        keyClass.getName(),
-        valueClass.getName(),
-        comparator.getClass().getName(),
-        -1,
-        null);
+        RssProtos.MergeContext.newBuilder()
+            .setKeyClass(keyClass.getName())
+            .setValueClass(valueClass.getName())
+            .setComparatorClass(comparator.getClass().getName())
+            .setMergedBlockSize(-1)
+            .setMergeClassLoader("")
+            .build());
 
     // 3 report shuffle result
     // this shuffle have three partition, which is hash by key index mode 3
@@ -714,11 +721,13 @@ public class RemoteMergeShuffleWithRssClientTest extends 
ShuffleReadWriteBase {
         ShuffleDataDistributionType.NORMAL,
         0,
         -1,
-        keyClass.getName(),
-        valueClass.getName(),
-        comparator.getClass().getName(),
-        -1,
-        null);
+        RssProtos.MergeContext.newBuilder()
+            .setKeyClass(keyClass.getName())
+            .setValueClass(valueClass.getName())
+            .setComparatorClass(comparator.getClass().getName())
+            .setMergedBlockSize(-1)
+            .setMergeClassLoader("")
+            .build());
 
     // 3 report shuffle result
     // this shuffle have three partition, which is hash by key index mode 3
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
index c0af5e1bf..d12b286c2 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
@@ -64,6 +64,7 @@ import org.apache.uniffle.common.serializer.SerializerUtils;
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.proto.RssProtos;
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.server.buffer.ShuffleBufferType;
 import org.apache.uniffle.storage.util.StorageType;
@@ -179,11 +180,13 @@ public class 
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
         ShuffleDataDistributionType.NORMAL,
         -1,
         0,
-        keyClass.getName(),
-        valueClass.getName(),
-        comparator.getClass().getName(),
-        -1,
-        null);
+        RssProtos.MergeContext.newBuilder()
+            .setKeyClass(keyClass.getName())
+            .setValueClass(valueClass.getName())
+            .setComparatorClass(comparator.getClass().getName())
+            .setMergedBlockSize(-1)
+            .setMergeClassLoader("")
+            .build());
 
     // 3 report shuffle result
     // task 0 attempt 0 generate three blocks
@@ -342,11 +345,13 @@ public class 
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
         ShuffleDataDistributionType.NORMAL,
         -1,
         0,
-        keyClass.getName(),
-        valueClass.getName(),
-        comparator.getClass().getName(),
-        -1,
-        null);
+        RssProtos.MergeContext.newBuilder()
+            .setKeyClass(keyClass.getName())
+            .setValueClass(valueClass.getName())
+            .setComparatorClass(comparator.getClass().getName())
+            .setMergedBlockSize(-1)
+            .setMergeClassLoader("")
+            .build());
     Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
 
     // 3 report shuffle result
@@ -514,11 +519,13 @@ public class 
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
         ShuffleDataDistributionType.NORMAL,
         -1,
         0,
-        keyClass.getName(),
-        valueClass.getName(),
-        comparator.getClass().getName(),
-        -1,
-        null);
+        RssProtos.MergeContext.newBuilder()
+            .setKeyClass(keyClass.getName())
+            .setValueClass(valueClass.getName())
+            .setComparatorClass(comparator.getClass().getName())
+            .setMergedBlockSize(-1)
+            .setMergeClassLoader("")
+            .build());
 
     // 3 report shuffle result
     // this shuffle have three partition, which is hash by key index mode 3
@@ -721,11 +728,13 @@ public class 
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
         ShuffleDataDistributionType.NORMAL,
         -1,
         0,
-        keyClass.getName(),
-        valueClass.getName(),
-        comparator.getClass().getName(),
-        -1,
-        null);
+        RssProtos.MergeContext.newBuilder()
+            .setKeyClass(keyClass.getName())
+            .setValueClass(valueClass.getName())
+            .setComparatorClass(comparator.getClass().getName())
+            .setMergedBlockSize(-1)
+            .setMergeClassLoader("")
+            .build());
 
     // 3 report shuffle result
     // this shuffle have three partition, which is hash by key index mode 3
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 63081041d..20b6bf98b 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -31,7 +31,6 @@ import com.google.common.collect.Lists;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.UnsafeByteOperations;
 import io.netty.buffer.Unpooled;
-import org.apache.commons.lang3.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -95,6 +94,7 @@ import 
org.apache.uniffle.proto.RssProtos.GetShuffleResultForMultiPartRequest;
 import org.apache.uniffle.proto.RssProtos.GetShuffleResultForMultiPartResponse;
 import org.apache.uniffle.proto.RssProtos.GetShuffleResultRequest;
 import org.apache.uniffle.proto.RssProtos.GetShuffleResultResponse;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
 import org.apache.uniffle.proto.RssProtos.PartitionToBlockIds;
 import org.apache.uniffle.proto.RssProtos.RemoteStorage;
 import org.apache.uniffle.proto.RssProtos.RemoteStorageConfItem;
@@ -198,11 +198,7 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
       ShuffleDataDistributionType dataDistributionType,
       int maxConcurrencyPerPartitionToWrite,
       int stageAttemptNumber,
-      String keyClassName,
-      String valueClassName,
-      String comparatorClassName,
-      int mergedBlockSize,
-      String mergeClassLoader) {
+      MergeContext mergeContext) {
     ShuffleRegisterRequest.Builder reqBuilder = 
ShuffleRegisterRequest.newBuilder();
     reqBuilder
         .setAppId(appId)
@@ -212,16 +208,8 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
         
.setMaxConcurrencyPerPartitionToWrite(maxConcurrencyPerPartitionToWrite)
         .addAllPartitionRanges(toShufflePartitionRanges(partitionRanges))
         .setStageAttemptNumber(stageAttemptNumber);
-    if (StringUtils.isNotBlank(keyClassName)) {
-      reqBuilder.setKeyClass(keyClassName);
-      reqBuilder.setValueClass(valueClassName);
-      if (StringUtils.isNotBlank(comparatorClassName)) {
-        reqBuilder.setComparatorClass(comparatorClassName);
-      }
-      reqBuilder.setMergedBlockSize(mergedBlockSize);
-      if (StringUtils.isNotBlank(mergeClassLoader)) {
-        reqBuilder.setMergeClassLoader(mergeClassLoader);
-      }
+    if (mergeContext != null) {
+      reqBuilder.setMergeContext(mergeContext);
     }
     RemoteStorage.Builder rsBuilder = RemoteStorage.newBuilder();
     rsBuilder.setPath(remoteStorageInfo.getPath());
@@ -496,11 +484,7 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
             request.getDataDistributionType(),
             request.getMaxConcurrencyPerPartitionToWrite(),
             request.getStageAttemptNumber(),
-            request.getKeyClassName(),
-            request.getValueClassName(),
-            request.getComparatorClassName(),
-            request.getMergedBlockSize(),
-            request.getMergeClassLoader());
+            request.getMergeContext());
 
     RssRegisterShuffleResponse response;
     RssProtos.StatusCode statusCode = rpcResponse.getStatus();
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
index 1db40a0d1..92ed1e15e 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
@@ -25,6 +25,7 @@ import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
 
 public class RssRegisterShuffleRequest {
 
@@ -36,11 +37,8 @@ public class RssRegisterShuffleRequest {
   private ShuffleDataDistributionType dataDistributionType;
   private int maxConcurrencyPerPartitionToWrite;
   private int stageAttemptNumber;
-  private String keyClassName;
-  private String valueClassName;
-  private String comparatorClassName;
-  private int mergedBlockSize;
-  private String mergeClassLoader;
+
+  private final MergeContext mergeContext;
 
   public RssRegisterShuffleRequest(
       String appId,
@@ -59,10 +57,6 @@ public class RssRegisterShuffleRequest {
         dataDistributionType,
         maxConcurrencyPerPartitionToWrite,
         0,
-        null,
-        null,
-        null,
-        -1,
         null);
   }
 
@@ -75,11 +69,7 @@ public class RssRegisterShuffleRequest {
       ShuffleDataDistributionType dataDistributionType,
       int maxConcurrencyPerPartitionToWrite,
       int stageAttemptNumber,
-      String keyClassName,
-      String valueClassName,
-      String comparatorClassName,
-      int mergedBlockSize,
-      String mergeClassLoader) {
+      MergeContext mergeContext) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.partitionRanges = partitionRanges;
@@ -88,11 +78,7 @@ public class RssRegisterShuffleRequest {
     this.dataDistributionType = dataDistributionType;
     this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite;
     this.stageAttemptNumber = stageAttemptNumber;
-    this.keyClassName = keyClassName;
-    this.valueClassName = valueClassName;
-    this.comparatorClassName = comparatorClassName;
-    this.mergedBlockSize = mergedBlockSize;
-    this.mergeClassLoader = mergeClassLoader;
+    this.mergeContext = mergeContext;
   }
 
   public RssRegisterShuffleRequest(
@@ -111,10 +97,6 @@ public class RssRegisterShuffleRequest {
         dataDistributionType,
         RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(),
         0,
-        null,
-        null,
-        null,
-        -1,
         null);
   }
 
@@ -129,10 +111,6 @@ public class RssRegisterShuffleRequest {
         ShuffleDataDistributionType.NORMAL,
         RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(),
         0,
-        null,
-        null,
-        null,
-        -1,
         null);
   }
 
@@ -168,23 +146,7 @@ public class RssRegisterShuffleRequest {
     return stageAttemptNumber;
   }
 
-  public String getKeyClassName() {
-    return keyClassName;
-  }
-
-  public String getValueClassName() {
-    return valueClassName;
-  }
-
-  public String getComparatorClassName() {
-    return comparatorClassName;
-  }
-
-  public int getMergedBlockSize() {
-    return mergedBlockSize;
-  }
-
-  public String getMergeClassLoader() {
-    return mergeClassLoader;
+  public MergeContext getMergeContext() {
+    return mergeContext;
   }
 }
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 7e4b19696..d92ec40c7 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -179,6 +179,14 @@ message ShufflePartitionRange {
   int32 end = 2;
 }
 
+message MergeContext {
+  string keyClass = 1;
+  string valueClass = 2;
+  string comparatorClass = 3;
+  int32 mergedBlockSize = 4;
+  string mergeClassLoader = 5;
+}
+
 message ShuffleRegisterRequest {
   string appId = 1;
   int32 shuffleId = 2;
@@ -188,11 +196,7 @@ message ShuffleRegisterRequest {
   DataDistribution shuffleDataDistribution = 6;
   int32 maxConcurrencyPerPartitionToWrite = 7;
   int32 stageAttemptNumber = 8;
-  string keyClass = 9;
-  string valueClass = 10;
-  string comparatorClass = 11;
-  int32 mergedBlockSize = 12;
-  string mergeClassLoader = 13;
+  MergeContext mergeContext = 9;
 }
 
 enum DataDistribution {
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index ee780b872..994a25c89 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -325,7 +325,7 @@ public class ShuffleServerGrpcService extends ShuffleServerImplBase {
                   maxConcurrencyPerPartitionToWrite);
       if (StatusCode.SUCCESS == result
           && shuffleServer.isRemoteMergeEnable()
-          && StringUtils.isNotBlank(req.getKeyClass())) {
+          && req.hasMergeContext()) {
         // The merged block is in a different domain from the original block,
         // so you need to register a new app for holding the merged block.
         result =
@@ -343,14 +343,7 @@ public class ShuffleServerGrpcService extends ShuffleServerImplBase {
           result =
               shuffleServer
                   .getShuffleMergeManager()
-                  .registerShuffle(
-                      appId,
-                      shuffleId,
-                      req.getKeyClass(),
-                      req.getValueClass(),
-                      req.getComparatorClass(),
-                      req.getMergedBlockSize(),
-                      req.getMergeClassLoader());
+                  .registerShuffle(appId, shuffleId, req.getMergeContext());
         }
       }
       auditContext.withStatusCode(result);
diff --git a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
index f50bfd271..f8a6c1bfd 100644
--- a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
@@ -43,6 +43,7 @@ import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.merger.Segment;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
 import org.apache.uniffle.server.ShuffleServer;
 import org.apache.uniffle.server.ShuffleServerConf;
 
@@ -145,22 +146,16 @@ public class ShuffleMergeManager {
     return cachedClassLoader.getOrDefault(label, cachedClassLoader.get(""));
   }
 
-  public StatusCode registerShuffle(
-      String appId,
-      int shuffleId,
-      String keyClassName,
-      String valueClassName,
-      String comparatorClassName,
-      int mergedBlockSize,
-      String classLoaderLabel) {
+  public StatusCode registerShuffle(String appId, int shuffleId, MergeContext mergeContext) {
     try {
-      ClassLoader classLoader = getClassLoader(classLoaderLabel);
-      Class kClass = ClassUtils.getClass(classLoader, keyClassName);
-      Class vClass = ClassUtils.getClass(classLoader, valueClassName);
+      ClassLoader classLoader = getClassLoader(mergeContext.getMergeClassLoader());
+      Class kClass = ClassUtils.getClass(classLoader, mergeContext.getKeyClass());
+      Class vClass = ClassUtils.getClass(classLoader, mergeContext.getValueClass());
       Comparator comparator;
-      if (StringUtils.isNotBlank(comparatorClassName)) {
+      if (StringUtils.isNotBlank(mergeContext.getComparatorClass())) {
         Constructor constructor =
-            ClassUtils.getClass(classLoader, comparatorClassName).getDeclaredConstructor();
+            ClassUtils.getClass(classLoader, mergeContext.getComparatorClass())
+                .getDeclaredConstructor();
         constructor.setAccessible(true);
         comparator = (Comparator) constructor.newInstance();
       } else {
@@ -180,7 +175,7 @@ public class ShuffleMergeManager {
                   kClass,
                   vClass,
                   comparator,
-                  mergedBlockSize,
+                  mergeContext.getMergedBlockSize(),
                   classLoader));
     } catch (ClassNotFoundException
         | InstantiationException
diff --git a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java
index 449881dc3..4ea82750c 100644
--- a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java
+++ b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java
@@ -45,6 +45,7 @@ import org.apache.uniffle.common.serializer.PartialInputStream;
 import org.apache.uniffle.common.serializer.SerializerUtils;
 import org.apache.uniffle.common.serializer.writable.WritableSerializer;
 import org.apache.uniffle.common.util.BlockIdLayout;
+import org.apache.uniffle.proto.RssProtos;
 import org.apache.uniffle.server.ShuffleServer;
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.server.ShuffleServerMetrics;
@@ -131,7 +132,15 @@ public class ShuffleMergeManagerTest {
         new RemoteStorageInfo(""),
         USER);
     mergeManager.registerShuffle(
-        APP_ID, SHUFFLE_ID, keyClassName, valueClassName, comparatorClassName, -1, "");
+        APP_ID,
+        SHUFFLE_ID,
+        RssProtos.MergeContext.newBuilder()
+            .setKeyClass(keyClassName)
+            .setValueClass(valueClassName)
+            .setComparatorClass(comparatorClassName)
+            .setMergedBlockSize(-1)
+            .setMergeClassLoader("")
+            .build());
 
     // 4 report blocks
     // 4.1 send shuffle data


Reply via email to