This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 96bf76cbc [#2591] feat(client): Introduce the mechanism to report
localfile read plan (#2603)
96bf76cbc is described below
commit 96bf76cbc9d9bbe51c3e8a502ba37b32ae2ce8d6
Author: Junfan Zhang <[email protected]>
AuthorDate: Tue Sep 9 19:36:18 2025 +0800
[#2591] feat(client): Introduce the mechanism to report localfile read plan
(#2603)
### What changes were proposed in this pull request?
This PR introduces a mechanism for the client to report its local-file read plan
to the shuffle server. The changes are scoped to the client side only; the
corresponding shuffle-server changes will be added in the future.
### Why are the changes needed?
For normal partitions, the reading mode is sequential, which makes
read-ahead optimization feasible. This has already been verified in the Riffle
project (see [issue #483](https://github.com/zuston/riffle/issues/483)).
For huge partitions, however, the reading mode becomes skippable (the reader may
jump over segments) because of Spark AQE's skew join optimization rule. In such
cases, it is difficult to predict the next read position and length.
Based on this analysis, we propose introducing a fixed read plan that is
propagated from the client to the server, allowing the server to recognize the
next read offset and thereby benefit from read-ahead optimization.
### Does this PR introduce _any_ user-facing change?
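Yes. Two new client options are introduced,
`rss.client.read.nextReadSegmentsReportEnabled` and
`rss.client.read.nextReadSegmentsReportCount`; the feature is disabled by default.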
### How was this patch tested?
Existing tests, plus a new unit test for `DataSkippableReadHandler#getNextSegments`.
---
.../spark/shuffle/reader/RssShuffleReader.java | 6 +-
.../client/factory/ShuffleClientFactory.java | 10 ++
.../uniffle/client/impl/ShuffleReadClientImpl.java | 7 ++
.../org/apache/uniffle/common/ReadSegment.java | 51 +++++++++
.../uniffle/common/config/RssClientConf.java | 13 +++
.../protocol/GetLocalShuffleDataV3Request.java | 118 +++++++++++++++++++++
.../uniffle/common/netty/protocol/Message.java | 1 +
.../client/impl/grpc/ShuffleServerGrpcClient.java | 10 ++
.../impl/grpc/ShuffleServerGrpcNettyClient.java | 35 ++++--
.../client/request/RssGetShuffleDataRequest.java | 33 +++++-
proto/src/main/proto/Rss.proto | 6 ++
.../storage/factory/ShuffleHandlerFactory.java | 5 +-
.../handler/impl/DataSkippableReadHandler.java | 30 +++++-
.../handler/impl/HadoopShuffleReadHandler.java | 6 +-
.../handler/impl/LocalFileClientReadHandler.java | 26 ++++-
.../request/CreateShuffleReadHandlerRequest.java | 29 +++++
.../handler/impl/DataSkippableReadHandlerTest.java | 52 +++++++++
17 files changed, 418 insertions(+), 20 deletions(-)
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 09a2b58c2..f03aa9955 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -316,7 +316,11 @@ public class RssShuffleReader<K, C> implements
ShuffleReader<K, C> {
.expectedTaskIdsBitmapFilterEnable(expectedTaskIdsBitmapFilterEnable)
.retryMax(retryMax)
.retryIntervalMax(retryIntervalMax)
- .rssConf(rssConf);
+ .rssConf(rssConf)
+ .taskAttemptId(
+ Optional.ofNullable(TaskContext.get())
+ .map(taskContext -> taskContext.taskAttemptId())
+ .orElse(0L));
if (codec.isPresent() &&
rssConf.get(RSS_READ_OVERLAPPING_DECOMPRESSION_ENABLED)) {
builder
.overlappingDecompressionEnabled(true)
diff --git
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
index d4a49fa97..33f7df0ce 100644
---
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
+++
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
@@ -227,6 +227,12 @@ public class ShuffleClientFactory {
private int retryMax;
private long retryIntervalMax;
private ShuffleServerReadCostTracker readCostTracker;
+ private long taskAttemptId;
+
+ public ReadClientBuilder taskAttemptId(long taskAttemptId) {
+ this.taskAttemptId = taskAttemptId;
+ return this;
+ }
private boolean overlappingDecompressionEnabled;
private int overlappingDecompressionThreadNum;
@@ -463,6 +469,10 @@ public class ShuffleClientFactory {
return codec;
}
+ public long getTaskAttemptId() {
+ return taskAttemptId;
+ }
+
public ShuffleReadClientImpl build() {
return new ShuffleReadClientImpl(this);
}
diff --git
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index e6ffc09d6..8586d3493 100644
---
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -57,6 +57,9 @@ import
org.apache.uniffle.storage.handler.api.ClientReadHandler;
import org.apache.uniffle.storage.handler.impl.ShuffleServerReadCostTracker;
import org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest;
+import static
org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_COUNT;
+import static
org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_ENABLED;
+
public class ShuffleReadClientImpl implements ShuffleReadClient {
private static final Logger LOG =
LoggerFactory.getLogger(ShuffleReadClientImpl.class);
@@ -171,6 +174,12 @@ public class ShuffleReadClientImpl implements
ShuffleReadClient {
this.readCostTracker = builder.getReadCostTracker();
CreateShuffleReadHandlerRequest request = new
CreateShuffleReadHandlerRequest();
+ request.setNextReadSegmentsReportCount(
+ builder.getRssConf().get(READ_CLIENT_NEXT_SEGMENTS_REPORT_COUNT));
+ request.setNextReadSegmentsReportEnabled(
+ builder.getRssConf().get(READ_CLIENT_NEXT_SEGMENTS_REPORT_ENABLED));
+ // Forward the reader's task attempt id; otherwise the read handler always reports 0.
+ request.setTaskAttemptId(builder.getTaskAttemptId());
request.setStorageType(builder.getStorageType());
request.setAppId(builder.getAppId());
request.setShuffleId(shuffleId);
diff --git a/common/src/main/java/org/apache/uniffle/common/ReadSegment.java
b/common/src/main/java/org/apache/uniffle/common/ReadSegment.java
new file mode 100644
index 000000000..552e4d552
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/ReadSegment.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * A lightweight (offset, length) view of a {@link ShuffleDataSegment}, used to report the segments
+ * a client will read next to the shuffle server.
+ */
+public class ReadSegment {
+  private final long offset;
+  private final long length;
+
+  public ReadSegment(long offset, long length) {
+    this.offset = offset;
+    this.length = length;
+  }
+
+  public long getOffset() {
+    return offset;
+  }
+
+  public long getLength() {
+    return length;
+  }
+
+  /** Creates a read segment covering the same byte range as {@code segment}. */
+  public static ReadSegment from(ShuffleDataSegment segment) {
+    return new ReadSegment(segment.getOffset(), segment.getLength());
+  }
+
+  /**
+   * Converts each data segment to a read segment, preserving order. A {@code null} input yields an
+   * empty (mutable) list instead of a NullPointerException.
+   */
+  public static List<ReadSegment> from(List<ShuffleDataSegment> segments) {
+    if (segments == null) {
+      return new ArrayList<>();
+    }
+    List<ReadSegment> readSegments = new ArrayList<>(segments.size());
+    for (ShuffleDataSegment segment : segments) {
+      readSegments.add(from(segment));
+    }
+    return readSegments;
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (!(o instanceof ReadSegment)) {
+      return false;
+    }
+    ReadSegment that = (ReadSegment) o;
+    return offset == that.offset && length == that.length;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(offset, length);
+  }
+
+  @Override
+  public String toString() {
+    return "ReadSegment{offset=" + offset + ", length=" + length + "}";
+  }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index 0c3a465f2..a7f67e1bc 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -368,4 +368,17 @@ public class RssClientConf {
.intType()
.defaultValue(120)
.withDescription("Read prefetch timeout seconds");
+
+ public static final ConfigOption<Boolean>
READ_CLIENT_NEXT_SEGMENTS_REPORT_ENABLED =
+ ConfigOptions.key("rss.client.read.nextReadSegmentsReportEnabled")
+ .booleanType()
+ .defaultValue(false)
+ .withDescription(
+ "Whether the next read segment report is enabled for
shuffle-server read ahead");
+
+ public static final ConfigOption<Integer>
READ_CLIENT_NEXT_SEGMENTS_REPORT_COUNT =
+ ConfigOptions.key("rss.client.read.nextReadSegmentsReportCount")
+ .intType()
+ .defaultValue(4)
+ .withDescription("Next read segment count for shuffle-server read
ahead");
}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataV3Request.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataV3Request.java
new file mode 100644
index 000000000..993b80c51
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataV3Request.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.uniffle.common.ReadSegment;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+/**
+ * Version 3 of the local shuffle data read request. On top of the V2 fields it carries the
+ * segments the client plans to read next and the reader's task attempt id, so the shuffle server
+ * can learn the upcoming read offsets.
+ *
+ * <p>Fields appended after the V2 payload: an {@code int} segment count, a {@code long} offset and
+ * a {@code long} length per segment, then a {@code long} taskAttemptId.
+ */
+public class GetLocalShuffleDataV3Request extends GetLocalShuffleDataV2Request {
+  private final List<ReadSegment> nextReadSegments;
+  private final long taskAttemptId;
+
+  public GetLocalShuffleDataV3Request(
+      long requestId,
+      String appId,
+      int shuffleId,
+      int partitionId,
+      int partitionNumPerRange,
+      int partitionNum,
+      long offset,
+      int length,
+      int storageId,
+      List<ReadSegment> nextReadSegments,
+      long timestamp,
+      long taskAttemptId) {
+    super(
+        requestId,
+        appId,
+        shuffleId,
+        partitionId,
+        partitionNumPerRange,
+        partitionNum,
+        offset,
+        length,
+        storageId,
+        timestamp);
+    // A missing read plan is encoded as zero segments instead of failing later with an NPE.
+    this.nextReadSegments =
+        nextReadSegments == null ? Collections.emptyList() : nextReadSegments;
+    this.taskAttemptId = taskAttemptId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.GET_LOCAL_SHUFFLE_DATA_V3_REQUEST;
+  }
+
+  /**
+   * Returns the exact number of bytes {@link #encode(ByteBuf)} writes. This must stay in sync with
+   * {@code encode}: besides the per-segment longs it includes the segment count and the trailing
+   * taskAttemptId, both of which are written on the wire.
+   */
+  @Override
+  public int encodedLength() {
+    return super.encodedLength()
+        + Integer.BYTES // segment count
+        + Long.BYTES * 2 * nextReadSegments.size() // offset + length of each segment
+        + Long.BYTES; // taskAttemptId
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    super.encode(buf);
+    buf.writeInt(nextReadSegments.size());
+    for (ReadSegment segment : nextReadSegments) {
+      buf.writeLong(segment.getOffset());
+      buf.writeLong(segment.getLength());
+    }
+    buf.writeLong(taskAttemptId);
+  }
+
+  /**
+   * Decodes a request produced by {@link #encode(ByteBuf)}.
+   *
+   * <p>NOTE(review): the leading fields are read in the order GetLocalShuffleDataV2Request is
+   * assumed to encode them (timestamp before storageId); confirm against the V2 encoder.
+   */
+  public static GetLocalShuffleDataV3Request decode(ByteBuf byteBuf) {
+    long requestId = byteBuf.readLong();
+    String appId = ByteBufUtils.readLengthAndString(byteBuf);
+    int shuffleId = byteBuf.readInt();
+    int partitionId = byteBuf.readInt();
+    int partitionNumPerRange = byteBuf.readInt();
+    int partitionNum = byteBuf.readInt();
+    long offset = byteBuf.readLong();
+    int length = byteBuf.readInt();
+    long timestamp = byteBuf.readLong();
+    int storageId = byteBuf.readInt();
+
+    int readSegmentCount = byteBuf.readInt();
+    if (readSegmentCount < 0) {
+      throw new IllegalArgumentException("Invalid next read segment count: " + readSegmentCount);
+    }
+    // Bound the pre-allocation by what the buffer can actually hold, so a corrupted count
+    // cannot trigger a huge allocation before the reads below fail.
+    List<ReadSegment> readSegments =
+        new ArrayList<>(Math.min(readSegmentCount, byteBuf.readableBytes() / (Long.BYTES * 2)));
+    for (int i = 0; i < readSegmentCount; i++) {
+      long segmentOffset = byteBuf.readLong();
+      long segmentLength = byteBuf.readLong();
+      readSegments.add(new ReadSegment(segmentOffset, segmentLength));
+    }
+    long taskAttemptId = byteBuf.readLong();
+    return new GetLocalShuffleDataV3Request(
+        requestId,
+        appId,
+        shuffleId,
+        partitionId,
+        partitionNumPerRange,
+        partitionNum,
+        offset,
+        length,
+        storageId,
+        readSegments,
+        timestamp,
+        taskAttemptId);
+  }
+
+  /** Returns the segments the client reported it will read next; never null. */
+  public List<ReadSegment> getNextReadSegments() {
+    return nextReadSegments;
+  }
+
+  public long getTaskAttemptId() {
+    return taskAttemptId;
+  }
+
+  @Override
+  public String getOperationType() {
+    return "getLocalShuffleDataV3";
+  }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
index 2ad0b0d77..5fcb47bd0 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
@@ -67,6 +67,7 @@ public abstract class Message implements Encodable {
GET_SORTED_SHUFFLE_DATA_RESPONSE(22),
GET_LOCAL_SHUFFLE_INDEX_V2_RESPONSE(23),
GET_LOCAL_SHUFFLE_DATA_V2_REQUEST(24),
+ GET_LOCAL_SHUFFLE_DATA_V3_REQUEST(25),
;
private final byte id;
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 1d9d7b0d0..0a5259284 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -26,6 +26,7 @@ import java.util.Random;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
@@ -958,6 +959,15 @@ public class ShuffleServerGrpcClient extends GrpcClient
implements ShuffleServer
.setLength(request.getLength())
.setTimestamp(start)
.setStorageId(request.getStorageId())
+ .addAllNextReadSegments(
+ request.getNextReadSegments().stream()
+ .map(
+ x ->
+ RssProtos.ReadSegment.newBuilder()
+ .setLength(x.getLength())
+ .setOffset(x.getOffset())
+ .build())
+ .collect(Collectors.toList()))
.build();
String requestInfo =
"appId["
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
index 1f014c8b2..080fed297 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
@@ -43,6 +43,7 @@ import
org.apache.uniffle.client.response.RssGetShuffleIndexResponse;
import org.apache.uniffle.client.response.RssGetSortedShuffleDataResponse;
import org.apache.uniffle.client.response.RssSendShuffleDataResponse;
import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.ReadSegment;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssClientConf;
@@ -57,6 +58,7 @@ import
org.apache.uniffle.common.netty.client.TransportContext;
import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataRequest;
import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataResponse;
import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataV2Request;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataV3Request;
import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexRequest;
import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse;
import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest;
@@ -387,9 +389,11 @@ public class ShuffleServerGrpcNettyClient extends
ShuffleServerGrpcClient {
public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest
request) {
TransportClient transportClient = getTransportClient();
// Construct old version or v2 get shuffle data request to compatible with
old server
- GetLocalShuffleDataRequest getLocalShuffleIndexRequest =
- request.storageIdSpecified()
- ? new GetLocalShuffleDataV2Request(
+ GetLocalShuffleDataRequest getLocalShuffleDataRequest = null;
+ if (request.storageIdSpecified()) {
+ if (request.isNextReadSegmentsReportEnabled()) {
+ getLocalShuffleDataRequest =
+ new GetLocalShuffleDataV3Request(
requestId(),
request.getAppId(),
request.getShuffleId(),
@@ -399,8 +403,12 @@ public class ShuffleServerGrpcNettyClient extends
ShuffleServerGrpcClient {
request.getOffset(),
request.getLength(),
request.getStorageId(),
- System.currentTimeMillis())
- : new GetLocalShuffleDataRequest(
+ ReadSegment.from(request.getNextReadSegments()),
+ System.currentTimeMillis(),
+ request.getTaskAttemptId());
+ } else {
+ getLocalShuffleDataRequest =
+ new GetLocalShuffleDataV2Request(
requestId(),
request.getAppId(),
request.getShuffleId(),
@@ -409,7 +417,22 @@ public class ShuffleServerGrpcNettyClient extends
ShuffleServerGrpcClient {
request.getPartitionNum(),
request.getOffset(),
request.getLength(),
+ request.getStorageId(),
System.currentTimeMillis());
+ }
+ } else {
+ getLocalShuffleDataRequest =
+ new GetLocalShuffleDataRequest(
+ requestId(),
+ request.getAppId(),
+ request.getShuffleId(),
+ request.getPartitionId(),
+ request.getPartitionNumPerRange(),
+ request.getPartitionNum(),
+ request.getOffset(),
+ request.getLength(),
+ System.currentTimeMillis());
+ }
String requestInfo =
"appId["
+ request.getAppId()
@@ -423,7 +446,7 @@ public class ShuffleServerGrpcNettyClient extends
ShuffleServerGrpcClient {
RpcResponse rpcResponse;
GetLocalShuffleDataResponse getLocalShuffleDataResponse;
while (true) {
- rpcResponse = transportClient.sendRpcSync(getLocalShuffleIndexRequest,
rpcTimeout);
+ rpcResponse = transportClient.sendRpcSync(getLocalShuffleDataRequest,
rpcTimeout);
getLocalShuffleDataResponse = (GetLocalShuffleDataResponse) rpcResponse;
if (rpcResponse.getStatusCode() != StatusCode.NO_BUFFER) {
break;
diff --git
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java
index c245e48b7..d401e12fc 100644
---
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java
+++
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java
@@ -17,8 +17,13 @@
package org.apache.uniffle.client.request;
+import java.util.Collections;
+import java.util.List;
+
import com.google.common.annotations.VisibleForTesting;
+import org.apache.uniffle.common.ShuffleDataSegment;
+
public class RssGetShuffleDataRequest extends RetryableRequest {
private final String appId;
@@ -29,6 +34,9 @@ public class RssGetShuffleDataRequest extends
RetryableRequest {
private final long offset;
private final int length;
private final int storageId;
+ private final long taskAttemptId;
+ private final List<ShuffleDataSegment> nextReadSegments;
+ private final boolean nextReadSegmentsReportEnabled;
public RssGetShuffleDataRequest(
String appId,
@@ -40,7 +48,10 @@ public class RssGetShuffleDataRequest extends
RetryableRequest {
int length,
int storageId,
int retryMax,
- long retryIntervalMax) {
+ long retryIntervalMax,
+ long taskAttemptId,
+ List<ShuffleDataSegment> nextReadSegments,
+ boolean nextReadSegmentsReportEnabled) {
this.appId = appId;
this.shuffleId = shuffleId;
this.partitionId = partitionId;
@@ -51,6 +62,9 @@ public class RssGetShuffleDataRequest extends
RetryableRequest {
this.storageId = storageId;
this.retryMax = retryMax;
this.retryIntervalMax = retryIntervalMax;
+ this.nextReadSegments = nextReadSegments;
+ this.nextReadSegmentsReportEnabled = nextReadSegmentsReportEnabled;
+ this.taskAttemptId = taskAttemptId;
}
@VisibleForTesting
@@ -72,7 +86,10 @@ public class RssGetShuffleDataRequest extends
RetryableRequest {
length,
-1,
1,
- 0);
+ 0,
+ 0,
+ Collections.emptyList(),
+ false);
}
public String getAppId() {
@@ -111,6 +128,18 @@ public class RssGetShuffleDataRequest extends
RetryableRequest {
return storageId != -1;
}
+ public List<ShuffleDataSegment> getNextReadSegments() {
+ return nextReadSegments;
+ }
+
+ public boolean isNextReadSegmentsReportEnabled() {
+ return nextReadSegmentsReportEnabled;
+ }
+
+ public long getTaskAttemptId() {
+ return taskAttemptId;
+ }
+
@Override
public String operationType() {
return "GetShuffleData";
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 2967d98c0..c3db61e39 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -87,6 +87,12 @@ message GetLocalShuffleDataRequest {
int32 length = 7;
int64 timestamp = 8;
int32 storageId = 9;
+ repeated ReadSegment nextReadSegments = 10;
+}
+
+message ReadSegment {
+ int64 offset = 1;
+ int64 length = 2;
}
message GetLocalShuffleDataResponse {
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index 356cfc1dd..e01ba1d1f 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -165,7 +165,10 @@ public class ShuffleHandlerFactory {
request.getRetryMax(),
request.getRetryIntervalMax(),
request.getPrefetchOption(),
- request.getReadCostTracker());
+ request.getReadCostTracker(),
+ request.isNextReadSegmentsReportEnabled(),
+ request.getNextReadSegmentsReportCount(),
+ request.getTaskAttemptId());
}
private ClientReadHandler getHadoopClientReadHandler(
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
index 58288d53c..1ecb32f23 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
@@ -22,6 +22,7 @@ import java.util.List;
import java.util.Optional;
import java.util.Set;
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
@@ -35,6 +36,7 @@ import
org.apache.uniffle.common.segment.SegmentSplitterFactory;
public abstract class DataSkippableReadHandler extends
PrefetchableClientReadHandler {
private static final Logger LOG =
LoggerFactory.getLogger(DataSkippableReadHandler.class);
+ private static final int DEFAULT_NEXT_READ_BATCH_NUMBER = 4;
protected List<ShuffleDataSegment> shuffleDataSegments =
Lists.newArrayList();
protected int segmentIndex = 0;
@@ -45,6 +47,8 @@ public abstract class DataSkippableReadHandler extends
PrefetchableClientReadHan
protected ShuffleDataDistributionType distributionType;
protected Roaring64NavigableMap expectTaskIds;
+ private int nextReadSegmentReportCount;
+
public DataSkippableReadHandler(
String appId,
int shuffleId,
@@ -54,7 +58,8 @@ public abstract class DataSkippableReadHandler extends
PrefetchableClientReadHan
Set<Long> processBlockIds,
ShuffleDataDistributionType distributionType,
Roaring64NavigableMap expectTaskIds,
- Optional<PrefetchOption> prefetchOption) {
+ Optional<PrefetchOption> prefetchOption,
+ int nextReadSegmentReportCount) {
super(prefetchOption);
this.appId = appId;
this.shuffleId = shuffleId;
@@ -64,11 +69,13 @@ public abstract class DataSkippableReadHandler extends
PrefetchableClientReadHan
this.processBlockIds = processBlockIds;
this.distributionType = distributionType;
this.expectTaskIds = expectTaskIds;
+ this.nextReadSegmentReportCount = nextReadSegmentReportCount;
}
protected abstract ShuffleIndexResult readShuffleIndex();
- protected abstract ShuffleDataResult readShuffleData(ShuffleDataSegment
segment);
+ protected abstract ShuffleDataResult readShuffleData(
+ ShuffleDataSegment segment, List<ShuffleDataSegment> nextReadSegments);
@Override
public ShuffleDataResult doReadShuffleData() {
@@ -100,7 +107,11 @@ public abstract class DataSkippableReadHandler extends
PrefetchableClientReadHan
// skip processed blockIds
blocksOfSegment.removeAll(processBlockIds);
if (!blocksOfSegment.isEmpty()) {
- result = readShuffleData(segment);
+ result =
+ readShuffleData(
+ segment,
+ getNextSegments(
+ shuffleDataSegments, segmentIndex + 1,
nextReadSegmentReportCount));
segmentIndex++;
break;
}
@@ -109,4 +120,17 @@ public abstract class DataSkippableReadHandler extends
PrefetchableClientReadHan
}
return result;
}
+
+ @VisibleForTesting
+ protected static List<ShuffleDataSegment> getNextSegments(
+ List<ShuffleDataSegment> shuffleDataSegments, int startIndex, int
number) {
+ List<ShuffleDataSegment> nextSegments = Lists.newArrayList();
+ for (int i = startIndex; i < shuffleDataSegments.size(); i++) {
+ if (nextSegments.size() >= number) {
+ break;
+ }
+ nextSegments.add(shuffleDataSegments.get(i));
+ }
+ return nextSegments;
+ }
}
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
index f2153d2be..09372c32f 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
@@ -71,7 +71,8 @@ public class HadoopShuffleReadHandler extends
DataSkippableReadHandler {
processBlockIds,
distributionType,
expectTaskIds,
- prefetchOption);
+ prefetchOption,
+ 4);
this.filePrefix = filePrefix;
this.indexReader =
createHadoopReader(ShuffleStorageUtils.generateIndexFileName(filePrefix), conf);
@@ -135,7 +136,8 @@ public class HadoopShuffleReadHandler extends
DataSkippableReadHandler {
return new ShuffleIndexResult();
}
- protected ShuffleDataResult readShuffleData(ShuffleDataSegment
shuffleDataSegment) {
+ protected ShuffleDataResult readShuffleData(
+ ShuffleDataSegment shuffleDataSegment, List<ShuffleDataSegment>
nextReadSegments) {
// Here we make an assumption that the rest of the file is corrupted, if
an unexpected data is
// read.
int expectedLength = shuffleDataSegment.getLength();
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
index cfdf5fdb7..9405cad62 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
@@ -17,6 +17,7 @@
package org.apache.uniffle.storage.handler.impl;
+import java.util.List;
import java.util.Optional;
import java.util.Set;
@@ -46,6 +47,8 @@ public class LocalFileClientReadHandler extends
DataSkippableReadHandler {
private int retryMax;
private long retryIntervalMax;
private ShuffleServerReadCostTracker readCostTracker;
+ private boolean nextReadSegmentsReportEnabled;
+ private long taskAttemptId;
public LocalFileClientReadHandler(
String appId,
@@ -63,7 +66,10 @@ public class LocalFileClientReadHandler extends
DataSkippableReadHandler {
int retryMax,
long retryIntervalMax,
Optional<PrefetchOption> prefetchOption,
- ShuffleServerReadCostTracker readCostTracker) {
+ ShuffleServerReadCostTracker readCostTracker,
+ boolean nextReadSegmentsReportEnabled,
+ int nextReadSegmentCount,
+ long taskAttemptId) {
super(
appId,
shuffleId,
@@ -73,13 +79,16 @@ public class LocalFileClientReadHandler extends
DataSkippableReadHandler {
processBlockIds,
distributionType,
expectTaskIds,
- prefetchOption);
+ prefetchOption,
+ nextReadSegmentCount);
this.shuffleServerClient = shuffleServerClient;
this.partitionNumPerRange = partitionNumPerRange;
this.partitionNum = partitionNum;
this.retryMax = retryMax;
this.retryIntervalMax = retryIntervalMax;
this.readCostTracker = readCostTracker;
+ this.nextReadSegmentsReportEnabled = nextReadSegmentsReportEnabled;
+ this.taskAttemptId = taskAttemptId;
}
@VisibleForTesting
@@ -110,7 +119,10 @@ public class LocalFileClientReadHandler extends
DataSkippableReadHandler {
1,
0,
Optional.empty(),
- new ShuffleServerReadCostTracker());
+ new ShuffleServerReadCostTracker(),
+ false,
+ 4,
+ 0);
}
@Override
@@ -144,7 +156,8 @@ public class LocalFileClientReadHandler extends
DataSkippableReadHandler {
}
@Override
- public ShuffleDataResult readShuffleData(ShuffleDataSegment
shuffleDataSegment) {
+ public ShuffleDataResult readShuffleData(
+ ShuffleDataSegment shuffleDataSegment, List<ShuffleDataSegment>
nextReadSegments) {
ShuffleDataResult result = null;
int expectedLength = shuffleDataSegment.getLength();
if (expectedLength <= 0) {
@@ -171,7 +184,10 @@ public class LocalFileClientReadHandler extends
DataSkippableReadHandler {
expectedLength,
shuffleDataSegment.getStorageId(),
retryMax,
- retryIntervalMax);
+ retryIntervalMax,
+ taskAttemptId,
+ nextReadSegments,
+ nextReadSegmentsReportEnabled);
try {
long start = System.currentTimeMillis();
RssGetShuffleDataResponse response =
shuffleServerClient.getShuffleData(request);
diff --git
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
index a0a6b6c6e..e8b93bb8b 100644
---
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
+++
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
@@ -59,6 +59,11 @@ public class CreateShuffleReadHandlerRequest {
private RssConf clientConf;
private ShuffleServerReadCostTracker readCostTracker;
+ private boolean nextReadSegmentsReportEnabled;
+ private int nextReadSegmentsReportCount;
+
+ private long taskAttemptId;
+
private IdHelper idHelper;
private ClientType clientType;
@@ -69,6 +74,30 @@ public class CreateShuffleReadHandlerRequest {
return rssBaseConf;
}
+ public int getNextReadSegmentsReportCount() {
+ return nextReadSegmentsReportCount;
+ }
+
+ public void setNextReadSegmentsReportCount(int nextReadSegmentsReportCount) {
+ this.nextReadSegmentsReportCount = nextReadSegmentsReportCount;
+ }
+
+ public long getTaskAttemptId() {
+ return taskAttemptId;
+ }
+
+ public void setTaskAttemptId(long taskAttemptId) {
+ this.taskAttemptId = taskAttemptId;
+ }
+
+ public boolean isNextReadSegmentsReportEnabled() {
+ return nextReadSegmentsReportEnabled;
+ }
+
+ public void setNextReadSegmentsReportEnabled(boolean
nextReadSegmentsReportEnabled) {
+ this.nextReadSegmentsReportEnabled = nextReadSegmentsReportEnabled;
+ }
+
public void setRssBaseConf(RssBaseConf rssBaseConf) {
this.rssBaseConf = rssBaseConf;
}
diff --git
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandlerTest.java
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandlerTest.java
new file mode 100644
index 000000000..d8ac8cc2b
--- /dev/null
+++
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandlerTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.storage.handler.impl;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.ShuffleDataSegment;
+
+import static org.apache.uniffle.storage.handler.impl.DataSkippableReadHandler.getNextSegments;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class DataSkippableReadHandlerTest {
+
+  /** Builds {@code count} contiguous segments of {@code length} bytes, starting at offset 0. */
+  private List<ShuffleDataSegment> createSegments(int count, int length) {
+    List<ShuffleDataSegment> segments = new ArrayList<>();
+    int offset = 0;
+    for (int i = 0; i < count; i++) {
+      segments.add(new ShuffleDataSegment(offset, length, -1, null));
+      offset += length;
+    }
+    return segments;
+  }
+
+  @Test
+  public void testGetNextSegments() {
+    List<ShuffleDataSegment> segments = createSegments(10, 10);
+
+    // Returns at most the requested number of segments, starting exactly at startIndex.
+    List<ShuffleDataSegment> nexts = getNextSegments(segments, 0, 2);
+    assertEquals(2, nexts.size());
+    assertEquals(0L, nexts.get(0).getOffset());
+    assertEquals(10L, nexts.get(1).getOffset());
+
+    nexts = getNextSegments(segments, 3, 4);
+    assertEquals(4, nexts.size());
+    assertEquals(30L, nexts.get(0).getOffset());
+    assertEquals(60L, nexts.get(3).getOffset());
+
+    // Fewer segments remain than requested: only the tail is returned.
+    nexts = getNextSegments(segments, 8, 4);
+    assertEquals(2, nexts.size());
+    assertEquals(80L, nexts.get(0).getOffset());
+
+    // A start index past the end, or a zero count, yields nothing.
+    assertTrue(getNextSegments(segments, 10, 2).isEmpty());
+    assertTrue(getNextSegments(segments, 0, 0).isEmpty());
+  }
+}