This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new ffff5bb94 [CELEBORN-2196][CIP-14] Support ReviveManager in CppClient
ffff5bb94 is described below
commit ffff5bb94efd7c3a2c6c51aebccfd3c11b5d5198
Author: HolyLow <[email protected]>
AuthorDate: Mon Nov 17 11:58:16 2025 +0800
[CELEBORN-2196][CIP-14] Support ReviveManager in CppClient
### What changes were proposed in this pull request?
This PR supports ReviveManager in CppClient.
### Why are the changes needed?
ReviveManager is a building block of the CppClient write path.
### Does this PR resolve a correctness bug?
No.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation and UTs.
Closes #3538 from
HolyLow/issue/celeborn-2196-support-ReviveManager-in-cpp-client.
Authored-by: HolyLow <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
cpp/celeborn/client/CMakeLists.txt | 1 +
cpp/celeborn/client/ShuffleClient.cpp | 19 ++
cpp/celeborn/client/ShuffleClient.h | 53 +++++-
cpp/celeborn/client/tests/CMakeLists.txt | 1 +
cpp/celeborn/client/tests/ReviveManagerTest.cpp | 237 ++++++++++++++++++++++++
cpp/celeborn/client/writer/ReviveManager.cpp | 160 ++++++++++++++++
cpp/celeborn/client/writer/ReviveManager.h | 73 ++++++++
cpp/celeborn/conf/CelebornConf.cpp | 11 ++
cpp/celeborn/conf/CelebornConf.h | 10 +
cpp/celeborn/tests/DataSumWithReaderClient.cpp | 2 +-
10 files changed, 564 insertions(+), 3 deletions(-)
diff --git a/cpp/celeborn/client/CMakeLists.txt
b/cpp/celeborn/client/CMakeLists.txt
index 1af9069ee..fa0290733 100644
--- a/cpp/celeborn/client/CMakeLists.txt
+++ b/cpp/celeborn/client/CMakeLists.txt
@@ -18,6 +18,7 @@ add_library(
reader/CelebornInputStream.cpp
writer/PushState.cpp
writer/PushStrategy.cpp
+ writer/ReviveManager.cpp
ShuffleClient.cpp
compress/Decompressor.cpp
compress/Lz4Decompressor.cpp
diff --git a/cpp/celeborn/client/ShuffleClient.cpp
b/cpp/celeborn/client/ShuffleClient.cpp
index 6e29db909..22d19256c 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -21,6 +21,14 @@
namespace celeborn {
namespace client {
+// Factory for ShuffleClientImpl. The constructor is protected, so
+// std::make_shared cannot reach it; allocate directly and hand ownership to a
+// shared_ptr right away so that shared_from_this()/weak_from_this() work.
+std::shared_ptr<ShuffleClientImpl> ShuffleClientImpl::create(
+    const std::string& appUniqueId,
+    const std::shared_ptr<const conf::CelebornConf>& conf,
+    const std::shared_ptr<network::TransportClientFactory>& clientFactory) {
+  std::shared_ptr<ShuffleClientImpl> client{
+      new ShuffleClientImpl(appUniqueId, conf, clientFactory)};
+  return client;
+}
+
ShuffleClientImpl::ShuffleClientImpl(
const std::string& appUniqueId,
const std::shared_ptr<const conf::CelebornConf>& conf,
@@ -135,6 +143,17 @@ bool ShuffleClientImpl::cleanupShuffle(int shuffleId) {
return true;
}
+// Returns true when locationMap already holds a location for partitionId whose
+// epoch is strictly greater than `epoch`, i.e. the partition has already been
+// moved past the location the caller knows about.
+bool ShuffleClientImpl::newerPartitionLocationExists(
+    std::shared_ptr<utils::ConcurrentHashMap<
+        int,
+        std::shared_ptr<const protocol::PartitionLocation>>> locationMap,
+    int partitionId,
+    int epoch) {
+  const auto current = locationMap->get(partitionId);
+  if (!current.has_value()) {
+    return false;
+  }
+  return (*current)->epoch > epoch;
+}
+
std::shared_ptr<protocol::GetReducerFileGroupResponse>
ShuffleClientImpl::getReducerFileGroupInfo(int shuffleId) {
auto reducerFileGroupInfoOptional = reducerFileGroupInfos_.get(shuffleId);
diff --git a/cpp/celeborn/client/ShuffleClient.h
b/cpp/celeborn/client/ShuffleClient.h
index b56c60cf8..41e953515 100644
--- a/cpp/celeborn/client/ShuffleClient.h
+++ b/cpp/celeborn/client/ShuffleClient.h
@@ -51,9 +51,21 @@ class ShuffleClient {
virtual void shutdown() = 0;
};
-class ShuffleClientImpl : public ShuffleClient {
+class ShuffleClientImpl
+ : public ShuffleClient,
+ public std::enable_shared_from_this<ShuffleClientImpl> {
public:
- ShuffleClientImpl(
+ friend class ReviveManager;
+
+ using PtrReviveRequest = std::shared_ptr<protocol::ReviveRequest>;
+ using PartitionLocationMap = utils::ConcurrentHashMap<
+ int,
+ std::shared_ptr<const protocol::PartitionLocation>>;
+ using PtrPartitionLocationMap = std::shared_ptr<PartitionLocationMap>;
+
+ // The only way to construct a ShuffleClientImpl: the returned object is
+ // owned by a std::shared_ptr, which shared_from_this()/weak_from_this()
+ // (from std::enable_shared_from_this) require.
+ static std::shared_ptr<ShuffleClientImpl> create(
const std::string& appUniqueId,
const std::shared_ptr<const conf::CelebornConf>& conf,
const std::shared_ptr<network::TransportClientFactory>& clientFactory);
@@ -84,7 +96,43 @@ class ShuffleClientImpl : public ShuffleClient {
void shutdown() override {}
+ protected:
+ // The constructor is not public so that every instance is created through
+ // create() and owned by a std::shared_ptr, which shared_from_this() and
+ // weak_from_this() require. Protected (not private) so test doubles can
+ // derive from this class.
+ ShuffleClientImpl(
+ const std::string& appUniqueId,
+ const std::shared_ptr<const conf::CelebornConf>& conf,
+ const std::shared_ptr<network::TransportClientFactory>& clientFactory);
+
+  // TODO: stub until mapper-end tracking lands in a future commit. Reporting
+  // every mapper as ended makes ReviveManager resolve each revive request as
+  // SUCCESS without calling reviveBatch(). Virtual so tests can override it.
+  virtual bool mapperEnded(int shuffleId, int mapId) {
+    constexpr bool kAssumeMapperEnded = true;
+    return kAssumeMapperEnded;
+  }
+
+  // TODO: stub until the batch revive RPC lands in a future commit. An empty
+  // optional means the whole batch failed; ReviveManager then marks every
+  // pending request REVIVE_FAILED. Virtual so tests can override it.
+  virtual std::optional<std::unordered_map<int, int>> reviveBatch(
+      int shuffleId,
+      const std::unordered_set<int>& mapIds,
+      const std::unordered_map<int, PtrReviveRequest>& requests) {
+    return {};
+  }
+
+  // Returns the partitionId -> PartitionLocation map kept for shuffleId in
+  // partitionLocationMaps_ (presumably no value when none is registered).
+  // Virtual so tests can supply their own map.
+  virtual std::optional<PtrPartitionLocationMap> getPartitionLocationMap(
+      int shuffleId) {
+    auto registered = partitionLocationMaps_.get(shuffleId);
+    return registered;
+  }
+
private:
+  // TODO: no support for WAIT as it is not used.
+  // Returns true when locationMap already holds a location for partitionId
+  // whose epoch is greater than `epoch`.
+  static bool newerPartitionLocationExists(
+      PtrPartitionLocationMap locationMap,
+      int partitionId,
+      int epoch);
+
std::shared_ptr<protocol::GetReducerFileGroupResponse>
getReducerFileGroupInfo(int shuffleId);
@@ -97,6 +145,7 @@ class ShuffleClientImpl : public ShuffleClient {
int,
std::shared_ptr<protocol::GetReducerFileGroupResponse>>
reducerFileGroupInfos_;
+ utils::ConcurrentHashMap<int, PtrPartitionLocationMap>
partitionLocationMaps_;
};
} // namespace client
} // namespace celeborn
diff --git a/cpp/celeborn/client/tests/CMakeLists.txt
b/cpp/celeborn/client/tests/CMakeLists.txt
index 6341c84e7..8183611c6 100644
--- a/cpp/celeborn/client/tests/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -17,6 +17,7 @@ add_executable(
celeborn_client_test
WorkerPartitionReaderTest.cpp
PushStateTest.cpp
+ ReviveManagerTest.cpp
Lz4DecompressorTest.cpp
ZstdDecompressorTest.cpp
Lz4CompressorTest.cpp
diff --git a/cpp/celeborn/client/tests/ReviveManagerTest.cpp
b/cpp/celeborn/client/tests/ReviveManagerTest.cpp
new file mode 100644
index 000000000..8db844485
--- /dev/null
+++ b/cpp/celeborn/client/tests/ReviveManagerTest.cpp
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <atomic>
+#include <chrono>
+#include <memory>
+#include <thread>
+
+#include <gtest/gtest.h>
+
+#include "celeborn/client/writer/ReviveManager.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+
+namespace {
+/// Test double for ShuffleClientImpl: each hook ReviveManager calls is
+/// overridden and forwards to a replaceable std::function. Setters are meant
+/// to be called before requests are handed to a ReviveManager, since the
+/// callbacks are invoked on the scheduler thread.
+class MockShuffleClient : public ShuffleClientImpl {
+ public:
+  // Instances must be owned by a shared_ptr so that weak_from_this() works;
+  // the constructor is private to enforce that.
+  static std::shared_ptr<MockShuffleClient> create() {
+    return std::shared_ptr<MockShuffleClient>(new MockShuffleClient());
+  }
+
+  virtual ~MockShuffleClient() = default;
+
+  bool mapperEnded(int shuffleId, int mapId) override {
+    return onMapperEnded_(shuffleId, mapId);
+  }
+
+  void setOnMapperEnded(std::function<bool(int, int)>&& onMapperEnded) {
+    onMapperEnded_ = std::move(onMapperEnded);
+  }
+
+  std::optional<std::unordered_map<int, int>> reviveBatch(
+      int shuffleId,
+      const std::unordered_set<int>& mapIds,
+      const std::unordered_map<int, PtrReviveRequest>& requests) override {
+    return onReviveBatch_(shuffleId, mapIds, requests);
+  }
+
+  void setOnReviveBatch(
+      std::function<std::optional<std::unordered_map<int, int>>(
+          int,
+          const std::unordered_set<int>&,
+          const std::unordered_map<int, PtrReviveRequest>&)>&& onReviveBatch) {
+    onReviveBatch_ = std::move(onReviveBatch);
+  }
+
+  std::optional<PtrPartitionLocationMap> getPartitionLocationMap(
+      int shuffleId) override {
+    return onGetPartitionLocationMap_(shuffleId);
+  }
+
+  void setOnGetPartitionLocationMap(
+      std::function<std::optional<PtrPartitionLocationMap>(int)>&&
+          onGetPartitionLocationMap) {
+    onGetPartitionLocationMap_ = std::move(onGetPartitionLocationMap);
+  }
+
+ private:
+  MockShuffleClient()
+      : ShuffleClientImpl(
+            "mock",
+            std::make_shared<conf::CelebornConf>(),
+            nullptr) {}
+
+  // Defaults: no mapper has ended, every revive batch fails (empty optional),
+  // and every shuffle has a fresh, empty partition-location map.
+  std::function<bool(int, int)> onMapperEnded_ = [](int, int) {
+    return false;
+  };
+  std::function<std::optional<std::unordered_map<int, int>>(
+      int,
+      const std::unordered_set<int>&,
+      const std::unordered_map<int, PtrReviveRequest>&)>
+      onReviveBatch_ = [](int,
+                          const std::unordered_set<int>&,
+                          const std::unordered_map<int, PtrReviveRequest>&) {
+        return std::nullopt;
+      };
+  std::function<std::optional<PtrPartitionLocationMap>(int)>
+      onGetPartitionLocationMap_ =
+          [](int) -> std::optional<PtrPartitionLocationMap> {
+    return {std::make_shared<PartitionLocationMap>()};
+  };
+};
+} // namespace
+
+/// Fixture: a MockShuffleClient plus a ReviveManager that polls its queue
+/// every pushReviveIntervalMs_ milliseconds.
+class ReviveManagerTest : public testing::Test {
+ protected:
+  void SetUp() override {
+    mockShuffleClient_ = MockShuffleClient::create();
+
+    conf::CelebornConf celebornConf{};
+    celebornConf.registerProperty(
+        conf::CelebornConf::kClientPushReviveInterval,
+        fmt::format("{}ms", pushReviveIntervalMs_));
+
+    // The manager only observes the client; the fixture keeps it alive.
+    std::weak_ptr<ShuffleClientImpl> weakClient = mockShuffleClient_;
+    reviveManager_ = ReviveManager::create("test", celebornConf, weakClient);
+  }
+
+  static constexpr int pushReviveIntervalMs_ = 100;
+
+  // Declaration order matters: reviveManager_ is destroyed before the mock.
+  std::shared_ptr<MockShuffleClient> mockShuffleClient_;
+  std::shared_ptr<ReviveManager> reviveManager_;
+};
+
+TEST_F(ReviveManagerTest, successOnMapperEnded) {
+  // The callback runs on the scheduler thread while this thread reads the
+  // counter, so it must be atomic. It is shared-owned so the stored callback
+  // can never reference a dead local.
+  auto mapperEndedCalledTimes = std::make_shared<std::atomic<int>>(0);
+  const int testShuffleId = 1000;
+  const int testMapId = 1001;
+  const int testAttemptId = 1002;
+  const int testPartitionId = 1003;
+  const int testEpoch = 1004;
+  auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testEpoch,
+      nullptr,
+      static_cast<protocol::StatusCode>(0));
+  mockShuffleClient_->setOnMapperEnded([=](int shuffleId, int mapId) {
+    EXPECT_EQ(testShuffleId, shuffleId);
+    EXPECT_EQ(testMapId, mapId);
+    mapperEndedCalledTimes->fetch_add(1);
+    return true;
+  });
+
+  EXPECT_EQ(
+      reviveRequest->reviveStatus, protocol::StatusCode::REVIVE_INITIALIZED);
+  reviveManager_->addRequest(reviveRequest);
+  // Give the scheduler three intervals to process the request.
+  // NOTE(review): reviveStatus is written on the scheduler thread; confirm it
+  // is atomic in ReviveRequest, otherwise the reads below are racy.
+  std::this_thread::sleep_for(
+      std::chrono::milliseconds(pushReviveIntervalMs_ * 3));
+
+  EXPECT_GT(mapperEndedCalledTimes->load(), 0);
+  EXPECT_EQ(reviveRequest->reviveStatus, protocol::StatusCode::SUCCESS);
+}
+
+TEST_F(ReviveManagerTest, successOnReviveSuccess) {
+  // Written on the scheduler thread, read here: atomic, and shared-owned so
+  // the stored callback never references a dead local.
+  auto reviveBatchCalledTimes = std::make_shared<std::atomic<int>>(0);
+  const int testShuffleId = 1000;
+  const int testMapId = 1001;
+  const int testAttemptId = 1002;
+  const int testPartitionId = 1003;
+  const int testEpoch = 1004;
+  auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testEpoch,
+      nullptr,
+      static_cast<protocol::StatusCode>(0));
+  mockShuffleClient_->setOnReviveBatch(
+      [=](int shuffleId,
+          const std::unordered_set<int>& mapIds,
+          const std::unordered_map<int, ShuffleClientImpl::PtrReviveRequest>&
+              requestsToSend) -> std::optional<std::unordered_map<int, int>> {
+        EXPECT_EQ(testShuffleId, shuffleId);
+        EXPECT_GT(mapIds.count(testMapId), 0);
+        EXPECT_GT(requestsToSend.count(testPartitionId), 0);
+        auto request = requestsToSend.find(testPartitionId)->second;
+        EXPECT_EQ(request->shuffleId, testShuffleId);
+        EXPECT_EQ(request->mapId, testMapId);
+        EXPECT_EQ(request->attemptId, testAttemptId);
+        EXPECT_EQ(request->partitionId, testPartitionId);
+        EXPECT_EQ(request->epoch, testEpoch);
+
+        std::unordered_map<int, int> result;
+        result[request->partitionId] = protocol::StatusCode::SUCCESS;
+        reviveBatchCalledTimes->fetch_add(1);
+        return {result};
+      });
+
+  EXPECT_EQ(
+      reviveRequest->reviveStatus, protocol::StatusCode::REVIVE_INITIALIZED);
+  reviveManager_->addRequest(reviveRequest);
+  // Give the scheduler three intervals to process the request.
+  // NOTE(review): reviveStatus is written on the scheduler thread; confirm it
+  // is atomic in ReviveRequest, otherwise the reads below are racy.
+  std::this_thread::sleep_for(
+      std::chrono::milliseconds(pushReviveIntervalMs_ * 3));
+
+  EXPECT_GT(reviveBatchCalledTimes->load(), 0);
+  EXPECT_EQ(reviveRequest->reviveStatus, protocol::StatusCode::SUCCESS);
+}
+
+TEST_F(ReviveManagerTest, failureOnReviveFailure) {
+  // Written on the scheduler thread, read here: atomic, and shared-owned so
+  // the stored callback never references a dead local.
+  auto reviveBatchCalledTimes = std::make_shared<std::atomic<int>>(0);
+  const int testShuffleId = 1000;
+  const int testMapId = 1001;
+  const int testAttemptId = 1002;
+  const int testPartitionId = 1003;
+  const int testEpoch = 1004;
+  auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testEpoch,
+      nullptr,
+      static_cast<protocol::StatusCode>(0));
+  mockShuffleClient_->setOnReviveBatch(
+      [=](int shuffleId,
+          const std::unordered_set<int>& mapIds,
+          const std::unordered_map<int, ShuffleClientImpl::PtrReviveRequest>&
+              requestsToSend) -> std::optional<std::unordered_map<int, int>> {
+        EXPECT_EQ(testShuffleId, shuffleId);
+        EXPECT_GT(mapIds.count(testMapId), 0);
+        EXPECT_GT(requestsToSend.count(testPartitionId), 0);
+        auto request = requestsToSend.find(testPartitionId)->second;
+        EXPECT_EQ(request->shuffleId, testShuffleId);
+        EXPECT_EQ(request->mapId, testMapId);
+        EXPECT_EQ(request->attemptId, testAttemptId);
+        EXPECT_EQ(request->partitionId, testPartitionId);
+        EXPECT_EQ(request->epoch, testEpoch);
+
+        reviveBatchCalledTimes->fetch_add(1);
+        // An empty result means the whole batch failed.
+        return std::nullopt;
+      });
+
+  EXPECT_EQ(
+      reviveRequest->reviveStatus, protocol::StatusCode::REVIVE_INITIALIZED);
+  reviveManager_->addRequest(reviveRequest);
+  // Give the scheduler three intervals to process the request.
+  // NOTE(review): reviveStatus is written on the scheduler thread; confirm it
+  // is atomic in ReviveRequest, otherwise the reads below are racy.
+  std::this_thread::sleep_for(
+      std::chrono::milliseconds(pushReviveIntervalMs_ * 3));
+
+  EXPECT_GT(reviveBatchCalledTimes->load(), 0);
+  EXPECT_EQ(reviveRequest->reviveStatus, protocol::StatusCode::REVIVE_FAILED);
+}
diff --git a/cpp/celeborn/client/writer/ReviveManager.cpp
b/cpp/celeborn/client/writer/ReviveManager.cpp
new file mode 100644
index 000000000..47a548546
--- /dev/null
+++ b/cpp/celeborn/client/writer/ReviveManager.cpp
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/ReviveManager.h"
+
+namespace celeborn {
+namespace client {
+
+// One scheduler shared by every ReviveManager; each manager registers its
+// periodic task under its own name_.
+folly::FunctionScheduler ReviveManager::globalExecutor_{};
+
+// Builds a ReviveManager owned by a shared_ptr and starts its periodic task.
+// start() uses weak_from_this(), which only works once the object is owned by
+// a shared_ptr, so it runs here rather than in the constructor.
+std::shared_ptr<ReviveManager> ReviveManager::create(
+    const std::string& name,
+    const conf::CelebornConf& conf,
+    std::weak_ptr<ShuffleClientImpl> weakClient) {
+  // The constructor is private, so std::make_shared cannot be used.
+  std::shared_ptr<ReviveManager> reviveManager(
+      new ReviveManager(name, conf, std::move(weakClient)));
+  reviveManager->start();
+  // Return the local by name: wrapping it in std::move would inhibit NRVO.
+  return reviveManager;
+}
+
+// Captures the batching configuration; scheduling happens later in start().
+ReviveManager::ReviveManager(
+    const std::string& name,
+    const conf::CelebornConf& conf,
+    std::weak_ptr<ShuffleClientImpl> weakClient)
+    : name_(name),
+      batchSize_(conf.clientPushReviveBatchSize()),
+      interval_(conf.clientPushReviveInterval()),
+      weakClient_(std::move(weakClient)) {
+  // Each scheduled round drains at most batchSize_ requests; a non-positive
+  // value would leave every queued request stuck forever, so reject it.
+  CELEBORN_CHECK(batchSize_ > 0);
+}
+
+// Unregisters the periodic task from the shared scheduler. name_ is the name
+// the task was actually registered under (startFunction() may have appended a
+// random suffix).
+ReviveManager::~ReviveManager() {
+  const std::string& scheduledName = name_;
+  globalExecutor_.cancelFunction(scheduledName);
+}
+
+// Schedules the periodic task that drains the request queue in batches and
+// resolves each request's reviveStatus. Idempotent: only the first call
+// schedules anything.
+void ReviveManager::start() {
+  bool expected = false;
+  if (!started_.compare_exchange_strong(expected, true)) {
+    return;
+  }
+  globalExecutor_.start();
+  // The task only holds a weak reference, so a scheduled run never keeps the
+  // manager alive; it becomes a no-op once the manager is gone.
+  auto task = [weak_this = weak_from_this(), batchSize = batchSize_]() {
+    try {
+      auto shared_this = weak_this.lock();
+      if (!shared_this) {
+        return;
+      }
+      bool continueFlag = true;
+      do {
+        // Drain up to batchSize requests while holding the queue lock.
+        std::vector<PtrReviveRequest> requests;
+        shared_this->requestQueue_.withLock([&](auto& queue) {
+          for (int i = 0; i < batchSize && !queue.empty(); i++) {
+            requests.push_back(queue.front());
+            queue.pop();
+          }
+        });
+        if (requests.empty()) {
+          break;
+        }
+
+        auto shuffleClient = shared_this->weakClient_.lock();
+        if (!shuffleClient) {
+          // Nobody can revive these any more; fail them explicitly instead of
+          // leaving them in REVIVE_INITIALIZED forever.
+          for (auto& request : requests) {
+            request->reviveStatus = protocol::StatusCode::REVIVE_FAILED;
+          }
+          return;
+        }
+
+        // Group by shuffle; the set de-duplicates a request enqueued twice.
+        std::unordered_map<int, std::unordered_set<PtrReviveRequest>>
+            shuffleMap;
+        for (auto& request : requests) {
+          shuffleMap[request->shuffleId].insert(request);
+        }
+
+        for (auto& [shuffleId, requestSet] : shuffleMap) {
+          std::unordered_set<int> mapIds;
+          std::vector<PtrReviveRequest> filteredRequests;
+          std::unordered_map<int, PtrReviveRequest> requestsToSend;
+
+          auto locationMapOptional =
+              shuffleClient->getPartitionLocationMap(shuffleId);
+          CELEBORN_CHECK(locationMapOptional.has_value());
+          auto locationMap = locationMapOptional.value();
+          for (auto& request : requestSet) {
+            // Nothing to revive if a newer location already exists or the
+            // mapper has already ended.
+            if (shuffleClient->newerPartitionLocationExists(
+                    locationMap, request->partitionId, request->epoch) ||
+                shuffleClient->mapperEnded(shuffleId, request->mapId)) {
+              request->reviveStatus = protocol::StatusCode::SUCCESS;
+            } else {
+              filteredRequests.push_back(request);
+              mapIds.insert(request->mapId);
+              // Send one request per partition: the one with highest epoch.
+              if (auto iter = requestsToSend.find(request->partitionId);
+                  iter == requestsToSend.end() ||
+                  iter->second->epoch < request->epoch) {
+                requestsToSend[request->partitionId] = request;
+              }
+            }
+          }
+
+          if (requestsToSend.empty()) {
+            continue;
+          }
+          auto resultOptional =
+              shuffleClient->reviveBatch(shuffleId, mapIds, requestsToSend);
+          if (!resultOptional.has_value()) {
+            // The whole batch failed.
+            for (auto& request : filteredRequests) {
+              request->reviveStatus = protocol::StatusCode::REVIVE_FAILED;
+            }
+            continue;
+          }
+          const auto& result = resultOptional.value();
+          for (auto& request : filteredRequests) {
+            if (shuffleClient->mapperEnded(shuffleId, request->mapId)) {
+              request->reviveStatus = protocol::StatusCode::SUCCESS;
+            } else if (auto iter = result.find(request->partitionId);
+                       iter != result.end()) {
+              request->reviveStatus = iter->second;
+            } else {
+              // No status was returned for this partition. Do not let
+              // operator[] default-insert a status of 0; treat it as failed.
+              request->reviveStatus = protocol::StatusCode::REVIVE_FAILED;
+            }
+          }
+        }
+        // Keep draining in this run while a large backlog remains.
+        continueFlag = shared_this->requestQueue_.lock()->size() >
+            static_cast<std::size_t>(batchSize / 2);
+      } while (continueFlag);
+    } catch (const std::exception& e) {
+      LOG(ERROR) << "ReviveManager error occurred: " << e.what();
+    }
+  };
+  startFunction(std::move(task));
+}
+
+// Registers `task` on the shared scheduler to run every interval_ under
+// name_. A failure (e.g. the name is already taken) is retried with a random
+// suffix appended to name_. Retries are bounded so that a failure a new name
+// cannot fix (presumably e.g. an invalid interval) is rethrown instead of
+// recursing until the stack overflows.
+void ReviveManager::startFunction(std::function<void()> task) {
+  constexpr int kMaxAttempts = 10;
+  for (int attempt = 1;; ++attempt) {
+    try {
+      globalExecutor_.addFunction(task, interval_, name_, interval_);
+      return;
+    } catch (const std::exception& e) {
+      if (attempt >= kMaxAttempts) {
+        LOG(ERROR) << "startFunction failed after " << attempt
+                   << " attempts, current function name " << name_
+                   << ", cause: " << e.what();
+        throw;
+      }
+      LOG(ERROR) << "startFunction failed, current function name " << name_
+                 << ", retry again..." << " cause: " << e.what();
+      name_ += "-";
+      name_ += std::to_string(rand() % 10000);
+    }
+  }
+}
+
+// Enqueues a request under the queue mutex; the scheduled task picks it up on
+// a later run.
+void ReviveManager::addRequest(PtrReviveRequest request) {
+  auto lockedQueue = requestQueue_.lock();
+  lockedQueue->push(std::move(request));
+}
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/ReviveManager.h
b/cpp/celeborn/client/writer/ReviveManager.h
new file mode 100644
index 000000000..4f3b7962e
--- /dev/null
+++ b/cpp/celeborn/client/writer/ReviveManager.h
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/experimental/FunctionScheduler.h>
+
+#include "celeborn/client/ShuffleClient.h"
+#include "celeborn/conf/CelebornConf.h"
+#include "celeborn/protocol/ControlMessages.h"
+
+namespace celeborn {
+namespace client {
+
+/// ReviveManager buffers ReviveRequests and issues them periodically, in
+/// batches, through the owning ShuffleClientImpl. Each run of its scheduled
+/// task drains up to batchSize_ queued requests and sets their reviveStatus.
+class ReviveManager : public std::enable_shared_from_this<ReviveManager> {
+ public:
+ using PtrReviveRequest = std::shared_ptr<protocol::ReviveRequest>;
+
+ // The only way to construct a ReviveManager. The returned object is owned
+ // by a std::shared_ptr (required by shared_from_this()/weak_from_this()) and
+ // its periodic task has already been started.
+ static std::shared_ptr<ReviveManager> create(
+ const std::string& name,
+ const conf::CelebornConf& conf,
+ std::weak_ptr<ShuffleClientImpl> weakClient);
+
+ // Cancels this manager's task on the shared scheduler.
+ ~ReviveManager();
+
+ // Queues a request for the next scheduled run. The queue is guarded by a
+ // mutex, so this may be called from multiple threads.
+ void addRequest(PtrReviveRequest request);
+
+ private:
+ // The constructor is private so that every instance is created through
+ // create() and owned by a std::shared_ptr.
+ ReviveManager(
+ const std::string& name,
+ const conf::CelebornConf& conf,
+ std::weak_ptr<ShuffleClientImpl> weakClient);
+
+ // start() must not be called from the constructor: with
+ // std::enable_shared_from_this, shared_from_this() and weak_from_this() do
+ // not work until construction is complete and a shared_ptr owns the object.
+ void start();
+
+ // Registers the task on globalExecutor_, renaming name_ on failure.
+ void startFunction(std::function<void()> task);
+
+ // Scheduler shared by all ReviveManager instances for issuing the requests
+ // periodically.
+ static folly::FunctionScheduler globalExecutor_;
+
+ // Name the task is registered under; may gain a random suffix on retry.
+ std::string name_;
+ // Maximum number of requests drained per scheduled round.
+ const int batchSize_;
+ // Period (and initial delay) of the scheduled task.
+ const Timeout interval_;
+ // Client used to issue revives; not owned by the manager.
+ std::weak_ptr<ShuffleClientImpl> weakClient_;
+ folly::Synchronized<std::queue<PtrReviveRequest>, std::mutex> requestQueue_;
+ // Set by the first start() call so later calls are no-ops.
+ std::atomic<bool> started_{false};
+};
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/conf/CelebornConf.cpp
b/cpp/celeborn/conf/CelebornConf.cpp
index a3e464332..ab1a7abf0 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -135,6 +135,8 @@ const std::unordered_map<std::string,
folly::Optional<std::string>>
CelebornConf::kDefaultProperties = {
STR_PROP(kRpcAskTimeout, "60s"),
STR_PROP(kRpcLookupTimeout, "30s"),
+ STR_PROP(kClientPushReviveInterval, "100ms"),
+ NUM_PROP(kClientPushReviveBatchSize, 2048),
STR_PROP(kClientPushLimitStrategy, kSimplePushStrategy),
NUM_PROP(kClientPushMaxReqsInFlightPerWorker, 32),
NUM_PROP(kClientPushMaxReqsInFlightTotal, 256),
@@ -190,6 +192,15 @@ Timeout CelebornConf::rpcLookupTimeout() const {
toDuration(optionalProperty(kRpcLookupTimeout).value()));
}
+// Period of the ReviveManager task, stored as a duration string such as the
+// "100ms" default.
+Timeout CelebornConf::clientPushReviveInterval() const {
+  auto intervalText = optionalProperty(kClientPushReviveInterval).value();
+  return utils::toTimeout(toDuration(intervalText));
+}
+
+// Maximum number of revive requests ReviveManager drains per scheduled round
+// (2048 by default); parsed with std::stoi.
+int CelebornConf::clientPushReviveBatchSize() const {
+  const std::string batchSizeText =
+      optionalProperty(kClientPushReviveBatchSize).value();
+  return std::stoi(batchSizeText);
+}
+
std::string CelebornConf::clientPushLimitStrategy() const {
return optionalProperty(kClientPushLimitStrategy).value();
}
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index 98b5b9306..d6ce24f18 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -45,6 +45,12 @@ class CelebornConf : public BaseConf {
static constexpr std::string_view kRpcLookupTimeout{
"celeborn.rpc.lookupTimeout"};
+ // Period of the client-side ReviveManager task, as a duration string
+ // (default "100ms").
+ static constexpr std::string_view kClientPushReviveInterval{
+ "celeborn.client.push.revive.interval"};
+
+ // Maximum number of revive requests handled per ReviveManager round
+ // (default 2048).
+ static constexpr std::string_view kClientPushReviveBatchSize{
+ "celeborn.client.push.revive.batchSize"};
+
static constexpr std::string_view kClientPushLimitStrategy{
"celeborn.client.push.limit.strategy"};
@@ -100,6 +106,10 @@ class CelebornConf : public BaseConf {
Timeout rpcLookupTimeout() const;
+ // Parsed value of kClientPushReviveInterval.
+ Timeout clientPushReviveInterval() const;
+
+ // Parsed value of kClientPushReviveBatchSize.
+ int clientPushReviveBatchSize() const;
+
std::string clientPushLimitStrategy() const;
int clientPushMaxReqsInFlightPerWorker() const;
diff --git a/cpp/celeborn/tests/DataSumWithReaderClient.cpp
b/cpp/celeborn/tests/DataSumWithReaderClient.cpp
index 8c060fa54..ac62d13fe 100644
--- a/cpp/celeborn/tests/DataSumWithReaderClient.cpp
+++ b/cpp/celeborn/tests/DataSumWithReaderClient.cpp
@@ -46,7 +46,7 @@ int main(int argc, char** argv) {
celeborn::conf::CelebornConf::kShuffleCompressionCodec, compressCodec);
auto clientFactory =
std::make_shared<celeborn::network::TransportClientFactory>(conf);
- auto shuffleClient = std::make_unique<celeborn::client::ShuffleClientImpl>(
+ auto shuffleClient = celeborn::client::ShuffleClientImpl::create(
appUniqueId, conf, clientFactory);
shuffleClient->setupLifecycleManagerRef(
lifecycleManagerHost, lifecycleManagerPort);