This is an automated email from the ASF dual-hosted git repository.
ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new ab0c5af44 [CELEBORN-2124][CIP-14] Support PushData network interfaces in cppClient
ab0c5af44 is described below
commit ab0c5af446f9b61a649f01d4e3f859ef958efe38
Author: HolyLow <[email protected]>
AuthorDate: Tue Sep 16 21:38:08 2025 +0800
[CELEBORN-2124][CIP-14] Support PushData network interfaces in cppClient
### What changes were proposed in this pull request?
Support PushData network interfaces in cppClient.
### Why are the changes needed?
The PushData network interfaces are used when writing data to a Worker node
via cppClient.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation and UTs.
Closes #3461 from HolyLow/issue/celeborn-2124-support-pushdata-network-interface.
Authored-by: HolyLow <[email protected]>
Signed-off-by: Ethan Feng <[email protected]>
---
cpp/celeborn/network/Message.h | 12 ++
cpp/celeborn/network/MessageDispatcher.cpp | 31 ++-
cpp/celeborn/network/MessageDispatcher.h | 3 +
cpp/celeborn/network/TransportClient.cpp | 37 ++++
cpp/celeborn/network/TransportClient.h | 16 ++
.../network/tests/MessageDispatcherTest.cpp | 86 +++++++++
cpp/celeborn/network/tests/TransportClientTest.cpp | 212 ++++++++++++++++++++-
7 files changed, 387 insertions(+), 10 deletions(-)
diff --git a/cpp/celeborn/network/Message.h b/cpp/celeborn/network/Message.h
index 8aeca60c9..e0944294c 100644
--- a/cpp/celeborn/network/Message.h
+++ b/cpp/celeborn/network/Message.h
@@ -252,6 +252,18 @@ class PushData : public Message {
return requestId_;
}
+ /// Returns the value stored in mode_.
+ uint8_t mode() const {
+   const uint8_t current = mode_;
+   return current;
+ }
+
+ /// Returns a copy of shuffleKey_.
+ std::string shuffleKey() const {
+   std::string keyCopy = shuffleKey_;
+   return keyCopy;
+ }
+
+ /// Returns a copy of partitionUniqueId_.
+ std::string partitionUniqueId() const {
+   std::string idCopy = partitionUniqueId_;
+   return idCopy;
+ }
+
private:
int internalEncodedLength() const override;
diff --git a/cpp/celeborn/network/MessageDispatcher.cpp b/cpp/celeborn/network/MessageDispatcher.cpp
index 30161483d..8dd08ec38 100644
--- a/cpp/celeborn/network/MessageDispatcher.cpp
+++ b/cpp/celeborn/network/MessageDispatcher.cpp
@@ -129,14 +129,30 @@ void MessageDispatcher::read(Context*, std::unique_ptr<Message> toRecvMsg) {
folly::Future<std::unique_ptr<Message>> MessageDispatcher::operator()(
std::unique_ptr<Message> toSendMsg) {
CELEBORN_CHECK(!closed_);
- CELEBORN_CHECK(toSendMsg->type() == Message::RPC_REQUEST);
- RpcRequest* request = reinterpret_cast<RpcRequest*>(toSendMsg.get());
+ auto currTime = std::chrono::system_clock::now();
+ long requestId;
+ switch (toSendMsg->type()) {
+ case Message::RPC_REQUEST: {
+ RpcRequest* request = reinterpret_cast<RpcRequest*>(toSendMsg.get());
+ requestId = request->requestId();
+ break;
+ }
+ case Message::PUSH_DATA: {
+ PushData* pushData = reinterpret_cast<PushData*>(toSendMsg.get());
+ requestId = pushData->requestId();
+ break;
+ }
+ default: {
+ CELEBORN_FAIL("unsupported type");
+ }
+ }
+
auto f = requestIdRegistry_.withLock(
[&](auto& registry) -> folly::Future<std::unique_ptr<Message>> {
- auto& holder = registry[request->requestId()];
- holder.requestTime = std::chrono::system_clock::now();
+ auto& holder = registry[requestId];
+ holder.requestTime = currTime;
auto& p = holder.msgPromise;
- p.setInterruptHandler([requestId = request->requestId(),
+ p.setInterruptHandler([requestId,
this](const folly::exception_wrapper&) {
this->requestIdRegistry_.lock()->erase(requestId);
LOG(WARNING) << "rpc request interrupted, requestId: " << requestId;
@@ -150,6 +166,11 @@ folly::Future<std::unique_ptr<Message>> MessageDispatcher::operator()(
return f;
}
+// Sends a PushData message by delegating to operator(), which accepts both
+// RPC_REQUEST and PUSH_DATA messages and returns the future for the reply.
+folly::Future<std::unique_ptr<Message>> MessageDispatcher::sendPushDataRequest(
+    std::unique_ptr<Message> toSendMsg) {
+  auto replyFuture = this->operator()(std::move(toSendMsg));
+  return replyFuture;
+}
+
folly::Future<std::unique_ptr<Message>>
MessageDispatcher::sendFetchChunkRequest(
const protocol::StreamChunkSlice& streamChunkSlice,
diff --git a/cpp/celeborn/network/MessageDispatcher.h b/cpp/celeborn/network/MessageDispatcher.h
index 676ef081b..7ce6a7afe 100644
--- a/cpp/celeborn/network/MessageDispatcher.h
+++ b/cpp/celeborn/network/MessageDispatcher.h
@@ -60,6 +60,9 @@ class MessageDispatcher : public wangle::ClientDispatcherBase<
return operator()(std::move(toSendMsg));
}
+ // Sends toSendMsg (expected to be a PushData message) and returns the
+ // future on which the reply is delivered. Virtual so tests can override it.
+ virtual folly::Future<std::unique_ptr<Message>> sendPushDataRequest(
+ std::unique_ptr<Message> toSendMsg);
+
virtual folly::Future<std::unique_ptr<Message>> sendFetchChunkRequest(
const protocol::StreamChunkSlice& streamChunkSlice,
std::unique_ptr<Message> toSendMsg);
diff --git a/cpp/celeborn/network/TransportClient.cpp b/cpp/celeborn/network/TransportClient.cpp
index 796eb5d50..c3e4d81fd 100644
--- a/cpp/celeborn/network/TransportClient.cpp
+++ b/cpp/celeborn/network/TransportClient.cpp
@@ -79,6 +79,43 @@ void TransportClient::sendRpcRequestWithoutResponse(const RpcRequest& request) {
}
}
+/**
+ * Sends a copy of `pushData` through the dispatcher and reports the outcome
+ * through `callback`.
+ *
+ * - A reply of type RPC_RESPONSE invokes callback->onSuccess() with its body.
+ * - Any other reply type, a failed future, or no reply within `timeout`
+ *   invokes callback->onFailure().
+ * - An exception thrown while issuing the request is logged and also reported
+ *   through callback->onFailure().
+ *
+ * @param pushData message to send; it is copied, the caller keeps its object.
+ * @param timeout  upper bound on how long the reply future may stay pending.
+ * @param callback shared callback; each continuation keeps its own reference.
+ */
+void TransportClient::pushDataAsync(
+    const PushData& pushData,
+    Timeout timeout,
+    std::shared_ptr<RpcResponseCallback> callback) {
+  try {
+    auto requestMsg = std::make_unique<PushData>(pushData);
+    auto future = dispatcher_->sendPushDataRequest(std::move(requestMsg));
+    std::move(future)
+        .within(timeout)
+        .thenValue(
+            [_callback = callback](std::unique_ptr<Message> responseMsg) {
+              if (responseMsg->type() == Message::RPC_RESPONSE) {
+                // The type tag was just checked, so a static downcast is the
+                // right cast here (reinterpret_cast does not adjust pointers).
+                auto rpcResponse =
+                    static_cast<RpcResponse*>(responseMsg.get());
+                _callback->onSuccess(rpcResponse->body());
+              } else {
+                _callback->onFailure(std::make_unique<std::runtime_error>(
+                    "pushData return value type is not rpcResponse"));
+              }
+            })
+        // NOTE(review): an exception escaping the continuation above would
+        // also land here, i.e. onFailure() could follow onSuccess() - confirm
+        // that callbacks tolerate this.
+        .thenError([_callback = callback](const folly::exception_wrapper& e) {
+          _callback->onFailure(
+              std::make_unique<std::runtime_error>(e.what().toStdString()));
+        });
+
+  } catch (const std::exception& e) {
+    auto errorMsg = fmt::format(
+        "PushData failed. shuffleKey: {}, partitionUniqueId: {}, mode: {}, error message: {}",
+        pushData.shuffleKey(),
+        pushData.partitionUniqueId(),
+        pushData.mode(),
+        e.what());
+    LOG(ERROR) << errorMsg;
+    callback->onFailure(std::make_unique<std::runtime_error>(errorMsg));
+  }
+}
+
void TransportClient::fetchChunkAsync(
const protocol::StreamChunkSlice& streamChunkSlice,
const RpcRequest& request,
diff --git a/cpp/celeborn/network/TransportClient.h b/cpp/celeborn/network/TransportClient.h
index 5bfa4afdc..e3ece7d22 100644
--- a/cpp/celeborn/network/TransportClient.h
+++ b/cpp/celeborn/network/TransportClient.h
@@ -54,6 +54,17 @@ using FetchChunkSuccessCallback = std::function<void(
using FetchChunkFailureCallback = std::function<
void(protocol::StreamChunkSlice, std::unique_ptr<std::exception>)>;
+/**
+ * Callback interface handed to asynchronous requests. The sender invokes
+ * onSuccess() with the response body or onFailure() with the error.
+ */
+class RpcResponseCallback {
+ public:
+  RpcResponseCallback() = default;
+
+  virtual ~RpcResponseCallback() = default;
+
+  /// Receives ownership of the response body.
+  virtual void onSuccess(
+      std::unique_ptr<memory::ReadOnlyByteBuffer> responseBody) = 0;
+
+  /// Receives ownership of the exception describing the failure.
+  virtual void onFailure(std::unique_ptr<std::exception> exception) = 0;
+};
+
/**
* TransportClient sends the messages to the network layer, and handles
* the message callback, timeout, error handling, etc.
@@ -76,6 +87,11 @@ class TransportClient {
// Ignore the response, return immediately.
virtual void sendRpcRequestWithoutResponse(const RpcRequest& request);
+ // Sends pushData asynchronously. The outcome (response body or failure) is
+ // delivered through callback; timeout bounds the wait for the reply.
+ virtual void pushDataAsync(
+ const PushData& pushData,
+ Timeout timeout,
+ std::shared_ptr<RpcResponseCallback> callback);
+
virtual void fetchChunkAsync(
const protocol::StreamChunkSlice& streamChunkSlice,
const RpcRequest& request,
diff --git a/cpp/celeborn/network/tests/MessageDispatcherTest.cpp b/cpp/celeborn/network/tests/MessageDispatcherTest.cpp
index 4d5dbc541..6127548d6 100644
--- a/cpp/celeborn/network/tests/MessageDispatcherTest.cpp
+++ b/cpp/celeborn/network/tests/MessageDispatcherTest.cpp
@@ -127,6 +127,92 @@ TEST(MessageDispatcherTest, sendRpcRequestAndReceiveFailure) {
EXPECT_TRUE(future.hasException());
}
+// Sends a PushData through sendPushDataRequest(), checks the message seen by
+// the mocked pipeline field by field, then feeds an RpcResponse with the same
+// requestId into read() and expects the future to complete with it.
+TEST(MessageDispatcherTest, sendPushDataAndReceiveSuccess) {
+ std::unique_ptr<Message> sentMsg;
+ MockHandler mockHandler(sentMsg);
+ auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+ auto dispatcher = std::make_unique<MessageDispatcher>();
+ dispatcher->setPipeline(mockPipeline.get());
+
+ const long requestId = 1001;
+ const uint8_t mode = 2;
+ const std::string shuffleKey = "test-shuffle-key";
+ const std::string partitionUniqueId = "test-partition-id";
+ const std::string requestBody = "test-request-body";
+ auto pushData = std::make_unique<PushData>(
+ requestId,
+ mode,
+ shuffleKey,
+ partitionUniqueId,
+ toReadOnlyByteBuffer(requestBody));
+ auto future = dispatcher->sendPushDataRequest(std::move(pushData));
+
+ // No reply has been delivered yet, so the future must still be pending.
+ EXPECT_FALSE(future.isReady());
+ EXPECT_EQ(sentMsg->type(), Message::PUSH_DATA);
+ auto sentPushData = dynamic_cast<PushData*>(sentMsg.get());
+ EXPECT_EQ(sentPushData->requestId(), requestId);
+ EXPECT_EQ(sentPushData->mode(), mode);
+ EXPECT_EQ(sentPushData->shuffleKey(), shuffleKey);
+ EXPECT_EQ(sentPushData->partitionUniqueId(), partitionUniqueId);
+ // remainingSize() is checked before readToString() reads the body.
+ EXPECT_EQ(sentPushData->body()->remainingSize(), requestBody.size());
+ EXPECT_EQ(
+ sentPushData->body()->readToString(requestBody.size()), requestBody);
+
+ // Deliver a response that carries the same requestId as the request.
+ const std::string responseBody = "test-response-body";
+ auto rpcResponse = std::make_unique<RpcResponse>(
+ requestId, toReadOnlyByteBuffer(responseBody));
+ dispatcher->read(nullptr, std::move(rpcResponse));
+
+ EXPECT_TRUE(future.isReady());
+ auto receivedMsg = std::move(future).get();
+ EXPECT_EQ(receivedMsg->type(), Message::RPC_RESPONSE);
+ auto receivedRpcResponse = dynamic_cast<RpcResponse*>(receivedMsg.get());
+ EXPECT_EQ(receivedRpcResponse->body()->remainingSize(), responseBody.size());
+ EXPECT_EQ(
+ receivedRpcResponse->body()->readToString(responseBody.size()),
+ responseBody);
+}
+
+// Same send path as the success case, but the reply fed into read() is an
+// RpcFailure with the same requestId; the future must then hold an exception.
+TEST(MessageDispatcherTest, sendPushDataAndReceiveFailure) {
+ std::unique_ptr<Message> sentMsg;
+ MockHandler mockHandler(sentMsg);
+ auto mockPipeline = createMockedPipeline(std::move(mockHandler));
+ auto dispatcher = std::make_unique<MessageDispatcher>();
+ dispatcher->setPipeline(mockPipeline.get());
+
+ const long requestId = 1001;
+ const uint8_t mode = 2;
+ const std::string shuffleKey = "test-shuffle-key";
+ const std::string partitionUniqueId = "test-partition-id";
+ const std::string requestBody = "test-request-body";
+ auto pushData = std::make_unique<PushData>(
+ requestId,
+ mode,
+ shuffleKey,
+ partitionUniqueId,
+ toReadOnlyByteBuffer(requestBody));
+ auto future = dispatcher->sendPushDataRequest(std::move(pushData));
+
+ // No reply has been delivered yet, so the future must still be pending.
+ EXPECT_FALSE(future.isReady());
+ EXPECT_EQ(sentMsg->type(), Message::PUSH_DATA);
+ auto sentPushData = dynamic_cast<PushData*>(sentMsg.get());
+ EXPECT_EQ(sentPushData->requestId(), requestId);
+ EXPECT_EQ(sentPushData->mode(), mode);
+ EXPECT_EQ(sentPushData->shuffleKey(), shuffleKey);
+ EXPECT_EQ(sentPushData->partitionUniqueId(), partitionUniqueId);
+ EXPECT_EQ(sentPushData->body()->remainingSize(), requestBody.size());
+ EXPECT_EQ(
+ sentPushData->body()->readToString(requestBody.size()), requestBody);
+
+ // errorMsg is const, so a copy is made to have something to move from.
+ const std::string errorMsg = "test-error-msg";
+ auto copiedErrorMsg = errorMsg;
+ auto rpcFailure =
+ std::make_unique<RpcFailure>(requestId, std::move(copiedErrorMsg));
+ dispatcher->read(nullptr, std::move(rpcFailure));
+
+ EXPECT_TRUE(future.hasException());
+}
+
TEST(MessageDispatcherTest, sendFetchChunkRequestAndReceiveSuccess) {
std::unique_ptr<Message> sentMsg;
MockHandler mockHandler(sentMsg);
diff --git a/cpp/celeborn/network/tests/TransportClientTest.cpp b/cpp/celeborn/network/tests/TransportClientTest.cpp
index 82740ad67..0135d61e5 100644
--- a/cpp/celeborn/network/tests/TransportClientTest.cpp
+++ b/cpp/celeborn/network/tests/TransportClientTest.cpp
@@ -15,6 +15,7 @@
* limitations under the License.
*/
+#include <folly/init/Init.h>
#include <gtest/gtest.h>
#include "celeborn/network/TransportClient.h"
@@ -38,6 +39,13 @@ class MockDispatcher : public MessageDispatcher {
sentMsg_ = std::move(toSendMsg);
}
+ // Records the outgoing message and hands back the future of a freshly
+ // created promise instead of touching any pipeline.
+ folly::Future<std::unique_ptr<Message>> sendPushDataRequest(
+     std::unique_ptr<Message> toSendMsg) override {
+   sentMsg_ = std::move(toSendMsg);
+   msgPromise_ = MsgPromise();
+   auto pendingReply = msgPromise_.getFuture();
+   return pendingReply;
+ }
+
folly::Future<std::unique_ptr<Message>> sendFetchChunkRequest(
const protocol::StreamChunkSlice& streamChunkSlice,
std::unique_ptr<Message> toSendMsg) override {
@@ -66,9 +74,61 @@ std::unique_ptr<memory::ReadOnlyByteBuffer> toReadOnlyByteBuffer(
buffer->writeFromString(content);
return memory::ByteBuffer::toReadOnly(std::move(buffer));
}
+
+// RpcResponseCallback that stores whatever it is called with so a test can
+// inspect it afterwards. Each getter hands the stored object over to the
+// caller, leaving the slot empty.
+class MockRpcResponseCallback : public RpcResponseCallback {
+ public:
+  MockRpcResponseCallback() = default;
+
+  ~MockRpcResponseCallback() override = default;
+
+  void onSuccess(std::unique_ptr<memory::ReadOnlyByteBuffer> data) override {
+    onSuccessBuffer_ = std::move(data);
+  }
+
+  void onFailure(std::unique_ptr<std::exception> exception) override {
+    onFailureException_ = std::move(exception);
+  }
+
+  std::unique_ptr<memory::ReadOnlyByteBuffer> getOnSuccessBuffer() {
+    auto taken = std::move(onSuccessBuffer_);
+    return taken;
+  }
+
+  std::unique_ptr<std::exception> getOnFailureException() {
+    auto taken = std::move(onFailureException_);
+    return taken;
+  }
+
+ private:
+  std::unique_ptr<memory::ReadOnlyByteBuffer> onSuccessBuffer_;
+  std::unique_ptr<std::exception> onFailureException_;
+};
} // namespace
-TEST(TransportClientTest, sendRpcRequestSync) {
+/**
+ * Fixture for the TransportClient tests. Its constructor makes sure a
+ * folly::Init object exists before a test body runs.
+ *
+ * folly::Init must be created only once per process, so the instance lives in
+ * a static member and is created by whichever fixture is constructed first.
+ */
+class TransportClientTest : public testing::Test {
+ protected:
+  TransportClientTest() {
+    // Always look at follyInit_ under the lock; the former unlocked pre-check
+    // read a plain unique_ptr without synchronization.
+    std::lock_guard<std::mutex> lock(initMutex_);
+    if (!follyInit_) {
+      // A string literal must not be bound to a mutable char*, so the fake
+      // program name lives in writable static storage. argc matches the one
+      // argv entry supplied and argv is null-terminated like a real one.
+      static char programName[] = "test-arg";
+      static char* argvStorage[] = {programName, nullptr};
+      int argc = 1;
+      char** argv = argvStorage;
+      follyInit_ = std::make_unique<folly::Init>(&argc, &argv, false);
+    }
+  }
+
+  ~TransportClientTest() override = default;
+
+ private:
+  // Must be only inited once per process.
+  static std::unique_ptr<folly::Init> follyInit_;
+  static std::mutex initMutex_;
+};
+
+std::unique_ptr<folly::Init> TransportClientTest::follyInit_ = {};
+std::mutex TransportClientTest::initMutex_ = {};
+
+TEST_F(TransportClientTest, sendRpcRequestSync) {
auto mockDispatcher = std::make_unique<MockDispatcher>();
auto rawMockDispatcher = mockDispatcher.get();
const auto timeoutInterval = MS(10000);
@@ -109,7 +169,7 @@ TEST(TransportClientTest, sendRpcRequestSync) {
responseBody);
}
-TEST(TransportClientTest, sendRpcRequestSyncTimeout) {
+TEST_F(TransportClientTest, sendRpcRequestSyncTimeout) {
auto mockDispatcher = std::make_unique<MockDispatcher>();
auto rawMockDispatcher = mockDispatcher.get();
const auto timeoutInterval = MS(200);
@@ -149,7 +209,7 @@ TEST(TransportClientTest, sendRpcRequestSyncTimeout) {
EXPECT_TRUE(timeoutHappened);
}
-TEST(TransportClientTest, sendRpcRequestWithoutResponse) {
+TEST_F(TransportClientTest, sendRpcRequestWithoutResponse) {
auto mockDispatcher = std::make_unique<MockDispatcher>();
auto rawMockDispatcher = mockDispatcher.get();
const auto timeoutInterval = MS(10000);
@@ -169,7 +229,149 @@ TEST(TransportClientTest, sendRpcRequestWithoutResponse) {
sentRpcRequest->body()->readToString(requestBody.size()), requestBody);
}
-TEST(TransportClientTest, fetchChunkAsyncSuccess) {
+// pushDataAsync() must forward an identical PushData to the dispatcher and,
+// once an RpcResponse arrives, call onSuccess() with the response body.
+TEST_F(TransportClientTest, pushDataAsyncSuccess) {
+ // Construct mock utils.
+ auto mockDispatcher = std::make_unique<MockDispatcher>();
+ auto rawMockDispatcher = mockDispatcher.get();
+ TransportClient client(nullptr, std::move(mockDispatcher), MS(10000));
+ auto mockRpcResponseCallback = std::make_shared<MockRpcResponseCallback>();
+
+ // Construct toSend PushData.
+ const long requestId = 1001;
+ const uint8_t mode = 2;
+ const std::string shuffleKey = "test-shuffle-key";
+ const std::string partitionUniqueId = "test-partition-id";
+ const std::string requestBody = "test-request-body";
+ auto pushData = std::make_unique<PushData>(
+ requestId,
+ mode,
+ shuffleKey,
+ partitionUniqueId,
+ toReadOnlyByteBuffer(requestBody));
+
+ // Send pushData via client, check that the sentPushData is identical to
+ // original pushData.
+ client.pushDataAsync(*pushData, MS(10000), mockRpcResponseCallback);
+ auto sentMsg = rawMockDispatcher->getSentMsg();
+ EXPECT_EQ(sentMsg->type(), Message::PUSH_DATA);
+ auto sentPushData = dynamic_cast<PushData*>(sentMsg.get());
+ EXPECT_EQ(sentPushData->requestId(), requestId);
+ EXPECT_EQ(sentPushData->mode(), mode);
+ EXPECT_EQ(sentPushData->shuffleKey(), shuffleKey);
+ EXPECT_EQ(sentPushData->partitionUniqueId(), partitionUniqueId);
+ EXPECT_EQ(sentPushData->body()->remainingSize(), requestBody.size());
+ EXPECT_EQ(
+ sentPushData->body()->readToString(requestBody.size()), requestBody);
+ // Nothing has been received yet, so the callback must not have fired.
+ EXPECT_FALSE(mockRpcResponseCallback->getOnSuccessBuffer());
+
+ // Construct response and make dispatcher receive it.
+ const std::string responseBody = "test-response-body";
+ auto rpcResponse = std::make_unique<RpcResponse>(
+ requestId, toReadOnlyByteBuffer(responseBody));
+ rawMockDispatcher->receiveMsg(std::move(rpcResponse));
+
+ // Check the received message.
+ auto onSuccessBuffer = mockRpcResponseCallback->getOnSuccessBuffer();
+ auto onFailureException = mockRpcResponseCallback->getOnFailureException();
+ EXPECT_TRUE(onSuccessBuffer);
+ EXPECT_FALSE(onFailureException);
+ EXPECT_EQ(onSuccessBuffer->remainingSize(), responseBody.size());
+ EXPECT_EQ(onSuccessBuffer->readToString(responseBody.size()), responseBody);
+}
+
+// Same send path as the success case, but the dispatcher receives an
+// RpcFailure; only onFailure() may have been called afterwards.
+TEST_F(TransportClientTest, pushDataAsyncFailure) {
+ // Construct mock utils.
+ auto mockDispatcher = std::make_unique<MockDispatcher>();
+ auto rawMockDispatcher = mockDispatcher.get();
+ TransportClient client(nullptr, std::move(mockDispatcher), MS(10000));
+ auto mockRpcResponseCallback = std::make_shared<MockRpcResponseCallback>();
+
+ // Construct toSend PushData.
+ const long requestId = 1001;
+ const uint8_t mode = 2;
+ const std::string shuffleKey = "test-shuffle-key";
+ const std::string partitionUniqueId = "test-partition-id";
+ const std::string requestBody = "test-request-body";
+ auto pushData = std::make_unique<PushData>(
+ requestId,
+ mode,
+ shuffleKey,
+ partitionUniqueId,
+ toReadOnlyByteBuffer(requestBody));
+
+ // Send pushData via client, check that the sentPushData is identical to
+ // original pushData.
+ client.pushDataAsync(*pushData, MS(10000), mockRpcResponseCallback);
+ auto sentMsg = rawMockDispatcher->getSentMsg();
+ EXPECT_EQ(sentMsg->type(), Message::PUSH_DATA);
+ auto sentPushData = dynamic_cast<PushData*>(sentMsg.get());
+ EXPECT_EQ(sentPushData->requestId(), requestId);
+ EXPECT_EQ(sentPushData->mode(), mode);
+ EXPECT_EQ(sentPushData->shuffleKey(), shuffleKey);
+ EXPECT_EQ(sentPushData->partitionUniqueId(), partitionUniqueId);
+ EXPECT_EQ(sentPushData->body()->remainingSize(), requestBody.size());
+ EXPECT_EQ(
+ sentPushData->body()->readToString(requestBody.size()), requestBody);
+ // Nothing has been received yet, so the callback must not have fired.
+ EXPECT_FALSE(mockRpcResponseCallback->getOnSuccessBuffer());
+
+ // Construct failure and make dispatcher receive it.
+ auto rpcFailure = std::make_unique<RpcFailure>(requestId,
"failure-msg-body");
+ rawMockDispatcher->receiveMsg(std::move(rpcFailure));
+
+ // Check the received message.
+ auto onSuccessBuffer = mockRpcResponseCallback->getOnSuccessBuffer();
+ auto onFailureException = mockRpcResponseCallback->getOnFailureException();
+ EXPECT_FALSE(onSuccessBuffer);
+ EXPECT_TRUE(onFailureException);
+}
+
+// No reply is ever delivered; after the per-call timeout has elapsed the
+// callback must have been notified through onFailure().
+TEST_F(TransportClientTest, pushDataAsyncTimeout) {
+ // Construct mock utils.
+ auto mockDispatcher = std::make_unique<MockDispatcher>();
+ auto rawMockDispatcher = mockDispatcher.get();
+ TransportClient client(nullptr, std::move(mockDispatcher), MS(10000));
+ auto mockRpcResponseCallback = std::make_shared<MockRpcResponseCallback>();
+
+ // Construct toSend PushData.
+ const long requestId = 1001;
+ const uint8_t mode = 2;
+ const std::string shuffleKey = "test-shuffle-key";
+ const std::string partitionUniqueId = "test-partition-id";
+ const std::string requestBody = "test-request-body";
+ auto pushData = std::make_unique<PushData>(
+ requestId,
+ mode,
+ shuffleKey,
+ partitionUniqueId,
+ toReadOnlyByteBuffer(requestBody));
+
+ // Send pushData via client, check that the sentPushData is identical to
+ // original pushData.
+ auto timeoutInterval = MS(100);
+ client.pushDataAsync(*pushData, timeoutInterval, mockRpcResponseCallback);
+ auto sentMsg = rawMockDispatcher->getSentMsg();
+ EXPECT_EQ(sentMsg->type(), Message::PUSH_DATA);
+ auto sentPushData = dynamic_cast<PushData*>(sentMsg.get());
+ EXPECT_EQ(sentPushData->requestId(), requestId);
+ EXPECT_EQ(sentPushData->mode(), mode);
+ EXPECT_EQ(sentPushData->shuffleKey(), shuffleKey);
+ EXPECT_EQ(sentPushData->partitionUniqueId(), partitionUniqueId);
+ EXPECT_EQ(sentPushData->body()->remainingSize(), requestBody.size());
+ EXPECT_EQ(
+ sentPushData->body()->readToString(requestBody.size()), requestBody);
+ EXPECT_FALSE(mockRpcResponseCallback->getOnSuccessBuffer());
+
+ // Wait for timeout.
+ // NOTE(review): waits three timeout intervals of wall-clock time; this may
+ // be timing-sensitive on a heavily loaded machine - confirm it is stable.
+ std::this_thread::sleep_for(timeoutInterval * 3);
+
+ // Check the received message.
+ auto onSuccessBuffer = mockRpcResponseCallback->getOnSuccessBuffer();
+ auto onFailureException = mockRpcResponseCallback->getOnFailureException();
+ EXPECT_FALSE(onSuccessBuffer);
+ EXPECT_TRUE(onFailureException);
+}
+
+TEST_F(TransportClientTest, fetchChunkAsyncSuccess) {
auto mockDispatcher = std::make_unique<MockDispatcher>();
auto rawMockDispatcher = mockDispatcher.get();
TransportClient client(nullptr, std::move(mockDispatcher), MS(10000));
@@ -218,7 +420,7 @@ TEST(TransportClientTest, fetchChunkAsyncSuccess) {
EXPECT_EQ(onSuccessBuffer->readToString(responseBody.size()), responseBody);
}
-TEST(TransportClientTest, fetchChunkAsyncFailure) {
+TEST_F(TransportClientTest, fetchChunkAsyncFailure) {
auto mockDispatcher = std::make_unique<MockDispatcher>();
auto rawMockDispatcher = mockDispatcher.get();
TransportClient client(nullptr, std::move(mockDispatcher), MS(10000));