This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 96bf76cbc [#2591] feat(client): Introduce the mechanism to report 
localfile read plan (#2603)
96bf76cbc is described below

commit 96bf76cbc9d9bbe51c3e8a502ba37b32ae2ce8d6
Author: Junfan Zhang <[email protected]>
AuthorDate: Tue Sep 9 19:36:18 2025 +0800

    [#2591] feat(client): Introduce the mechanism to report localfile read plan 
(#2603)
    
    ### What changes were proposed in this pull request?
    
    This PR introduces a mechanism to report the localfile read plan. The 
changes are scoped to the client side only; the corresponding shuffle server 
changes will be added in the future.
    
    ### Why are the changes needed?
    
    For normal partitions, the reading mode is sequential, which makes 
read-ahead optimization feasible. This has already been verified in the Riffle 
project (see [issue #483](https://github.com/zuston/riffle/issues/483)).
    
    For huge partitions, however, reads are no longer sequential: the AQE skew 
join optimization rule causes segments to be skipped. In such cases, it is 
difficult for the server to predict the next read position and length.
    
    Based on this analysis, we propose introducing a fixed read plan that is 
propagated from the client to the server, allowing the server to recognize the 
next read offset and thereby benefit from read-ahead optimization.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. The feature is disabled by default and is enabled by setting 
rss.client.read.nextReadSegmentsReportEnabled to true.
    
    ### How was this patch tested?
    
    Existing tests, plus new cases added to DataSkippableReadHandlerTest.
---
 .../spark/shuffle/reader/RssShuffleReader.java     |   6 +-
 .../client/factory/ShuffleClientFactory.java       |  10 ++
 .../uniffle/client/impl/ShuffleReadClientImpl.java |   7 ++
 .../org/apache/uniffle/common/ReadSegment.java     |  51 +++++++++
 .../uniffle/common/config/RssClientConf.java       |  13 +++
 .../protocol/GetLocalShuffleDataV3Request.java     | 118 +++++++++++++++++++++
 .../uniffle/common/netty/protocol/Message.java     |   1 +
 .../client/impl/grpc/ShuffleServerGrpcClient.java  |  10 ++
 .../impl/grpc/ShuffleServerGrpcNettyClient.java    |  35 ++++--
 .../client/request/RssGetShuffleDataRequest.java   |  33 +++++-
 proto/src/main/proto/Rss.proto                     |   6 ++
 .../storage/factory/ShuffleHandlerFactory.java     |   5 +-
 .../handler/impl/DataSkippableReadHandler.java     |  30 +++++-
 .../handler/impl/HadoopShuffleReadHandler.java     |   6 +-
 .../handler/impl/LocalFileClientReadHandler.java   |  26 ++++-
 .../request/CreateShuffleReadHandlerRequest.java   |  29 +++++
 .../handler/impl/DataSkippableReadHandlerTest.java |  52 +++++++++
 17 files changed, 418 insertions(+), 20 deletions(-)

diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index 09a2b58c2..f03aa9955 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -316,7 +316,11 @@ public class RssShuffleReader<K, C> implements 
ShuffleReader<K, C> {
                 
.expectedTaskIdsBitmapFilterEnable(expectedTaskIdsBitmapFilterEnable)
                 .retryMax(retryMax)
                 .retryIntervalMax(retryIntervalMax)
-                .rssConf(rssConf);
+                .rssConf(rssConf)
+                .taskAttemptId(
+                    Optional.ofNullable(TaskContext.get())
+                        .map(taskContext -> taskContext.taskAttemptId())
+                        .orElse(0L));
         if (codec.isPresent() && 
rssConf.get(RSS_READ_OVERLAPPING_DECOMPRESSION_ENABLED)) {
           builder
               .overlappingDecompressionEnabled(true)
diff --git 
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
 
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
index d4a49fa97..33f7df0ce 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/factory/ShuffleClientFactory.java
@@ -227,6 +227,12 @@ public class ShuffleClientFactory {
     private int retryMax;
     private long retryIntervalMax;
     private ShuffleServerReadCostTracker readCostTracker;
+    private long taskAttemptId;
+
+    /**
+     * Stores the task attempt id that {@link #getTaskAttemptId()} later returns.
+     *
+     * @param taskAttemptId the value to store
+     * @return this builder, so calls can be chained
+     */
+    public ReadClientBuilder taskAttemptId(long taskAttemptId) {
+      this.taskAttemptId = taskAttemptId;
+      return this;
+    }
 
     private boolean overlappingDecompressionEnabled;
     private int overlappingDecompressionThreadNum;
@@ -463,6 +469,10 @@ public class ShuffleClientFactory {
       return codec;
     }
 
+    /** Returns the stored task attempt id; this is 0 when it was never set on the builder. */
+    public long getTaskAttemptId() {
+      return taskAttemptId;
+    }
+
     public ShuffleReadClientImpl build() {
       return new ShuffleReadClientImpl(this);
     }
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index e6ffc09d6..8586d3493 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -57,6 +57,9 @@ import 
org.apache.uniffle.storage.handler.api.ClientReadHandler;
 import org.apache.uniffle.storage.handler.impl.ShuffleServerReadCostTracker;
 import org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest;
 
+import static 
org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_COUNT;
+import static 
org.apache.uniffle.common.config.RssClientConf.READ_CLIENT_NEXT_SEGMENTS_REPORT_ENABLED;
+
 public class ShuffleReadClientImpl implements ShuffleReadClient {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(ShuffleReadClientImpl.class);
@@ -171,6 +174,10 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
     this.readCostTracker = builder.getReadCostTracker();
 
     CreateShuffleReadHandlerRequest request = new 
CreateShuffleReadHandlerRequest();
+    request.setNextReadSegmentsReportCount(
+        builder.getRssConf().get(READ_CLIENT_NEXT_SEGMENTS_REPORT_COUNT));
+    request.setNextReadSegmentsReportEnabled(
+        builder.getRssConf().get(READ_CLIENT_NEXT_SEGMENTS_REPORT_ENABLED));
     request.setStorageType(builder.getStorageType());
     request.setAppId(builder.getAppId());
     request.setShuffleId(shuffleId);
diff --git a/common/src/main/java/org/apache/uniffle/common/ReadSegment.java 
b/common/src/main/java/org/apache/uniffle/common/ReadSegment.java
new file mode 100644
index 000000000..552e4d552
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/ReadSegment.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * An offset/length pair whose two values are fixed at construction. Instances can be built
+ * directly or derived from {@link ShuffleDataSegment} objects through the {@code from} factories.
+ */
+public class ReadSegment {
+  private final long offset;
+  private final long length;
+
+  /**
+   * @param offset value returned by {@link #getOffset()}
+   * @param length value returned by {@link #getLength()}
+   */
+  public ReadSegment(long offset, long length) {
+    this.length = length;
+    this.offset = offset;
+  }
+
+  /** Returns the length given at construction. */
+  public long getLength() {
+    return length;
+  }
+
+  /** Returns the offset given at construction. */
+  public long getOffset() {
+    return offset;
+  }
+
+  /**
+   * Converts each element of {@code segments}, in iteration order, and returns the results in a
+   * newly allocated list.
+   */
+  public static List<ReadSegment> from(List<ShuffleDataSegment> segments) {
+    List<ReadSegment> converted = new ArrayList<>(segments.size());
+    segments.forEach(
+        source -> converted.add(new ReadSegment(source.getOffset(), source.getLength())));
+    return converted;
+  }
+
+  /** Builds a segment carrying the offset and length of {@code segment}. */
+  public static ReadSegment from(ShuffleDataSegment segment) {
+    long segmentOffset = segment.getOffset();
+    long segmentLength = segment.getLength();
+    return new ReadSegment(segmentOffset, segmentLength);
+  }
+}
diff --git 
a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java 
b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index 0c3a465f2..a7f67e1bc 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -368,4 +368,17 @@ public class RssClientConf {
           .intType()
           .defaultValue(120)
           .withDescription("Read prefetch timeout seconds");
+
+  // Boolean switch, false by default, read by ShuffleReadClientImpl when it builds the
+  // read-handler request; it controls whether upcoming read segments are reported.
+  public static final ConfigOption<Boolean> 
READ_CLIENT_NEXT_SEGMENTS_REPORT_ENABLED =
+      ConfigOptions.key("rss.client.read.nextReadSegmentsReportEnabled")
+          .booleanType()
+          .defaultValue(false)
+          .withDescription(
+              "Whether the next read segment report is enabled for 
shuffle-server read ahead");
+
+  // Upper bound (default 4) on how many upcoming segments are collected for each data read;
+  // fewer are collected when the segment list ends first.
+  public static final ConfigOption<Integer> 
READ_CLIENT_NEXT_SEGMENTS_REPORT_COUNT =
+      ConfigOptions.key("rss.client.read.nextReadSegmentsReportCount")
+          .intType()
+          .defaultValue(4)
+          .withDescription("Next read segment count for shuffle-server read 
ahead");
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataV3Request.java
 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataV3Request.java
new file mode 100644
index 000000000..993b80c51
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataV3Request.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.netty.protocol;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.uniffle.common.ReadSegment;
+import org.apache.uniffle.common.util.ByteBufUtils;
+
+/**
+ * Netty request for local shuffle data that extends the V2 request with the segments the client
+ * plans to read next and the id of the requesting task attempt.
+ *
+ * <p>Bytes appended after the parent's encoding, in order: segment count (int), then offset and
+ * length (two longs) for each segment, then the task attempt id (long).
+ */
+public class GetLocalShuffleDataV3Request extends GetLocalShuffleDataV2Request {
+  private final List<ReadSegment> nextReadSegments;
+  private final long taskAttemptId;
+
+  /**
+   * Creates a V3 request; all parameters other than the two below are handed to the V2
+   * constructor unchanged.
+   *
+   * @param nextReadSegments segments encoded after the parent's fields; must not be null because
+   *     {@link #encode(ByteBuf)} and {@link #encodedLength()} dereference it
+   * @param taskAttemptId id written as the last field of the encoded request
+   */
+  public GetLocalShuffleDataV3Request(
+      long requestId,
+      String appId,
+      int shuffleId,
+      int partitionId,
+      int partitionNumPerRange,
+      int partitionNum,
+      long offset,
+      int length,
+      int storageId,
+      List<ReadSegment> nextReadSegments,
+      long timestamp,
+      long taskAttemptId) {
+    super(
+        requestId,
+        appId,
+        shuffleId,
+        partitionId,
+        partitionNumPerRange,
+        partitionNum,
+        offset,
+        length,
+        storageId,
+        timestamp);
+    this.nextReadSegments = nextReadSegments;
+    this.taskAttemptId = taskAttemptId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.GET_LOCAL_SHUFFLE_DATA_V3_REQUEST;
+  }
+
+  /**
+   * Returns the number of bytes {@link #encode(ByteBuf)} writes: the parent's length plus the
+   * segment count, two longs per segment, and the trailing task attempt id.
+   */
+  @Override
+  public int encodedLength() {
+    // Every term below must match one write in encode(); the count and the task attempt id
+    // are part of the payload and have to be included in the reported length.
+    return super.encodedLength()
+        + Integer.BYTES // segment count
+        + 2 * Long.BYTES * nextReadSegments.size() // (offset, length) per segment
+        + Long.BYTES; // taskAttemptId
+  }
+
+  /** Writes the parent's fields, then the segment count, each segment, and the task attempt id. */
+  @Override
+  public void encode(ByteBuf buf) {
+    super.encode(buf);
+    buf.writeInt(nextReadSegments.size());
+    for (ReadSegment segment : nextReadSegments) {
+      buf.writeLong(segment.getOffset());
+      buf.writeLong(segment.getLength());
+    }
+    buf.writeLong(taskAttemptId);
+  }
+
+  /**
+   * Reads a request from {@code byteBuf}. Fields are consumed in the order the decoder below
+   * lists them; note that timestamp precedes storageId on the wire although the constructor
+   * takes them in the opposite order.
+   */
+  public static GetLocalShuffleDataV3Request decode(ByteBuf byteBuf) {
+    long requestId = byteBuf.readLong();
+    String appId = ByteBufUtils.readLengthAndString(byteBuf);
+    int shuffleId = byteBuf.readInt();
+    int partitionId = byteBuf.readInt();
+    int partitionNumPerRange = byteBuf.readInt();
+    int partitionNum = byteBuf.readInt();
+    long offset = byteBuf.readLong();
+    int length = byteBuf.readInt();
+    long timestamp = byteBuf.readLong();
+    int storageId = byteBuf.readInt();
+
+    int readSegmentCount = byteBuf.readInt();
+    List<ReadSegment> readSegments = new ArrayList<>(readSegmentCount);
+    for (int i = 0; i < readSegmentCount; i++) {
+      // Argument evaluation is left to right, so offset is read before length.
+      readSegments.add(new ReadSegment(byteBuf.readLong(), byteBuf.readLong()));
+    }
+    long taskAttemptId = byteBuf.readLong();
+    return new GetLocalShuffleDataV3Request(
+        requestId,
+        appId,
+        shuffleId,
+        partitionId,
+        partitionNumPerRange,
+        partitionNum,
+        offset,
+        length,
+        storageId,
+        readSegments,
+        timestamp,
+        taskAttemptId);
+  }
+
+  /** Returns the segment list given at construction (the same instance, not a copy). */
+  public List<ReadSegment> getNextReadSegments() {
+    return nextReadSegments;
+  }
+
+  /** Returns the task attempt id given at construction. */
+  public long getTaskAttemptId() {
+    return taskAttemptId;
+  }
+
+  @Override
+  public String getOperationType() {
+    return "getLocalShuffleDataV3";
+  }
+}
diff --git 
a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java 
b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
index 2ad0b0d77..5fcb47bd0 100644
--- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
+++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java
@@ -67,6 +67,7 @@ public abstract class Message implements Encodable {
     GET_SORTED_SHUFFLE_DATA_RESPONSE(22),
     GET_LOCAL_SHUFFLE_INDEX_V2_RESPONSE(23),
     GET_LOCAL_SHUFFLE_DATA_V2_REQUEST(24),
+    GET_LOCAL_SHUFFLE_DATA_V3_REQUEST(25),
     ;
 
     private final byte id;
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index 1d9d7b0d0..0a5259284 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -26,6 +26,7 @@ import java.util.Random;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
@@ -958,6 +959,15 @@ public class ShuffleServerGrpcClient extends GrpcClient 
implements ShuffleServer
             .setLength(request.getLength())
             .setTimestamp(start)
             .setStorageId(request.getStorageId())
+            .addAllNextReadSegments(
+                request.getNextReadSegments().stream()
+                    .map(
+                        x ->
+                            RssProtos.ReadSegment.newBuilder()
+                                .setLength(x.getLength())
+                                .setOffset(x.getOffset())
+                                .build())
+                    .collect(Collectors.toList()))
             .build();
     String requestInfo =
         "appId["
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
index 1f014c8b2..080fed297 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java
@@ -43,6 +43,7 @@ import 
org.apache.uniffle.client.response.RssGetShuffleIndexResponse;
 import org.apache.uniffle.client.response.RssGetSortedShuffleDataResponse;
 import org.apache.uniffle.client.response.RssSendShuffleDataResponse;
 import org.apache.uniffle.common.ClientType;
+import org.apache.uniffle.common.ReadSegment;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssClientConf;
@@ -57,6 +58,7 @@ import 
org.apache.uniffle.common.netty.client.TransportContext;
 import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataRequest;
 import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataResponse;
 import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataV2Request;
+import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataV3Request;
 import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexRequest;
 import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse;
 import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest;
@@ -387,9 +389,11 @@ public class ShuffleServerGrpcNettyClient extends 
ShuffleServerGrpcClient {
   public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest 
request) {
     TransportClient transportClient = getTransportClient();
     // Construct old version or v2 get shuffle data request to compatible with 
old server
-    GetLocalShuffleDataRequest getLocalShuffleIndexRequest =
-        request.storageIdSpecified()
-            ? new GetLocalShuffleDataV2Request(
+    GetLocalShuffleDataRequest getLocalShuffleDataRequest = null;
+    if (request.storageIdSpecified()) {
+      if (request.isNextReadSegmentsReportEnabled()) {
+        getLocalShuffleDataRequest =
+            new GetLocalShuffleDataV3Request(
                 requestId(),
                 request.getAppId(),
                 request.getShuffleId(),
@@ -399,8 +403,12 @@ public class ShuffleServerGrpcNettyClient extends 
ShuffleServerGrpcClient {
                 request.getOffset(),
                 request.getLength(),
                 request.getStorageId(),
-                System.currentTimeMillis())
-            : new GetLocalShuffleDataRequest(
+                ReadSegment.from(request.getNextReadSegments()),
+                System.currentTimeMillis(),
+                request.getTaskAttemptId());
+      } else {
+        getLocalShuffleDataRequest =
+            new GetLocalShuffleDataV2Request(
                 requestId(),
                 request.getAppId(),
                 request.getShuffleId(),
@@ -409,7 +417,22 @@ public class ShuffleServerGrpcNettyClient extends 
ShuffleServerGrpcClient {
                 request.getPartitionNum(),
                 request.getOffset(),
                 request.getLength(),
+                request.getStorageId(),
                 System.currentTimeMillis());
+      }
+    } else {
+      getLocalShuffleDataRequest =
+          new GetLocalShuffleDataRequest(
+              requestId(),
+              request.getAppId(),
+              request.getShuffleId(),
+              request.getPartitionId(),
+              request.getPartitionNumPerRange(),
+              request.getPartitionNum(),
+              request.getOffset(),
+              request.getLength(),
+              System.currentTimeMillis());
+    }
     String requestInfo =
         "appId["
             + request.getAppId()
@@ -423,7 +446,7 @@ public class ShuffleServerGrpcNettyClient extends 
ShuffleServerGrpcClient {
     RpcResponse rpcResponse;
     GetLocalShuffleDataResponse getLocalShuffleDataResponse;
     while (true) {
-      rpcResponse = transportClient.sendRpcSync(getLocalShuffleIndexRequest, 
rpcTimeout);
+      rpcResponse = transportClient.sendRpcSync(getLocalShuffleDataRequest, 
rpcTimeout);
       getLocalShuffleDataResponse = (GetLocalShuffleDataResponse) rpcResponse;
       if (rpcResponse.getStatusCode() != StatusCode.NO_BUFFER) {
         break;
diff --git 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java
 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java
index c245e48b7..d401e12fc 100644
--- 
a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java
+++ 
b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java
@@ -17,8 +17,13 @@
 
 package org.apache.uniffle.client.request;
 
+import java.util.Collections;
+import java.util.List;
+
 import com.google.common.annotations.VisibleForTesting;
 
+import org.apache.uniffle.common.ShuffleDataSegment;
+
 public class RssGetShuffleDataRequest extends RetryableRequest {
 
   private final String appId;
@@ -29,6 +34,9 @@ public class RssGetShuffleDataRequest extends 
RetryableRequest {
   private final long offset;
   private final int length;
   private final int storageId;
+  private final long taskAttemptId;
+  private final List<ShuffleDataSegment> nextReadSegments;
+  private final boolean nextReadSegmentsReportEnabled;
 
   public RssGetShuffleDataRequest(
       String appId,
@@ -40,7 +48,10 @@ public class RssGetShuffleDataRequest extends 
RetryableRequest {
       int length,
       int storageId,
       int retryMax,
-      long retryIntervalMax) {
+      long retryIntervalMax,
+      long taskAttemptId,
+      List<ShuffleDataSegment> nextReadSegments,
+      boolean nextReadSegmentsReportEnabled) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.partitionId = partitionId;
@@ -51,6 +62,9 @@ public class RssGetShuffleDataRequest extends 
RetryableRequest {
     this.storageId = storageId;
     this.retryMax = retryMax;
     this.retryIntervalMax = retryIntervalMax;
+    this.nextReadSegments = nextReadSegments;
+    this.nextReadSegmentsReportEnabled = nextReadSegmentsReportEnabled;
+    this.taskAttemptId = taskAttemptId;
   }
 
   @VisibleForTesting
@@ -72,7 +86,10 @@ public class RssGetShuffleDataRequest extends 
RetryableRequest {
         length,
         -1,
         1,
-        0);
+        0,
+        0,
+        Collections.emptyList(),
+        false);
   }
 
   public String getAppId() {
@@ -111,6 +128,18 @@ public class RssGetShuffleDataRequest extends 
RetryableRequest {
     return storageId != -1;
   }
 
+  /** Returns the {@code nextReadSegments} list passed to the constructor (not a copy). */
+  public List<ShuffleDataSegment> getNextReadSegments() {
+    return nextReadSegments;
+  }
+
+  /** Returns the {@code nextReadSegmentsReportEnabled} flag passed to the constructor. */
+  public boolean isNextReadSegmentsReportEnabled() {
+    return nextReadSegmentsReportEnabled;
+  }
+
+  /** Returns the {@code taskAttemptId} passed to the constructor. */
+  public long getTaskAttemptId() {
+    return taskAttemptId;
+  }
+
   @Override
   public String operationType() {
     return "GetShuffleData";
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 2967d98c0..c3db61e39 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -87,6 +87,12 @@ message GetLocalShuffleDataRequest {
   int32 length = 7;
   int64 timestamp = 8;
   int32 storageId = 9;
+  repeated ReadSegment nextReadSegments = 10;
+}
+
+message ReadSegment {
+  int64 offset = 1;
+  int64 length = 2;
 }
 
 message GetLocalShuffleDataResponse {
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
 
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index 356cfc1dd..e01ba1d1f 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -165,7 +165,10 @@ public class ShuffleHandlerFactory {
         request.getRetryMax(),
         request.getRetryIntervalMax(),
         request.getPrefetchOption(),
-        request.getReadCostTracker());
+        request.getReadCostTracker(),
+        request.isNextReadSegmentsReportEnabled(),
+        request.getNextReadSegmentsReportCount(),
+        request.getTaskAttemptId());
   }
 
   private ClientReadHandler getHadoopClientReadHandler(
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
index 58288d53c..1ecb32f23 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
@@ -22,6 +22,7 @@ import java.util.List;
 import java.util.Optional;
 import java.util.Set;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
@@ -35,6 +36,7 @@ import 
org.apache.uniffle.common.segment.SegmentSplitterFactory;
 
 public abstract class DataSkippableReadHandler extends 
PrefetchableClientReadHandler {
   private static final Logger LOG = 
LoggerFactory.getLogger(DataSkippableReadHandler.class);
+  private static final int DEFAULT_NEXT_READ_BATCH_NUMBER = 4;
 
   protected List<ShuffleDataSegment> shuffleDataSegments = 
Lists.newArrayList();
   protected int segmentIndex = 0;
@@ -45,6 +47,8 @@ public abstract class DataSkippableReadHandler extends 
PrefetchableClientReadHan
   protected ShuffleDataDistributionType distributionType;
   protected Roaring64NavigableMap expectTaskIds;
 
+  private int nextReadSegmentReportCount;
+
   public DataSkippableReadHandler(
       String appId,
       int shuffleId,
@@ -54,7 +58,8 @@ public abstract class DataSkippableReadHandler extends 
PrefetchableClientReadHan
       Set<Long> processBlockIds,
       ShuffleDataDistributionType distributionType,
       Roaring64NavigableMap expectTaskIds,
-      Optional<PrefetchOption> prefetchOption) {
+      Optional<PrefetchOption> prefetchOption,
+      int nextReadSegmentReportCount) {
     super(prefetchOption);
     this.appId = appId;
     this.shuffleId = shuffleId;
@@ -64,11 +69,13 @@ public abstract class DataSkippableReadHandler extends 
PrefetchableClientReadHan
     this.processBlockIds = processBlockIds;
     this.distributionType = distributionType;
     this.expectTaskIds = expectTaskIds;
+    this.nextReadSegmentReportCount = nextReadSegmentReportCount;
   }
 
   protected abstract ShuffleIndexResult readShuffleIndex();
 
-  protected abstract ShuffleDataResult readShuffleData(ShuffleDataSegment 
segment);
+  protected abstract ShuffleDataResult readShuffleData(
+      ShuffleDataSegment segment, List<ShuffleDataSegment> nextReadSegments);
 
   @Override
   public ShuffleDataResult doReadShuffleData() {
@@ -100,7 +107,11 @@ public abstract class DataSkippableReadHandler extends 
PrefetchableClientReadHan
         // skip processed blockIds
         blocksOfSegment.removeAll(processBlockIds);
         if (!blocksOfSegment.isEmpty()) {
-          result = readShuffleData(segment);
+          result =
+              readShuffleData(
+                  segment,
+                  getNextSegments(
+                      shuffleDataSegments, segmentIndex + 1, 
nextReadSegmentReportCount));
           segmentIndex++;
           break;
         }
@@ -109,4 +120,17 @@ public abstract class DataSkippableReadHandler extends 
PrefetchableClientReadHan
     }
     return result;
   }
+
+  /**
+   * Collects up to {@code number} consecutive segments of {@code shuffleDataSegments}, beginning
+   * at {@code startIndex}, into a new list. The result is shorter (possibly empty) when the list
+   * ends first or when {@code number} is not positive.
+   */
+  @VisibleForTesting
+  protected static List<ShuffleDataSegment> getNextSegments(
+      List<ShuffleDataSegment> shuffleDataSegments, int startIndex, int number) {
+    List<ShuffleDataSegment> picked = Lists.newArrayList();
+    // Exclusive upper bound; computed in long so startIndex + number cannot overflow.
+    long endExclusive = Math.min((long) shuffleDataSegments.size(), (long) startIndex + number);
+    for (int idx = startIndex; idx < endExclusive; idx++) {
+      picked.add(shuffleDataSegments.get(idx));
+    }
+    return picked;
+  }
 }
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
index f2153d2be..09372c32f 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
@@ -71,7 +71,8 @@ public class HadoopShuffleReadHandler extends 
DataSkippableReadHandler {
         processBlockIds,
         distributionType,
         expectTaskIds,
-        prefetchOption);
+        prefetchOption,
+        4);
     this.filePrefix = filePrefix;
     this.indexReader =
         
createHadoopReader(ShuffleStorageUtils.generateIndexFileName(filePrefix), conf);
@@ -135,7 +136,8 @@ public class HadoopShuffleReadHandler extends 
DataSkippableReadHandler {
     return new ShuffleIndexResult();
   }
 
-  protected ShuffleDataResult readShuffleData(ShuffleDataSegment 
shuffleDataSegment) {
+  protected ShuffleDataResult readShuffleData(
+      ShuffleDataSegment shuffleDataSegment, List<ShuffleDataSegment> 
nextReadSegments) {
     // Here we make an assumption that the rest of the file is corrupted, if 
an unexpected data is
     // read.
     int expectedLength = shuffleDataSegment.getLength();
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
index cfdf5fdb7..9405cad62 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
@@ -17,6 +17,7 @@
 
 package org.apache.uniffle.storage.handler.impl;
 
+import java.util.List;
 import java.util.Optional;
 import java.util.Set;
 
@@ -46,6 +47,8 @@ public class LocalFileClientReadHandler extends 
DataSkippableReadHandler {
   private int retryMax;
   private long retryIntervalMax;
   private ShuffleServerReadCostTracker readCostTracker;
+  private boolean nextReadSegmentsReportEnabled;
+  private long taskAttemptId;
 
   public LocalFileClientReadHandler(
       String appId,
@@ -63,7 +66,10 @@ public class LocalFileClientReadHandler extends 
DataSkippableReadHandler {
       int retryMax,
       long retryIntervalMax,
       Optional<PrefetchOption> prefetchOption,
-      ShuffleServerReadCostTracker readCostTracker) {
+      ShuffleServerReadCostTracker readCostTracker,
+      boolean nextReadSegmentsReportEnabled,
+      int nextReadSegmentCount,
+      long taskAttemptId) {
     super(
         appId,
         shuffleId,
@@ -73,13 +79,16 @@ public class LocalFileClientReadHandler extends 
DataSkippableReadHandler {
         processBlockIds,
         distributionType,
         expectTaskIds,
-        prefetchOption);
+        prefetchOption,
+        nextReadSegmentCount);
     this.shuffleServerClient = shuffleServerClient;
     this.partitionNumPerRange = partitionNumPerRange;
     this.partitionNum = partitionNum;
     this.retryMax = retryMax;
     this.retryIntervalMax = retryIntervalMax;
     this.readCostTracker = readCostTracker;
+    this.nextReadSegmentsReportEnabled = nextReadSegmentsReportEnabled;
+    this.taskAttemptId = taskAttemptId;
   }
 
   @VisibleForTesting
@@ -110,7 +119,10 @@ public class LocalFileClientReadHandler extends 
DataSkippableReadHandler {
         1,
         0,
         Optional.empty(),
-        new ShuffleServerReadCostTracker());
+        new ShuffleServerReadCostTracker(),
+        false,
+        4,
+        0);
   }
 
   @Override
@@ -144,7 +156,8 @@ public class LocalFileClientReadHandler extends DataSkippableReadHandler {
   }
 
   @Override
-  public ShuffleDataResult readShuffleData(ShuffleDataSegment shuffleDataSegment) {
+  public ShuffleDataResult readShuffleData(
+      ShuffleDataSegment shuffleDataSegment, List<ShuffleDataSegment> nextReadSegments) {
     ShuffleDataResult result = null;
     int expectedLength = shuffleDataSegment.getLength();
     if (expectedLength <= 0) {
@@ -171,7 +184,10 @@ public class LocalFileClientReadHandler extends DataSkippableReadHandler {
             expectedLength,
             shuffleDataSegment.getStorageId(),
             retryMax,
-            retryIntervalMax);
+            retryIntervalMax,
+            taskAttemptId,
+            nextReadSegments,
+            nextReadSegmentsReportEnabled);
     try {
       long start = System.currentTimeMillis();
       RssGetShuffleDataResponse response = shuffleServerClient.getShuffleData(request);
diff --git a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
index a0a6b6c6e..e8b93bb8b 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
@@ -59,6 +59,11 @@ public class CreateShuffleReadHandlerRequest {
   private RssConf clientConf;
   private ShuffleServerReadCostTracker readCostTracker;
 
+  private boolean nextReadSegmentsReportEnabled;
+  private int nextReadSegmentsReportCount;
+
+  private long taskAttemptId;
+
   private IdHelper idHelper;
 
   private ClientType clientType;
@@ -69,6 +74,30 @@ public class CreateShuffleReadHandlerRequest {
     return rssBaseConf;
   }
 
+  public int getNextReadSegmentsReportCount() {
+    return nextReadSegmentsReportCount;
+  }
+
+  public void setNextReadSegmentsReportCount(int nextReadSegmentsReportCount) {
+    this.nextReadSegmentsReportCount = nextReadSegmentsReportCount;
+  }
+
+  public long getTaskAttemptId() {
+    return taskAttemptId;
+  }
+
+  public void setTaskAttemptId(long taskAttemptId) {
+    this.taskAttemptId = taskAttemptId;
+  }
+
+  public boolean isNextReadSegmentsReportEnabled() {
+    return nextReadSegmentsReportEnabled;
+  }
+
+  public void setNextReadSegmentsReportEnabled(boolean nextReadSegmentsReportEnabled) {
+    this.nextReadSegmentsReportEnabled = nextReadSegmentsReportEnabled;
+  }
+
   public void setRssBaseConf(RssBaseConf rssBaseConf) {
     this.rssBaseConf = rssBaseConf;
   }
diff --git a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandlerTest.java b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandlerTest.java
new file mode 100644
index 000000000..d8ac8cc2b
--- /dev/null
+++ b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandlerTest.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.storage.handler.impl;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.ShuffleDataSegment;
+
+import static org.apache.uniffle.storage.handler.impl.DataSkippableReadHandler.getNextSegments;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class DataSkippableReadHandlerTest {
+
+  private List<ShuffleDataSegment> createSegments(int count, int length) {
+    List<ShuffleDataSegment> segments = new ArrayList<>();
+    int offset = 0;
+    for (int i = 0; i < count; i++) {
+      segments.add(new ShuffleDataSegment(offset, length, -1, null));
+      offset += length;
+    }
+    return segments;
+  }
+
+  @Test
+  public void testGetNextSegments() {
+    List<ShuffleDataSegment> segments = createSegments(10, 10);
+    List<ShuffleDataSegment> nexts = getNextSegments(segments, 0, 2);
+    assertEquals(2, nexts.size());
+
+    segments = createSegments(10, 10);
+    nexts = getNextSegments(segments, 10, 2);
+    assertEquals(0, nexts.size());
+  }
+}


Reply via email to