This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new ffff5bb94 [CELEBORN-2196][CIP-14] Support ReviveManager in CppClient
ffff5bb94 is described below

commit ffff5bb94efd7c3a2c6c51aebccfd3c11b5d5198
Author: HolyLow <[email protected]>
AuthorDate: Mon Nov 17 11:58:16 2025 +0800

    [CELEBORN-2196][CIP-14] Support ReviveManager in CppClient
    
    ### What changes were proposed in this pull request?
    This PR supports ReviveManager in CppClient.
    
    ### Why are the changes needed?
    ReviveManager is the building component for writing procedure of CppClient.
    
    ### Does this PR resolve a correctness bug?
    No.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation and UTs.
    
    Closes #3538 from 
HolyLow/issue/celeborn-2196-support-ReviveManager-in-cpp-client.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
---
 cpp/celeborn/client/CMakeLists.txt              |   1 +
 cpp/celeborn/client/ShuffleClient.cpp           |  19 ++
 cpp/celeborn/client/ShuffleClient.h             |  53 +++++-
 cpp/celeborn/client/tests/CMakeLists.txt        |   1 +
 cpp/celeborn/client/tests/ReviveManagerTest.cpp | 237 ++++++++++++++++++++++++
 cpp/celeborn/client/writer/ReviveManager.cpp    | 160 ++++++++++++++++
 cpp/celeborn/client/writer/ReviveManager.h      |  73 ++++++++
 cpp/celeborn/conf/CelebornConf.cpp              |  11 ++
 cpp/celeborn/conf/CelebornConf.h                |  10 +
 cpp/celeborn/tests/DataSumWithReaderClient.cpp  |   2 +-
 10 files changed, 564 insertions(+), 3 deletions(-)

diff --git a/cpp/celeborn/client/CMakeLists.txt 
b/cpp/celeborn/client/CMakeLists.txt
index 1af9069ee..fa0290733 100644
--- a/cpp/celeborn/client/CMakeLists.txt
+++ b/cpp/celeborn/client/CMakeLists.txt
@@ -18,6 +18,7 @@ add_library(
         reader/CelebornInputStream.cpp
         writer/PushState.cpp
         writer/PushStrategy.cpp
+        writer/ReviveManager.cpp
         ShuffleClient.cpp
         compress/Decompressor.cpp
         compress/Lz4Decompressor.cpp
diff --git a/cpp/celeborn/client/ShuffleClient.cpp 
b/cpp/celeborn/client/ShuffleClient.cpp
index 6e29db909..22d19256c 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -21,6 +21,14 @@
 
 namespace celeborn {
 namespace client {
+// Factory for ShuffleClientImpl. The instance is always handed out inside a
+// std::shared_ptr; std::make_shared is not used because the constructor is
+// not public.
+std::shared_ptr<ShuffleClientImpl> ShuffleClientImpl::create(
+    const std::string& appUniqueId,
+    const std::shared_ptr<const conf::CelebornConf>& conf,
+    const std::shared_ptr<network::TransportClientFactory>& clientFactory) {
+  std::shared_ptr<ShuffleClientImpl> client;
+  client.reset(new ShuffleClientImpl(appUniqueId, conf, clientFactory));
+  return client;
+}
+
 ShuffleClientImpl::ShuffleClientImpl(
     const std::string& appUniqueId,
     const std::shared_ptr<const conf::CelebornConf>& conf,
@@ -135,6 +143,17 @@ bool ShuffleClientImpl::cleanupShuffle(int shuffleId) {
   return true;
 }
 
+// Returns true when `locationMap` holds an entry for `partitionId` whose epoch
+// is strictly greater than `epoch`; returns false when there is no entry.
+bool ShuffleClientImpl::newerPartitionLocationExists(
+    std::shared_ptr<utils::ConcurrentHashMap<
+        int,
+        std::shared_ptr<const protocol::PartitionLocation>>> locationMap,
+    int partitionId,
+    int epoch) {
+  const auto knownLocation = locationMap->get(partitionId);
+  if (!knownLocation.has_value()) {
+    return false;
+  }
+  return knownLocation.value()->epoch > epoch;
+}
+
 std::shared_ptr<protocol::GetReducerFileGroupResponse>
 ShuffleClientImpl::getReducerFileGroupInfo(int shuffleId) {
   auto reducerFileGroupInfoOptional = reducerFileGroupInfos_.get(shuffleId);
diff --git a/cpp/celeborn/client/ShuffleClient.h 
b/cpp/celeborn/client/ShuffleClient.h
index b56c60cf8..41e953515 100644
--- a/cpp/celeborn/client/ShuffleClient.h
+++ b/cpp/celeborn/client/ShuffleClient.h
@@ -51,9 +51,21 @@ class ShuffleClient {
   virtual void shutdown() = 0;
 };
 
-class ShuffleClientImpl : public ShuffleClient {
+class ShuffleClientImpl
+    : public ShuffleClient,
+      public std::enable_shared_from_this<ShuffleClientImpl> {
  public:
-  ShuffleClientImpl(
+  friend class ReviveManager;
+
+  using PtrReviveRequest = std::shared_ptr<protocol::ReviveRequest>;
+  using PartitionLocationMap = utils::ConcurrentHashMap<
+      int,
+      std::shared_ptr<const protocol::PartitionLocation>>;
+  using PtrPartitionLocationMap = std::shared_ptr<PartitionLocationMap>;
+
+  // Only allow construction from create() method to ensure that functionality
+  // of std::shared_from_this works properly.
+  static std::shared_ptr<ShuffleClientImpl> create(
       const std::string& appUniqueId,
       const std::shared_ptr<const conf::CelebornConf>& conf,
       const std::shared_ptr<network::TransportClientFactory>& clientFactory);
@@ -84,7 +96,43 @@ class ShuffleClientImpl : public ShuffleClient {
 
   void shutdown() override {}
 
+ protected:
+  // The constructor is hidden to ensure that functionality of
+  // std::shared_from_this works properly.
+  ShuffleClientImpl(
+      const std::string& appUniqueId,
+      const std::shared_ptr<const conf::CelebornConf>& conf,
+      const std::shared_ptr<network::TransportClientFactory>& clientFactory);
+
+  // Stub hook: unconditionally returns true for every (shuffleId, mapId).
+  // ReviveManager calls it and marks a request SUCCESS when it returns true.
+  // TODO: currently this function serves as a stub. will be updated in future
+  //  commits.
+  virtual bool mapperEnded(int shuffleId, int mapId) {
+    return true;
+  }
+
+  // Stub hook: always returns std::nullopt. ReviveManager passes the map ids
+  // plus one request per partition id, and marks the affected requests
+  // REVIVE_FAILED when std::nullopt comes back.
+  // TODO: currently this function serves as a stub. will be updated in future
+  //  commits.
+  virtual std::optional<std::unordered_map<int, int>> reviveBatch(
+      int shuffleId,
+      const std::unordered_set<int>& mapIds,
+      const std::unordered_map<int, PtrReviveRequest>& requests) {
+    return std::nullopt;
+  }
+
+  // Looks up `shuffleId` in partitionLocationMaps_ and returns the result of
+  // its get() unchanged (presumably an empty optional when the shuffle is
+  // unknown -- TODO confirm against utils::ConcurrentHashMap).
+  virtual std::optional<PtrPartitionLocationMap> getPartitionLocationMap(
+      int shuffleId) {
+    return partitionLocationMaps_.get(shuffleId);
+  }
+
  private:
+  // TODO: no support for WAIT as it is not used.
+  static bool newerPartitionLocationExists(
+      std::shared_ptr<utils::ConcurrentHashMap<
+          int,
+          std::shared_ptr<const protocol::PartitionLocation>>> locationMap,
+      int partitionId,
+      int epoch);
+
   std::shared_ptr<protocol::GetReducerFileGroupResponse>
   getReducerFileGroupInfo(int shuffleId);
 
@@ -97,6 +145,7 @@ class ShuffleClientImpl : public ShuffleClient {
       int,
       std::shared_ptr<protocol::GetReducerFileGroupResponse>>
       reducerFileGroupInfos_;
+  utils::ConcurrentHashMap<int, PtrPartitionLocationMap> 
partitionLocationMaps_;
 };
 } // namespace client
 } // namespace celeborn
diff --git a/cpp/celeborn/client/tests/CMakeLists.txt 
b/cpp/celeborn/client/tests/CMakeLists.txt
index 6341c84e7..8183611c6 100644
--- a/cpp/celeborn/client/tests/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -17,6 +17,7 @@ add_executable(
         celeborn_client_test
         WorkerPartitionReaderTest.cpp
         PushStateTest.cpp
+        ReviveManagerTest.cpp
         Lz4DecompressorTest.cpp
         ZstdDecompressorTest.cpp
         Lz4CompressorTest.cpp
diff --git a/cpp/celeborn/client/tests/ReviveManagerTest.cpp 
b/cpp/celeborn/client/tests/ReviveManagerTest.cpp
new file mode 100644
index 000000000..8db844485
--- /dev/null
+++ b/cpp/celeborn/client/tests/ReviveManagerTest.cpp
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <atomic>
+#include <utility>
+
+#include <gtest/gtest.h>
+
+#include "celeborn/client/writer/ReviveManager.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+
+namespace {
+// Test double for ShuffleClientImpl. Each overridden hook forwards to a
+// std::function that a test can replace through the matching setOn*() method.
+// Defaults: mapperEnded -> false, reviveBatch -> std::nullopt,
+// getPartitionLocationMap -> a freshly created empty PartitionLocationMap.
+class MockShuffleClient : public ShuffleClientImpl {
+ public:
+  friend class ReviveManager;
+
+  // The constructor is private, so instances are only created here and are
+  // always owned by a std::shared_ptr.
+  static std::shared_ptr<MockShuffleClient> create() {
+    return std::shared_ptr<MockShuffleClient>(new MockShuffleClient());
+  }
+
+  virtual ~MockShuffleClient() = default;
+
+  bool mapperEnded(int shuffleId, int mapId) override {
+    return onMapperEnded_(shuffleId, mapId);
+  }
+
+  // The callback arrives as an rvalue, so it is moved into place instead of
+  // copying the std::function together with whatever it captured.
+  void setOnMapperEnded(std::function<bool(int, int)>&& onMapperEnded) {
+    onMapperEnded_ = std::move(onMapperEnded);
+  }
+
+  std::optional<std::unordered_map<int, int>> reviveBatch(
+      int shuffleId,
+      const std::unordered_set<int>& mapIds,
+      const std::unordered_map<int, PtrReviveRequest>& requests) override {
+    return onReviveBatch_(shuffleId, mapIds, requests);
+  }
+
+  void setOnReviveBatch(
+      std::function<std::optional<std::unordered_map<int, int>>(
+          int,
+          const std::unordered_set<int>&,
+          const std::unordered_map<int, PtrReviveRequest>&)>&& onReviveBatch) {
+    onReviveBatch_ = std::move(onReviveBatch);
+  }
+
+  std::optional<PtrPartitionLocationMap> getPartitionLocationMap(
+      int shuffleId) override {
+    return onGetPartitionLocationMap_(shuffleId);
+  }
+
+  void setOnGetPartitionLocationMap(
+      std::function<std::optional<PtrPartitionLocationMap>(int)>&&
+          onGetPartitionLocationMap) {
+    onGetPartitionLocationMap_ = std::move(onGetPartitionLocationMap);
+  }
+
+ private:
+  // Builds the base client with app id "mock", a default-constructed
+  // CelebornConf and no TransportClientFactory.
+  MockShuffleClient()
+      : ShuffleClientImpl(
+            "mock",
+            std::make_shared<conf::CelebornConf>(),
+            nullptr) {}
+
+  std::function<bool(int, int)> onMapperEnded_ = [](int, int) {
+    return false;
+  };
+  std::function<std::optional<std::unordered_map<int, int>>(
+      int,
+      const std::unordered_set<int>&,
+      const std::unordered_map<int, PtrReviveRequest>&)>
+      onReviveBatch_ = [](int,
+                          const std::unordered_set<int>&,
+                          const std::unordered_map<int, PtrReviveRequest>&) {
+        return std::nullopt;
+      };
+  std::function<std::optional<PtrPartitionLocationMap>(int)>
+      onGetPartitionLocationMap_ =
+          [](int) -> std::optional<PtrPartitionLocationMap> {
+    return {std::make_shared<PartitionLocationMap>()};
+  };
+};
+} // namespace
+
+// Fixture shared by the tests below: SetUp() creates a MockShuffleClient and
+// a ReviveManager named "test" whose revive interval is pushReviveIntervalMs_
+// milliseconds.
+class ReviveManagerTest : public testing::Test {
+ protected:
+  void SetUp() override {
+    mockShuffleClient_ = MockShuffleClient::create();
+
+    conf::CelebornConf testConf{};
+    testConf.registerProperty(
+        conf::CelebornConf::kClientPushReviveInterval,
+        fmt::format("{}ms", pushReviveIntervalMs_));
+
+    // The manager only observes the client; it never owns it.
+    std::weak_ptr<ShuffleClientImpl> weakClient =
+        mockShuffleClient_->weak_from_this();
+    reviveManager_ = ReviveManager::create("test", testConf, weakClient);
+  }
+
+  std::shared_ptr<MockShuffleClient> mockShuffleClient_;
+  std::shared_ptr<ReviveManager> reviveManager_;
+
+  static constexpr int pushReviveIntervalMs_ = 100;
+};
+
+// Checks that a queued request ends up SUCCESS once the client's
+// mapperEnded() hook reports true for its (shuffleId, mapId).
+TEST_F(ReviveManagerTest, successOnMapperEnded) {
+  // Incremented from the ReviveManager's scheduled task and read here on the
+  // test thread, so it is atomic to avoid a data race.
+  std::atomic<int> mapperEndedCalledTimes{0};
+  const int testShuffleId = 1000;
+  const int testMapId = 1001;
+  const int testAttemptId = 1002;
+  const int testPartitionId = 1003;
+  const int testEpoch = 1004;
+  auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testEpoch,
+      nullptr,
+      static_cast<protocol::StatusCode>(0));
+  mockShuffleClient_->setOnMapperEnded(
+      [=, &mapperEndedCalledTimes](int shuffleId, int mapId) {
+        EXPECT_EQ(testShuffleId, shuffleId);
+        EXPECT_EQ(testMapId, mapId);
+        ++mapperEndedCalledTimes;
+        return true;
+      });
+
+  EXPECT_EQ(
+      reviveRequest->reviveStatus, protocol::StatusCode::REVIVE_INITIALIZED);
+  reviveManager_->addRequest(reviveRequest);
+  // Wait for several intervals so the periodic task has run at least once.
+  std::this_thread::sleep_for(
+      std::chrono::milliseconds(pushReviveIntervalMs_ * 3));
+
+  EXPECT_GT(mapperEndedCalledTimes.load(), 0);
+  EXPECT_EQ(reviveRequest->reviveStatus, protocol::StatusCode::SUCCESS);
+}
+
+// Checks that a queued request is forwarded to reviveBatch() with its fields
+// intact and takes the SUCCESS status returned for its partition id.
+TEST_F(ReviveManagerTest, successOnReviveSuccess) {
+  // Incremented from the ReviveManager's scheduled task and read here on the
+  // test thread, so it is atomic to avoid a data race.
+  std::atomic<int> reviveBatchCalledTimes{0};
+  const int testShuffleId = 1000;
+  const int testMapId = 1001;
+  const int testAttemptId = 1002;
+  const int testPartitionId = 1003;
+  const int testEpoch = 1004;
+  auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testEpoch,
+      nullptr,
+      static_cast<protocol::StatusCode>(0));
+  mockShuffleClient_->setOnReviveBatch(
+      [=, &reviveBatchCalledTimes](
+          int shuffleId,
+          const std::unordered_set<int>& mapIds,
+          const std::unordered_map<int, ShuffleClientImpl::PtrReviveRequest>&
+              requestsToSend)
+          -> std::optional<std::unordered_map<int, int>> {
+        EXPECT_EQ(testShuffleId, shuffleId);
+        EXPECT_GT(mapIds.count(testMapId), 0);
+        EXPECT_GT(requestsToSend.count(testPartitionId), 0);
+        auto request = requestsToSend.find(testPartitionId)->second;
+        EXPECT_EQ(request->shuffleId, testShuffleId);
+        EXPECT_EQ(request->mapId, testMapId);
+        EXPECT_EQ(request->attemptId, testAttemptId);
+        EXPECT_EQ(request->partitionId, testPartitionId);
+        EXPECT_EQ(request->epoch, testEpoch);
+
+        std::unordered_map<int, int> result;
+        result[request->partitionId] = protocol::StatusCode::SUCCESS;
+        ++reviveBatchCalledTimes;
+        return {result};
+      });
+
+  EXPECT_EQ(
+      reviveRequest->reviveStatus, protocol::StatusCode::REVIVE_INITIALIZED);
+  reviveManager_->addRequest(reviveRequest);
+  // Wait for several intervals so the periodic task has run at least once.
+  std::this_thread::sleep_for(
+      std::chrono::milliseconds(pushReviveIntervalMs_ * 3));
+
+  EXPECT_GT(reviveBatchCalledTimes.load(), 0);
+  EXPECT_EQ(reviveRequest->reviveStatus, protocol::StatusCode::SUCCESS);
+}
+
+// Checks that a queued request is marked REVIVE_FAILED when reviveBatch()
+// returns std::nullopt.
+TEST_F(ReviveManagerTest, failureOnReviveFailure) {
+  // Incremented from the ReviveManager's scheduled task and read here on the
+  // test thread, so it is atomic to avoid a data race.
+  std::atomic<int> reviveBatchCalledTimes{0};
+  const int testShuffleId = 1000;
+  const int testMapId = 1001;
+  const int testAttemptId = 1002;
+  const int testPartitionId = 1003;
+  const int testEpoch = 1004;
+  auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testEpoch,
+      nullptr,
+      static_cast<protocol::StatusCode>(0));
+  mockShuffleClient_->setOnReviveBatch(
+      [=, &reviveBatchCalledTimes](
+          int shuffleId,
+          const std::unordered_set<int>& mapIds,
+          const std::unordered_map<int, ShuffleClientImpl::PtrReviveRequest>&
+              requestsToSend)
+          -> std::optional<std::unordered_map<int, int>> {
+        EXPECT_EQ(testShuffleId, shuffleId);
+        EXPECT_GT(mapIds.count(testMapId), 0);
+        EXPECT_GT(requestsToSend.count(testPartitionId), 0);
+        auto request = requestsToSend.find(testPartitionId)->second;
+        EXPECT_EQ(request->shuffleId, testShuffleId);
+        EXPECT_EQ(request->mapId, testMapId);
+        EXPECT_EQ(request->attemptId, testAttemptId);
+        EXPECT_EQ(request->partitionId, testPartitionId);
+        EXPECT_EQ(request->epoch, testEpoch);
+
+        ++reviveBatchCalledTimes;
+        return std::nullopt;
+      });
+
+  EXPECT_EQ(
+      reviveRequest->reviveStatus, protocol::StatusCode::REVIVE_INITIALIZED);
+  reviveManager_->addRequest(reviveRequest);
+  // Wait for several intervals so the periodic task has run at least once.
+  std::this_thread::sleep_for(
+      std::chrono::milliseconds(pushReviveIntervalMs_ * 3));
+
+  EXPECT_GT(reviveBatchCalledTimes.load(), 0);
+  EXPECT_EQ(reviveRequest->reviveStatus, protocol::StatusCode::REVIVE_FAILED);
+}
diff --git a/cpp/celeborn/client/writer/ReviveManager.cpp 
b/cpp/celeborn/client/writer/ReviveManager.cpp
new file mode 100644
index 000000000..47a548546
--- /dev/null
+++ b/cpp/celeborn/client/writer/ReviveManager.cpp
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/ReviveManager.h"
+
+namespace celeborn {
+namespace client {
+
+// Definition of the static scheduler shared by every ReviveManager instance.
+folly::FunctionScheduler ReviveManager::globalExecutor_;
+
+// Builds a ReviveManager and starts its periodic task before returning it.
+// start() is invoked here, once the object is already owned by a shared_ptr,
+// because it relies on weak_from_this().
+std::shared_ptr<ReviveManager> ReviveManager::create(
+    const std::string& name,
+    const conf::CelebornConf& conf,
+    std::weak_ptr<ShuffleClientImpl> weakClient) {
+  // The constructor is private, so std::make_shared cannot be used.
+  auto reviveManager = std::shared_ptr<ReviveManager>(
+      new ReviveManager(name, conf, std::move(weakClient)));
+  reviveManager->start();
+  // Return the local by name: wrapping it in std::move would only inhibit
+  // copy elision.
+  return reviveManager;
+}
+
+// Stores the task name and the weak client reference, and reads the batch
+// size and revive interval from `conf`. Does not start the periodic task;
+// create() does that after construction.
+ReviveManager::ReviveManager(
+    const std::string& name,
+    const conf::CelebornConf& conf,
+    std::weak_ptr<ShuffleClientImpl> weakClient)
+    : name_(name),
+      batchSize_(conf.clientPushReviveBatchSize()),
+      interval_(conf.clientPushReviveInterval()),
+      // `weakClient` is a by-value sink parameter: move it instead of paying
+      // for a second reference-count update.
+      weakClient_(std::move(weakClient)) {}
+
+// Cancels the function registered under name_ on the shared scheduler.
+ReviveManager::~ReviveManager() {
+  globalExecutor_.cancelFunction(name_);
+}
+
+// Registers the periodic revive task on the shared scheduler. Only the first
+// call has any effect; later calls return immediately.
+//
+// Each run of the task drains requestQueue_ in batches of at most batchSize_
+// requests, groups every batch by shuffle id, and for each group:
+//   - marks a request SUCCESS when a newer partition location already exists
+//     or the mapper has ended;
+//   - otherwise sends the remaining requests through reviveBatch(), keeping
+//     only the highest-epoch request per partition id;
+//   - copies the returned status onto every remaining request, or marks them
+//     REVIVE_FAILED when reviveBatch() returns std::nullopt or omits the
+//     request's partition id.
+// The task keeps looping while more than batchSize_ / 2 requests are queued.
+// Exceptions derived from std::exception are logged and end the current run.
+void ReviveManager::start() {
+  bool expected = false;
+  if (!started_.compare_exchange_strong(expected, true)) {
+    return;
+  }
+  globalExecutor_.start();
+  // Capture only a weak reference so the scheduled task does not keep the
+  // manager alive.
+  auto task = [weak_this = weak_from_this(), batchSize = batchSize_]() {
+    try {
+      auto shared_this = weak_this.lock();
+      if (!shared_this) {
+        return;
+      }
+      bool continueFlag = true;
+      do {
+        // Take at most one batch out of the queue while holding its lock.
+        std::vector<PtrReviveRequest> requests;
+        shared_this->requestQueue_.withLock([&](auto& queue) {
+          for (int i = 0; i < batchSize && !queue.empty(); i++) {
+            requests.push_back(queue.front());
+            queue.pop();
+          }
+        });
+        if (requests.empty()) {
+          break;
+        }
+        // Group the batch by shuffle id.
+        std::unordered_map<int, std::unordered_set<PtrReviveRequest>>
+            shuffleMap;
+        for (auto& request : requests) {
+          shuffleMap[request->shuffleId].insert(request);
+        }
+        // NOTE(review): if the client is gone, the requests dequeued above are
+        // dropped without a status update.
+        auto shuffleClient = shared_this->weakClient_.lock();
+        if (!shuffleClient) {
+          return;
+        }
+        for (auto& [shuffleId, requestSet] : shuffleMap) {
+          std::unordered_set<int> mapIds;
+          std::vector<PtrReviveRequest> filteredRequests;
+          std::unordered_map<int, PtrReviveRequest> requestsToSend;
+
+          auto locationMapOptional =
+              shuffleClient->getPartitionLocationMap(shuffleId);
+          CELEBORN_CHECK(locationMapOptional.has_value());
+          auto locationMap = locationMapOptional.value();
+          for (auto& request : requestSet) {
+            if (shuffleClient->newerPartitionLocationExists(
+                    locationMap, request->partitionId, request->epoch) ||
+                shuffleClient->mapperEnded(shuffleId, request->mapId)) {
+              request->reviveStatus = protocol::StatusCode::SUCCESS;
+            } else {
+              filteredRequests.push_back(request);
+              mapIds.insert(request->mapId);
+              // Only the request with the highest epoch is sent per partition.
+              if (auto iter = requestsToSend.find(request->partitionId);
+                  iter == requestsToSend.end() ||
+                  iter->second->epoch < request->epoch) {
+                requestsToSend[request->partitionId] = request;
+              }
+            }
+          }
+
+          if (requestsToSend.empty()) {
+            continue;
+          }
+          auto resultOptional =
+              shuffleClient->reviveBatch(shuffleId, mapIds, requestsToSend);
+          if (resultOptional.has_value()) {
+            // Bind by reference: the status map does not need to be copied.
+            const auto& result = resultOptional.value();
+            for (auto& request : filteredRequests) {
+              if (shuffleClient->mapperEnded(shuffleId, request->mapId)) {
+                request->reviveStatus = protocol::StatusCode::SUCCESS;
+              } else if (auto statusIter = result.find(request->partitionId);
+                         statusIter != result.end()) {
+                request->reviveStatus = statusIter->second;
+              } else {
+                // No status came back for this partition. Looking it up with
+                // operator[] would silently insert and assign a zero status.
+                request->reviveStatus = protocol::StatusCode::REVIVE_FAILED;
+              }
+            }
+          } else {
+            for (auto& request : filteredRequests) {
+              request->reviveStatus = protocol::StatusCode::REVIVE_FAILED;
+            }
+          }
+        }
+        continueFlag =
+            (shared_this->requestQueue_.lock()->size() > batchSize / 2);
+      } while (continueFlag);
+    } catch (const std::exception& e) {
+      LOG(ERROR) << "ReviveManager error occurred: " << e.what();
+    }
+  };
+  startFunction(task);
+}
+
+// Registers `task` on the shared scheduler under name_, using interval_ both
+// as the period and as the initial delay. When registration throws, a random
+// numeric suffix is appended to name_ and the registration is retried.
+// Retries are bounded: renaming can only cure a name clash, so retrying
+// without limit would never terminate for any other failure. After the last
+// attempt the exception is rethrown to the caller.
+void ReviveManager::startFunction(std::function<void()> task) {
+  constexpr int kMaxAttempts = 16;
+  for (int attempt = 1;; ++attempt) {
+    try {
+      globalExecutor_.addFunction(task, interval_, name_, interval_);
+      return;
+    } catch (const std::exception& e) {
+      if (attempt >= kMaxAttempts) {
+        LOG(ERROR) << "startFunction failed, current function name " << name_
+                   << ", giving up after " << attempt
+                   << " attempts: " << e.what();
+        throw;
+      }
+      LOG(ERROR) << "startFunction failed, current function name " << name_
+                 << ", retry again... (" << e.what() << ")";
+      name_ += "-";
+      name_ += std::to_string(rand() % 10000);
+    }
+  }
+}
+
+// Appends `request` to the pending queue while holding the queue's lock.
+void ReviveManager::addRequest(PtrReviveRequest request) {
+  auto lockedQueue = requestQueue_.lock();
+  lockedQueue->push(std::move(request));
+}
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/ReviveManager.h 
b/cpp/celeborn/client/writer/ReviveManager.h
new file mode 100644
index 000000000..4f3b7962e
--- /dev/null
+++ b/cpp/celeborn/client/writer/ReviveManager.h
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <folly/experimental/FunctionScheduler.h>
+
+#include "celeborn/client/ShuffleClient.h"
+#include "celeborn/conf/CelebornConf.h"
+#include "celeborn/protocol/ControlMessages.h"
+
+namespace celeborn {
+namespace client {
+
+/// ReviveManager is responsible for buffering the ReviveRequests, and issue
+/// the revive requests periodically in batches.
+class ReviveManager : public std::enable_shared_from_this<ReviveManager> {
+ public:
+  using PtrReviveRequest = std::shared_ptr<protocol::ReviveRequest>;
+
+  // Only allow construction from create() method to ensure that functionality
+  // of std::enable_shared_from_this works properly.
+  static std::shared_ptr<ReviveManager> create(
+      const std::string& name,
+      const conf::CelebornConf& conf,
+      std::weak_ptr<ShuffleClientImpl> weakClient);
+
+  // Cancels this manager's function on the shared scheduler.
+  ~ReviveManager();
+
+  // Queues a request; it is handled by the periodic task.
+  void addRequest(PtrReviveRequest request);
+
+ private:
+  // The constructor is hidden to ensure that functionality of
+  // std::enable_shared_from_this works properly.
+  ReviveManager(
+      const std::string& name,
+      const conf::CelebornConf& conf,
+      std::weak_ptr<ShuffleClientImpl> weakClient);
+
+  // The start() method must not be called within constructor, because for
+  // std::enable_shared_from_this the shared_from_this() or weak_from_this()
+  // would not work until the object construction is complete.
+  void start();
+
+  // Registers `task` on globalExecutor_ under name_.
+  void startFunction(std::function<void()> task);
+
+  // Scheduler for issuing the requests periodically.
+  static folly::FunctionScheduler globalExecutor_;
+
+  // Name the periodic function is registered under; startFunction() may
+  // append a suffix to it when registration fails.
+  std::string name_;
+  // Upper bound on the number of requests taken from the queue per batch.
+  const int batchSize_;
+  // Period (and initial delay) of the periodic function.
+  const Timeout interval_;
+  // Non-owning reference to the client used to issue the revive calls.
+  std::weak_ptr<ShuffleClientImpl> weakClient_;
+  // Pending requests, guarded by a std::mutex.
+  folly::Synchronized<std::queue<PtrReviveRequest>, std::mutex> requestQueue_;
+  // Set by the first start() call so that later calls are no-ops.
+  std::atomic<bool> started_{false};
+};
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/conf/CelebornConf.cpp 
b/cpp/celeborn/conf/CelebornConf.cpp
index a3e464332..ab1a7abf0 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -135,6 +135,8 @@ const std::unordered_map<std::string, 
folly::Optional<std::string>>
     CelebornConf::kDefaultProperties = {
         STR_PROP(kRpcAskTimeout, "60s"),
         STR_PROP(kRpcLookupTimeout, "30s"),
+        STR_PROP(kClientPushReviveInterval, "100ms"),
+        NUM_PROP(kClientPushReviveBatchSize, 2048),
         STR_PROP(kClientPushLimitStrategy, kSimplePushStrategy),
         NUM_PROP(kClientPushMaxReqsInFlightPerWorker, 32),
         NUM_PROP(kClientPushMaxReqsInFlightTotal, 256),
@@ -190,6 +192,15 @@ Timeout CelebornConf::rpcLookupTimeout() const {
       toDuration(optionalProperty(kRpcLookupTimeout).value()));
 }
 
+// Reads kClientPushReviveInterval and converts its duration string into a
+// Timeout.
+Timeout CelebornConf::clientPushReviveInterval() const {
+  const auto configured = optionalProperty(kClientPushReviveInterval);
+  return utils::toTimeout(toDuration(configured.value()));
+}
+
+// Reads kClientPushReviveBatchSize and parses it as an int with std::stoi.
+int CelebornConf::clientPushReviveBatchSize() const {
+  const auto configured = optionalProperty(kClientPushReviveBatchSize);
+  return std::stoi(configured.value());
+}
+
 std::string CelebornConf::clientPushLimitStrategy() const {
   return optionalProperty(kClientPushLimitStrategy).value();
 }
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index 98b5b9306..d6ce24f18 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -45,6 +45,12 @@ class CelebornConf : public BaseConf {
   static constexpr std::string_view kRpcLookupTimeout{
       "celeborn.rpc.lookupTimeout"};
 
+  static constexpr std::string_view kClientPushReviveInterval{
+      "celeborn.client.push.revive.interval"};
+
+  static constexpr std::string_view kClientPushReviveBatchSize{
+      "celeborn.client.push.revive.batchSize"};
+
   static constexpr std::string_view kClientPushLimitStrategy{
       "celeborn.client.push.limit.strategy"};
 
@@ -100,6 +106,10 @@ class CelebornConf : public BaseConf {
 
   Timeout rpcLookupTimeout() const;
 
+  Timeout clientPushReviveInterval() const;
+
+  int clientPushReviveBatchSize() const;
+
   std::string clientPushLimitStrategy() const;
 
   int clientPushMaxReqsInFlightPerWorker() const;
diff --git a/cpp/celeborn/tests/DataSumWithReaderClient.cpp 
b/cpp/celeborn/tests/DataSumWithReaderClient.cpp
index 8c060fa54..ac62d13fe 100644
--- a/cpp/celeborn/tests/DataSumWithReaderClient.cpp
+++ b/cpp/celeborn/tests/DataSumWithReaderClient.cpp
@@ -46,7 +46,7 @@ int main(int argc, char** argv) {
       celeborn::conf::CelebornConf::kShuffleCompressionCodec, compressCodec);
   auto clientFactory =
       std::make_shared<celeborn::network::TransportClientFactory>(conf);
-  auto shuffleClient = std::make_unique<celeborn::client::ShuffleClientImpl>(
+  auto shuffleClient = celeborn::client::ShuffleClientImpl::create(
       appUniqueId, conf, clientFactory);
   shuffleClient->setupLifecycleManagerRef(
       lifecycleManagerHost, lifecycleManagerPort);

Reply via email to