This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 9244cf2cf [CELEBORN-772] Convert StreamChunkSlice, ChunkFetchRequest, TransportableError to PB
9244cf2cf is described below

commit 9244cf2cf2a6f319f6a6aa78271a2f6cc48b7de3
Author: SteNicholas <[email protected]>
AuthorDate: Tue Oct 17 11:12:01 2023 +0800

    [CELEBORN-772] Convert StreamChunkSlice, ChunkFetchRequest, TransportableError to PB
    
    ### What changes were proposed in this pull request?
    
    `StreamChunkSlice`, `ChunkFetchRequest` and `TransportableError` should be converted to
    protobuf-based transport messages (`PbStreamChunkSlice`, `PbChunkFetchRequest` and
    `PbTransportableError`) to improve Celeborn's compatibility.
    
    ### Why are the changes needed?
    
    1. Makes Celeborn's transport layer more flexible when RPC messages need to change.
    2. Keeps the worker compatible with 0.2 clients (the legacy `ChunkFetchRequest` is still handled).
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    - `FetchHandlerSuiteJ`
    - `RequestTimeoutIntegrationSuiteJ`
    - `ChunkFetchIntegrationSuiteJ`
    
    Closes #1982 from SteNicholas/CELEBORN-772.
    
    Authored-by: SteNicholas <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../plugin/flink/network/ReadClientHandler.java    |  4 ++
 .../common/network/client/TransportClient.java     | 23 ++++++++-
 .../common/network/protocol/ChunkFetchFailure.java |  4 +-
 .../common/network/protocol/ChunkFetchRequest.java |  1 +
 .../common/network/protocol/ChunkFetchSuccess.java |  3 +-
 .../common/network/protocol/StreamChunkSlice.java  | 15 ++++++
 .../common/network/protocol/TransportMessage.java  |  9 ++++
 .../network/protocol/TransportableError.java       |  6 +++
 common/src/main/proto/TransportMessages.proto      | 19 +++++++
 .../service/deploy/worker/FetchHandler.scala       | 60 ++++++++++++----------
 .../service/deploy/worker/FetchHandlerSuiteJ.java  | 21 ++++++--
 .../network/RequestTimeoutIntegrationSuiteJ.java   | 18 ++++++-
 .../storage/ChunkFetchIntegrationSuiteJ.java       | 14 ++++-
 13 files changed, 159 insertions(+), 38 deletions(-)

diff --git 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
index 5c100002c..b55e3a618 100644
--- 
a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
+++ 
b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/network/ReadClientHandler.java
@@ -19,6 +19,7 @@ package org.apache.celeborn.plugin.flink.network;
 
 import static 
org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE;
 import static 
org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE;
+import static 
org.apache.celeborn.common.protocol.MessageType.TRANSPORTABLE_ERROR_VALUE;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
@@ -103,6 +104,9 @@ public class ReadClientHandler extends BaseMessageHandler {
             case BUFFER_STREAM_END_VALUE:
               receive(client, 
BufferStreamEnd.fromProto(transportMessage.getParsedPayload()));
               break;
+            case TRANSPORTABLE_ERROR_VALUE:
+              receive(client, 
TransportableError.fromProto(transportMessage.getParsedPayload()));
+              break;
           }
         } catch (IOException e) {
           logger.warn("Failed to process RpcRequest message {}. ", msg, e);
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
index 151895607..7865732c3 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClient.java
@@ -36,8 +36,15 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
-import org.apache.celeborn.common.network.protocol.*;
+import org.apache.celeborn.common.network.protocol.OneWayMessage;
+import org.apache.celeborn.common.network.protocol.PushData;
+import org.apache.celeborn.common.network.protocol.PushMergedData;
+import org.apache.celeborn.common.network.protocol.RpcRequest;
+import org.apache.celeborn.common.network.protocol.StreamChunkSlice;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.util.NettyUtils;
+import org.apache.celeborn.common.protocol.MessageType;
+import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
 import org.apache.celeborn.common.read.FetchRequestInfo;
 import org.apache.celeborn.common.write.PushRequestInfo;
 
@@ -140,7 +147,19 @@ public class TransportClient implements Closeable {
     handler.addFetchRequest(streamChunkSlice, info);
 
     ChannelFuture channelFuture =
-        channel.writeAndFlush(new 
ChunkFetchRequest(streamChunkSlice)).addListener(listener);
+        channel
+            .writeAndFlush(
+                new RpcRequest(
+                    TransportClient.requestId(),
+                    new NioManagedBuffer(
+                        new TransportMessage(
+                                MessageType.CHUNK_FETCH_REQUEST,
+                                PbChunkFetchRequest.newBuilder()
+                                    
.setStreamChunkSlice(streamChunkSlice.toProto())
+                                    .build()
+                                    .toByteArray())
+                            .toByteBuffer())))
+            .addListener(listener);
     info.setChannelFuture(channelFuture);
   }
 
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchFailure.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchFailure.java
index 3100824d6..532d0bab0 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchFailure.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchFailure.java
@@ -20,7 +20,9 @@ package org.apache.celeborn.common.network.protocol;
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
 
-/** Response to {@link ChunkFetchRequest} when there is an error fetching the 
chunk. */
+import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
+
+/** Response to {@link PbChunkFetchRequest} when there is an error fetching 
the chunk. */
 public final class ChunkFetchFailure extends ResponseMessage {
   public final StreamChunkSlice streamChunkSlice;
   public final String errorString;
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchRequest.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchRequest.java
index 2081000ca..28672eac1 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchRequest.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchRequest.java
@@ -24,6 +24,7 @@ import io.netty.buffer.ByteBuf;
  * Request to fetch a sequence of a single chunk of a stream. This will 
correspond to a single
  * {@link ResponseMessage} (either success or failure).
  */
+@Deprecated
 public final class ChunkFetchRequest extends RequestMessage {
   public final StreamChunkSlice streamChunkSlice;
 
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchSuccess.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchSuccess.java
index baa663ce6..7d5992003 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchSuccess.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/ChunkFetchSuccess.java
@@ -22,9 +22,10 @@ import io.netty.buffer.ByteBuf;
 
 import org.apache.celeborn.common.network.buffer.ManagedBuffer;
 import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;
+import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
 
 /**
- * Response to {@link ChunkFetchRequest} when a chunk exists and has been 
successfully fetched.
+ * Response to {@link PbChunkFetchRequest} when a chunk exists and has been 
successfully fetched.
  *
  * <p>Note that the server-side encoding of this message does NOT include the 
buffer itself, as this
  * may be written by Netty in a more efficient manner (i.e., zero-copy write). 
Similarly, the
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamChunkSlice.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamChunkSlice.java
index a3918e4fd..4771faf87 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamChunkSlice.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/StreamChunkSlice.java
@@ -20,6 +20,8 @@ package org.apache.celeborn.common.network.protocol;
 import com.google.common.base.Objects;
 import io.netty.buffer.ByteBuf;
 
+import org.apache.celeborn.common.protocol.PbStreamChunkSlice;
+
 /** Encapsulates a request for a particular chunk of a stream. */
 public final class StreamChunkSlice implements Encodable {
   public final long streamId;
@@ -90,4 +92,17 @@ public final class StreamChunkSlice implements Encodable {
         .add("len", len)
         .toString();
   }
+
+  public PbStreamChunkSlice toProto() {
+    return PbStreamChunkSlice.newBuilder()
+        .setStreamId(streamId)
+        .setChunkIndex(chunkIndex)
+        .setOffset(offset)
+        .setLen(len)
+        .build();
+  }
+
+  public static StreamChunkSlice fromProto(PbStreamChunkSlice pb) {
+    return new StreamChunkSlice(pb.getStreamId(), pb.getChunkIndex(), 
pb.getOffset(), pb.getLen());
+  }
 }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
index 87e59151e..8fa07a145 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportMessage.java
@@ -31,12 +31,15 @@ import 
org.apache.celeborn.common.exception.CelebornIOException;
 import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
 import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
+import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
 import org.apache.celeborn.common.protocol.PbOpenStream;
 import org.apache.celeborn.common.protocol.PbPushDataHandShake;
 import org.apache.celeborn.common.protocol.PbReadAddCredit;
 import org.apache.celeborn.common.protocol.PbRegionFinish;
 import org.apache.celeborn.common.protocol.PbRegionStart;
+import org.apache.celeborn.common.protocol.PbStreamChunkSlice;
 import org.apache.celeborn.common.protocol.PbStreamHandler;
+import org.apache.celeborn.common.protocol.PbTransportableError;
 
 public class TransportMessage implements Serializable {
   private static final long serialVersionUID = -3259000920699629773L;
@@ -81,6 +84,12 @@ public class TransportMessage implements Serializable {
         return (T) PbBufferStreamEnd.parseFrom(payload);
       case READ_ADD_CREDIT_VALUE:
         return (T) PbReadAddCredit.parseFrom(payload);
+      case STREAM_CHUNK_SLICE_VALUE:
+        return (T) PbStreamChunkSlice.parseFrom(payload);
+      case CHUNK_FETCH_REQUEST_VALUE:
+        return (T) PbChunkFetchRequest.parseFrom(payload);
+      case TRANSPORTABLE_ERROR_VALUE:
+        return (T) PbTransportableError.parseFrom(payload);
       default:
         logger.error("Unexpected type {}", type);
     }
diff --git 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java
 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java
index 8762067c1..262c4ee66 100644
--- 
a/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java
+++ 
b/common/src/main/java/org/apache/celeborn/common/network/protocol/TransportableError.java
@@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets;
 
 import io.netty.buffer.ByteBuf;
 
+import org.apache.celeborn.common.protocol.PbTransportableError;
 import org.apache.celeborn.common.util.ExceptionUtils;
 
 public class TransportableError extends RequestMessage {
@@ -70,4 +71,9 @@ public class TransportableError extends RequestMessage {
   public String getErrorMessage() {
     return new String(errorMessage, StandardCharsets.UTF_8);
   }
+
+  public static TransportableError fromProto(PbTransportableError pb) {
+    return new TransportableError(
+        pb.getStreamId(), pb.getMessage().getBytes(StandardCharsets.UTF_8));
+  }
 }
diff --git a/common/src/main/proto/TransportMessages.proto 
b/common/src/main/proto/TransportMessages.proto
index ffeae3058..3e2e7a54c 100644
--- a/common/src/main/proto/TransportMessages.proto
+++ b/common/src/main/proto/TransportMessages.proto
@@ -82,6 +82,9 @@ enum MessageType {
   BACKLOG_ANNOUNCEMENT = 59;
   BUFFER_STREAM_END = 60;
   READ_ADD_CREDIT = 61;
+  STREAM_CHUNK_SLICE = 62;
+  CHUNK_FETCH_REQUEST = 63;
+  TRANSPORTABLE_ERROR = 64;
 }
 
 enum StreamType {
@@ -551,3 +554,19 @@ message PbReadAddCredit {
   int64 streamId = 1;
   int32 credit = 2;
 }
+
+message PbStreamChunkSlice {
+  int64 streamId = 1;
+  int32 chunkIndex = 2;
+  int32 offset = 3;
+  int32 len = 4;
+}
+
+message PbChunkFetchRequest {
+  PbStreamChunkSlice streamChunkSlice = 1;
+}
+
+message PbTransportableError {
+  int64 streamId = 1;
+  string message = 2;
+}
diff --git 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
index 1530e9440..4b299728b 100644
--- 
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
+++ 
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/FetchHandler.scala
@@ -37,7 +37,7 @@ import 
org.apache.celeborn.common.network.client.TransportClient
 import org.apache.celeborn.common.network.protocol._
 import org.apache.celeborn.common.network.server.BaseMessageHandler
 import org.apache.celeborn.common.network.util.{NettyUtils, TransportConf}
-import org.apache.celeborn.common.protocol.{MessageType, PartitionType, 
PbBufferStreamEnd, PbOpenStream, PbReadAddCredit, PbStreamHandler, StreamType}
+import org.apache.celeborn.common.protocol.{MessageType, PartitionType, 
PbBufferStreamEnd, PbChunkFetchRequest, PbOpenStream, PbReadAddCredit, 
PbStreamHandler, StreamType}
 import org.apache.celeborn.common.util.{ExceptionUtils, Utils}
 import org.apache.celeborn.service.deploy.worker.storage.{ChunkStreamManager, 
CreditStreamManager, PartitionFilesSorter, StorageManager}
 
@@ -93,7 +93,7 @@ class FetchHandler(val conf: CelebornConf, val transportConf: 
TransportConf)
       case r: ReadAddCredit =>
         handleReadAddCredit(r.getCredit, r.getStreamId)
       case r: ChunkFetchRequest =>
-        handleChunkFetchRequest(client, r)
+        handleChunkFetchRequest(client, r.streamChunkSlice, r)
       case r: RpcRequest =>
         handleRpcRequest(client, r)
       case unknown: RequestMessage =>
@@ -125,9 +125,14 @@ class FetchHandler(val conf: CelebornConf, val 
transportConf: TransportConf)
             isLegacy = false,
             openStream.getReadLocalShuffle)
         case bufferStreamEnd: PbBufferStreamEnd =>
-          handleEndStreamFromClient(bufferStreamEnd)
+          handleEndStreamFromClient(bufferStreamEnd.getStreamId, 
bufferStreamEnd.getStreamType)
         case readAddCredit: PbReadAddCredit =>
           handleReadAddCredit(readAddCredit.getCredit, 
readAddCredit.getStreamId)
+        case chunkFetchRequest: PbChunkFetchRequest =>
+          handleChunkFetchRequest(
+            client,
+            StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice),
+            rpcRequest)
         case message: GeneratedMessageV3 =>
           logError(s"Unknown message $message")
       }
@@ -318,18 +323,18 @@ class FetchHandler(val conf: CelebornConf, val 
transportConf: TransportConf)
   }
 
   def handleEndStreamFromClient(streamId: Long): Unit = {
-    creditStreamManager.notifyStreamEndByClient(streamId)
+    handleEndStreamFromClient(streamId, StreamType.CreditStream)
   }
 
-  def handleEndStreamFromClient(req: PbBufferStreamEnd): Unit = {
-    req.getStreamType match {
+  def handleEndStreamFromClient(streamId: Long, streamType: StreamType): Unit 
= {
+    streamType match {
       case StreamType.ChunkStream =>
-        val (shuffleKey, fileName) = 
chunkStreamManager.getShuffleKeyAndFileName(req.getStreamId)
-        getRawFileInfo(shuffleKey, fileName).closeStream(req.getStreamId)
+        val (shuffleKey, fileName) = 
chunkStreamManager.getShuffleKeyAndFileName(streamId)
+        getRawFileInfo(shuffleKey, fileName).closeStream(streamId)
       case StreamType.CreditStream =>
-        creditStreamManager.notifyStreamEndByClient(req.getStreamId)
+        creditStreamManager.notifyStreamEndByClient(streamId)
       case _ =>
-        logError(s"Received a PbBufferStreamEnd message with unknown type 
${req.getStreamType}")
+        logError(s"Received a PbBufferStreamEnd message with unknown type 
$streamType")
     }
   }
 
@@ -337,9 +342,12 @@ class FetchHandler(val conf: CelebornConf, val 
transportConf: TransportConf)
     creditStreamManager.addCredit(credit, streamId)
   }
 
-  def handleChunkFetchRequest(client: TransportClient, req: 
ChunkFetchRequest): Unit = {
+  def handleChunkFetchRequest(
+      client: TransportClient,
+      streamChunkSlice: StreamChunkSlice,
+      req: RequestMessage): Unit = {
     logDebug(s"Received req from 
${NettyUtils.getRemoteAddress(client.getChannel)}" +
-      s" to fetch block ${req.streamChunkSlice}")
+      s" to fetch block $streamChunkSlice")
 
     maxChunkBeingTransferred.foreach { threshold =>
       val chunksBeingTransferred = chunkStreamManager.chunksBeingTransferred 
// take high cpu usage
@@ -348,35 +356,35 @@ class FetchHandler(val conf: CelebornConf, val 
transportConf: TransportConf)
           s"$chunksBeingTransferred exceeds 
${MAX_CHUNKS_BEING_TRANSFERRED.key} " +
           s"${Utils.bytesToString(threshold)}."
         logError(message)
-        client.getChannel.writeAndFlush(new 
ChunkFetchFailure(req.streamChunkSlice, message))
+        client.getChannel.writeAndFlush(new 
ChunkFetchFailure(streamChunkSlice, message))
         return
       }
     }
 
     workerSource.startTimer(WorkerSource.FETCH_CHUNK_TIME, req.toString)
-    val fetchTimeMetric = 
chunkStreamManager.getFetchTimeMetric(req.streamChunkSlice.streamId)
+    val fetchTimeMetric = 
chunkStreamManager.getFetchTimeMetric(streamChunkSlice.streamId)
     val fetchBeginTime = System.nanoTime()
     try {
       val buf = chunkStreamManager.getChunk(
-        req.streamChunkSlice.streamId,
-        req.streamChunkSlice.chunkIndex,
-        req.streamChunkSlice.offset,
-        req.streamChunkSlice.len)
-      chunkStreamManager.chunkBeingSent(req.streamChunkSlice.streamId)
-      client.getChannel.writeAndFlush(new 
ChunkFetchSuccess(req.streamChunkSlice, buf))
+        streamChunkSlice.streamId,
+        streamChunkSlice.chunkIndex,
+        streamChunkSlice.offset,
+        streamChunkSlice.len)
+      chunkStreamManager.chunkBeingSent(streamChunkSlice.streamId)
+      client.getChannel.writeAndFlush(new ChunkFetchSuccess(streamChunkSlice, 
buf))
         .addListener(new GenericFutureListener[Future[_ >: Void]] {
           override def operationComplete(future: Future[_ >: Void]): Unit = {
-            if (future.isSuccess()) {
+            if (future.isSuccess) {
               if (log.isDebugEnabled) {
                 logDebug(
-                  s"Sending ChunkFetchSuccess operation succeeded, chunk 
${req.streamChunkSlice}")
+                  s"Sending ChunkFetchSuccess operation succeeded, chunk 
$streamChunkSlice")
               }
             } else {
               logError(
-                s"Sending ChunkFetchSuccess operation failed, chunk 
${req.streamChunkSlice}",
+                s"Sending ChunkFetchSuccess operation failed, chunk 
$streamChunkSlice",
                 future.cause())
             }
-            chunkStreamManager.chunkSent(req.streamChunkSlice.streamId)
+            chunkStreamManager.chunkSent(streamChunkSlice.streamId)
             if (fetchTimeMetric != null) {
               fetchTimeMetric.update(System.nanoTime() - fetchBeginTime)
             }
@@ -386,11 +394,11 @@ class FetchHandler(val conf: CelebornConf, val 
transportConf: TransportConf)
     } catch {
       case e: Exception =>
         logError(
-          s"Error opening block ${req.streamChunkSlice} for request from " +
+          s"Error opening block $streamChunkSlice for request from " +
             NettyUtils.getRemoteAddress(client.getChannel),
           e)
         client.getChannel.writeAndFlush(new ChunkFetchFailure(
-          req.streamChunkSlice,
+          streamChunkSlice,
           Throwables.getStackTraceAsString(e)))
         workerSource.stopTimer(WorkerSource.FETCH_CHUNK_TIME, req.toString)
     }
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
index 99a188ad3..7234c21fa 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/FetchHandlerSuiteJ.java
@@ -47,19 +47,19 @@ import org.apache.celeborn.common.meta.FileInfo;
 import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportResponseHandler;
-import org.apache.celeborn.common.network.protocol.ChunkFetchRequest;
 import org.apache.celeborn.common.network.protocol.ChunkFetchSuccess;
 import org.apache.celeborn.common.network.protocol.Message;
 import org.apache.celeborn.common.network.protocol.OpenStream;
 import org.apache.celeborn.common.network.protocol.RpcRequest;
 import org.apache.celeborn.common.network.protocol.RpcResponse;
-import org.apache.celeborn.common.network.protocol.StreamChunkSlice;
 import org.apache.celeborn.common.network.protocol.StreamHandle;
 import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.util.TransportConf;
 import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
+import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
 import org.apache.celeborn.common.protocol.PbOpenStream;
+import org.apache.celeborn.common.protocol.PbStreamChunkSlice;
 import org.apache.celeborn.common.protocol.PbStreamHandler;
 import org.apache.celeborn.common.protocol.StreamType;
 import org.apache.celeborn.common.protocol.TransportModuleConstants;
@@ -339,8 +339,21 @@ public class FetchHandlerSuiteJ {
     for (int chunkIndex = 0; chunkIndex < streamHandler.getNumChunks(); 
chunkIndex++) {
       fetchHandler.receive(
           client,
-          new ChunkFetchRequest(
-              new StreamChunkSlice(streamHandler.getStreamId(), chunkIndex, 0, 
Integer.MAX_VALUE)));
+          new RpcRequest(
+              TransportClient.requestId(),
+              new NioManagedBuffer(
+                  new TransportMessage(
+                          MessageType.CHUNK_FETCH_REQUEST,
+                          PbChunkFetchRequest.newBuilder()
+                              .setStreamChunkSlice(
+                                  PbStreamChunkSlice.newBuilder()
+                                      .setStreamId(streamHandler.getStreamId())
+                                      .setChunkIndex(chunkIndex)
+                                      .setOffset(0)
+                                      .setLen(Integer.MAX_VALUE))
+                              .build()
+                              .toByteArray())
+                      .toByteBuffer())));
       ChunkFetchSuccess chunkFetchSuccess = channel.readOutbound();
       chunkFetchSuccess.body().retain();
       // chunk size 8m
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
index 11223a0d3..85ce86367 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/network/RequestTimeoutIntegrationSuiteJ.java
@@ -39,10 +39,16 @@ import 
org.apache.celeborn.common.network.client.ChunkReceivedCallback;
 import org.apache.celeborn.common.network.client.RpcResponseCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.protocol.*;
+import org.apache.celeborn.common.network.protocol.ChunkFetchSuccess;
+import org.apache.celeborn.common.network.protocol.RequestMessage;
+import org.apache.celeborn.common.network.protocol.RpcRequest;
+import org.apache.celeborn.common.network.protocol.RpcResponse;
+import org.apache.celeborn.common.network.protocol.StreamChunkSlice;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.server.BaseMessageHandler;
 import org.apache.celeborn.common.network.server.TransportServer;
 import org.apache.celeborn.common.network.util.TransportConf;
+import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
 import org.apache.celeborn.service.deploy.worker.storage.ChunkStreamManager;
 
 /**
@@ -196,7 +202,15 @@ public class RequestTimeoutIntegrationSuiteJ {
         new BaseMessageHandler() {
           @Override
           public void receive(TransportClient client, RequestMessage msg) {
-            StreamChunkSlice slice = ((ChunkFetchRequest) 
msg).streamChunkSlice;
+            PbChunkFetchRequest chunkFetchRequest;
+            try {
+              chunkFetchRequest =
+                  
TransportMessage.fromByteBuffer(msg.body().nioByteBuffer()).getParsedPayload();
+            } catch (IOException e) {
+              throw new RuntimeException(e);
+            }
+            StreamChunkSlice slice =
+                
StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice());
             ManagedBuffer buf =
                 manager.getChunk(slice.streamId, slice.chunkIndex, 
slice.offset, slice.len);
             client.getChannel().writeAndFlush(new ChunkFetchSuccess(slice, 
buf));
diff --git 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
index 9c3fd7c32..6a1df6747 100644
--- 
a/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
+++ 
b/worker/src/test/java/org/apache/celeborn/service/deploy/worker/storage/ChunkFetchIntegrationSuiteJ.java
@@ -21,6 +21,7 @@ import static 
org.apache.celeborn.common.util.JavaUtils.getLocalHost;
 import static org.junit.Assert.*;
 
 import java.io.File;
+import java.io.IOException;
 import java.io.RandomAccessFile;
 import java.nio.ByteBuffer;
 import java.util.*;
@@ -41,13 +42,14 @@ import 
org.apache.celeborn.common.network.buffer.NioManagedBuffer;
 import org.apache.celeborn.common.network.client.ChunkReceivedCallback;
 import org.apache.celeborn.common.network.client.TransportClient;
 import org.apache.celeborn.common.network.client.TransportClientFactory;
-import org.apache.celeborn.common.network.protocol.ChunkFetchRequest;
 import org.apache.celeborn.common.network.protocol.ChunkFetchSuccess;
 import org.apache.celeborn.common.network.protocol.RequestMessage;
 import org.apache.celeborn.common.network.protocol.StreamChunkSlice;
+import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.network.server.BaseMessageHandler;
 import org.apache.celeborn.common.network.server.TransportServer;
 import org.apache.celeborn.common.network.util.TransportConf;
+import org.apache.celeborn.common.protocol.PbChunkFetchRequest;
 
 public class ChunkFetchIntegrationSuiteJ {
   static final long STREAM_ID = 1;
@@ -106,7 +108,15 @@ public class ChunkFetchIntegrationSuiteJ {
         new BaseMessageHandler() {
           @Override
           public void receive(TransportClient client, RequestMessage msg) {
-            StreamChunkSlice slice = ((ChunkFetchRequest) 
msg).streamChunkSlice;
+            PbChunkFetchRequest chunkFetchRequest;
+            try {
+              chunkFetchRequest =
+                  
TransportMessage.fromByteBuffer(msg.body().nioByteBuffer()).getParsedPayload();
+            } catch (IOException e) {
+              throw new RuntimeException(e);
+            }
+            StreamChunkSlice slice =
+                
StreamChunkSlice.fromProto(chunkFetchRequest.getStreamChunkSlice());
             ManagedBuffer buf =
                 chunkStreamManager.getChunk(
                     slice.streamId, slice.chunkIndex, slice.offset, slice.len);

Reply via email to