This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 8ccd4dad1 [CELEBORN-2182][CIP-14] Support PushState and PushStrategy in CppClient
8ccd4dad1 is described below
commit 8ccd4dad10a53647d54231337b5c6966160ee236
Author: HolyLow <[email protected]>
AuthorDate: Mon Nov 3 11:36:17 2025 +0800
[CELEBORN-2182][CIP-14] Support PushState and PushStrategy in CppClient
### What changes were proposed in this pull request?
This PR supports PushState and PushStrategy in CppClient.
### Why are the changes needed?
These functionalities are used in the writing procedure in CppClient.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation and UTs.
Closes #3515 from HolyLow/issue/celeborn-2182-support-PushState-in-cpp-client.
Authored-by: HolyLow <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
cpp/celeborn/client/CMakeLists.txt | 2 +
cpp/celeborn/client/tests/CMakeLists.txt | 1 +
cpp/celeborn/client/tests/PushStateTest.cpp | 155 +++++++++++++++++++++++
cpp/celeborn/client/writer/PushState.cpp | 163 +++++++++++++++++++++++++
cpp/celeborn/client/writer/PushState.h | 83 +++++++++++++
cpp/celeborn/client/writer/PushStrategy.cpp | 34 ++++++
cpp/celeborn/client/writer/PushStrategy.h | 79 ++++++++++++
cpp/celeborn/conf/CelebornConf.cpp | 27 ++++
cpp/celeborn/conf/CelebornConf.h | 27 ++++
cpp/celeborn/utils/CelebornUtils.h | 56 +++++++++
cpp/celeborn/utils/tests/CelebornUtilsTest.cpp | 85 +++++++++++--
11 files changed, 705 insertions(+), 7 deletions(-)
diff --git a/cpp/celeborn/client/CMakeLists.txt b/cpp/celeborn/client/CMakeLists.txt
index 2586f6855..1af9069ee 100644
--- a/cpp/celeborn/client/CMakeLists.txt
+++ b/cpp/celeborn/client/CMakeLists.txt
@@ -16,6 +16,8 @@ add_library(
client
reader/WorkerPartitionReader.cpp
reader/CelebornInputStream.cpp
+ writer/PushState.cpp
+ writer/PushStrategy.cpp
ShuffleClient.cpp
compress/Decompressor.cpp
compress/Lz4Decompressor.cpp
diff --git a/cpp/celeborn/client/tests/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt
index d8a98e2b6..6341c84e7 100644
--- a/cpp/celeborn/client/tests/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -16,6 +16,7 @@
add_executable(
celeborn_client_test
WorkerPartitionReaderTest.cpp
+ PushStateTest.cpp
Lz4DecompressorTest.cpp
ZstdDecompressorTest.cpp
Lz4CompressorTest.cpp
diff --git a/cpp/celeborn/client/tests/PushStateTest.cpp b/cpp/celeborn/client/tests/PushStateTest.cpp
new file mode 100644
index 000000000..94b6c3aa3
--- /dev/null
+++ b/cpp/celeborn/client/tests/PushStateTest.cpp
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/client/writer/PushState.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+
+/// Fixture that builds a PushState from a CelebornConf whose in-flight limits,
+/// wait timeout and polling interval are overridden with the small values
+/// below, so the blocking paths can be exercised in a short time.
+class PushStateTest : public testing::Test {
+ protected:
+  // Total time (ms) a limit*InFlight() call may wait.
+  static constexpr int pushTimeoutMs_ = 100;
+  // Polling interval (ms) used while waiting.
+  static constexpr int pushSleepDeltaMs_ = 10;
+  // Used for both the total and the per-worker in-flight request limit.
+  static constexpr int maxReqsInFlight_ = 2;
+
+  std::unique_ptr<PushState> pushState_;
+
+  void SetUp() override {
+    conf::CelebornConf conf;
+    // Registers one integer-valued property in its string form.
+    const auto registerInt = [&conf](std::string_view key, int value) {
+      conf.registerProperty(key, std::to_string(value));
+    };
+    registerInt(
+        conf::CelebornConf::kClientPushLimitInFlightTimeoutMs, pushTimeoutMs_);
+    registerInt(
+        conf::CelebornConf::kClientPushLimitInFlightSleepDeltaMs,
+        pushSleepDeltaMs_);
+    registerInt(
+        conf::CelebornConf::kClientPushMaxReqsInFlightTotal, maxReqsInFlight_);
+    registerInt(
+        conf::CelebornConf::kClientPushMaxReqsInFlightPerWorker,
+        maxReqsInFlight_);
+
+    pushState_ = std::make_unique<PushState>(conf);
+  }
+};
+
+// Adding one batch more than the per-worker limit must block the pushing
+// thread until another thread removes a batch.
+TEST_F(PushStateTest, limitMaxInFlight) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+  constexpr int addBatchCalls = maxReqsInFlight_ + 1;
+  // The marks are written by the worker thread and read by this thread while
+  // the worker is still running, so they must be atomic. std::vector<bool>
+  // packs its elements into shared words and would be a data race here.
+  std::vector<std::atomic<bool>> addBatchMarks(addBatchCalls);
+  for (auto& mark : addBatchMarks) {
+    mark.store(false);
+  }
+  std::thread addBatchThread([&]() {
+    for (auto i = 0; i < addBatchCalls; i++) {
+      pushState_->addBatch(i, hostAndPushPort);
+      EXPECT_FALSE(pushState_->limitMaxInFlight(hostAndPushPort));
+      addBatchMarks[i].store(true);
+    }
+  });
+
+  // Give the worker time to push up to the limit and block on the extra batch.
+  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+  for (auto i = 0; i < maxReqsInFlight_; i++) {
+    EXPECT_TRUE(addBatchMarks[i].load());
+  }
+  EXPECT_FALSE(addBatchMarks[maxReqsInFlight_].load());
+
+  // Freeing one slot lets the blocked call return before the timeout.
+  pushState_->removeBatch(0, hostAndPushPort);
+  addBatchThread.join();
+  EXPECT_TRUE(addBatchMarks[maxReqsInFlight_].load());
+}
+
+// When nobody removes a batch, the call for the batch exceeding the limit
+// must give up after the configured timeout and report it by returning true.
+TEST_F(PushStateTest, limitMaxInFlightTimeout) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+  constexpr int addBatchCalls = maxReqsInFlight_ + 1;
+  // Atomic marks: they are shared between the worker and this thread without
+  // any other synchronization until join().
+  std::vector<std::atomic<bool>> addBatchMarks(addBatchCalls);
+  for (auto& mark : addBatchMarks) {
+    mark.store(false);
+  }
+  std::thread addBatchThread([&]() {
+    for (auto i = 0; i < addBatchCalls; i++) {
+      pushState_->addBatch(i, hostAndPushPort);
+      const bool timedOut = pushState_->limitMaxInFlight(hostAndPushPort);
+      if (i < maxReqsInFlight_) {
+        EXPECT_FALSE(timedOut);
+      } else {
+        EXPECT_TRUE(timedOut);
+      }
+      // A mark is set only when the limit call did not time out.
+      addBatchMarks[i].store(!timedOut);
+    }
+  });
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+  for (auto i = 0; i < maxReqsInFlight_; i++) {
+    EXPECT_TRUE(addBatchMarks[i].load());
+  }
+  EXPECT_FALSE(addBatchMarks[maxReqsInFlight_].load());
+
+  addBatchThread.join();
+  EXPECT_FALSE(addBatchMarks[maxReqsInFlight_].load());
+}
+
+// limitZeroInFlight() must block while a batch is in flight and return false
+// once the batch is removed before the timeout.
+TEST_F(PushStateTest, limitZeroInFlight) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+  // Written by the worker and read concurrently by this thread, hence atomic.
+  std::atomic<bool> addBatchMark{false};
+  std::thread addBatchThread([&]() {
+    pushState_->addBatch(0, hostAndPushPort);
+    EXPECT_FALSE(pushState_->limitZeroInFlight());
+    addBatchMark.store(true);
+  });
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+  EXPECT_FALSE(addBatchMark.load());
+
+  pushState_->removeBatch(0, hostAndPushPort);
+  addBatchThread.join();
+  EXPECT_TRUE(addBatchMark.load());
+}
+
+// With a batch that is never removed, limitZeroInFlight() must time out and
+// report it by returning true.
+TEST_F(PushStateTest, limitZeroInFlightTimeout) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+  // Written by the worker and read concurrently by this thread, hence atomic.
+  std::atomic<bool> addBatchMark{false};
+  std::thread addBatchThread([&]() {
+    pushState_->addBatch(0, hostAndPushPort);
+    const bool timedOut = pushState_->limitZeroInFlight();
+    EXPECT_TRUE(timedOut);
+    addBatchMark.store(!timedOut);
+  });
+
+  std::this_thread::sleep_for(std::chrono::milliseconds(pushSleepDeltaMs_));
+  EXPECT_FALSE(addBatchMark.load());
+
+  addBatchThread.join();
+  EXPECT_FALSE(addBatchMark.load());
+}
+
+// Once an exception has been recorded, both limit calls must throw instead of
+// waiting.
+TEST_F(PushStateTest, throwException) {
+  const std::string hostAndPushPort = "xx.xx.xx.xx:8080";
+  pushState_->setException(std::make_unique<std::exception>());
+
+  EXPECT_ANY_THROW(pushState_->limitMaxInFlight(hostAndPushPort));
+  EXPECT_ANY_THROW(pushState_->limitZeroInFlight());
+}
diff --git a/cpp/celeborn/client/writer/PushState.cpp b/cpp/celeborn/client/writer/PushState.cpp
new file mode 100644
index 000000000..b86a2aadb
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushState.cpp
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/PushState.h"
+
+namespace celeborn {
+namespace client {
+
+PushState::PushState(const conf::CelebornConf& conf)
+ : waitInflightTimeoutMs_(conf.clientPushLimitInFlightTimeoutMs()),
+ deltaMs_(conf.clientPushLimitInFlightSleepDeltaMs()),
+ pushStrategy_(PushStrategy::create(conf)),
+ maxInFlightReqsTotal_(conf.clientPushMaxReqsInFlightTotal()) {}
+
+/// Returns the current batch id and atomically advances the counter by one.
+int PushState::nextBatchId() {
+  // Post-increment on std::atomic yields the value held before the increment.
+  return currBatchId_++;
+}
+
+/// Records `batchId` as in flight for `hostAndPushPort` (creating the
+/// per-address set on first use) and increments the total in-flight counter.
+void PushState::addBatch(int batchId, const std::string& hostAndPushPort) {
+  const auto makeEmptySet = []() {
+    return std::make_shared<utils::ConcurrentHashSet<int>>();
+  };
+  auto inflightBatches =
+      inflightBatchesPerAddress_.computeIfAbsent(hostAndPushPort, makeEmptySet);
+  inflightBatches->insert(batchId);
+  ++totalInflightReqs_;
+}
+
+/// Forwards a successful push for `hostAndPushPort` to the push strategy.
+void PushState::onSuccess(const std::string& hostAndPushPort) {
+  auto& strategy = *pushStrategy_;
+  strategy.onSuccess(hostAndPushPort);
+}
+
+/// Forwards a congestion-control signal for `hostAndPushPort` to the push
+/// strategy.
+void PushState::onCongestControl(const std::string& hostAndPushPort) {
+  auto& strategy = *pushStrategy_;
+  strategy.onCongestControl(hostAndPushPort);
+}
+
+/// Removes `batchId` from the in-flight set of `hostAndPushPort` and
+/// decrements the total in-flight counter. If no set exists for the address,
+/// only a warning is logged and the counter is left untouched.
+void PushState::removeBatch(int batchId, const std::string& hostAndPushPort) {
+  const auto inflightBatches = inflightBatchesPerAddress_.get(hostAndPushPort);
+  if (!inflightBatches.has_value()) {
+    LOG(WARNING) << "BatchIdSet of " << hostAndPushPort << " doesn't exist.";
+    return;
+  }
+  (*inflightBatches)->erase(batchId);
+  --totalInflightReqs_;
+}
+
+/// Waits until both the total number of in-flight requests and the number of
+/// in-flight batches for `hostAndPushPort` are within their limits.
+///
+/// The limits are polled every `deltaMs_` milliseconds for at most
+/// `waitInflightTimeoutMs_` milliseconds.
+///
+/// @return false if the limits were satisfied, true if the wait timed out.
+/// @throws via CELEBORN_FAIL if an exception has been recorded with
+///         setException(), checked before, during and after the wait.
+bool PushState::limitMaxInFlight(const std::string& hostAndPushPort) {
+  throwIfExceptionExists();
+
+  pushStrategy_->limitPushSpeed(*this, hostAndPushPort);
+  const int currentMaxReqsInFlight =
+      pushStrategy_->getCurrentMaxReqsInFlight(hostAndPushPort);
+
+  auto batchIdSet = inflightBatchesPerAddress_.computeIfAbsent(
+      hostAndPushPort,
+      [&]() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
+  const auto withinLimits = [&]() {
+    return totalInflightReqs_ <= maxInFlightReqsTotal_ &&
+        batchIdSet->size() <= currentMaxReqsInFlight;
+  };
+
+  // A non-positive sleep interval would divide by zero below; poll every 1 ms
+  // in that case instead.
+  const long sleepMs = deltaMs_ > 0 ? deltaMs_ : 1;
+  // Check once up front and again after every sleep, so a timeout is never
+  // reported without having looked at the state after the final wait (or at
+  // all, when the timeout is shorter than the sleep interval).
+  bool satisfied = withinLimits();
+  for (long times = waitInflightTimeoutMs_ / sleepMs; !satisfied && times > 0;
+       times--) {
+    throwIfExceptionExists();
+    std::this_thread::sleep_for(utils::MS(sleepMs));
+    satisfied = withinLimits();
+  }
+
+  if (!satisfied) {
+    LOG(WARNING) << "After waiting for " << waitInflightTimeoutMs_
+                 << " ms, there are still " << batchIdSet->size()
+                 << " batches in flight for hostAndPushPort "
+                 << hostAndPushPort << ", which exceeds the current limit "
+                 << currentMaxReqsInFlight;
+  }
+  throwIfExceptionExists();
+  return !satisfied;
+}
+
+/// Waits until no request is in flight any more.
+///
+/// The total in-flight counter is polled every `deltaMs_` milliseconds for at
+/// most `waitInflightTimeoutMs_` milliseconds. On timeout, the per-address
+/// in-flight batch counts are logged.
+///
+/// @return false if the counter reached zero, true if the wait timed out.
+/// @throws via CELEBORN_FAIL if an exception has been recorded with
+///         setException(), checked before, during and after the wait.
+bool PushState::limitZeroInFlight() {
+  throwIfExceptionExists();
+
+  const auto drained = [&]() { return totalInflightReqs_ <= 0; };
+
+  // A non-positive sleep interval would divide by zero below; poll every 1 ms
+  // in that case instead.
+  const long sleepMs = deltaMs_ > 0 ? deltaMs_ : 1;
+  // Check once up front and again after every sleep, so a timeout is never
+  // reported without having looked at the state after the final wait (or at
+  // all, when the timeout is shorter than the sleep interval).
+  bool satisfied = drained();
+  for (long times = waitInflightTimeoutMs_ / sleepMs; !satisfied && times > 0;
+       times--) {
+    throwIfExceptionExists();
+    std::this_thread::sleep_for(utils::MS(sleepMs));
+    satisfied = drained();
+  }
+
+  if (!satisfied) {
+    // Summarize the addresses that still have batches in flight.
+    std::string addressInfos;
+    inflightBatchesPerAddress_.forEach(
+        [&](const std::string& address,
+            const std::shared_ptr<utils::ConcurrentHashSet<int>>&
+                inflightBatches) {
+          if (inflightBatches->size() <= 0) {
+            return;
+          }
+          if (!addressInfos.empty()) {
+            addressInfos += ", ";
+          }
+          addressInfos += fmt::format(
+              "{} batches for hostAndPushPort {}",
+              inflightBatches->size(),
+              address);
+        });
+    LOG(ERROR) << "After waiting for " << waitInflightTimeoutMs_
+               << " ms, there are still " << totalInflightReqs_
+               << " in flight: [" << addressInfos
+               << "] which exceeds the current limit 0.";
+  }
+  throwIfExceptionExists();
+  return !satisfied;
+}
+
+/// Returns true if an exception has been recorded with setException().
+bool PushState::exceptionExists() const {
+  const auto guard = exception_.rlock();
+  return *guard != nullptr;
+}
+
+/// Records `exception` unless one is already stored; the first recorded
+/// exception is kept and later ones are dropped.
+void PushState::setException(std::unique_ptr<std::exception> exception) {
+  auto guard = exception_.wlock();
+  if (*guard) {
+    return;
+  }
+  *guard = std::move(exception);
+}
+
+/// Returns the what() message of the recorded exception, or std::nullopt if
+/// none has been recorded.
+std::optional<std::string> PushState::getExceptionMsg() const {
+  const auto guard = exception_.rlock();
+  if (!*guard) {
+    return std::nullopt;
+  }
+  return std::string((*guard)->what());
+}
+
+/// Drops all per-address in-flight bookkeeping, resets the total in-flight
+/// counter to zero and clears the push strategy's state.
+void PushState::cleanup() {
+  inflightBatchesPerAddress_.clear();
+  totalInflightReqs_.store(0);
+  pushStrategy_->clear();
+}
+
+/// Fails via CELEBORN_FAIL with the recorded exception's message, if any
+/// exception has been recorded; otherwise does nothing.
+void PushState::throwIfExceptionExists() {
+  const auto guard = exception_.rlock();
+  if (!*guard) {
+    return;
+  }
+  CELEBORN_FAIL((*guard)->what());
+}
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/PushState.h b/cpp/celeborn/client/writer/PushState.h
new file mode 100644
index 000000000..a32b0971c
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushState.h
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <atomic>
+
+#include "celeborn/client/writer/PushStrategy.h"
+#include "celeborn/conf/CelebornConf.h"
+#include "celeborn/utils/CelebornUtils.h"
+
+namespace celeborn {
+namespace client {
+
+class PushStrategy;
+
+/// Records the states of a mapKey, including the ongoing package number id, the
+/// exception, etc. Besides, the congestionControl is also enforced.
+class PushState {
+ public:
+  /// Reads the wait timeout, polling interval, total in-flight limit and push
+  /// strategy from `conf`. Explicit so that a CelebornConf is never silently
+  /// converted into a PushState.
+  explicit PushState(const conf::CelebornConf& conf);
+
+  /// Returns the current batch id and advances the internal counter by one.
+  /// The counter starts at 1.
+  int nextBatchId();
+
+  /// Marks `batchId` as in flight for `hostAndPushPort`.
+  void addBatch(int batchId, const std::string& hostAndPushPort);
+
+  /// Notifies the push strategy of a successful push to `hostAndPushPort`.
+  void onSuccess(const std::string& hostAndPushPort);
+
+  /// Notifies the push strategy of congestion control on `hostAndPushPort`.
+  void onCongestControl(const std::string& hostAndPushPort);
+
+  /// Marks `batchId` as no longer in flight for `hostAndPushPort`.
+  void removeBatch(int batchId, const std::string& hostAndPushPort);
+
+  // Check if the host's ongoing package num reaches the max limit, if so,
+  // block until the ongoing package num decreases below max limit. If the
+  // limit operation succeeds before timeout, return false, otherwise return
+  // true.
+  bool limitMaxInFlight(const std::string& hostAndPushPort);
+
+  // Check if the pushState's ongoing package num reaches zero, if not, block
+  // until the ongoing package num decreases to zero. If the limit operation
+  // succeeds before timeout, return false, otherwise return true.
+  bool limitZeroInFlight();
+
+  /// Returns true if an exception has been recorded.
+  bool exceptionExists() const;
+
+  /// Records `exception`; ignored if an exception is already recorded.
+  void setException(std::unique_ptr<std::exception> exception);
+
+  /// Returns the message of the recorded exception, or std::nullopt if none.
+  std::optional<std::string> getExceptionMsg() const;
+
+  /// Clears all in-flight bookkeeping and the push strategy's state.
+  void cleanup();
+
+ private:
+  /// Fails with the recorded exception's message if one has been recorded.
+  void throwIfExceptionExists();
+
+  std::atomic<int> currBatchId_{1};
+  std::atomic<long> totalInflightReqs_{0};
+  const long waitInflightTimeoutMs_;
+  const long deltaMs_;
+  const std::unique_ptr<PushStrategy> pushStrategy_;
+  const int maxInFlightReqsTotal_;
+  // In-flight batch ids, keyed by the worker's host-and-push-port string.
+  utils::ConcurrentHashMap<
+      std::string,
+      std::shared_ptr<utils::ConcurrentHashSet<int>>>
+      inflightBatchesPerAddress_;
+  folly::Synchronized<std::unique_ptr<std::exception>> exception_;
+};
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/PushStrategy.cpp b/cpp/celeborn/client/writer/PushStrategy.cpp
new file mode 100644
index 000000000..66dad6c4e
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushStrategy.cpp
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/PushStrategy.h"
+
+namespace celeborn {
+namespace client {
+
+/// Creates the push strategy named by `conf.clientPushLimitStrategy()`.
+/// Only CelebornConf::kSimplePushStrategy is supported; any other name fails
+/// via CELEBORN_FAIL.
+std::unique_ptr<PushStrategy> PushStrategy::create(
+    const conf::CelebornConf& conf) {
+  const auto requestedStrategy = conf.clientPushLimitStrategy();
+  if (requestedStrategy == conf::CelebornConf::kSimplePushStrategy) {
+    return std::make_unique<SimplePushStrategy>(conf);
+  }
+  CELEBORN_FAIL("unsupported pushStrategy: " + requestedStrategy);
+}
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/PushStrategy.h b/cpp/celeborn/client/writer/PushStrategy.h
new file mode 100644
index 000000000..60a40f593
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushStrategy.h
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <memory>
+
+#include "celeborn/client/writer/PushState.h"
+#include "celeborn/conf/CelebornConf.h"
+
+namespace celeborn {
+namespace client {
+
+class PushState;
+
+/// Abstract interface for the policy that PushState consults to throttle
+/// pushes to a worker, identified by its host-and-push-port string.
+class PushStrategy {
+ public:
+  /// Creates the strategy selected by `conf`; fails for an unsupported name.
+  static std::unique_ptr<PushStrategy> create(const conf::CelebornConf& conf);
+
+  PushStrategy() = default;
+
+  virtual ~PushStrategy() = default;
+
+  /// Called when a push to `hostAndPushPort` succeeded.
+  virtual void onSuccess(const std::string& hostAndPushPort) = 0;
+
+  /// Called when `hostAndPushPort` signalled congestion control.
+  virtual void onCongestControl(const std::string& hostAndPushPort) = 0;
+
+  /// Resets any state held by the strategy.
+  virtual void clear() = 0;
+
+  // Control the push speed to meet the requirement.
+  virtual void limitPushSpeed(
+      PushState& pushState,
+      const std::string& hostAndPushPort) = 0;
+
+  /// Returns the current in-flight request limit for `hostAndPushPort`.
+  virtual int getCurrentMaxReqsInFlight(const std::string& hostAndPushPort) = 0;
+};
+
+/// Strategy with a fixed per-worker in-flight limit taken from the
+/// configuration. It keeps no per-worker state, so every notification hook is
+/// a no-op.
+class SimplePushStrategy : public PushStrategy {
+ public:
+  /// Reads the fixed limit from `conf.clientPushMaxReqsInFlightPerWorker()`.
+  explicit SimplePushStrategy(const conf::CelebornConf& conf)
+      : maxInFlightPerWorker_(conf.clientPushMaxReqsInFlightPerWorker()) {}
+
+  ~SimplePushStrategy() override = default;
+
+  // Parameter names are commented out because the hooks ignore them; this
+  // avoids unused-parameter warnings.
+  void onSuccess(const std::string& /*hostAndPushPort*/) override {}
+
+  void onCongestControl(const std::string& /*hostAndPushPort*/) override {}
+
+  void clear() override {}
+
+  void limitPushSpeed(
+      PushState& /*pushState*/,
+      const std::string& /*hostAndPushPort*/) override {}
+
+  /// Returns the same configured limit for every worker.
+  int getCurrentMaxReqsInFlight(
+      const std::string& /*hostAndPushPort*/) override {
+    return maxInFlightPerWorker_;
+  }
+
+ private:
+  const int maxInFlightPerWorker_;
+};
+
+// TODO: support SlowStartPushStrategy
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp
index 73b2b24dc..a3e464332 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -135,6 +135,11 @@ const std::unordered_map<std::string, folly::Optional<std::string>>
CelebornConf::kDefaultProperties = {
STR_PROP(kRpcAskTimeout, "60s"),
STR_PROP(kRpcLookupTimeout, "30s"),
+ STR_PROP(kClientPushLimitStrategy, kSimplePushStrategy),
+ NUM_PROP(kClientPushMaxReqsInFlightPerWorker, 32),
+ NUM_PROP(kClientPushMaxReqsInFlightTotal, 256),
+ NUM_PROP(kClientPushLimitInFlightTimeoutMs, 240000),
+ NUM_PROP(kClientPushLimitInFlightSleepDeltaMs, 50),
STR_PROP(kClientRpcGetReducerFileGroupRpcAskTimeout, "60s"),
STR_PROP(kNetworkConnectTimeout, "10s"),
STR_PROP(kClientFetchTimeout, "600s"),
@@ -185,6 +190,28 @@ Timeout CelebornConf::rpcLookupTimeout() const {
toDuration(optionalProperty(kRpcLookupTimeout).value()));
}
+/// Returns the value configured under kClientPushLimitStrategy.
+std::string CelebornConf::clientPushLimitStrategy() const {
+  auto strategyName = optionalProperty(kClientPushLimitStrategy).value();
+  return strategyName;
+}
+
+/// Returns the value configured under kClientPushMaxReqsInFlightPerWorker,
+/// parsed with std::stoi.
+int CelebornConf::clientPushMaxReqsInFlightPerWorker() const {
+  const auto rawValue =
+      optionalProperty(kClientPushMaxReqsInFlightPerWorker).value();
+  return std::stoi(rawValue);
+}
+
+/// Returns the value configured under kClientPushMaxReqsInFlightTotal, parsed
+/// with std::stoi.
+int CelebornConf::clientPushMaxReqsInFlightTotal() const {
+  const auto rawValue =
+      optionalProperty(kClientPushMaxReqsInFlightTotal).value();
+  return std::stoi(rawValue);
+}
+
+/// Returns the value configured under kClientPushLimitInFlightTimeoutMs,
+/// parsed with std::stol.
+/// NOTE(review): std::stol stops at the first non-numeric character, so a
+/// value carrying a unit suffix (e.g. "240s") is read as the bare number
+/// (240). Confirm that callers always supply plain milliseconds.
+long CelebornConf::clientPushLimitInFlightTimeoutMs() const {
+  const auto rawValue =
+      optionalProperty(kClientPushLimitInFlightTimeoutMs).value();
+  return std::stol(rawValue);
+}
+
+/// Returns the value configured under kClientPushLimitInFlightSleepDeltaMs,
+/// parsed with std::stol.
+/// NOTE(review): std::stol stops at the first non-numeric character, so a
+/// value carrying a unit suffix (e.g. "50ms") is read as the bare number (50).
+/// Confirm that callers always supply plain milliseconds.
+long CelebornConf::clientPushLimitInFlightSleepDeltaMs() const {
+  const auto rawValue =
+      optionalProperty(kClientPushLimitInFlightSleepDeltaMs).value();
+  return std::stol(rawValue);
+}
+
Timeout CelebornConf::clientRpcGetReducerFileGroupRpcAskTimeout() const {
return utils::toTimeout(toDuration(
optionalProperty(kClientRpcGetReducerFileGroupRpcAskTimeout).value()));
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index 6ee278148..98b5b9306 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -45,6 +45,23 @@ class CelebornConf : public BaseConf {
static constexpr std::string_view kRpcLookupTimeout{
"celeborn.rpc.lookupTimeout"};
+  // Name of the push-limit strategy to use.
+  static constexpr std::string_view kClientPushLimitStrategy{
+      "celeborn.client.push.limit.strategy"};
+
+  // Value of kClientPushLimitStrategy that selects SimplePushStrategy.
+  static constexpr std::string_view kSimplePushStrategy{"SIMPLE"};
+
+  // Limit on in-flight push requests per worker.
+  static constexpr std::string_view kClientPushMaxReqsInFlightPerWorker{
+      "celeborn.client.push.maxReqsInFlight.perWorker"};
+
+  // Limit on in-flight push requests in total.
+  static constexpr std::string_view kClientPushMaxReqsInFlightTotal{
+      "celeborn.client.push.maxReqsInFlight.total"};
+
+  // Longest time to wait for in-flight requests; read as an integer by
+  // clientPushLimitInFlightTimeoutMs().
+  static constexpr std::string_view kClientPushLimitInFlightTimeoutMs{
+      "celeborn.client.push.limit.inFlight.timeout"};
+
+  // Polling interval while waiting for in-flight requests; read as an integer
+  // by clientPushLimitInFlightSleepDeltaMs().
+  static constexpr std::string_view kClientPushLimitInFlightSleepDeltaMs{
+      "celeborn.client.push.limit.inFlight.sleepInterval"};
+
+
static constexpr std::string_view kClientRpcGetReducerFileGroupRpcAskTimeout{
"celeborn.client.rpc.getReducerFileGroup.askTimeout"};
@@ -83,6 +100,16 @@ class CelebornConf : public BaseConf {
Timeout rpcLookupTimeout() const;
+  // Value of kClientPushLimitStrategy.
+  std::string clientPushLimitStrategy() const;
+
+  // Value of kClientPushMaxReqsInFlightPerWorker as an int.
+  int clientPushMaxReqsInFlightPerWorker() const;
+
+  // Value of kClientPushMaxReqsInFlightTotal as an int.
+  int clientPushMaxReqsInFlightTotal() const;
+
+  // Value of kClientPushLimitInFlightTimeoutMs as a long.
+  long clientPushLimitInFlightTimeoutMs() const;
+
+  // Value of kClientPushLimitInFlightSleepDeltaMs as a long.
+  long clientPushLimitInFlightSleepDeltaMs() const;
+
Timeout clientRpcGetReducerFileGroupRpcAskTimeout() const;
Timeout networkConnectTimeout() const;
diff --git a/cpp/celeborn/utils/CelebornUtils.h b/cpp/celeborn/utils/CelebornUtils.h
index df42a09e8..669e4060e 100644
--- a/cpp/celeborn/utils/CelebornUtils.h
+++ b/cpp/celeborn/utils/CelebornUtils.h
@@ -57,6 +57,7 @@ void writeRpcAddress(
using Duration = std::chrono::duration<double>;
using Timeout = std::chrono::milliseconds;
+// Shorthand for a millisecond duration.
+using MS = std::chrono::milliseconds;
inline Timeout toTimeout(Duration duration) {
return std::chrono::duration_cast<Timeout>(duration);
}
@@ -161,6 +162,15 @@ class ConcurrentHashMap {
return synchronizedMap_.lock()->size();
}
+  // Invokes apply(key, value) for every entry while the map's lock is held.
+  // NOTE(review): because the lock is held for the whole iteration, `apply`
+  // presumably must not call back into this map - confirm against the mutex
+  // type used by synchronizedMap_.
+  template <class Function>
+  void forEach(Function&& apply) {
+    auto lockedMap = synchronizedMap_.lock();
+    for (auto& [key, value] : *lockedMap) {
+      apply(key, value);
+    }
+  }
+
std::optional<TValue> erase(const TKey& key) {
// Explicitly declaring the return type helps type deduction.
return synchronizedMap_.withLock([&](auto& map) -> std::optional<TValue> {
@@ -183,5 +193,51 @@ class ConcurrentHashMap {
synchronizedMap_;
};
+/// A std::unordered_set guarded by a std::mutex through folly::Synchronized.
+/// Each member function acquires the lock for the duration of the call.
+template <typename TValue, typename THasher = std::hash<TValue>>
+class ConcurrentHashSet {
+ public:
+  // Return true if the value exists, false otherwise.
+  bool contains(const TValue& value) const {
+    return synchronizedSet_.lock()->count(value) > 0;
+  }
+
+  // Return true if the value is inserted because it doesn't exist,
+  // false if the value is not inserted because it already exists.
+  bool insert(const TValue& value) {
+    // unordered_set::insert already reports whether insertion happened, so a
+    // separate find() beforehand would only repeat the hash lookup.
+    return synchronizedSet_.withLock(
+        [&](auto& set) { return set.insert(value).second; });
+  }
+
+  // Return true if the erasure happens because the value exists,
+  // false if the erasure doesn't happen because the value doesn't exist.
+  bool erase(const TValue& value) {
+    // erase(key) returns the number of removed elements (0 or 1).
+    return synchronizedSet_.withLock(
+        [&](auto& set) { return set.erase(value) > 0; });
+  }
+
+  // Return the number of values currently stored.
+  size_t size() const {
+    return synchronizedSet_.lock()->size();
+  }
+
+  // Remove all values.
+  void clear() {
+    synchronizedSet_.lock()->clear();
+  }
+
+ private:
+  folly::Synchronized<std::unordered_set<TValue, THasher>, std::mutex>
+      synchronizedSet_;
+};
+
} // namespace utils
} // namespace celeborn
diff --git a/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp b/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp
index 53f29cb62..8517074ec 100644
--- a/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp
+++ b/cpp/celeborn/utils/tests/CelebornUtilsTest.cpp
@@ -29,15 +29,17 @@ class CelebornUtilsTest : public testing::Test {
protected:
void SetUp() override {
map_ = std::make_unique<ConcurrentHashMap<std::string, int>>();
+ set_ = std::make_unique<ConcurrentHashSet<int>>();
}
std::unique_ptr<ConcurrentHashMap<std::string, int>> map_;
+ std::unique_ptr<ConcurrentHashSet<int>> set_;
};
TEST_F(CelebornUtilsTest, mapBasicInsertAndRetrieve) {
   map_->set("apple", 10);
   auto result = map_->get("apple");
+  // Keep ASSERT_TRUE (not EXPECT_TRUE): the next line dereferences `result`,
+  // which is undefined behavior on an empty optional, so the test must stop
+  // here when the value is missing.
   ASSERT_TRUE(result.has_value());
   EXPECT_EQ(10, *result);
}
@@ -45,7 +47,7 @@ TEST_F(CelebornUtilsTest, mapUpdateExistingKey) {
map_->set("apple", 10);
map_->set("apple", 20);
auto result = map_->get("apple");
- ASSERT_TRUE(result.has_value());
+ EXPECT_TRUE(result.has_value());
EXPECT_EQ(20, *result);
}
@@ -53,16 +55,33 @@ TEST_F(CelebornUtilsTest, mapComputeIfAbsent) {
map_->set("apple", 10);
map_->computeIfAbsent("apple", []() { return 20; });
auto result = map_->get("apple");
- ASSERT_TRUE(result.has_value());
+ EXPECT_TRUE(result.has_value());
EXPECT_EQ(10, *result);
map_->computeIfAbsent("banana", []() { return 30; });
map_->computeIfAbsent("banana", []() { return 40; });
result = map_->get("banana");
- ASSERT_TRUE(result.has_value());
+ EXPECT_TRUE(result.has_value());
EXPECT_EQ(30, *result);
}
+// forEach must visit every entry exactly once with its key and value.
+TEST_F(CelebornUtilsTest, mapForEach) {
+  map_->set("apple", 10);
+  map_->set("banana", 20);
+
+  const std::hash<std::string> hasher{};
+  int sum = 0;
+  // std::hash yields size_t; accumulate in the same unsigned type so the sum
+  // wraps identically on both sides of the comparison and no signed/unsigned
+  // conversion is involved.
+  size_t hashSum = 0;
+  map_->forEach([&](const std::string& key, const int value) {
+    sum += value;
+    hashSum += hasher(key);
+  });
+
+  const size_t expectedHashSum = hasher("apple") + hasher("banana");
+  EXPECT_EQ(sum, 10 + 20);
+  EXPECT_EQ(hashSum, expectedHashSum);
+}
+
TEST_F(CelebornUtilsTest, mapRemoveKey) {
map_->set("banana", 30);
map_->erase("banana");
@@ -99,7 +118,7 @@ TEST_F(CelebornUtilsTest, mapConcurrentInserts) {
for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
std::string key = "thread" + std::to_string(i) + "-" + std::to_string(j);
auto result = map_->get(key);
- ASSERT_TRUE(result.has_value()) << "Missing key: " << key;
+ EXPECT_TRUE(result.has_value()) << "Missing key: " << key;
EXPECT_EQ(j, *result);
}
}
@@ -131,7 +150,7 @@ TEST_F(CelebornUtilsTest, mapConcurrentUpdates) {
// Verify the final value is from the last writer
auto result = map_->get("contended");
- ASSERT_TRUE(result.has_value());
+ EXPECT_TRUE(result.has_value());
// The exact value depends on thread scheduling, but it should be
// from one of the threads (between 0*100+99 and 7*100+99)
EXPECT_GE(*result, 99);
@@ -180,7 +199,59 @@ TEST_F(CelebornUtilsTest, mapConcurrentReadWrite) {
// Verify final values are from writers
for (int i = 0; i < NUM_WRITERS; ++i) {
auto result = map_->get("key" + std::to_string(i));
- ASSERT_TRUE(result.has_value());
+ EXPECT_TRUE(result.has_value());
EXPECT_EQ(i, *result);
}
}
+
+// A newly inserted value is reported as inserted and is then contained.
+TEST_F(CelebornUtilsTest, setBasicInsertAndRetrieve) {
+  const bool inserted = set_->insert(10);
+  EXPECT_TRUE(inserted);
+  const bool present = set_->contains(10);
+  EXPECT_TRUE(present);
+}
+
+// Inserting the same value twice succeeds only the first time.
+TEST_F(CelebornUtilsTest, setUpdateExistingValue) {
+  const bool firstInsert = set_->insert(10);
+  EXPECT_TRUE(firstInsert);
+  const bool secondInsert = set_->insert(10);
+  EXPECT_FALSE(secondInsert);
+}
+
+// Erasing an existing value succeeds once; afterwards the value is gone.
+TEST_F(CelebornUtilsTest, setRemoveValue) {
+  const bool inserted = set_->insert(10);
+  EXPECT_TRUE(inserted);
+  const bool firstErase = set_->erase(10);
+  EXPECT_TRUE(firstErase);
+  const bool secondErase = set_->erase(10);
+  EXPECT_FALSE(secondErase);
+  const bool stillPresent = set_->contains(10);
+  EXPECT_FALSE(stillPresent);
+}
+
+// A value that was never inserted is not contained.
+TEST_F(CelebornUtilsTest, setNonExistingValue) {
+  const bool present = set_->contains(10);
+  EXPECT_FALSE(present);
+}
+
+// Concurrent inserts of distinct values from several threads must all land
+// in the set.
+TEST_F(CelebornUtilsTest, setConcurrentInserts) {
+  constexpr int NUM_THREADS = 8;
+  constexpr int ITEMS_PER_THREAD = 100;
+  // Each (thread, item) pair maps to a distinct int. Narrowing a 64-bit string
+  // hash to int, as an alternative, could make two keys collide and break the
+  // size check below.
+  const auto valueOf = [](int thread, int item) {
+    return thread * ITEMS_PER_THREAD + item;
+  };
+  std::vector<std::thread> threads;
+  threads.reserve(NUM_THREADS);
+
+  for (int i = 0; i < NUM_THREADS; ++i) {
+    threads.emplace_back([this, i, &valueOf] {
+      for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
+        set_->insert(valueOf(i, j));
+      }
+    });
+  }
+
+  for (auto& t : threads) {
+    t.join();
+  }
+
+  EXPECT_EQ(set_->size(), static_cast<size_t>(NUM_THREADS * ITEMS_PER_THREAD));
+
+  // Verify all items were inserted
+  for (int i = 0; i < NUM_THREADS; ++i) {
+    for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
+      EXPECT_TRUE(set_->contains(valueOf(i, j)))
+          << "Missing value from thread " << i << ", item " << j;
+    }
+  }
+}