This is an automated email from the ASF dual-hosted git repository.
roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 87f9b6f73 [MINOR]improvement(client/server): (RemoteMerger) Refactor
to use mergeContext collect the arguments related (#2195)
87f9b6f73 is described below
commit 87f9b6f73b2499e241761fb36ca9f3b7edf3375d
Author: maobaolong <[email protected]>
AuthorDate: Tue Oct 22 19:12:19 2024 +0800
[MINOR]improvement(client/server): (RemoteMerger) Refactor to use
MergeContext to collect the related arguments (#2195)
### What changes were proposed in this pull request?
Refactor to use a `MergeContext` message to collect the arguments related to the remote merger.
### Why are the changes needed?
- Make the code cleaner and friendlier to other developers who are not
working on the `Remote Merger`.
- Allow `MergeContext` to be extended in the future without changing the
`registerShuffle` API signatures.
### Does this PR introduce _any_ user-facing change?
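No user-facing API change. Note, however, that the `ShuffleRegisterRequest` wire format changes. Field number 9, previously `string keyClass`, now carries the `MergeContext` message, and fields 10–13 are removed. As a result, clients and servers built before and after this change cannot correctly exchange register-shuffle requests that carry remote-merge arguments.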
### How was this patch tested?
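No new tests were added because this is a pure refactor. The existing unit and integration tests were updated to the new `MergeContext`-based signatures.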
---
.../hadoop/mapreduce/v2/app/RssMRAppMaster.java | 24 +++++-----
.../hadoop/mapred/SortWriteBufferManagerTest.java | 7 +--
.../hadoop/mapreduce/task/reduce/FetcherTest.java | 7 +--
.../shuffle/manager/RssShuffleManagerBase.java | 4 --
.../tez/dag/app/TezRemoteShuffleManager.java | 25 ++++++++---
.../common/sort/buffer/WriteBufferManagerTest.java | 7 +--
.../uniffle/client/api/ShuffleWriteClient.java | 11 +----
.../client/impl/ShuffleWriteClientImpl.java | 13 ++----
.../record/reader/MockedShuffleWriteClient.java | 7 +--
.../test/RemoteMergeShuffleWithRssClientTest.java | 49 +++++++++++---------
...ShuffleWithRssClientTestWhenShuffleFlushed.java | 49 +++++++++++---------
.../client/impl/grpc/ShuffleServerGrpcClient.java | 26 +++--------
.../client/request/RssRegisterShuffleRequest.java | 52 +++-------------------
proto/src/main/proto/Rss.proto | 14 +++---
.../uniffle/server/ShuffleServerGrpcService.java | 11 +----
.../uniffle/server/merge/ShuffleMergeManager.java | 23 ++++------
.../server/merge/ShuffleMergeManagerTest.java | 11 ++++-
17 files changed, 145 insertions(+), 195 deletions(-)
diff --git
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
index 4403120d5..06973d079 100644
---
a/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
+++
b/client-mr/core/src/main/java/org/apache/hadoop/mapreduce/v2/app/RssMRAppMaster.java
@@ -80,6 +80,7 @@ import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.hadoop.shim.HadoopShimImpl;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
import org.apache.uniffle.storage.util.StorageType;
import static
org.apache.hadoop.mapreduce.RssMRConfig.RSS_REMOTE_MERGE_CLASS_LOADER;
@@ -285,17 +286,20 @@ public class RssMRAppMaster extends MRAppMaster {
RssMRConfig.toRssConf(conf)
.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE),
0,
- remoteMergeEnable ?
conf.getMapOutputKeyClass().getName() : null,
remoteMergeEnable
- ? conf.getMapOutputValueClass().getName()
- : null,
- remoteMergeEnable
- ?
conf.getOutputKeyComparator().getClass().getName()
- : null,
- conf.getInt(
- RssMRConfig.RSS_MERGED_BLOCK_SZIE,
-
RssMRConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT),
- conf.get(RSS_REMOTE_MERGE_CLASS_LOADER)));
+ ? MergeContext.newBuilder()
+
.setKeyClass(conf.getMapOutputKeyClass().getName())
+
.setValueClass(conf.getMapOutputValueClass().getName())
+ .setComparatorClass(
+
conf.getOutputKeyComparator().getClass().getName())
+ .setMergedBlockSize(
+ conf.getInt(
+
RssMRConfig.RSS_MERGED_BLOCK_SZIE,
+
RssMRConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT))
+ .setMergeClassLoader(
+
conf.get(RSS_REMOTE_MERGE_CLASS_LOADER, ""))
+ .build()
+ : null));
LOG.info(
"Finish register shuffle with "
+ (System.currentTimeMillis() - start)
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 813ea1218..0385cb58e 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -62,6 +62,7 @@ import org.apache.uniffle.common.serializer.SerializerFactory;
import org.apache.uniffle.common.serializer.SerializerInstance;
import org.apache.uniffle.common.serializer.SerializerUtils;
import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.proto.RssProtos;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -722,11 +723,7 @@ public class SortWriteBufferManagerTest {
ShuffleDataDistributionType distributionType,
int maxConcurrencyPerPartitionToWrite,
int stageAttemptNumber,
- String keyClassName,
- String valueClassName,
- String comparatorClassName,
- int mergedBlockSize,
- String mergeClassLoader) {}
+ RssProtos.MergeContext mergeContext) {}
@Override
public boolean sendCommit(
diff --git
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index a66f16f56..d2aaebe04 100644
---
a/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++
b/client-mr/core/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -77,6 +77,7 @@ import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.hadoop.shim.HadoopShimImpl;
+import org.apache.uniffle.proto.RssProtos;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
@@ -507,11 +508,7 @@ public class FetcherTest {
ShuffleDataDistributionType distributionType,
int maxConcurrencyPerPartitionToWrite,
int stageAttemptNumber,
- String keyClassName,
- String valueClassName,
- String comparatorClassName,
- int mergedBlockSize,
- String mergeClassLoader) {}
+ RssProtos.MergeContext mergeContext) {}
@Override
public boolean sendCommit(
diff --git
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index d82d3a509..47f9e271d 100644
---
a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++
b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -1028,10 +1028,6 @@ public abstract class RssShuffleManagerBase implements
RssShuffleManagerInterfac
ShuffleDataDistributionType.NORMAL,
maxConcurrencyPerPartitionToWrite,
stageAttemptNumber,
- null,
- null,
- null,
- -1,
null);
});
LOG.info(
diff --git
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
index 85a138c13..f44ad0c5e 100644
---
a/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
+++
b/client-tez/src/main/java/org/apache/tez/dag/app/TezRemoteShuffleManager.java
@@ -66,6 +66,7 @@ import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RetryUtils;
+import org.apache.uniffle.proto.RssProtos;
import static
org.apache.uniffle.common.config.RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE;
@@ -305,13 +306,23 @@ public class TezRemoteShuffleManager implements
ServicePluginLifecycle {
RssTezConfig.toRssConf(conf)
.get(MAX_CONCURRENCY_PER_PARTITION_TO_WRITE),
0,
- keyClassName,
- valueClassName,
- comparatorClassName,
- conf.getInt(
-
RssTezConfig.RSS_MERGED_BLOCK_SZIE,
-
RssTezConfig.RSS_MERGED_BLOCK_SZIE_DEFAULT),
-
conf.get(RssTezConfig.RSS_REMOTE_MERGE_CLASS_LOADER)));
+ StringUtils.isBlank(keyClassName)
+ ? null
+ :
RssProtos.MergeContext.newBuilder()
+ .setKeyClass(keyClassName)
+
.setValueClass(valueClassName)
+
.setComparatorClass(comparatorClassName)
+ .setMergedBlockSize(
+ conf.getInt(
+
RssTezConfig.RSS_MERGED_BLOCK_SZIE,
+ RssTezConfig
+
.RSS_MERGED_BLOCK_SZIE_DEFAULT))
+ .setMergeClassLoader(
+ conf.get(
+ RssTezConfig
+
.RSS_REMOTE_MERGE_CLASS_LOADER,
+ ""))
+ .build()));
LOG.info(
"Finish register shuffle with "
+ (System.currentTimeMillis() - start)
diff --git
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
index 724cec6c0..ce0458219 100644
---
a/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
+++
b/client-tez/src/test/java/org/apache/tez/runtime/library/common/sort/buffer/WriteBufferManagerTest.java
@@ -78,6 +78,7 @@ import org.apache.uniffle.common.serializer.SerializerFactory;
import org.apache.uniffle.common.serializer.SerializerInstance;
import org.apache.uniffle.common.serializer.SerializerUtils;
import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.storage.util.StorageType;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -719,11 +720,7 @@ public class WriteBufferManagerTest {
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite,
int stageAttemptNumber,
- String keyClassName,
- String valueClassName,
- String comparatorClassName,
- int mergedBlockSize,
- String mergeClassLoader) {}
+ RssProtos.MergeContext mergeContext) {}
@Override
public boolean sendCommit(
diff --git
a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index 121271e36..d21c7e67b 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -33,6 +33,7 @@ import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
public interface ShuffleWriteClient {
@@ -72,10 +73,6 @@ public interface ShuffleWriteClient {
dataDistributionType,
maxConcurrencyPerPartitionToWrite,
0,
- null,
- null,
- null,
- -1,
null);
}
@@ -88,11 +85,7 @@ public interface ShuffleWriteClient {
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite,
int stageAttemptNumber,
- String keyClassName,
- String valueClassName,
- String comparatorClassName,
- int mergedBlockSize,
- String mergeClassLoader);
+ MergeContext mergeContext);
boolean sendCommit(
Set<ShuffleServerInfo> shuffleServerInfoSet, String appId, int
shuffleId, int numMaps);
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index cba7ccc06..c81d3c725 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -95,6 +95,7 @@ import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.ThreadUtils;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
public class ShuffleWriteClientImpl implements ShuffleWriteClient {
@@ -564,11 +565,7 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite,
int stageAttemptNumber,
- String keyClassName,
- String valueClassName,
- String comparatorClassName,
- int mergedBlockSize,
- String mergeClassLoader) {
+ MergeContext mergeContext) {
String user = null;
try {
user = UserGroupInformation.getCurrentUser().getShortUserName();
@@ -586,11 +583,7 @@ public class ShuffleWriteClientImpl implements
ShuffleWriteClient {
dataDistributionType,
maxConcurrencyPerPartitionToWrite,
stageAttemptNumber,
- keyClassName,
- valueClassName,
- comparatorClassName,
- mergedBlockSize,
- mergeClassLoader);
+ mergeContext);
RssRegisterShuffleResponse response =
getShuffleServerClient(shuffleServerInfo).registerShuffle(request);
diff --git
a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java
b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java
index 7d4cbf980..6798a792c 100644
---
a/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java
+++
b/client/src/test/java/org/apache/uniffle/client/record/reader/MockedShuffleWriteClient.java
@@ -34,6 +34,7 @@ import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.proto.RssProtos;
public class MockedShuffleWriteClient implements ShuffleWriteClient {
@@ -63,11 +64,7 @@ public class MockedShuffleWriteClient implements
ShuffleWriteClient {
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite,
int stageAttemptNumber,
- String keyClassName,
- String valueClassName,
- String comparatorClassName,
- int mergedBlockSize,
- String mergeClassLoader) {}
+ RssProtos.MergeContext mergeContext) {}
@Override
public boolean sendCommit(
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
index 17ec0c212..f34117586 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTest.java
@@ -64,6 +64,7 @@ import org.apache.uniffle.common.serializer.SerializerUtils;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.server.buffer.ShuffleBufferType;
import org.apache.uniffle.storage.util.StorageType;
@@ -174,11 +175,13 @@ public class RemoteMergeShuffleWithRssClientTest extends
ShuffleReadWriteBase {
ShuffleDataDistributionType.NORMAL,
0,
-1,
- keyClass.getName(),
- valueClass.getName(),
- comparator.getClass().getName(),
- -1,
- null);
+ RssProtos.MergeContext.newBuilder()
+ .setKeyClass(keyClass.getName())
+ .setValueClass(valueClass.getName())
+ .setComparatorClass(comparator.getClass().getName())
+ .setMergedBlockSize(-1)
+ .setMergeClassLoader("")
+ .build());
// 3 report shuffle result
// task 0 attempt 0 generate three blocks
@@ -337,11 +340,13 @@ public class RemoteMergeShuffleWithRssClientTest extends
ShuffleReadWriteBase {
ShuffleDataDistributionType.NORMAL,
0,
-1,
- keyClass.getName(),
- valueClass.getName(),
- comparator.getClass().getName(),
- -1,
- null);
+ RssProtos.MergeContext.newBuilder()
+ .setKeyClass(keyClass.getName())
+ .setValueClass(valueClass.getName())
+ .setComparatorClass(comparator.getClass().getName())
+ .setMergedBlockSize(-1)
+ .setMergeClassLoader("")
+ .build());
// 3 report shuffle result
// task 0 attempt 0 generate three blocks
@@ -508,11 +513,13 @@ public class RemoteMergeShuffleWithRssClientTest extends
ShuffleReadWriteBase {
ShuffleDataDistributionType.NORMAL,
0,
-1,
- keyClass.getName(),
- valueClass.getName(),
- comparator.getClass().getName(),
- -1,
- null);
+ RssProtos.MergeContext.newBuilder()
+ .setKeyClass(keyClass.getName())
+ .setValueClass(valueClass.getName())
+ .setComparatorClass(comparator.getClass().getName())
+ .setMergedBlockSize(-1)
+ .setMergeClassLoader("")
+ .build());
// 3 report shuffle result
// this shuffle have three partition, which is hash by key index mode 3
@@ -714,11 +721,13 @@ public class RemoteMergeShuffleWithRssClientTest extends
ShuffleReadWriteBase {
ShuffleDataDistributionType.NORMAL,
0,
-1,
- keyClass.getName(),
- valueClass.getName(),
- comparator.getClass().getName(),
- -1,
- null);
+ RssProtos.MergeContext.newBuilder()
+ .setKeyClass(keyClass.getName())
+ .setValueClass(valueClass.getName())
+ .setComparatorClass(comparator.getClass().getName())
+ .setMergedBlockSize(-1)
+ .setMergeClassLoader("")
+ .build());
// 3 report shuffle result
// this shuffle have three partition, which is hash by key index mode 3
diff --git
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
index c0af5e1bf..d12b286c2 100644
---
a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
+++
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java
@@ -64,6 +64,7 @@ import org.apache.uniffle.common.serializer.SerializerUtils;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.server.buffer.ShuffleBufferType;
import org.apache.uniffle.storage.util.StorageType;
@@ -179,11 +180,13 @@ public class
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
ShuffleDataDistributionType.NORMAL,
-1,
0,
- keyClass.getName(),
- valueClass.getName(),
- comparator.getClass().getName(),
- -1,
- null);
+ RssProtos.MergeContext.newBuilder()
+ .setKeyClass(keyClass.getName())
+ .setValueClass(valueClass.getName())
+ .setComparatorClass(comparator.getClass().getName())
+ .setMergedBlockSize(-1)
+ .setMergeClassLoader("")
+ .build());
// 3 report shuffle result
// task 0 attempt 0 generate three blocks
@@ -342,11 +345,13 @@ public class
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
ShuffleDataDistributionType.NORMAL,
-1,
0,
- keyClass.getName(),
- valueClass.getName(),
- comparator.getClass().getName(),
- -1,
- null);
+ RssProtos.MergeContext.newBuilder()
+ .setKeyClass(keyClass.getName())
+ .setValueClass(valueClass.getName())
+ .setComparatorClass(comparator.getClass().getName())
+ .setMergedBlockSize(-1)
+ .setMergeClassLoader("")
+ .build());
Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
// 3 report shuffle result
@@ -514,11 +519,13 @@ public class
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
ShuffleDataDistributionType.NORMAL,
-1,
0,
- keyClass.getName(),
- valueClass.getName(),
- comparator.getClass().getName(),
- -1,
- null);
+ RssProtos.MergeContext.newBuilder()
+ .setKeyClass(keyClass.getName())
+ .setValueClass(valueClass.getName())
+ .setComparatorClass(comparator.getClass().getName())
+ .setMergedBlockSize(-1)
+ .setMergeClassLoader("")
+ .build());
// 3 report shuffle result
// this shuffle have three partition, which is hash by key index mode 3
@@ -721,11 +728,13 @@ public class
RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff
ShuffleDataDistributionType.NORMAL,
-1,
0,
- keyClass.getName(),
- valueClass.getName(),
- comparator.getClass().getName(),
- -1,
- null);
+ RssProtos.MergeContext.newBuilder()
+ .setKeyClass(keyClass.getName())
+ .setValueClass(valueClass.getName())
+ .setComparatorClass(comparator.getClass().getName())
+ .setMergedBlockSize(-1)
+ .setMergeClassLoader("")
+ .build());
// 3 report shuffle result
// this shuffle have three partition, which is hash by key index mode 3
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 63081041d..20b6bf98b 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -31,7 +31,6 @@ import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
import com.google.protobuf.UnsafeByteOperations;
import io.netty.buffer.Unpooled;
-import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -95,6 +94,7 @@ import
org.apache.uniffle.proto.RssProtos.GetShuffleResultForMultiPartRequest;
import org.apache.uniffle.proto.RssProtos.GetShuffleResultForMultiPartResponse;
import org.apache.uniffle.proto.RssProtos.GetShuffleResultRequest;
import org.apache.uniffle.proto.RssProtos.GetShuffleResultResponse;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
import org.apache.uniffle.proto.RssProtos.PartitionToBlockIds;
import org.apache.uniffle.proto.RssProtos.RemoteStorage;
import org.apache.uniffle.proto.RssProtos.RemoteStorageConfItem;
@@ -198,11 +198,7 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite,
int stageAttemptNumber,
- String keyClassName,
- String valueClassName,
- String comparatorClassName,
- int mergedBlockSize,
- String mergeClassLoader) {
+ MergeContext mergeContext) {
ShuffleRegisterRequest.Builder reqBuilder =
ShuffleRegisterRequest.newBuilder();
reqBuilder
.setAppId(appId)
@@ -212,16 +208,8 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
.setMaxConcurrencyPerPartitionToWrite(maxConcurrencyPerPartitionToWrite)
.addAllPartitionRanges(toShufflePartitionRanges(partitionRanges))
.setStageAttemptNumber(stageAttemptNumber);
- if (StringUtils.isNotBlank(keyClassName)) {
- reqBuilder.setKeyClass(keyClassName);
- reqBuilder.setValueClass(valueClassName);
- if (StringUtils.isNotBlank(comparatorClassName)) {
- reqBuilder.setComparatorClass(comparatorClassName);
- }
- reqBuilder.setMergedBlockSize(mergedBlockSize);
- if (StringUtils.isNotBlank(mergeClassLoader)) {
- reqBuilder.setMergeClassLoader(mergeClassLoader);
- }
+ if (mergeContext != null) {
+ reqBuilder.setMergeContext(mergeContext);
}
RemoteStorage.Builder rsBuilder = RemoteStorage.newBuilder();
rsBuilder.setPath(remoteStorageInfo.getPath());
@@ -496,11 +484,7 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
request.getDataDistributionType(),
request.getMaxConcurrencyPerPartitionToWrite(),
request.getStageAttemptNumber(),
- request.getKeyClassName(),
- request.getValueClassName(),
- request.getComparatorClassName(),
- request.getMergedBlockSize(),
- request.getMergeClassLoader());
+ request.getMergeContext());
RssRegisterShuffleResponse response;
RssProtos.StatusCode statusCode = rpcResponse.getStatus();
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
index 1db40a0d1..92ed1e15e 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java
@@ -25,6 +25,7 @@ import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.config.RssClientConf;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
public class RssRegisterShuffleRequest {
@@ -36,11 +37,8 @@ public class RssRegisterShuffleRequest {
private ShuffleDataDistributionType dataDistributionType;
private int maxConcurrencyPerPartitionToWrite;
private int stageAttemptNumber;
- private String keyClassName;
- private String valueClassName;
- private String comparatorClassName;
- private int mergedBlockSize;
- private String mergeClassLoader;
+
+ private final MergeContext mergeContext;
public RssRegisterShuffleRequest(
String appId,
@@ -59,10 +57,6 @@ public class RssRegisterShuffleRequest {
dataDistributionType,
maxConcurrencyPerPartitionToWrite,
0,
- null,
- null,
- null,
- -1,
null);
}
@@ -75,11 +69,7 @@ public class RssRegisterShuffleRequest {
ShuffleDataDistributionType dataDistributionType,
int maxConcurrencyPerPartitionToWrite,
int stageAttemptNumber,
- String keyClassName,
- String valueClassName,
- String comparatorClassName,
- int mergedBlockSize,
- String mergeClassLoader) {
+ MergeContext mergeContext) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionRanges = partitionRanges;
@@ -88,11 +78,7 @@ public class RssRegisterShuffleRequest {
this.dataDistributionType = dataDistributionType;
this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite;
this.stageAttemptNumber = stageAttemptNumber;
- this.keyClassName = keyClassName;
- this.valueClassName = valueClassName;
- this.comparatorClassName = comparatorClassName;
- this.mergedBlockSize = mergedBlockSize;
- this.mergeClassLoader = mergeClassLoader;
+ this.mergeContext = mergeContext;
}
public RssRegisterShuffleRequest(
@@ -111,10 +97,6 @@ public class RssRegisterShuffleRequest {
dataDistributionType,
RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(),
0,
- null,
- null,
- null,
- -1,
null);
}
@@ -129,10 +111,6 @@ public class RssRegisterShuffleRequest {
ShuffleDataDistributionType.NORMAL,
RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(),
0,
- null,
- null,
- null,
- -1,
null);
}
@@ -168,23 +146,7 @@ public class RssRegisterShuffleRequest {
return stageAttemptNumber;
}
- public String getKeyClassName() {
- return keyClassName;
- }
-
- public String getValueClassName() {
- return valueClassName;
- }
-
- public String getComparatorClassName() {
- return comparatorClassName;
- }
-
- public int getMergedBlockSize() {
- return mergedBlockSize;
- }
-
- public String getMergeClassLoader() {
- return mergeClassLoader;
+ public MergeContext getMergeContext() {
+ return mergeContext;
}
}
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 7e4b19696..d92ec40c7 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -179,6 +179,14 @@ message ShufflePartitionRange {
int32 end = 2;
}
+message MergeContext {
+ string keyClass = 1;
+ string valueClass = 2;
+ string comparatorClass = 3;
+ int32 mergedBlockSize = 4;
+ string mergeClassLoader = 5;
+}
+
message ShuffleRegisterRequest {
string appId = 1;
int32 shuffleId = 2;
@@ -188,11 +196,7 @@ message ShuffleRegisterRequest {
DataDistribution shuffleDataDistribution = 6;
int32 maxConcurrencyPerPartitionToWrite = 7;
int32 stageAttemptNumber = 8;
- string keyClass = 9;
- string valueClass = 10;
- string comparatorClass = 11;
- int32 mergedBlockSize = 12;
- string mergeClassLoader = 13;
+ MergeContext mergeContext = 9;
}
enum DataDistribution {
diff --git
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index ee780b872..994a25c89 100644
---
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -325,7 +325,7 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
maxConcurrencyPerPartitionToWrite);
if (StatusCode.SUCCESS == result
&& shuffleServer.isRemoteMergeEnable()
- && StringUtils.isNotBlank(req.getKeyClass())) {
+ && req.hasMergeContext()) {
// The merged block is in a different domain from the original block,
// so you need to register a new app for holding the merged block.
result =
@@ -343,14 +343,7 @@ public class ShuffleServerGrpcService extends
ShuffleServerImplBase {
result =
shuffleServer
.getShuffleMergeManager()
- .registerShuffle(
- appId,
- shuffleId,
- req.getKeyClass(),
- req.getValueClass(),
- req.getComparatorClass(),
- req.getMergedBlockSize(),
- req.getMergeClassLoader());
+ .registerShuffle(appId, shuffleId, req.getMergeContext());
}
}
auditContext.withStatusCode(result);
diff --git
a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
index f50bfd271..f8a6c1bfd 100644
---
a/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
+++
b/server/src/main/java/org/apache/uniffle/server/merge/ShuffleMergeManager.java
@@ -43,6 +43,7 @@ import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.merger.Segment;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.JavaUtils;
+import org.apache.uniffle.proto.RssProtos.MergeContext;
import org.apache.uniffle.server.ShuffleServer;
import org.apache.uniffle.server.ShuffleServerConf;
@@ -145,22 +146,16 @@ public class ShuffleMergeManager {
return cachedClassLoader.getOrDefault(label, cachedClassLoader.get(""));
}
- public StatusCode registerShuffle(
- String appId,
- int shuffleId,
- String keyClassName,
- String valueClassName,
- String comparatorClassName,
- int mergedBlockSize,
- String classLoaderLabel) {
+ public StatusCode registerShuffle(String appId, int shuffleId, MergeContext
mergeContext) {
try {
- ClassLoader classLoader = getClassLoader(classLoaderLabel);
- Class kClass = ClassUtils.getClass(classLoader, keyClassName);
- Class vClass = ClassUtils.getClass(classLoader, valueClassName);
+ ClassLoader classLoader =
getClassLoader(mergeContext.getMergeClassLoader());
+ Class kClass = ClassUtils.getClass(classLoader,
mergeContext.getKeyClass());
+ Class vClass = ClassUtils.getClass(classLoader,
mergeContext.getValueClass());
Comparator comparator;
- if (StringUtils.isNotBlank(comparatorClassName)) {
+ if (StringUtils.isNotBlank(mergeContext.getComparatorClass())) {
Constructor constructor =
- ClassUtils.getClass(classLoader,
comparatorClassName).getDeclaredConstructor();
+ ClassUtils.getClass(classLoader, mergeContext.getComparatorClass())
+ .getDeclaredConstructor();
constructor.setAccessible(true);
comparator = (Comparator) constructor.newInstance();
} else {
@@ -180,7 +175,7 @@ public class ShuffleMergeManager {
kClass,
vClass,
comparator,
- mergedBlockSize,
+ mergeContext.getMergedBlockSize(),
classLoader));
} catch (ClassNotFoundException
| InstantiationException
diff --git
a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java
b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java
index 449881dc3..4ea82750c 100644
---
a/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java
+++
b/server/src/test/java/org/apache/uniffle/server/merge/ShuffleMergeManagerTest.java
@@ -45,6 +45,7 @@ import
org.apache.uniffle.common.serializer.PartialInputStream;
import org.apache.uniffle.common.serializer.SerializerUtils;
import org.apache.uniffle.common.serializer.writable.WritableSerializer;
import org.apache.uniffle.common.util.BlockIdLayout;
+import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.server.ShuffleServer;
import org.apache.uniffle.server.ShuffleServerConf;
import org.apache.uniffle.server.ShuffleServerMetrics;
@@ -131,7 +132,15 @@ public class ShuffleMergeManagerTest {
new RemoteStorageInfo(""),
USER);
mergeManager.registerShuffle(
- APP_ID, SHUFFLE_ID, keyClassName, valueClassName, comparatorClassName,
-1, "");
+ APP_ID,
+ SHUFFLE_ID,
+ RssProtos.MergeContext.newBuilder()
+ .setKeyClass(keyClassName)
+ .setValueClass(valueClassName)
+ .setComparatorClass(comparatorClassName)
+ .setMergedBlockSize(-1)
+ .setMergeClassLoader("")
+ .build());
// 4 report blocks
// 4.1 send shuffle data