This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new ab0c5af44 [CELEBORN-2124][CIP-14] Support PushData network interfaces in cppClient
ab0c5af44 is described below

commit ab0c5af446f9b61a649f01d4e3f859ef958efe38
Author: HolyLow <[email protected]>
AuthorDate: Tue Sep 16 21:38:08 2025 +0800

    [CELEBORN-2124][CIP-14] Support PushData network interfaces in cppClient
    
    ### What changes were proposed in this pull request?
    Support PushData network interfaces in cppClient.
    
    ### Why are the changes needed?
    PushData network interfaces are used when writing data to Worker node via cppClient.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation and UTs.
    
    Closes #3461 from HolyLow/issue/celeborn-2124-support-pushdata-network-interface.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: Ethan Feng <[email protected]>
---
 cpp/celeborn/network/Message.h                     |  12 ++
 cpp/celeborn/network/MessageDispatcher.cpp         |  31 ++-
 cpp/celeborn/network/MessageDispatcher.h           |   3 +
 cpp/celeborn/network/TransportClient.cpp           |  37 ++++
 cpp/celeborn/network/TransportClient.h             |  16 ++
 .../network/tests/MessageDispatcherTest.cpp        |  86 +++++++++
 cpp/celeborn/network/tests/TransportClientTest.cpp | 212 ++++++++++++++++++++-
 7 files changed, 387 insertions(+), 10 deletions(-)

diff --git a/cpp/celeborn/network/Message.h b/cpp/celeborn/network/Message.h
index 8aeca60c9..e0944294c 100644
--- a/cpp/celeborn/network/Message.h
+++ b/cpp/celeborn/network/Message.h
@@ -252,6 +252,18 @@ class PushData : public Message {
     return requestId_;
   }
 
+  uint8_t mode() const {
+    return mode_;
+  }
+
+  std::string shuffleKey() const {
+    return shuffleKey_;
+  }
+
+  std::string partitionUniqueId() const {
+    return partitionUniqueId_;
+  }
+
  private:
   int internalEncodedLength() const override;
 
diff --git a/cpp/celeborn/network/MessageDispatcher.cpp b/cpp/celeborn/network/MessageDispatcher.cpp
index 30161483d..8dd08ec38 100644
--- a/cpp/celeborn/network/MessageDispatcher.cpp
+++ b/cpp/celeborn/network/MessageDispatcher.cpp
@@ -129,14 +129,30 @@ void MessageDispatcher::read(Context*, std::unique_ptr<Message> toRecvMsg) {
 folly::Future<std::unique_ptr<Message>> MessageDispatcher::operator()(
     std::unique_ptr<Message> toSendMsg) {
   CELEBORN_CHECK(!closed_);
-  CELEBORN_CHECK(toSendMsg->type() == Message::RPC_REQUEST);
-  RpcRequest* request = reinterpret_cast<RpcRequest*>(toSendMsg.get());
+  auto currTime = std::chrono::system_clock::now();
+  long requestId;
+  switch (toSendMsg->type()) {
+    case Message::RPC_REQUEST: {
+      RpcRequest* request = reinterpret_cast<RpcRequest*>(toSendMsg.get());
+      requestId = request->requestId();
+      break;
+    }
+    case Message::PUSH_DATA: {
+      PushData* pushData = reinterpret_cast<PushData*>(toSendMsg.get());
+      requestId = pushData->requestId();
+      break;
+    }
+    default: {
+      CELEBORN_FAIL("unsupported type");
+    }
+  }
+
   auto f = requestIdRegistry_.withLock(
       [&](auto& registry) -> folly::Future<std::unique_ptr<Message>> {
-        auto& holder = registry[request->requestId()];
-        holder.requestTime = std::chrono::system_clock::now();
+        auto& holder = registry[requestId];
+        holder.requestTime = currTime;
         auto& p = holder.msgPromise;
-        p.setInterruptHandler([requestId = request->requestId(),
+        p.setInterruptHandler([requestId,
                                this](const folly::exception_wrapper&) {
           this->requestIdRegistry_.lock()->erase(requestId);
           LOG(WARNING) << "rpc request interrupted, requestId: " << requestId;
@@ -150,6 +166,11 @@ folly::Future<std::unique_ptr<Message>> MessageDispatcher::operator()(
   return f;
 }
 
+folly::Future<std::unique_ptr<Message>> MessageDispatcher::sendPushDataRequest(
+    std::unique_ptr<Message> toSendMsg) {
+  return (*this)(std::move(toSendMsg));
+}
+
 folly::Future<std::unique_ptr<Message>>
 MessageDispatcher::sendFetchChunkRequest(
     const protocol::StreamChunkSlice& streamChunkSlice,
diff --git a/cpp/celeborn/network/MessageDispatcher.h b/cpp/celeborn/network/MessageDispatcher.h
index 676ef081b..7ce6a7afe 100644
--- a/cpp/celeborn/network/MessageDispatcher.h
+++ b/cpp/celeborn/network/MessageDispatcher.h
@@ -60,6 +60,9 @@ class MessageDispatcher : public wangle::ClientDispatcherBase<
     return operator()(std::move(toSendMsg));
   }
 
+  virtual folly::Future<std::unique_ptr<Message>> sendPushDataRequest(
+      std::unique_ptr<Message> toSendMsg);
+
   virtual folly::Future<std::unique_ptr<Message>> sendFetchChunkRequest(
       const protocol::StreamChunkSlice& streamChunkSlice,
       std::unique_ptr<Message> toSendMsg);
diff --git a/cpp/celeborn/network/TransportClient.cpp b/cpp/celeborn/network/TransportClient.cpp
index 796eb5d50..c3e4d81fd 100644
--- a/cpp/celeborn/network/TransportClient.cpp
+++ b/cpp/celeborn/network/TransportClient.cpp
@@ -79,6 +79,43 @@ void TransportClient::sendRpcRequestWithoutResponse(const RpcRequest& request) {
   }
 }
 
+void TransportClient::pushDataAsync(
+    const PushData& pushData,
+    Timeout timeout,
+    std::shared_ptr<RpcResponseCallback> callback) {
+  try {
+    auto requestMsg = std::make_unique<PushData>(pushData);
+    auto future = dispatcher_->sendPushDataRequest(std::move(requestMsg));
+    std::move(future)
+        .within(timeout)
+        .thenValue(
+            [_callback = callback](std::unique_ptr<Message> responseMsg) {
+              if (responseMsg->type() == Message::RPC_RESPONSE) {
+                auto rpcResponse =
+                    reinterpret_cast<RpcResponse*>(responseMsg.get());
+                _callback->onSuccess(rpcResponse->body());
+              } else {
+                _callback->onFailure(std::make_unique<std::runtime_error>(
+                    "pushData return value type is not rpcResponse"));
+              }
+            })
+        .thenError([_callback = callback](const folly::exception_wrapper& e) {
+          _callback->onFailure(
+              std::make_unique<std::runtime_error>(e.what().toStdString()));
+        });
+
+  } catch (std::exception& e) {
+    auto errorMsg = fmt::format(
+        "PushData failed. shuffleKey: {}, partitionUniqueId: {}, mode: {}, error message: {}",
+        pushData.shuffleKey(),
+        pushData.partitionUniqueId(),
+        pushData.mode(),
+        e.what());
+    LOG(ERROR) << errorMsg;
+    callback->onFailure(std::make_unique<std::runtime_error>(errorMsg));
+  }
+}
+
 void TransportClient::fetchChunkAsync(
     const protocol::StreamChunkSlice& streamChunkSlice,
     const RpcRequest& request,
diff --git a/cpp/celeborn/network/TransportClient.h b/cpp/celeborn/network/TransportClient.h
index 5bfa4afdc..e3ece7d22 100644
--- a/cpp/celeborn/network/TransportClient.h
+++ b/cpp/celeborn/network/TransportClient.h
@@ -54,6 +54,17 @@ using FetchChunkSuccessCallback = std::function<void(
 using FetchChunkFailureCallback = std::function<
     void(protocol::StreamChunkSlice, std::unique_ptr<std::exception>)>;
 
+class RpcResponseCallback {
+ public:
+  RpcResponseCallback() = default;
+
+  virtual ~RpcResponseCallback() = default;
+
+  virtual void onSuccess(std::unique_ptr<memory::ReadOnlyByteBuffer>) = 0;
+
+  virtual void onFailure(std::unique_ptr<std::exception> exception) = 0;
+};
+
 /**
  * TransportClient sends the messages to the network layer, and handles
  * the message callback, timeout, error handling, etc.
@@ -76,6 +87,11 @@ class TransportClient {
   // Ignore the response, return immediately.
   virtual void sendRpcRequestWithoutResponse(const RpcRequest& request);
 
+  virtual void pushDataAsync(
+      const PushData& pushData,
+      Timeout timeout,
+      std::shared_ptr<RpcResponseCallback> callback);
+
   virtual void fetchChunkAsync(
       const protocol::StreamChunkSlice& streamChunkSlice,
       const RpcRequest& request,
diff --git a/cpp/celeborn/network/tests/MessageDispatcherTest.cpp b/cpp/celeborn/network/tests/MessageDispatcherTest.cpp
index 4d5dbc541..6127548d6 100644
--- a/cpp/celeborn/network/tests/MessageDispatcherTest.cpp
+++ b/cpp/celeborn/network/tests/MessageDispatcherTest.cpp
@@ -127,6 +127,92 @@ TEST(MessageDispatcherTest, sendRpcRequestAndReceiveFailure) {
   EXPECT_TRUE(future.hasException());
 }
 
+TEST(MessageDispatcherTest, sendPushDataAndReceiveSuccess) {
+  std::unique_ptr<Message> sentMsg;
+  MockHandler mockHandler(sentMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  const long requestId = 1001;
+  const uint8_t mode = 2;
+  const std::string shuffleKey = "test-shuffle-key";
+  const std::string partitionUniqueId = "test-partition-id";
+  const std::string requestBody = "test-request-body";
+  auto pushData = std::make_unique<PushData>(
+      requestId,
+      mode,
+      shuffleKey,
+      partitionUniqueId,
+      toReadOnlyByteBuffer(requestBody));
+  auto future = dispatcher->sendPushDataRequest(std::move(pushData));
+
+  EXPECT_FALSE(future.isReady());
+  EXPECT_EQ(sentMsg->type(), Message::PUSH_DATA);
+  auto sentPushData = dynamic_cast<PushData*>(sentMsg.get());
+  EXPECT_EQ(sentPushData->requestId(), requestId);
+  EXPECT_EQ(sentPushData->mode(), mode);
+  EXPECT_EQ(sentPushData->shuffleKey(), shuffleKey);
+  EXPECT_EQ(sentPushData->partitionUniqueId(), partitionUniqueId);
+  EXPECT_EQ(sentPushData->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sentPushData->body()->readToString(requestBody.size()), requestBody);
+
+  const std::string responseBody = "test-response-body";
+  auto rpcResponse = std::make_unique<RpcResponse>(
+      requestId, toReadOnlyByteBuffer(responseBody));
+  dispatcher->read(nullptr, std::move(rpcResponse));
+
+  EXPECT_TRUE(future.isReady());
+  auto receivedMsg = std::move(future).get();
+  EXPECT_EQ(receivedMsg->type(), Message::RPC_RESPONSE);
+  auto receivedRpcResponse = dynamic_cast<RpcResponse*>(receivedMsg.get());
+  EXPECT_EQ(receivedRpcResponse->body()->remainingSize(), responseBody.size());
+  EXPECT_EQ(
+      receivedRpcResponse->body()->readToString(responseBody.size()),
+      responseBody);
+}
+
+TEST(MessageDispatcherTest, sendPushDataAndReceiveFailure) {
+  std::unique_ptr<Message> sentMsg;
+  MockHandler mockHandler(sentMsg);
+  auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+  auto dispatcher = std::make_unique<MessageDispatcher>();
+  dispatcher->setPipeline(mockPipeline.get());
+
+  const long requestId = 1001;
+  const uint8_t mode = 2;
+  const std::string shuffleKey = "test-shuffle-key";
+  const std::string partitionUniqueId = "test-partition-id";
+  const std::string requestBody = "test-request-body";
+  auto pushData = std::make_unique<PushData>(
+      requestId,
+      mode,
+      shuffleKey,
+      partitionUniqueId,
+      toReadOnlyByteBuffer(requestBody));
+  auto future = dispatcher->sendPushDataRequest(std::move(pushData));
+
+  EXPECT_FALSE(future.isReady());
+  EXPECT_EQ(sentMsg->type(), Message::PUSH_DATA);
+  auto sentPushData = dynamic_cast<PushData*>(sentMsg.get());
+  EXPECT_EQ(sentPushData->requestId(), requestId);
+  EXPECT_EQ(sentPushData->mode(), mode);
+  EXPECT_EQ(sentPushData->shuffleKey(), shuffleKey);
+  EXPECT_EQ(sentPushData->partitionUniqueId(), partitionUniqueId);
+  EXPECT_EQ(sentPushData->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sentPushData->body()->readToString(requestBody.size()), requestBody);
+
+  const std::string errorMsg = "test-error-msg";
+  auto copiedErrorMsg = errorMsg;
+  auto rpcFailure =
+      std::make_unique<RpcFailure>(requestId, std::move(copiedErrorMsg));
+  dispatcher->read(nullptr, std::move(rpcFailure));
+
+  EXPECT_TRUE(future.hasException());
+}
+
 TEST(MessageDispatcherTest, sendFetchChunkRequestAndReceiveSuccess) {
   std::unique_ptr<Message> sentMsg;
   MockHandler mockHandler(sentMsg);
diff --git a/cpp/celeborn/network/tests/TransportClientTest.cpp b/cpp/celeborn/network/tests/TransportClientTest.cpp
index 82740ad67..0135d61e5 100644
--- a/cpp/celeborn/network/tests/TransportClientTest.cpp
+++ b/cpp/celeborn/network/tests/TransportClientTest.cpp
@@ -15,6 +15,7 @@
  * limitations under the License.
  */
 
+#include <folly/init/Init.h>
 #include <gtest/gtest.h>
 
 #include "celeborn/network/TransportClient.h"
@@ -38,6 +39,13 @@ class MockDispatcher : public MessageDispatcher {
     sentMsg_ = std::move(toSendMsg);
   }
 
+  folly::Future<std::unique_ptr<Message>> sendPushDataRequest(
+      std::unique_ptr<Message> toSendMsg) override {
+    sentMsg_ = std::move(toSendMsg);
+    msgPromise_ = MsgPromise();
+    return msgPromise_.getFuture();
+  }
+
   folly::Future<std::unique_ptr<Message>> sendFetchChunkRequest(
       const protocol::StreamChunkSlice& streamChunkSlice,
       std::unique_ptr<Message> toSendMsg) override {
@@ -66,9 +74,61 @@ std::unique_ptr<memory::ReadOnlyByteBuffer> toReadOnlyByteBuffer(
   buffer->writeFromString(content);
   return memory::ByteBuffer::toReadOnly(std::move(buffer));
 }
+
+class MockRpcResponseCallback : public RpcResponseCallback {
+ public:
+  MockRpcResponseCallback() = default;
+
+  ~MockRpcResponseCallback() override = default;
+
+  void onSuccess(std::unique_ptr<memory::ReadOnlyByteBuffer> data) override {
+    onSuccessBuffer_ = std::move(data);
+  }
+
+  void onFailure(std::unique_ptr<std::exception> exception) override {
+    onFailureException_ = std::move(exception);
+  }
+
+  std::unique_ptr<memory::ReadOnlyByteBuffer> getOnSuccessBuffer() {
+    return std::move(onSuccessBuffer_);
+  }
+
+  std::unique_ptr<std::exception> getOnFailureException() {
+    return std::move(onFailureException_);
+  }
+
+ private:
+  std::unique_ptr<memory::ReadOnlyByteBuffer> onSuccessBuffer_;
+  std::unique_ptr<std::exception> onFailureException_;
+};
 } // namespace
 
-TEST(TransportClientTest, sendRpcRequestSync) {
+class TransportClientTest : public testing::Test {
+ protected:
+  TransportClientTest() {
+    if (!follyInit_) {
+      std::lock_guard<std::mutex> lock(initMutex_);
+      if (!follyInit_) {
+        int argc = 0;
+        char* arg = "test-arg";
+        char** argv = &arg;
+        follyInit_ = std::make_unique<folly::Init>(&argc, &argv, false);
+      }
+    }
+  }
+
+  ~TransportClientTest() override = default;
+
+ private:
+  // Must be only inited once per process.
+  static std::unique_ptr<folly::Init> follyInit_;
+  static std::mutex initMutex_;
+};
+
+std::unique_ptr<folly::Init> TransportClientTest::follyInit_ = {};
+std::mutex TransportClientTest::initMutex_ = {};
+
+TEST_F(TransportClientTest, sendRpcRequestSync) {
   auto mockDispatcher = std::make_unique<MockDispatcher>();
   auto rawMockDispatcher = mockDispatcher.get();
   const auto timeoutInterval = MS(10000);
@@ -109,7 +169,7 @@ TEST(TransportClientTest, sendRpcRequestSync) {
       responseBody);
 }
 
-TEST(TransportClientTest, sendRpcRequestSyncTimeout) {
+TEST_F(TransportClientTest, sendRpcRequestSyncTimeout) {
   auto mockDispatcher = std::make_unique<MockDispatcher>();
   auto rawMockDispatcher = mockDispatcher.get();
   const auto timeoutInterval = MS(200);
@@ -149,7 +209,7 @@ TEST(TransportClientTest, sendRpcRequestSyncTimeout) {
   EXPECT_TRUE(timeoutHappened);
 }
 
-TEST(TransportClientTest, sendRpcRequestWithoutResponse) {
+TEST_F(TransportClientTest, sendRpcRequestWithoutResponse) {
   auto mockDispatcher = std::make_unique<MockDispatcher>();
   auto rawMockDispatcher = mockDispatcher.get();
   const auto timeoutInterval = MS(10000);
@@ -169,7 +229,149 @@ TEST(TransportClientTest, sendRpcRequestWithoutResponse) {
       sentRpcRequest->body()->readToString(requestBody.size()), requestBody);
 }
 
-TEST(TransportClientTest, fetchChunkAsyncSuccess) {
+TEST_F(TransportClientTest, pushDataAsyncSuccess) {
+  // Construct mock utils.
+  auto mockDispatcher = std::make_unique<MockDispatcher>();
+  auto rawMockDispatcher = mockDispatcher.get();
+  TransportClient client(nullptr, std::move(mockDispatcher), MS(10000));
+  auto mockRpcResponseCallback = std::make_shared<MockRpcResponseCallback>();
+
+  // Construct toSend PushData.
+  const long requestId = 1001;
+  const uint8_t mode = 2;
+  const std::string shuffleKey = "test-shuffle-key";
+  const std::string partitionUniqueId = "test-partition-id";
+  const std::string requestBody = "test-request-body";
+  auto pushData = std::make_unique<PushData>(
+      requestId,
+      mode,
+      shuffleKey,
+      partitionUniqueId,
+      toReadOnlyByteBuffer(requestBody));
+
+  // Send pushData via client, check that the sentPushData is identical to
+  // original pushData.
+  client.pushDataAsync(*pushData, MS(10000), mockRpcResponseCallback);
+  auto sentMsg = rawMockDispatcher->getSentMsg();
+  EXPECT_EQ(sentMsg->type(), Message::PUSH_DATA);
+  auto sentPushData = dynamic_cast<PushData*>(sentMsg.get());
+  EXPECT_EQ(sentPushData->requestId(), requestId);
+  EXPECT_EQ(sentPushData->mode(), mode);
+  EXPECT_EQ(sentPushData->shuffleKey(), shuffleKey);
+  EXPECT_EQ(sentPushData->partitionUniqueId(), partitionUniqueId);
+  EXPECT_EQ(sentPushData->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sentPushData->body()->readToString(requestBody.size()), requestBody);
+  EXPECT_FALSE(mockRpcResponseCallback->getOnSuccessBuffer());
+
+  // Construct response and make dispatcher receive it.
+  const std::string responseBody = "test-response-body";
+  auto rpcResponse = std::make_unique<RpcResponse>(
+      requestId, toReadOnlyByteBuffer(responseBody));
+  rawMockDispatcher->receiveMsg(std::move(rpcResponse));
+
+  // Check the received message.
+  auto onSuccessBuffer = mockRpcResponseCallback->getOnSuccessBuffer();
+  auto onFailureException = mockRpcResponseCallback->getOnFailureException();
+  EXPECT_TRUE(onSuccessBuffer);
+  EXPECT_FALSE(onFailureException);
+  EXPECT_EQ(onSuccessBuffer->remainingSize(), responseBody.size());
+  EXPECT_EQ(onSuccessBuffer->readToString(responseBody.size()), responseBody);
+}
+
+TEST_F(TransportClientTest, pushDataAsyncFailure) {
+  // Construct mock utils.
+  auto mockDispatcher = std::make_unique<MockDispatcher>();
+  auto rawMockDispatcher = mockDispatcher.get();
+  TransportClient client(nullptr, std::move(mockDispatcher), MS(10000));
+  auto mockRpcResponseCallback = std::make_shared<MockRpcResponseCallback>();
+
+  // Construct toSend PushData.
+  const long requestId = 1001;
+  const uint8_t mode = 2;
+  const std::string shuffleKey = "test-shuffle-key";
+  const std::string partitionUniqueId = "test-partition-id";
+  const std::string requestBody = "test-request-body";
+  auto pushData = std::make_unique<PushData>(
+      requestId,
+      mode,
+      shuffleKey,
+      partitionUniqueId,
+      toReadOnlyByteBuffer(requestBody));
+
+  // Send pushData via client, check that the sentPushData is identical to
+  // original pushData.
+  client.pushDataAsync(*pushData, MS(10000), mockRpcResponseCallback);
+  auto sentMsg = rawMockDispatcher->getSentMsg();
+  EXPECT_EQ(sentMsg->type(), Message::PUSH_DATA);
+  auto sentPushData = dynamic_cast<PushData*>(sentMsg.get());
+  EXPECT_EQ(sentPushData->requestId(), requestId);
+  EXPECT_EQ(sentPushData->mode(), mode);
+  EXPECT_EQ(sentPushData->shuffleKey(), shuffleKey);
+  EXPECT_EQ(sentPushData->partitionUniqueId(), partitionUniqueId);
+  EXPECT_EQ(sentPushData->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sentPushData->body()->readToString(requestBody.size()), requestBody);
+  EXPECT_FALSE(mockRpcResponseCallback->getOnSuccessBuffer());
+
+  // Construct failure and make dispatcher receive it.
+  auto rpcFailure = std::make_unique<RpcFailure>(requestId, "failure-msg-body");
+  rawMockDispatcher->receiveMsg(std::move(rpcFailure));
+
+  // Check the received message.
+  auto onSuccessBuffer = mockRpcResponseCallback->getOnSuccessBuffer();
+  auto onFailureException = mockRpcResponseCallback->getOnFailureException();
+  EXPECT_FALSE(onSuccessBuffer);
+  EXPECT_TRUE(onFailureException);
+}
+
+TEST_F(TransportClientTest, pushDataAsyncTimeout) {
+  // Construct mock utils.
+  auto mockDispatcher = std::make_unique<MockDispatcher>();
+  auto rawMockDispatcher = mockDispatcher.get();
+  TransportClient client(nullptr, std::move(mockDispatcher), MS(10000));
+  auto mockRpcResponseCallback = std::make_shared<MockRpcResponseCallback>();
+
+  // Construct toSend PushData.
+  const long requestId = 1001;
+  const uint8_t mode = 2;
+  const std::string shuffleKey = "test-shuffle-key";
+  const std::string partitionUniqueId = "test-partition-id";
+  const std::string requestBody = "test-request-body";
+  auto pushData = std::make_unique<PushData>(
+      requestId,
+      mode,
+      shuffleKey,
+      partitionUniqueId,
+      toReadOnlyByteBuffer(requestBody));
+
+  // Send pushData via client, check that the sentPushData is identical to
+  // original pushData.
+  auto timeoutInterval = MS(100);
+  client.pushDataAsync(*pushData, timeoutInterval, mockRpcResponseCallback);
+  auto sentMsg = rawMockDispatcher->getSentMsg();
+  EXPECT_EQ(sentMsg->type(), Message::PUSH_DATA);
+  auto sentPushData = dynamic_cast<PushData*>(sentMsg.get());
+  EXPECT_EQ(sentPushData->requestId(), requestId);
+  EXPECT_EQ(sentPushData->mode(), mode);
+  EXPECT_EQ(sentPushData->shuffleKey(), shuffleKey);
+  EXPECT_EQ(sentPushData->partitionUniqueId(), partitionUniqueId);
+  EXPECT_EQ(sentPushData->body()->remainingSize(), requestBody.size());
+  EXPECT_EQ(
+      sentPushData->body()->readToString(requestBody.size()), requestBody);
+  EXPECT_FALSE(mockRpcResponseCallback->getOnSuccessBuffer());
+
+  // Wait for timeout.
+  std::this_thread::sleep_for(timeoutInterval * 3);
+
+  // Check the received message.
+  auto onSuccessBuffer = mockRpcResponseCallback->getOnSuccessBuffer();
+  auto onFailureException = mockRpcResponseCallback->getOnFailureException();
+  EXPECT_FALSE(onSuccessBuffer);
+  EXPECT_TRUE(onFailureException);
+}
+
+TEST_F(TransportClientTest, fetchChunkAsyncSuccess) {
   auto mockDispatcher = std::make_unique<MockDispatcher>();
   auto rawMockDispatcher = mockDispatcher.get();
   TransportClient client(nullptr, std::move(mockDispatcher), MS(10000));
@@ -218,7 +420,7 @@ TEST(TransportClientTest, fetchChunkAsyncSuccess) {
   EXPECT_EQ(onSuccessBuffer->readToString(responseBody.size()), responseBody);
 }
 
-TEST(TransportClientTest, fetchChunkAsyncFailure) {
+TEST_F(TransportClientTest, fetchChunkAsyncFailure) {
   auto mockDispatcher = std::make_unique<MockDispatcher>();
   auto rawMockDispatcher = mockDispatcher.get();
   TransportClient client(nullptr, std::move(mockDispatcher), MS(10000));

Reply via email to