This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 8ccd4dad1 [CELEBORN-2182][CIP-14] Support PushState and PushStrategy in CppClient
8ccd4dad1 is described below

commit 8ccd4dad10a53647d54231337b5c6966160ee236
Author: HolyLow <[email protected]>
AuthorDate: Mon Nov 3 11:36:17 2025 +0800

    [CELEBORN-2182][CIP-14] Support PushState and PushStrategy in CppClient
    
    ### What changes were proposed in this pull request?
    This PR supports PushState and PushStrategy in CppClient.
    
    ### Why are the changes needed?
    These functionalities are used in the writing procedure in CppClient.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation and UTs.
    
    Closes #3515 from HolyLow/issue/celeborn-2182-support-PushState-in-cpp-client.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
---
 cpp/celeborn/client/CMakeLists.txt             |   2 +
 cpp/celeborn/client/tests/CMakeLists.txt       |   1 +
 cpp/celeborn/client/tests/PushStateTest.cpp    | 155 +++++++++++++++++++++++
 cpp/celeborn/client/writer/PushState.cpp       | 163 +++++++++++++++++++++++++
 cpp/celeborn/client/writer/PushState.h         |  83 +++++++++++++
 cpp/celeborn/client/writer/PushStrategy.cpp    |  34 ++++++
 cpp/celeborn/client/writer/PushStrategy.h      |  79 ++++++++++++
 cpp/celeborn/conf/CelebornConf.cpp             |  27 ++++
 cpp/celeborn/conf/CelebornConf.h               |  27 ++++
 cpp/celeborn/utils/CelebornUtils.h             |  56 +++++++++
 cpp/celeborn/utils/tests/CelebornUtilsTest.cpp |  85 +++++++++++--
 11 files changed, 705 insertions(+), 7 deletions(-)

diff --git a/cpp/celeborn/client/CMakeLists.txt b/cpp/celeborn/client/CMakeLists.txt
index 2586f6855..1af9069ee 100644
--- a/cpp/celeborn/client/CMakeLists.txt
+++ b/cpp/celeborn/client/CMakeLists.txt
@@ -16,6 +16,8 @@ add_library(
         client
         reader/WorkerPartitionReader.cpp
         reader/CelebornInputStream.cpp
+        writer/PushState.cpp
+        writer/PushStrategy.cpp
         ShuffleClient.cpp
         compress/Decompressor.cpp
         compress/Lz4Decompressor.cpp
diff --git a/cpp/celeborn/client/tests/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt
index d8a98e2b6..6341c84e7 100644
--- a/cpp/celeborn/client/tests/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -16,6 +16,7 @@
 add_executable(
         celeborn_client_test
         WorkerPartitionReaderTest.cpp
+        PushStateTest.cpp
         Lz4DecompressorTest.cpp
         ZstdDecompressorTest.cpp
         Lz4CompressorTest.cpp
diff --git a/cpp/celeborn/client/tests/PushStateTest.cpp b/cpp/celeborn/client/tests/PushStateTest.cpp
new file mode 100644
index 000000000..94b6c3aa3
--- /dev/null
+++ b/cpp/celeborn/client/tests/PushStateTest.cpp
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/client/writer/PushState.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+
+class PushStateTest : public testing::Test {
+ protected:
+  static constexpr int pushTimeoutMs_ = 100;
+  static constexpr int pushSleepDeltaMs_ = 10;
+  static constexpr int maxReqsInFlight_ = 2;
+
+  // Builds a PushState with a short in-flight timeout and small in-flight
+  // limits (the same value per worker and in total), so the tests reach the
+  // limits after a couple of batches and time out quickly.
+  void SetUp() override {
+    conf::CelebornConf conf;
+    const auto setIntProperty = [&conf](std::string_view key, int value) {
+      conf.registerProperty(key, std::to_string(value));
+    };
+    setIntProperty(
+        conf::CelebornConf::kClientPushLimitInFlightTimeoutMs, pushTimeoutMs_);
+    setIntProperty(
+        conf::CelebornConf::kClientPushLimitInFlightSleepDeltaMs,
+        pushSleepDeltaMs_);
+    setIntProperty(
+        conf::CelebornConf::kClientPushMaxReqsInFlightTotal, maxReqsInFlight_);
+    setIntProperty(
+        conf::CelebornConf::kClientPushMaxReqsInFlightPerWorker,
+        maxReqsInFlight_);
+
+    pushState_ = std::make_unique<PushState>(conf);
+  }
+
+  std::unique_ptr<PushState> pushState_;
+};
+
+TEST_F(PushStateTest, limitMaxInFlight) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+  const int addBatchCalls = maxReqsInFlight_ + 1;
+  // The marks are written by addBatchThread while the test thread reads them,
+  // so they must be atomic. std::vector<bool> packs bits into shared words,
+  // which makes even writes to distinct indices a data race.
+  // The elements are value-initialized to false.
+  std::vector<std::atomic<bool>> addBatchMarks(addBatchCalls);
+  std::thread addBatchThread([&]() {
+    for (auto i = 0; i < addBatchCalls; i++) {
+      pushState_->addBatch(i, hostAndPushPort);
+      EXPECT_FALSE(pushState_->limitMaxInFlight(hostAndPushPort));
+      addBatchMarks[i].store(true);
+    }
+  });
+
+  // The first maxReqsInFlight_ batches fit under the limit; the last one
+  // blocks inside limitMaxInFlight until a batch is removed.
+  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+  for (auto i = 0; i < maxReqsInFlight_; i++) {
+    EXPECT_TRUE(addBatchMarks[i].load());
+  }
+  EXPECT_FALSE(addBatchMarks[maxReqsInFlight_].load());
+
+  pushState_->removeBatch(0, hostAndPushPort);
+  addBatchThread.join();
+  EXPECT_TRUE(addBatchMarks[maxReqsInFlight_].load());
+}
+
+TEST_F(PushStateTest, limitMaxInFlightTimeout) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+  const int addBatchCalls = maxReqsInFlight_ + 1;
+  // Written by addBatchThread and read concurrently by the test thread, so
+  // the marks must be atomic (std::vector<bool> would race on shared words).
+  // The elements are value-initialized to false.
+  std::vector<std::atomic<bool>> addBatchMarks(addBatchCalls);
+  std::thread addBatchThread([&]() {
+    for (auto i = 0; i < addBatchCalls; i++) {
+      pushState_->addBatch(i, hostAndPushPort);
+      // Nothing removes batches in this test, so the call that exceeds the
+      // limit must report a timeout (true).
+      auto result = pushState_->limitMaxInFlight(hostAndPushPort);
+      if (i < maxReqsInFlight_) {
+        EXPECT_FALSE(result);
+      } else {
+        EXPECT_TRUE(result);
+      }
+      addBatchMarks[i].store(!result);
+    }
+  });
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+  for (auto i = 0; i < maxReqsInFlight_; i++) {
+    EXPECT_TRUE(addBatchMarks[i].load());
+  }
+  EXPECT_FALSE(addBatchMarks[maxReqsInFlight_].load());
+
+  addBatchThread.join();
+  EXPECT_FALSE(addBatchMarks[maxReqsInFlight_].load());
+}
+
+TEST_F(PushStateTest, limitZeroInFlight) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+  const int addBatchCalls = 1;
+  // Written by addBatchThread and read concurrently by the test thread, so
+  // the mark must be atomic. The element is value-initialized to false.
+  std::vector<std::atomic<bool>> addBatchMarks(addBatchCalls);
+  std::thread addBatchThread([&]() {
+    pushState_->addBatch(0, hostAndPushPort);
+    // Blocks until the test thread removes the only in-flight batch.
+    EXPECT_FALSE(pushState_->limitZeroInFlight());
+    addBatchMarks[0].store(true);
+  });
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+  EXPECT_FALSE(addBatchMarks[0].load());
+
+  pushState_->removeBatch(0, hostAndPushPort);
+  addBatchThread.join();
+  EXPECT_TRUE(addBatchMarks[0].load());
+}
+
+TEST_F(PushStateTest, limitZeroInFlightTimeout) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+  const int addBatchCalls = 1;
+  // Written by addBatchThread and read concurrently by the test thread, so
+  // the mark must be atomic. The element is value-initialized to false.
+  std::vector<std::atomic<bool>> addBatchMarks(addBatchCalls);
+  std::thread addBatchThread([&]() {
+    pushState_->addBatch(0, hostAndPushPort);
+    // The batch is never removed, so waiting for zero in flight times out.
+    auto result = pushState_->limitZeroInFlight();
+    EXPECT_TRUE(result);
+    addBatchMarks[0].store(!result);
+  });
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+  EXPECT_FALSE(addBatchMarks[0].load());
+
+  addBatchThread.join();
+  EXPECT_FALSE(addBatchMarks[0].load());
+}
+
+TEST_F(PushStateTest, throwException) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+  // Once an exception has been recorded, both limit checks must throw
+  // instead of waiting for in-flight batches.
+  pushState_->setException(std::make_unique<std::exception>());
+
+  EXPECT_ANY_THROW(pushState_->limitMaxInFlight(hostAndPushPort));
+  EXPECT_ANY_THROW(pushState_->limitZeroInFlight());
+}
diff --git a/cpp/celeborn/client/writer/PushState.cpp b/cpp/celeborn/client/writer/PushState.cpp
new file mode 100644
index 000000000..b86a2aadb
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushState.cpp
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/PushState.h"
+
+namespace celeborn {
+namespace client {
+
+PushState::PushState(const conf::CelebornConf& conf)
+    : waitInflightTimeoutMs_(conf.clientPushLimitInFlightTimeoutMs()),
+      deltaMs_(conf.clientPushLimitInFlightSleepDeltaMs()),
+      pushStrategy_(PushStrategy::create(conf)),
+      maxInFlightReqsTotal_(conf.clientPushMaxReqsInFlightTotal()) {}
+
+// Hands out batch ids; the atomic post-increment makes ids unique across
+// threads.
+int PushState::nextBatchId() {
+  return currBatchId_++;
+}
+
+// Records batchId as in flight towards hostAndPushPort, creating the
+// per-address set on first use.
+void PushState::addBatch(int batchId, const std::string& hostAndPushPort) {
+  auto batchIdSet = inflightBatchesPerAddress_.computeIfAbsent(
+      hostAndPushPort,
+      [&]() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
+  // Count the batch only if it was actually added. Re-adding an id that is
+  // already in flight must not inflate totalInflightReqs_; otherwise the
+  // total drifts away from the per-address sets and can never drain to 0.
+  if (batchIdSet->insert(batchId)) {
+    totalInflightReqs_.fetch_add(1);
+  }
+}
+
+// Forwards a successful push to hostAndPushPort to the push strategy.
+void PushState::onSuccess(const std::string& hostAndPushPort) {
+  PushStrategy& strategy = *pushStrategy_;
+  strategy.onSuccess(hostAndPushPort);
+}
+
+// Forwards a congestion-control signal for hostAndPushPort to the push
+// strategy.
+void PushState::onCongestControl(const std::string& hostAndPushPort) {
+  PushStrategy& strategy = *pushStrategy_;
+  strategy.onCongestControl(hostAndPushPort);
+}
+
+// Marks batchId as no longer in flight towards hostAndPushPort. Logs a
+// warning if nothing was ever recorded for that address.
+void PushState::removeBatch(int batchId, const std::string& hostAndPushPort) {
+  auto batchIdSetOptional = inflightBatchesPerAddress_.get(hostAndPushPort);
+  if (!batchIdSetOptional.has_value()) {
+    LOG(WARNING) << "BatchIdSet of " << hostAndPushPort << " doesn't exist.";
+    return;
+  }
+  // Decrement only for a batch that really was in flight, so removing the
+  // same batch twice (or an unknown one) cannot skew totalInflightReqs_
+  // relative to the per-address sets.
+  if (batchIdSetOptional.value()->erase(batchId)) {
+    totalInflightReqs_.fetch_sub(1);
+  }
+}
+
+// Waits (polling every deltaMs_) until both the total and the per-address
+// in-flight counts are within their limits. Returns false if the limits were
+// met, true if waitInflightTimeoutMs_ elapsed first. Throws (via
+// throwIfExceptionExists) if an exception has been recorded.
+bool PushState::limitMaxInFlight(const std::string& hostAndPushPort) {
+  throwIfExceptionExists();
+
+  pushStrategy_->limitPushSpeed(*this, hostAndPushPort);
+  const int currentMaxReqsInFlight =
+      pushStrategy_->getCurrentMaxReqsInFlight(hostAndPushPort);
+
+  auto batchIdSet = inflightBatchesPerAddress_.computeIfAbsent(
+      hostAndPushPort,
+      [&]() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
+  // Compare the per-address count as a signed value, so the int limit is not
+  // converted to size_t (where a negative limit would wrap to a huge value).
+  const auto withinLimits = [&]() {
+    return totalInflightReqs_.load() <= maxInFlightReqsTotal_ &&
+        static_cast<long>(batchIdSet->size()) <= currentMaxReqsInFlight;
+  };
+
+  // A non-positive configured interval would divide by zero below.
+  const long sleepMs = deltaMs_ > 0 ? deltaMs_ : 1;
+  long times = waitInflightTimeoutMs_ / sleepMs;
+  bool timedOut = false;
+  // Check before every sleep and once more after the last one. Capacity freed
+  // during the final interval, or a timeout shorter than one interval, is
+  // then not misreported as a timeout.
+  while (!withinLimits()) {
+    if (times <= 0) {
+      timedOut = true;
+      break;
+    }
+    throwIfExceptionExists();
+    std::this_thread::sleep_for(utils::MS(sleepMs));
+    times--;
+  }
+
+  if (timedOut) {
+    LOG(WARNING) << "After waiting for " << waitInflightTimeoutMs_
+                 << " ms, there are still " << batchIdSet->size()
+                 << " batches in flight for hostAndPushPort "
+                 << hostAndPushPort << ", which exceeds the current limit "
+                 << currentMaxReqsInFlight;
+  }
+  throwIfExceptionExists();
+  return timedOut;
+}
+
+// Waits (polling every deltaMs_) until no batch is in flight. Returns false
+// once drained, true if waitInflightTimeoutMs_ elapsed first (after logging
+// the remaining batches per address). Throws (via throwIfExceptionExists) if
+// an exception has been recorded.
+bool PushState::limitZeroInFlight() {
+  throwIfExceptionExists();
+
+  // A non-positive configured interval would divide by zero below.
+  const long sleepMs = deltaMs_ > 0 ? deltaMs_ : 1;
+  long times = waitInflightTimeoutMs_ / sleepMs;
+  bool timedOut = false;
+  // Check before every sleep and once more after the last one, so a state
+  // that drains during the final interval (or is already drained when the
+  // timeout is shorter than one interval) is not reported as a timeout.
+  while (totalInflightReqs_.load() > 0) {
+    if (times <= 0) {
+      timedOut = true;
+      break;
+    }
+    throwIfExceptionExists();
+    std::this_thread::sleep_for(utils::MS(sleepMs));
+    times--;
+  }
+
+  if (timedOut) {
+    std::string addressInfos;
+    inflightBatchesPerAddress_.forEach(
+        [&](const std::string& address,
+            const std::shared_ptr<utils::ConcurrentHashSet<int>>&
+                inflightBatches) {
+          if (inflightBatches->size() <= 0) {
+            return;
+          }
+          if (!addressInfos.empty()) {
+            addressInfos += ", ";
+          }
+          addressInfos += fmt::format(
+              "{} batches for hostAndPushPort {}",
+              inflightBatches->size(),
+              address);
+        });
+    LOG(ERROR) << "After waiting for " << waitInflightTimeoutMs_
+               << " ms, there are still " << totalInflightReqs_
+               << " in flight: [" << addressInfos
+               << "] which exceeds the current limit 0.";
+  }
+  throwIfExceptionExists();
+  return timedOut;
+}
+
+// Returns true if setException has stored an exception.
+bool PushState::exceptionExists() const {
+  // Explicit null check instead of a C-style cast to bool; the read lock is
+  // held only for this full-expression.
+  return *exception_.rlock() != nullptr;
+}
+
+// Stores the first reported exception; later calls are ignored so the
+// original failure is the one surfaced to callers.
+void PushState::setException(std::unique_ptr<std::exception> exception) {
+  exception_.withWLock([&](auto& current) {
+    if (current == nullptr) {
+      current = std::move(exception);
+    }
+  });
+}
+
+// Returns the what() message of the recorded exception, or nullopt if none.
+std::optional<std::string> PushState::getExceptionMsg() const {
+  return exception_.withRLock(
+      [](const auto& current) -> std::optional<std::string> {
+        if (!current) {
+          return std::nullopt;
+        }
+        return std::string(current->what());
+      });
+}
+
+// Drops all in-flight bookkeeping and resets the push strategy. A recorded
+// exception is left untouched.
+void PushState::cleanup() {
+  inflightBatchesPerAddress_.clear();
+  totalInflightReqs_.store(0);
+  pushStrategy_->clear();
+}
+
+// Fails (via CELEBORN_FAIL) with the recorded exception's message, if any.
+void PushState::throwIfExceptionExists() {
+  auto recorded = exception_.rlock();
+  if (!*recorded) {
+    return;
+  }
+  CELEBORN_FAIL((*recorded)->what());
+}
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/PushState.h b/cpp/celeborn/client/writer/PushState.h
new file mode 100644
index 000000000..a32b0971c
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushState.h
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <atomic>
+
+#include "celeborn/client/writer/PushStrategy.h"
+#include "celeborn/conf/CelebornConf.h"
+#include "celeborn/utils/CelebornUtils.h"
+
+namespace celeborn {
+namespace client {
+
+class PushStrategy;
+
+/// Records the push state of a mapKey: the next batch id to hand out, the
+/// batches currently in flight (per worker address and in total), and the
+/// first exception reported by a push. It also enforces the in-flight limits
+/// (congestion control) configured via CelebornConf and the PushStrategy.
+/// Counters are atomics, and the per-address sets and the exception slot are
+/// lock-protected, so the methods may be called from different threads.
+class PushState {
+ public:
+  explicit PushState(const conf::CelebornConf& conf);
+
+  /// Returns a new batch id; ids start at 1 and increase by one per call.
+  int nextBatchId();
+
+  /// Records batchId as in flight towards hostAndPushPort.
+  void addBatch(int batchId, const std::string& hostAndPushPort);
+
+  /// Forwards a successful push to hostAndPushPort to the PushStrategy.
+  void onSuccess(const std::string& hostAndPushPort);
+
+  /// Forwards a congestion-control signal for hostAndPushPort to the
+  /// PushStrategy.
+  void onCongestControl(const std::string& hostAndPushPort);
+
+  /// Records batchId as no longer in flight towards hostAndPushPort.
+  void removeBatch(int batchId, const std::string& hostAndPushPort);
+
+  // Check if the host's ongoing package num (or the total ongoing package
+  // num) exceeds its max limit; if so, block until it decreases to within the
+  // limit. If the limit operation succeeds before timeout, return false,
+  // otherwise return true. Throws if an exception has been set.
+  bool limitMaxInFlight(const std::string& hostAndPushPort);
+
+  // Check if the pushState's ongoing package num reaches zero, if not, block
+  // until the ongoing package num decreases to zero. If the limit operation
+  // succeeds before timeout, return false, otherwise return true. Throws if an
+  // exception has been set.
+  bool limitZeroInFlight();
+
+  /// Returns true if an exception has been recorded via setException.
+  bool exceptionExists() const;
+
+  /// Records exception unless one was already recorded (first one wins).
+  void setException(std::unique_ptr<std::exception> exception);
+
+  /// Returns the recorded exception's message, or nullopt if none.
+  std::optional<std::string> getExceptionMsg() const;
+
+  /// Clears in-flight bookkeeping and the strategy; keeps the exception.
+  void cleanup();
+
+ private:
+  void throwIfExceptionExists();
+
+  std::atomic<int> currBatchId_{1};
+  std::atomic<long> totalInflightReqs_{0};
+  // Maximum time limitMaxInFlight/limitZeroInFlight wait, in milliseconds.
+  const long waitInflightTimeoutMs_;
+  // Polling interval while waiting, in milliseconds.
+  const long deltaMs_;
+  const std::unique_ptr<PushStrategy> pushStrategy_;
+  const int maxInFlightReqsTotal_;
+  utils::ConcurrentHashMap<
+      std::string,
+      std::shared_ptr<utils::ConcurrentHashSet<int>>>
+      inflightBatchesPerAddress_;
+  folly::Synchronized<std::unique_ptr<std::exception>> exception_;
+};
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/PushStrategy.cpp b/cpp/celeborn/client/writer/PushStrategy.cpp
new file mode 100644
index 000000000..66dad6c4e
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushStrategy.cpp
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/PushStrategy.h"
+
+namespace celeborn {
+namespace client {
+
+// Instantiates the strategy named by clientPushLimitStrategy(). The name is
+// matched case-insensitively (ASCII, locale-independent), so "simple" and
+// "SIMPLE" both select SimplePushStrategy. Fails for unsupported names.
+std::unique_ptr<PushStrategy> PushStrategy::create(
+    const conf::CelebornConf& conf) {
+  const std::string strategyName = conf.clientPushLimitStrategy();
+  std::string normalizedName = strategyName;
+  for (char& c : normalizedName) {
+    if (c >= 'a' && c <= 'z') {
+      c = static_cast<char>(c - 'a' + 'A');
+    }
+  }
+  if (normalizedName == conf::CelebornConf::kSimplePushStrategy) {
+    return std::make_unique<SimplePushStrategy>(conf);
+  } else {
+    CELEBORN_FAIL("unsupported pushStrategy: " + strategyName);
+  }
+}
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/PushStrategy.h b/cpp/celeborn/client/writer/PushStrategy.h
new file mode 100644
index 000000000..60a40f593
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushStrategy.h
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <memory>
+
+#include "celeborn/client/writer/PushState.h"
+#include "celeborn/conf/CelebornConf.h"
+
+namespace celeborn {
+namespace client {
+
+class PushState;
+
+/// Policy consulted by PushState to decide how many push requests may be in
+/// flight towards each worker address. Implementations are chosen by name
+/// through create().
+class PushStrategy {
+ public:
+  /// Builds the strategy named by CelebornConf::clientPushLimitStrategy();
+  /// fails for unsupported names.
+  static std::unique_ptr<PushStrategy> create(const conf::CelebornConf& conf);
+
+  PushStrategy() = default;
+
+  virtual ~PushStrategy() = default;
+
+  /// Hook invoked through PushState::onSuccess for hostAndPushPort.
+  virtual void onSuccess(const std::string& hostAndPushPort) = 0;
+
+  /// Hook invoked through PushState::onCongestControl for hostAndPushPort.
+  virtual void onCongestControl(const std::string& hostAndPushPort) = 0;
+
+  /// Resets any state kept by the strategy (called from PushState::cleanup).
+  virtual void clear() = 0;
+
+  // Control the push speed to meet the requirement.
+  virtual void limitPushSpeed(
+      PushState& pushState,
+      const std::string& hostAndPushPort) = 0;
+
+  /// Maximum number of requests currently allowed in flight towards
+  /// hostAndPushPort.
+  virtual int getCurrentMaxReqsInFlight(
+      const std::string& hostAndPushPort) = 0;
+};
+
+/// Strategy with a fixed per-worker in-flight limit taken from
+/// CelebornConf::clientPushMaxReqsInFlightPerWorker(). Its feedback hooks and
+/// limitPushSpeed are no-ops.
+class SimplePushStrategy : public PushStrategy {
+ public:
+  explicit SimplePushStrategy(const conf::CelebornConf& conf)
+      : maxInFlightPerWorker_(conf.clientPushMaxReqsInFlightPerWorker()) {}
+
+  ~SimplePushStrategy() override = default;
+
+  void onSuccess(const std::string& hostAndPushPort) override {}
+
+  void onCongestControl(const std::string& hostAndPushPort) override {}
+
+  void clear() override {}
+
+  void limitPushSpeed(PushState& pushState, const std::string& hostAndPushPort)
+      override {}
+
+  /// Always the configured per-worker limit, independent of the address.
+  int getCurrentMaxReqsInFlight(const std::string& hostAndPushPort) override {
+    return maxInFlightPerWorker_;
+  }
+
+ private:
+  const int maxInFlightPerWorker_;
+};
+
+// TODO: support SlowStartPushStrategy
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp
index 73b2b24dc..a3e464332 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -135,6 +135,11 @@ const std::unordered_map<std::string, folly::Optional<std::string>>
     CelebornConf::kDefaultProperties = {
         STR_PROP(kRpcAskTimeout, "60s"),
         STR_PROP(kRpcLookupTimeout, "30s"),
+        STR_PROP(kClientPushLimitStrategy, kSimplePushStrategy),
+        NUM_PROP(kClientPushMaxReqsInFlightPerWorker, 32),
+        NUM_PROP(kClientPushMaxReqsInFlightTotal, 256),
+        NUM_PROP(kClientPushLimitInFlightTimeoutMs, 240000),
+        NUM_PROP(kClientPushLimitInFlightSleepDeltaMs, 50),
         STR_PROP(kClientRpcGetReducerFileGroupRpcAskTimeout, "60s"),
         STR_PROP(kNetworkConnectTimeout, "10s"),
         STR_PROP(kClientFetchTimeout, "600s"),
@@ -185,6 +190,28 @@ Timeout CelebornConf::rpcLookupTimeout() const {
       toDuration(optionalProperty(kRpcLookupTimeout).value()));
 }
 
+// Name of the configured push limit strategy, as stored in the properties.
+std::string CelebornConf::clientPushLimitStrategy() const {
+  const auto strategy = optionalProperty(kClientPushLimitStrategy);
+  return strategy.value();
+}
+
+// Maximum number of push requests in flight towards a single worker.
+int CelebornConf::clientPushMaxReqsInFlightPerWorker() const {
+  const std::string raw =
+      optionalProperty(kClientPushMaxReqsInFlightPerWorker).value();
+  return std::stoi(raw);
+}
+
+// Maximum number of push requests in flight across all workers.
+int CelebornConf::clientPushMaxReqsInFlightTotal() const {
+  const std::string raw =
+      optionalProperty(kClientPushMaxReqsInFlightTotal).value();
+  return std::stoi(raw);
+}
+
+namespace {
+// Converts a time-valued property to milliseconds. A bare integer (the form
+// used by this client's defaults, e.g. "240000") is already milliseconds. A
+// trailing "s", "m"/"min" or "h" unit is honoured, so "240s" means 240000 ms
+// instead of being truncated to 240 by std::stol. "ms" and any other suffix
+// keep the parsed number unchanged, as plain std::stol did before.
+long parseTimePropertyMs(const std::string& value) {
+  std::string::size_type parsed = 0;
+  const long amount = std::stol(value, &parsed);
+  const std::string unit = value.substr(parsed);
+  if (unit == "s") {
+    return amount * 1000L;
+  }
+  if (unit == "m" || unit == "min") {
+    return amount * 60L * 1000L;
+  }
+  if (unit == "h") {
+    return amount * 60L * 60L * 1000L;
+  }
+  return amount;
+}
+} // namespace
+
+long CelebornConf::clientPushLimitInFlightTimeoutMs() const {
+  return parseTimePropertyMs(
+      optionalProperty(kClientPushLimitInFlightTimeoutMs).value());
+}
+
+long CelebornConf::clientPushLimitInFlightSleepDeltaMs() const {
+  return parseTimePropertyMs(
+      optionalProperty(kClientPushLimitInFlightSleepDeltaMs).value());
+}
+
 Timeout CelebornConf::clientRpcGetReducerFileGroupRpcAskTimeout() const {
   return utils::toTimeout(toDuration(
       optionalProperty(kClientRpcGetReducerFileGroupRpcAskTimeout).value()));
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index 6ee278148..98b5b9306 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -45,6 +45,23 @@ class CelebornConf : public BaseConf {
   static constexpr std::string_view kRpcLookupTimeout{
       "celeborn.rpc.lookupTimeout"};
 
+  // Name of the PushStrategy used to throttle pushes (default
+  // kSimplePushStrategy).
+  static constexpr std::string_view kClientPushLimitStrategy{
+      "celeborn.client.push.limit.strategy"};
+
+  // Value of kClientPushLimitStrategy that selects SimplePushStrategy.
+  static constexpr std::string_view kSimplePushStrategy{"SIMPLE"};
+
+  // Max push requests in flight towards one worker (default 32).
+  static constexpr std::string_view kClientPushMaxReqsInFlightPerWorker{
+      "celeborn.client.push.maxReqsInFlight.perWorker"};
+
+  // Max push requests in flight across all workers (default 256).
+  static constexpr std::string_view kClientPushMaxReqsInFlightTotal{
+      "celeborn.client.push.maxReqsInFlight.total"};
+
+  // How long to wait for in-flight pushes, read as milliseconds by
+  // clientPushLimitInFlightTimeoutMs() (default 240000).
+  static constexpr std::string_view kClientPushLimitInFlightTimeoutMs{
+      "celeborn.client.push.limit.inFlight.timeout"};
+
+  // Polling interval while waiting for in-flight pushes, read as milliseconds
+  // by clientPushLimitInFlightSleepDeltaMs() (default 50).
+  static constexpr std::string_view kClientPushLimitInFlightSleepDeltaMs{
+      "celeborn.client.push.limit.inFlight.sleepInterval"};
+
   static constexpr std::string_view kClientRpcGetReducerFileGroupRpcAskTimeout{
       "celeborn.client.rpc.getReducerFileGroup.askTimeout"};
 
@@ -83,6 +100,16 @@ class CelebornConf : public BaseConf {
 
   Timeout rpcLookupTimeout() const;
 
+  // Value of kClientPushLimitStrategy.
+  std::string clientPushLimitStrategy() const;
+
+  // Value of kClientPushMaxReqsInFlightPerWorker.
+  int clientPushMaxReqsInFlightPerWorker() const;
+
+  // Value of kClientPushMaxReqsInFlightTotal.
+  int clientPushMaxReqsInFlightTotal() const;
+
+  // Value of kClientPushLimitInFlightTimeoutMs, in milliseconds.
+  long clientPushLimitInFlightTimeoutMs() const;
+
+  // Value of kClientPushLimitInFlightSleepDeltaMs, in milliseconds.
+  long clientPushLimitInFlightSleepDeltaMs() const;
+
   Timeout clientRpcGetReducerFileGroupRpcAskTimeout() const;
 
   Timeout networkConnectTimeout() const;
diff --git a/cpp/celeborn/utils/CelebornUtils.h b/cpp/celeborn/utils/CelebornUtils.h
index df42a09e8..669e4060e 100644
--- a/cpp/celeborn/utils/CelebornUtils.h
+++ b/cpp/celeborn/utils/CelebornUtils.h
@@ -57,6 +57,7 @@ void writeRpcAddress(
 
 using Duration = std::chrono::duration<double>;
 using Timeout = std::chrono::milliseconds;
+// Shorthand for millisecond durations, e.g. utils::MS(50) as a sleep interval.
+using MS = std::chrono::milliseconds;
 inline Timeout toTimeout(Duration duration) {
   return std::chrono::duration_cast<Timeout>(duration);
 }
@@ -161,6 +162,15 @@ class ConcurrentHashMap {
     return synchronizedMap_.lock()->size();
   }
 
+  // Invokes apply(key, value) for every entry while the map's lock is held
+  // for the whole iteration, so apply should not call back into this map.
+  template <class Function>
+  void forEach(Function&& apply) {
+    synchronizedMap_.withLock([&](auto& map) {
+      for (auto& entry : map) {
+        apply(entry.first, entry.second);
+      }
+    });
+  }
+
   std::optional<TValue> erase(const TKey& key) {
     // Explicitly declaring the return type helps type deduction.
     return synchronizedMap_.withLock([&](auto& map) -> std::optional<TValue> {
@@ -183,5 +193,51 @@ class ConcurrentHashMap {
       synchronizedMap_;
 };
 
+// A std::unordered_set guarded by a single std::mutex: every member function
+// runs under the lock, so each call is atomic with respect to the others.
+template <typename TValue, typename THasher = std::hash<TValue>>
+class ConcurrentHashSet {
+ public:
+  // Return true if the value exists, false otherwise.
+  bool contains(const TValue& value) const {
+    return synchronizedSet_.withLock(
+        [&](const auto& set) { return set.find(value) != set.end(); });
+  }
+
+  // Return true if the value is inserted because it didn't exist,
+  // false if the value is not inserted because it already exists.
+  // Uses insert()'s own result instead of a separate lookup.
+  bool insert(const TValue& value) {
+    return synchronizedSet_.withLock(
+        [&](auto& set) { return set.insert(value).second; });
+  }
+
+  // Return true if the value was erased because it existed,
+  // false if nothing was erased because the value didn't exist.
+  // Uses erase()'s removed-element count instead of a separate lookup.
+  bool erase(const TValue& value) {
+    return synchronizedSet_.withLock(
+        [&](auto& set) { return set.erase(value) > 0; });
+  }
+
+  size_t size() const {
+    return synchronizedSet_.lock()->size();
+  }
+
+  void clear() {
+    synchronizedSet_.lock()->clear();
+  }
+
+ private:
+  folly::Synchronized<std::unordered_set<TValue, THasher>, std::mutex>
+      synchronizedSet_;
+};
+
 } // namespace utils
 } // namespace celeborn
diff --git a/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp b/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp
index 53f29cb62..8517074ec 100644
--- a/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp
+++ b/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp
@@ -29,15 +29,17 @@ class CelebornUtilsTest : public testing::Test {
  protected:
   void SetUp() override {
     map_ = std::make_unique<ConcurrentHashMap<std::string, int>>();
+    set_ = std::make_unique<ConcurrentHashSet<int>>();
   }
 
   std::unique_ptr<ConcurrentHashMap<std::string, int>> map_;
+  std::unique_ptr<ConcurrentHashSet<int>> set_;
 };
 
 TEST_F(CelebornUtilsTest, mapBasicInsertAndRetrieve) {
   map_->set("apple", 10);
   auto result = map_->get("apple");
+  // ASSERT, not EXPECT: *result below is UB if the optional is empty.
   ASSERT_TRUE(result.has_value());
   EXPECT_EQ(10, *result);
 }
 
@@ -45,7 +47,7 @@ TEST_F(CelebornUtilsTest, mapUpdateExistingKey) {
   map_->set("apple", 10);
   map_->set("apple", 20);
   auto result = map_->get("apple");
+  // ASSERT, not EXPECT: *result below is UB if the optional is empty.
   ASSERT_TRUE(result.has_value());
   EXPECT_EQ(20, *result);
 }
 
@@ -53,16 +55,33 @@ TEST_F(CelebornUtilsTest, mapComputeIfAbsent) {
   map_->set("apple", 10);
   map_->computeIfAbsent("apple", []() { return 20; });
   auto result = map_->get("apple");
+  // ASSERT, not EXPECT: *result below is UB if the optional is empty.
   ASSERT_TRUE(result.has_value());
   EXPECT_EQ(10, *result);
 
   map_->computeIfAbsent("banana", []() { return 30; });
   map_->computeIfAbsent("banana", []() { return 40; });
   result = map_->get("banana");
   ASSERT_TRUE(result.has_value());
   EXPECT_EQ(30, *result);
 }
 
+TEST_F(CelebornUtilsTest, mapForEach) {
+  map_->set("apple", 10);
+  map_->set("banana", 20);
+
+  // Accumulate the std::hash values in size_t so both sides of the
+  // comparison wrap identically, with no signed/unsigned mixing.
+  int sum = 0;
+  int visited = 0;
+  size_t hashSum = 0;
+  map_->forEach([&](const std::string& key, const int value) {
+    sum += value;
+    hashSum += std::hash<std::string>{}(key);
+    ++visited;
+  });
+
+  EXPECT_EQ(visited, 2);
+  EXPECT_EQ(sum, 10 + 20);
+  EXPECT_EQ(
+      hashSum,
+      std::hash<std::string>{}("apple") + std::hash<std::string>{}("banana"));
+}
+
 TEST_F(CelebornUtilsTest, mapRemoveKey) {
   map_->set("banana", 30);
   map_->erase("banana");
@@ -99,7 +118,7 @@ TEST_F(CelebornUtilsTest, mapConcurrentInserts) {
     for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
       std::string key = "thread" + std::to_string(i) + "-" + std::to_string(j);
       auto result = map_->get(key);
+      // ASSERT, not EXPECT: *result below is UB if the key is missing.
       ASSERT_TRUE(result.has_value()) << "Missing key: " << key;
       EXPECT_EQ(j, *result);
     }
   }
@@ -131,7 +150,7 @@ TEST_F(CelebornUtilsTest, mapConcurrentUpdates) {
 
   // Verify the final value is from the last writer
   auto result = map_->get("contended");
+  // ASSERT, not EXPECT: *result below is UB if the optional is empty.
   ASSERT_TRUE(result.has_value());
   // The exact value depends on thread scheduling, but it should be
   // from one of the threads (between 0*100+99 and 7*100+99)
   EXPECT_GE(*result, 99);
@@ -180,7 +199,59 @@ TEST_F(CelebornUtilsTest, mapConcurrentReadWrite) {
   // Verify final values are from writers
   for (int i = 0; i < NUM_WRITERS; ++i) {
     auto result = map_->get("key" + std::to_string(i));
+    // ASSERT, not EXPECT: *result below is UB if the optional is empty.
     ASSERT_TRUE(result.has_value());
     EXPECT_EQ(i, *result);
   }
 }
+
+TEST_F(CelebornUtilsTest, setBasicInsertAndRetrieve) {
+  // A freshly inserted value is reported as inserted and then as present.
+  const int value = 10;
+  EXPECT_TRUE(set_->insert(value));
+  EXPECT_TRUE(set_->contains(value));
+}
+
+TEST_F(CelebornUtilsTest, setUpdateExistingValue) {
+  // Inserting a duplicate is rejected and reported as "not inserted".
+  const int value = 10;
+  EXPECT_TRUE(set_->insert(value));
+  EXPECT_FALSE(set_->insert(value));
+}
+
+TEST_F(CelebornUtilsTest, setRemoveValue) {
+  // Only the first erase finds the value; afterwards it is gone.
+  const int value = 10;
+  EXPECT_TRUE(set_->insert(value));
+  EXPECT_TRUE(set_->erase(value));
+  EXPECT_FALSE(set_->erase(value));
+  EXPECT_FALSE(set_->contains(value));
+}
+
+TEST_F(CelebornUtilsTest, setNonExistingValue) {
+  // An empty set contains nothing.
+  const int absent = 10;
+  EXPECT_FALSE(set_->contains(absent));
+}
+
+TEST_F(CelebornUtilsTest, setConcurrentInserts) {
+  constexpr int NUM_THREADS = 8;
+  constexpr int ITEMS_PER_THREAD = 100;
+  // Each thread inserts a disjoint range of ids, so exactly
+  // NUM_THREADS * ITEMS_PER_THREAD distinct values must end up in the set.
+  // Truncating size_t string hashes to int could make two values collide
+  // and spuriously fail the size check.
+  std::vector<std::thread> threads;
+  threads.reserve(NUM_THREADS);
+
+  for (int i = 0; i < NUM_THREADS; ++i) {
+    threads.emplace_back([this, i] {
+      for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
+        set_->insert(i * ITEMS_PER_THREAD + j);
+      }
+    });
+  }
+
+  for (auto& t : threads) {
+    t.join();
+  }
+
+  EXPECT_EQ(set_->size(), static_cast<size_t>(NUM_THREADS * ITEMS_PER_THREAD));
+
+  // Verify all items were inserted
+  for (int i = 0; i < NUM_THREADS; ++i) {
+    for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
+      const int value = i * ITEMS_PER_THREAD + j;
+      EXPECT_TRUE(set_->contains(value)) << "Missing key: " << value;
+    }
+  }
+}

Reply via email to