This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 36cdc29ea [CELEBORN-2206][CIP-14] Support PushDataCallback in CppClient
36cdc29ea is described below
commit 36cdc29eab749fe24fa23cb654c252bdbde1bd85
Author: HolyLow <[email protected]>
AuthorDate: Fri Nov 28 15:22:48 2025 +0800
[CELEBORN-2206][CIP-14] Support PushDataCallback in CppClient
### What changes were proposed in this pull request?
This PR supports PushDataCallback in CppClient.
### Why are the changes needed?
PushDataCallback is a building block of the PushData logic in the CppClient's
write path.
### Does this PR resolve a correctness bug?
No.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation and unit tests (UTs).
Closes #3543 from HolyLow/issue/celeborn-2206-support-PushDataCallback-in-cpp-client.
Authored-by: HolyLow <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
cpp/celeborn/client/CMakeLists.txt | 1 +
cpp/celeborn/client/ShuffleClient.h | 38 ++
cpp/celeborn/client/tests/CMakeLists.txt | 1 +
cpp/celeborn/client/tests/PushDataCallbackTest.cpp | 389 +++++++++++++++++++++
cpp/celeborn/client/writer/PushDataCallback.cpp | 266 ++++++++++++++
cpp/celeborn/client/writer/PushDataCallback.h | 96 +++++
cpp/celeborn/client/writer/ReviveManager.h | 1 +
cpp/celeborn/conf/CelebornConf.cpp | 8 +-
cpp/celeborn/conf/CelebornConf.h | 6 +
cpp/celeborn/protocol/PartitionLocation.cpp | 12 +
cpp/celeborn/protocol/PartitionLocation.h | 9 +-
cpp/celeborn/utils/CelebornUtils.h | 6 +
12 files changed, 828 insertions(+), 5 deletions(-)
diff --git a/cpp/celeborn/client/CMakeLists.txt b/cpp/celeborn/client/CMakeLists.txt
index fa0290733..112b1f38d 100644
--- a/cpp/celeborn/client/CMakeLists.txt
+++ b/cpp/celeborn/client/CMakeLists.txt
@@ -19,6 +19,7 @@ add_library(
writer/PushState.cpp
writer/PushStrategy.cpp
writer/ReviveManager.cpp
+ writer/PushDataCallback.cpp
ShuffleClient.cpp
compress/Decompressor.cpp
compress/Lz4Decompressor.cpp
diff --git a/cpp/celeborn/client/ShuffleClient.h b/cpp/celeborn/client/ShuffleClient.h
index 41e953515..dc71a39bc 100644
--- a/cpp/celeborn/client/ShuffleClient.h
+++ b/cpp/celeborn/client/ShuffleClient.h
@@ -18,6 +18,9 @@
#pragma once
#include "celeborn/client/reader/CelebornInputStream.h"
+#include "celeborn/client/writer/PushDataCallback.h"
+#include "celeborn/client/writer/PushState.h"
+#include "celeborn/client/writer/ReviveManager.h"
#include "celeborn/network/NettyRpcEndpointRef.h"
namespace celeborn {
@@ -51,11 +54,15 @@ class ShuffleClient {
virtual void shutdown() = 0;
};
+class ReviveManager;
+class PushDataCallback;
+
class ShuffleClientImpl
: public ShuffleClient,
public std::enable_shared_from_this<ShuffleClientImpl> {
public:
friend class ReviveManager;
+ friend class PushDataCallback;
using PtrReviveRequest = std::shared_ptr<protocol::ReviveRequest>;
using PartitionLocationMap = utils::ConcurrentHashMap<
@@ -104,12 +111,29 @@ class ShuffleClientImpl
const std::shared_ptr<const conf::CelebornConf>& conf,
const std::shared_ptr<network::TransportClientFactory>& clientFactory);
+  // Re-sends a batch after a revive has been requested for its partition.
+  // TODO: stub only; the real retry logic lands in a follow-up commit. Until
+  // then the call does nothing, so the retried batch is not re-sent.
+  virtual void submitRetryPushData(
+      int shuffleId,
+      std::unique_ptr<memory::ReadOnlyByteBuffer> body,
+      int batchId,
+      std::shared_ptr<PushDataCallback> pushDataCallback,
+      std::shared_ptr<PushState> pushState,
+      PtrReviveRequest request,
+      int remainReviveTimes,
+      long dueTimeMs) {
+    // Intentionally empty until retry support is implemented.
+  }
+
// TODO: currently this function serves as a stub. will be updated in future
// commits.
virtual bool mapperEnded(int shuffleId, int mapId) {
return true;
}
+  // Hands a ReviveRequest over to the revive machinery.
+  // TODO: stub only; it will forward to the ReviveManager in a follow-up
+  // commit. For now the request is ignored.
+  virtual void addRequestToReviveManager(
+      std::shared_ptr<protocol::ReviveRequest> reviveRequest) {
+    // Intentionally empty until the ReviveManager is wired up.
+  }
+
// TODO: currently this function serves as a stub. will be updated in future
// commits.
virtual std::optional<std::unordered_map<int, int>> reviveBatch(
@@ -124,6 +148,16 @@ class ShuffleClientImpl
return partitionLocationMaps_.get(shuffleId);
}
+  // Returns, by mutable reference, the map from shuffleId to the set of map
+  // ids reported as ended (PushDataCallback records MAP_ENDED responses here).
+  virtual utils::ConcurrentHashMap<
+      int,
+      std::shared_ptr<utils::ConcurrentHashSet<int>>>&
+  mapperEndSets() {
+    return this->mapperEndSets_;
+  }
+
+  // Schedules `task` on the push-data retry executor.
+  // NOTE(review): nothing visible here initializes pushDataRetryPool_, so a
+  // null executor is reported through CELEBORN_CHECK instead of being
+  // dereferenced — TODO drop once the pool is always constructed.
+  virtual void addPushDataRetryTask(folly::Func&& task) {
+    CELEBORN_CHECK(pushDataRetryPool_ != nullptr);
+    pushDataRetryPool_->add(std::move(task));
+  }
+
private:
// TODO: no support for WAIT as it is not used.
static bool newerPartitionLocationExists(
@@ -140,12 +174,16 @@ class ShuffleClientImpl
std::shared_ptr<const conf::CelebornConf> conf_;
std::shared_ptr<network::NettyRpcEndpointRef> lifecycleManagerRef_;
std::shared_ptr<network::TransportClientFactory> clientFactory_;
+ std::shared_ptr<folly::IOExecutor> pushDataRetryPool_;
+ std::shared_ptr<ReviveManager> reviveManager_;
std::mutex mutex_;
utils::ConcurrentHashMap<
int,
std::shared_ptr<protocol::GetReducerFileGroupResponse>>
reducerFileGroupInfos_;
utils::ConcurrentHashMap<int, PtrPartitionLocationMap>
partitionLocationMaps_;
+ utils::ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>
+ mapperEndSets_;
};
} // namespace client
} // namespace celeborn
diff --git a/cpp/celeborn/client/tests/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt
index 8183611c6..e19703f31 100644
--- a/cpp/celeborn/client/tests/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -16,6 +16,7 @@
add_executable(
celeborn_client_test
WorkerPartitionReaderTest.cpp
+ PushDataCallbackTest.cpp
PushStateTest.cpp
ReviveManagerTest.cpp
Lz4DecompressorTest.cpp
diff --git a/cpp/celeborn/client/tests/PushDataCallbackTest.cpp b/cpp/celeborn/client/tests/PushDataCallbackTest.cpp
new file mode 100644
index 000000000..171c8e714
--- /dev/null
+++ b/cpp/celeborn/client/tests/PushDataCallbackTest.cpp
@@ -0,0 +1,389 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/client/writer/PushDataCallback.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+
+namespace {
+// Test double for ShuffleClientImpl. Every client hook that PushDataCallback
+// may invoke is routed to a replaceable std::function; hooks a test does not
+// install fail via CELEBORN_UNREACHABLE so unexpected calls are caught.
+class MockShuffleClient : public ShuffleClientImpl {
+ public:
+  using FuncOnSubmitRetryPushData = std::function<void(
+      int,
+      std::unique_ptr<memory::ReadOnlyByteBuffer>,
+      int,
+      std::shared_ptr<PushDataCallback>,
+      std::shared_ptr<PushState>,
+      PtrReviveRequest,
+      int,
+      long)>;
+  using FuncOnAddRequestToReviveManager =
+      std::function<void(std::shared_ptr<protocol::ReviveRequest>)>;
+  using FuncOnAddPushDataRetryTask = std::function<void(folly::Func&&)>;
+
+  // The constructor is private, so std::make_shared cannot be used here.
+  static std::shared_ptr<MockShuffleClient> create() {
+    return std::shared_ptr<MockShuffleClient>(new MockShuffleClient());
+  }
+
+  virtual ~MockShuffleClient() = default;
+
+  // Mappers never end unless a test records it through mapperEndSets().
+  bool mapperEnded(int shuffleId, int mapId) override {
+    return false;
+  }
+
+  // Every shuffle has an (empty) partition location map.
+  std::optional<PtrPartitionLocationMap> getPartitionLocationMap(
+      int shuffleId) override {
+    return {std::make_shared<PartitionLocationMap>()};
+  }
+
+  void submitRetryPushData(
+      int shuffleId,
+      std::unique_ptr<memory::ReadOnlyByteBuffer> body,
+      int batchId,
+      std::shared_ptr<PushDataCallback> pushDataCallback,
+      std::shared_ptr<PushState> pushState,
+      PtrReviveRequest request,
+      int remainReviveTimes,
+      long dueTimeMs) override {
+    onSubmitRetryPushData_(
+        shuffleId,
+        std::move(body),
+        batchId,
+        std::move(pushDataCallback),
+        std::move(pushState),
+        std::move(request),
+        remainReviveTimes,
+        dueTimeMs);
+  }
+
+  void setOnSubmitRetryPushData(FuncOnSubmitRetryPushData&& func) {
+    onSubmitRetryPushData_ = std::move(func);
+  }
+
+  void addRequestToReviveManager(
+      std::shared_ptr<protocol::ReviveRequest> reviveRequest) override {
+    onAddRequestToReviveManager_(std::move(reviveRequest));
+  }
+
+  void setOnAddRequestToReviveManager(FuncOnAddRequestToReviveManager&& func) {
+    onAddRequestToReviveManager_ = std::move(func);
+  }
+
+  void addPushDataRetryTask(folly::Func&& task) override {
+    onAddPushDataRetryTask_(std::move(task));
+  }
+
+  void setOnAddPushDataRetryTask(FuncOnAddPushDataRetryTask&& func) {
+    onAddPushDataRetryTask_ = std::move(func);
+  }
+
+ private:
+  MockShuffleClient()
+      : ShuffleClientImpl(
+            "mock",
+            std::make_shared<conf::CelebornConf>(),
+            nullptr) {}
+
+  FuncOnSubmitRetryPushData onSubmitRetryPushData_ =
+      [](int,
+         std::unique_ptr<memory::ReadOnlyByteBuffer>,
+         int,
+         std::shared_ptr<PushDataCallback>,
+         std::shared_ptr<PushState>,
+         PtrReviveRequest,
+         int,
+         long) { CELEBORN_UNREACHABLE("not expected to call this"); };
+  FuncOnAddRequestToReviveManager onAddRequestToReviveManager_ =
+      [](std::shared_ptr<protocol::ReviveRequest>) {
+        CELEBORN_UNREACHABLE("not expected to call this");
+      };
+  FuncOnAddPushDataRetryTask onAddPushDataRetryTask_ = [](folly::Func&&) {
+    CELEBORN_UNREACHABLE("not expected to call this");
+  };
+};
+
+// Builds a one-byte read-only buffer holding `code`, i.e. a response whose
+// first byte is a protocol::StatusCode.
+std::unique_ptr<memory::ReadOnlyByteBuffer> createReadOnlyByteBuffer(
+    uint8_t code) {
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(1);
+  writeBuffer->write<uint8_t>(code);
+  // Returning the prvalue directly; wrapping it in std::move would only
+  // inhibit copy elision.
+  return memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+}
+} // namespace
+
+// Fixed identifiers shared by every PushDataCallback test case.
+namespace {
+constexpr int testPushReviveIntervalMs = 100;
+constexpr int testShuffleId = 1000;
+constexpr int testMapId = 1001;
+constexpr int testAttemptId = 1002;
+constexpr int testPartitionId = 1003;
+constexpr int testNumMappers = 1004;
+constexpr int testNumPartitions = 1005;
+const std::string testMapKey = "test-map-key";
+constexpr int testBatchId = 1006;
+// A default-constructed location used as the callback's current location.
+const auto testLastestLocation =
+    std::make_shared<const protocol::PartitionLocation>();
+} // namespace
+
+TEST(PushDataCallbackTest, onSuccessAndNoOperation) {
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  auto mockClient = MockShuffleClient::create();
+  auto pushDataCallback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      1,
+      testLastestLocation);
+
+  // An empty response is a plain success. It must not revive or retry (the
+  // mock's default hooks would fail), record an exception, or mark the
+  // mapper as ended.
+  auto response = memory::ReadOnlyByteBuffer::createEmptyBuffer();
+  pushDataCallback->onSuccess(std::move(response));
+  EXPECT_FALSE(pushState->exceptionExists());
+  EXPECT_FALSE(mockClient->mapperEndSets().containsKey(testShuffleId));
+}
+
+TEST(PushDataCallbackTest, onSuccessAndMapEnd) {
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  auto mockClient = MockShuffleClient::create();
+  EXPECT_EQ(mockClient->mapperEndSets().size(), 0);
+  auto pushDataCallback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      1,
+      testLastestLocation);
+
+  auto response = createReadOnlyByteBuffer(protocol::StatusCode::MAP_ENDED);
+  pushDataCallback->onSuccess(std::move(response));
+  auto& mapperEndSets = mockClient->mapperEndSets();
+  // ASSERT rather than EXPECT: value() below would throw on a missing key.
+  ASSERT_TRUE(mapperEndSets.containsKey(testShuffleId));
+  auto mapperEndSet = mapperEndSets.get(testShuffleId).value();
+  EXPECT_TRUE(mapperEndSet->contains(testMapId));
+  EXPECT_FALSE(pushState->exceptionExists());
+}
+
+TEST(PushDataCallbackTest, onSuccessAndSoftSplit) {
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  auto mockClient = MockShuffleClient::create();
+
+  // The mock's partition location map is empty, so no newer location exists
+  // and exactly one SOFT_SPLIT revive request is expected.
+  int addRequestCalledTimes = 0;
+  mockClient->setOnAddRequestToReviveManager(
+      [&addRequestCalledTimes](
+          std::shared_ptr<protocol::ReviveRequest> request) {
+        EXPECT_EQ(request->shuffleId, testShuffleId);
+        EXPECT_EQ(request->mapId, testMapId);
+        EXPECT_EQ(request->attemptId, testAttemptId);
+        EXPECT_EQ(request->partitionId, testPartitionId);
+        EXPECT_EQ(request->cause, protocol::StatusCode::SOFT_SPLIT);
+
+        addRequestCalledTimes++;
+      });
+
+  auto pushDataCallback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      1,
+      testLastestLocation);
+
+  auto response = createReadOnlyByteBuffer(protocol::StatusCode::SOFT_SPLIT);
+  pushDataCallback->onSuccess(std::move(response));
+  EXPECT_EQ(addRequestCalledTimes, 1);
+}
+
+TEST(PushDataCallbackTest, onSuccessAndHardSplit) {
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  const int testRemainingReviveTimes = 1;
+  auto mockClient = MockShuffleClient::create();
+
+  // A HARD_SPLIT must issue exactly one revive request and schedule exactly
+  // one retry. The retry task is run inline here, and it keeps the revive
+  // budget unchanged.
+  int addRequestCalledTimes = 0;
+  mockClient->setOnAddRequestToReviveManager(
+      [&addRequestCalledTimes](
+          std::shared_ptr<protocol::ReviveRequest> request) {
+        EXPECT_EQ(request->shuffleId, testShuffleId);
+        EXPECT_EQ(request->mapId, testMapId);
+        EXPECT_EQ(request->attemptId, testAttemptId);
+        EXPECT_EQ(request->partitionId, testPartitionId);
+        EXPECT_EQ(request->cause, protocol::StatusCode::HARD_SPLIT);
+
+        addRequestCalledTimes++;
+      });
+  int addRetryTaskCalledTimes = 0;
+  mockClient->setOnAddPushDataRetryTask(
+      [&addRetryTaskCalledTimes](folly::Func&& task) {
+        addRetryTaskCalledTimes++;
+        task();
+      });
+  int submitRetryCalledTimes = 0;
+  mockClient->setOnSubmitRetryPushData(
+      [&submitRetryCalledTimes, testRemainingReviveTimes](
+          int shuffleId,
+          std::unique_ptr<memory::ReadOnlyByteBuffer> /*body*/,
+          int batchId,
+          std::shared_ptr<PushDataCallback> /*pushDataCallback*/,
+          std::shared_ptr<PushState> /*pushState*/,
+          ShuffleClientImpl::PtrReviveRequest /*request*/,
+          int remainReviveTimes,
+          long /*dueTimeMs*/) {
+        EXPECT_EQ(shuffleId, testShuffleId);
+        EXPECT_EQ(batchId, testBatchId);
+        EXPECT_EQ(remainReviveTimes, testRemainingReviveTimes);
+
+        submitRetryCalledTimes++;
+      });
+
+  auto pushDataCallback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      testRemainingReviveTimes,
+      testLastestLocation);
+
+  auto response = createReadOnlyByteBuffer(protocol::StatusCode::HARD_SPLIT);
+  pushDataCallback->onSuccess(std::move(response));
+  EXPECT_EQ(addRequestCalledTimes, 1);
+  EXPECT_EQ(addRetryTaskCalledTimes, 1);
+  EXPECT_EQ(submitRetryCalledTimes, 1);
+}
+
+TEST(PushDataCallbackTest, onFailureAndRevive) {
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  const int testRemainingReviveTimes = 1;
+  auto mockClient = MockShuffleClient::create();
+
+  // With revive budget left, a failure consumes one revive, issues one revive
+  // request and schedules exactly one retry, without recording an exception.
+  int addRequestCalledTimes = 0;
+  mockClient->setOnAddRequestToReviveManager(
+      [&addRequestCalledTimes](
+          std::shared_ptr<protocol::ReviveRequest> request) {
+        EXPECT_EQ(request->shuffleId, testShuffleId);
+        EXPECT_EQ(request->mapId, testMapId);
+        EXPECT_EQ(request->attemptId, testAttemptId);
+        EXPECT_EQ(request->partitionId, testPartitionId);
+
+        addRequestCalledTimes++;
+      });
+  int addRetryTaskCalledTimes = 0;
+  mockClient->setOnAddPushDataRetryTask(
+      [&addRetryTaskCalledTimes](folly::Func&& task) {
+        addRetryTaskCalledTimes++;
+        task();
+      });
+  int submitRetryCalledTimes = 0;
+  mockClient->setOnSubmitRetryPushData(
+      [&submitRetryCalledTimes, testRemainingReviveTimes](
+          int shuffleId,
+          std::unique_ptr<memory::ReadOnlyByteBuffer> /*body*/,
+          int batchId,
+          std::shared_ptr<PushDataCallback> /*pushDataCallback*/,
+          std::shared_ptr<PushState> /*pushState*/,
+          ShuffleClientImpl::PtrReviveRequest /*request*/,
+          int remainReviveTimes,
+          long /*dueTimeMs*/) {
+        EXPECT_EQ(shuffleId, testShuffleId);
+        EXPECT_EQ(batchId, testBatchId);
+        EXPECT_EQ(remainReviveTimes, testRemainingReviveTimes - 1);
+
+        submitRetryCalledTimes++;
+      });
+
+  auto pushDataCallback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      testRemainingReviveTimes,
+      testLastestLocation);
+
+  auto exception = std::make_unique<std::runtime_error>("test");
+  pushDataCallback->onFailure(std::move(exception));
+  EXPECT_EQ(addRequestCalledTimes, 1);
+  EXPECT_EQ(addRetryTaskCalledTimes, 1);
+  EXPECT_EQ(submitRetryCalledTimes, 1);
+  EXPECT_FALSE(pushState->exceptionExists());
+}
+
+TEST(PushDataCallbackTest, onFailureAndNoRevive) {
+  // With no revive budget left, a failure is recorded on the PushState. Any
+  // revive/retry hook call would hit the mock's CELEBORN_UNREACHABLE default.
+  const int testRemainingReviveTimes = 0;
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  auto mockClient = MockShuffleClient::create();
+
+  auto callback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      testRemainingReviveTimes,
+      testLastestLocation);
+
+  callback->onFailure(std::make_unique<std::runtime_error>("test"));
+
+  EXPECT_TRUE(pushState->exceptionExists());
+}
diff --git a/cpp/celeborn/client/writer/PushDataCallback.cpp b/cpp/celeborn/client/writer/PushDataCallback.cpp
new file mode 100644
index 000000000..43550ba7f
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushDataCallback.cpp
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/PushDataCallback.h"
+#include "celeborn/conf/CelebornConf.h"
+
+namespace celeborn {
+namespace client {
+
+// Creates a PushDataCallback that is owned by a std::shared_ptr from the
+// start, which is required for shared_from_this() in the retry path. The
+// constructor is private, so std::make_shared cannot be used. Sink parameters
+// are moved to avoid needless reference-count traffic.
+std::shared_ptr<PushDataCallback> PushDataCallback::create(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    int numMappers,
+    int numPartitions,
+    const std::string& mapKey,
+    int batchId,
+    std::unique_ptr<memory::ReadOnlyByteBuffer> databody,
+    std::shared_ptr<PushState> pushState,
+    std::weak_ptr<ShuffleClientImpl> weakClient,
+    int remainingReviveTimes,
+    std::shared_ptr<const protocol::PartitionLocation> latestLocation) {
+  return std::shared_ptr<PushDataCallback>(new PushDataCallback(
+      shuffleId,
+      mapId,
+      attemptId,
+      partitionId,
+      numMappers,
+      numPartitions,
+      mapKey,
+      batchId,
+      std::move(databody),
+      std::move(pushState),
+      std::move(weakClient),
+      remainingReviveTimes,
+      std::move(latestLocation)));
+}
+
+// Stores the push context. Owning/sharing handles are moved into the members
+// instead of copied.
+PushDataCallback::PushDataCallback(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    int numMappers,
+    int numPartitions,
+    const std::string& mapKey,
+    int batchId,
+    std::unique_ptr<memory::ReadOnlyByteBuffer> databody,
+    std::shared_ptr<PushState> pushState,
+    std::weak_ptr<ShuffleClientImpl> weakClient,
+    int remainingReviveTimes,
+    std::shared_ptr<const protocol::PartitionLocation> latestLocation)
+    : shuffleId_(shuffleId),
+      mapId_(mapId),
+      attemptId_(attemptId),
+      partitionId_(partitionId),
+      numMappers_(numMappers),
+      numPartitions_(numPartitions),
+      mapKey_(mapKey),
+      batchId_(batchId),
+      databody_(std::move(databody)),
+      pushState_(std::move(pushState)),
+      weakClient_(std::move(weakClient)),
+      remainingReviveTimes_(remainingReviveTimes),
+      latestLocation_(std::move(latestLocation)) {}
+
+// Handles a successful PushData RPC.
+// An empty response is a plain success. Otherwise the first byte is a
+// protocol::StatusCode:
+//   MAP_ENDED            record the mapper as ended and release the batch;
+//   SOFT_SPLIT           request a revive (unless a newer location already
+//                        exists) and release the batch;
+//   HARD_SPLIT           revive and re-push the batch;
+//   *_CONGESTED          apply congestion control and release the batch;
+//   anything else        treated as success, so the batch is released too.
+void PushDataCallback::onSuccess(
+    std::unique_ptr<memory::ReadOnlyByteBuffer> response) {
+  auto sharedClient = weakClient_.lock();
+  if (!sharedClient) {
+    LOG(WARNING) << "ShuffleClientImpl has expired when "
+                    "PushDataCallbackOnSuccess, ignored, shuffle "
+                 << shuffleId_ << " map " << mapId_ << " attempt " << attemptId_
+                 << " partition " << partitionId_ << " batch " << batchId_
+                 << ".";
+    return;
+  }
+
+  // latestLocation_ is not reassigned in this handler, so format its address
+  // once instead of on every use.
+  const std::string hostAndPushPort = latestLocation_->hostAndPushPort();
+
+  // Marks the push as successful and frees the batch's in-flight slot.
+  auto releaseBatch = [this, &hostAndPushPort]() {
+    pushState_->onSuccess(hostAndPushPort);
+    pushState_->removeBatch(batchId_, hostAndPushPort);
+  };
+
+  if (response->remainingSize() <= 0) {
+    releaseBatch();
+    return;
+  }
+
+  protocol::StatusCode reason =
+      static_cast<protocol::StatusCode>(response->read<uint8_t>());
+  switch (reason) {
+    case protocol::StatusCode::MAP_ENDED: {
+      auto mapperEndSet = sharedClient->mapperEndSets().computeIfAbsent(
+          shuffleId_,
+          []() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
+      mapperEndSet->insert(mapId_);
+      // The batch will not be retried, so it must leave the in-flight set
+      // (the Java client does the same); otherwise it stays counted as in
+      // flight indefinitely.
+      releaseBatch();
+      break;
+    }
+    case protocol::StatusCode::SOFT_SPLIT: {
+      VLOG(1) << "Push data to " << hostAndPushPort
+              << " soft split required for shuffle " << shuffleId_ << " map "
+              << mapId_ << " attempt " << attemptId_ << " partition "
+              << partitionId_ << " batch " << batchId_ << ".";
+      if (!ShuffleClientImpl::newerPartitionLocationExists(
+              sharedClient->getPartitionLocationMap(shuffleId_).value(),
+              partitionId_,
+              latestLocation_->epoch)) {
+        auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+            shuffleId_,
+            mapId_,
+            attemptId_,
+            partitionId_,
+            latestLocation_->epoch,
+            latestLocation_,
+            protocol::StatusCode::SOFT_SPLIT);
+        sharedClient->addRequestToReviveManager(reviveRequest);
+      }
+      releaseBatch();
+      break;
+    }
+    case protocol::StatusCode::HARD_SPLIT: {
+      VLOG(1) << "Push data to " << hostAndPushPort
+              << " hard split required for shuffle " << shuffleId_ << " map "
+              << mapId_ << " attempt " << attemptId_ << " partition "
+              << partitionId_ << " batch " << batchId_ << ".";
+      // The batch stays in flight: it is re-pushed after the revive.
+      reviveAndRetryPushData(*sharedClient, protocol::StatusCode::HARD_SPLIT);
+      break;
+    }
+    case protocol::StatusCode::PUSH_DATA_SUCCESS_PRIMARY_CONGESTED: {
+      VLOG(1) << "Push data to " << hostAndPushPort
+              << " primary congestion required for shuffle " << shuffleId_
+              << " map " << mapId_ << " attempt " << attemptId_
+              << " partition " << partitionId_ << " batch " << batchId_
+              << ".";
+      pushState_->onCongestControl(hostAndPushPort);
+      pushState_->removeBatch(batchId_, hostAndPushPort);
+      break;
+    }
+    case protocol::StatusCode::PUSH_DATA_SUCCESS_REPLICA_CONGESTED: {
+      VLOG(1) << "Push data to " << hostAndPushPort
+              << " replicate congestion required for shuffle " << shuffleId_
+              << " map " << mapId_ << " attempt " << attemptId_
+              << " partition " << partitionId_ << " batch " << batchId_
+              << ".";
+      pushState_->onCongestControl(hostAndPushPort);
+      pushState_->removeBatch(batchId_, hostAndPushPort);
+      break;
+    }
+    default: {
+      // This is treated as success, so the batch must be released as well.
+      LOG(WARNING) << "unhandled PushData success protocol::StatusCode: "
+                   << reason;
+      releaseBatch();
+      break;
+    }
+  }
+}
+
+// Handles a failed PushData RPC.
+// - If the client has expired or the PushState already carries an exception,
+//   the failure is ignored.
+// - With no revive budget left, the exception is stored on the PushState.
+// - If the mapper has already ended, the batch is simply released.
+// - Otherwise one revive is consumed and the batch is revived and re-pushed.
+void PushDataCallback::onFailure(std::unique_ptr<std::exception> exception) {
+  auto sharedClient = weakClient_.lock();
+  if (!sharedClient) {
+    LOG(WARNING) << "ShuffleClientImpl has expired when "
+                    "PushDataCallbackOnFailure, ignored, shuffle "
+                 << shuffleId_ << " map " << mapId_ << " attempt " << attemptId_
+                 << " partition " << partitionId_ << " batch " << batchId_
+                 << ".";
+    return;
+  }
+  if (pushState_->exceptionExists()) {
+    return;
+  }
+
+  // Include the underlying cause so the failure can be diagnosed from logs.
+  LOG(ERROR) << "Push data to " << latestLocation_->hostAndPushPort()
+             << " failed for shuffle " << shuffleId_ << " map " << mapId_
+             << " attempt " << attemptId_ << " partition " << partitionId_
+             << " batch " << batchId_ << ", remain revive times "
+             << remainingReviveTimes_ << ", cause: "
+             << (exception ? exception->what() : "unknown");
+
+  if (remainingReviveTimes_ <= 0) {
+    // TODO: set more specific exception.
+    pushState_->setException(std::move(exception));
+    return;
+  }
+
+  if (sharedClient->mapperEnded(shuffleId_, mapId_)) {
+    pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
+    LOG(INFO) << "Push data to " << latestLocation_->hostAndPushPort()
+              << " failed but mapper already ended for shuffle " << shuffleId_
+              << " map " << mapId_ << " attempt " << attemptId_
+              << " partition " << partitionId_ << " batch " << batchId_
+              << ", remain revive times " << remainingReviveTimes_ << ".";
+    return;
+  }
+  remainingReviveTimes_--;
+  // TODO: we use PRIMARY exception as the dummy value here, but the cause
+  // should be extracted from error msg. Especially, the cause should tell if
+  // the exception is from PRIMARY or REPLICATE.
+  protocol::StatusCode cause =
+      protocol::StatusCode::PUSH_DATA_CONNECTION_EXCEPTION_PRIMARY;
+  reviveAndRetryPushData(*sharedClient, cause);
+}
+
+// Moves this batch's in-flight bookkeeping to a revived location.
+// The batch is registered under the new address before being released from
+// the old one, so it is never untracked in between. When the new location
+// has the same host:pushPort, add-then-remove on that single address could
+// drop the batch (e.g. if PushState tracks a set of batch ids per address —
+// NOTE(review): confirm), so the bookkeeping is skipped in that case.
+void PushDataCallback::updateLatestLocation(
+    std::shared_ptr<const protocol::PartitionLocation> latestLocation) {
+  const std::string newHostAndPushPort = latestLocation->hostAndPushPort();
+  const std::string oldHostAndPushPort = latestLocation_->hostAndPushPort();
+  if (newHostAndPushPort != oldHostAndPushPort) {
+    pushState_->addBatch(batchId_, newHostAndPushPort);
+    pushState_->removeBatch(batchId_, oldHostAndPushPort);
+  }
+  latestLocation_ = std::move(latestLocation);
+}
+
+// Requests a revive for this batch's partition and schedules a retry task
+// that re-submits the batch through ShuffleClientImpl::submitRetryPushData.
+// The task holds its own copy of the payload and keeps this callback alive
+// via shared_from_this(). It holds the client only weakly, so a pending retry
+// does not extend the client's lifetime.
+void PushDataCallback::reviveAndRetryPushData(
+    ShuffleClientImpl& shuffleClient,
+    protocol::StatusCode cause) {
+  auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+      shuffleId_,
+      mapId_,
+      attemptId_,
+      partitionId_,
+      latestLocation_->epoch,
+      latestLocation_,
+      cause);
+  VLOG(1) << "addRequest to reviveManager, shuffleId "
+          << reviveRequest->shuffleId << " mapId " << reviveRequest->mapId
+          << " attemptId " << reviveRequest->attemptId << " partitionId "
+          << reviveRequest->partitionId << " batchId " << batchId_ << " epoch "
+          << reviveRequest->epoch;
+  shuffleClient.addRequestToReviveManager(reviveRequest);
+
+  // now + requestPartition ask timeout; presumably the deadline for the
+  // revive to complete before retrying — confirm against submitRetryPushData.
+  long dueTimeMs = utils::currentTimeMillis() +
+      shuffleClient.conf_->clientRpcRequestPartitionLocationRpcAskTimeout() /
+          utils::MS(1);
+  shuffleClient.addPushDataRetryTask(
+      [weakClient = weakClient_,
+       shuffleId = shuffleId_,
+       body = databody_->clone(),
+       batchId = batchId_,
+       callback = shared_from_this(),
+       pushState = pushState_,
+       reviveRequest = std::move(reviveRequest),
+       remainingReviveTimes = remainingReviveTimes_,
+       dueTimeMs]() mutable {
+        auto sharedClient = weakClient.lock();
+        if (!sharedClient) {
+          LOG(WARNING) << "ShuffleClientImpl has expired when "
+                          "PushDataFailureCallback, ignored, shuffleId "
+                       << shuffleId;
+          return;
+        }
+        // A queued executor task is invoked once, so hand over the captured
+        // payload copy instead of cloning it a second time.
+        sharedClient->submitRetryPushData(
+            shuffleId,
+            std::move(body),
+            batchId,
+            std::move(callback),
+            std::move(pushState),
+            std::move(reviveRequest),
+            remainingReviveTimes,
+            dueTimeMs);
+      });
+}
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/PushDataCallback.h b/cpp/celeborn/client/writer/PushDataCallback.h
new file mode 100644
index 000000000..9916cd191
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushDataCallback.h
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "celeborn/client/ShuffleClient.h"
+#include "celeborn/client/writer/PushState.h"
+
+namespace celeborn {
+namespace client {
+
+class ShuffleClientImpl;
+
+// RpcResponseCallback for a single PushData batch. On success it interprets
+// the worker's status byte (split / congestion / map-ended handling). On
+// failure it revives and re-pushes the batch while the revive budget lasts.
+// NOTE(review): remainingReviveTimes_ and latestLocation_ are mutated without
+// locking; presumably callers never run these handlers concurrently for the
+// same batch — TODO confirm.
+class PushDataCallback
+    : public network::RpcResponseCallback,
+      public std::enable_shared_from_this<PushDataCallback> {
+ public:
+  // The only way to construct a PushDataCallback. Instances are always owned
+  // by a std::shared_ptr, which shared_from_this() requires (it is used to
+  // keep the callback alive in scheduled retry tasks).
+  static std::shared_ptr<PushDataCallback> create(
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      int numMappers,
+      int numPartitions,
+      const std::string& mapKey,
+      int batchId,
+      std::unique_ptr<memory::ReadOnlyByteBuffer> databody,
+      std::shared_ptr<PushState> pushState,
+      std::weak_ptr<ShuffleClientImpl> weakClient,
+      int remainingReviveTimes,
+      std::shared_ptr<const protocol::PartitionLocation> latestLocation);
+
+  // Called when the PushData RPC succeeds. An empty response is a plain
+  // success; otherwise its first byte is a protocol::StatusCode.
+  void onSuccess(
+      std::unique_ptr<memory::ReadOnlyByteBuffer> response) override;
+
+  // Called when the PushData RPC fails.
+  void onFailure(std::unique_ptr<std::exception> exception) override;
+
+  // Must be called whenever a revive moves this batch to a new location, so
+  // that the in-flight bookkeeping and latestLocation_ follow it.
+  void updateLatestLocation(
+      std::shared_ptr<const protocol::PartitionLocation> latestLocation);
+
+ private:
+  // Private so that every instance is created through create().
+  PushDataCallback(
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      int numMappers,
+      int numPartitions,
+      const std::string& mapKey,
+      int batchId,
+      std::unique_ptr<memory::ReadOnlyByteBuffer> databody,
+      std::shared_ptr<PushState> pushState,
+      std::weak_ptr<ShuffleClientImpl> weakClient,
+      int remainingReviveTimes,
+      std::shared_ptr<const protocol::PartitionLocation> latestLocation);
+
+  // Issues a revive request with `cause` and schedules a retry of the batch.
+  void reviveAndRetryPushData(
+      ShuffleClientImpl& shuffleClient,
+      protocol::StatusCode cause);
+
+  // Identity of the pushed batch.
+  const int shuffleId_;
+  const int mapId_;
+  const int attemptId_;
+  const int partitionId_;
+  const int numMappers_;
+  const int numPartitions_;
+  const std::string mapKey_;
+  const int batchId_;
+  // The pushed payload; cloned for every retry.
+  const std::unique_ptr<memory::ReadOnlyByteBuffer> databody_;
+  const std::shared_ptr<PushState> pushState_;
+  // Held weakly so callbacks never keep the client alive.
+  const std::weak_ptr<ShuffleClientImpl> weakClient_;
+  // Decremented on each failure-triggered revive.
+  int remainingReviveTimes_;
+  // Location the batch is currently pushed to; replaced on revive.
+  std::shared_ptr<const protocol::PartitionLocation> latestLocation_;
+};
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/ReviveManager.h b/cpp/celeborn/client/writer/ReviveManager.h
index 4f3b7962e..925e2b924 100644
--- a/cpp/celeborn/client/writer/ReviveManager.h
+++ b/cpp/celeborn/client/writer/ReviveManager.h
@@ -26,6 +26,7 @@
namespace celeborn {
namespace client {
+class ShuffleClientImpl;
/// ReviveManager is responsible for buffering the ReviveRequests, and issue
/// the revive requests periodically in batches.
class ReviveManager : public std::enable_shared_from_this<ReviveManager> {
diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp
index ab1a7abf0..1d58516b3 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -142,10 +142,11 @@ const std::unordered_map<std::string, folly::Optional<std::string>>
NUM_PROP(kClientPushMaxReqsInFlightTotal, 256),
NUM_PROP(kClientPushLimitInFlightTimeoutMs, 240000),
NUM_PROP(kClientPushLimitInFlightSleepDeltaMs, 50),
+ STR_PROP(kClientRpcRequestPartitionLocationAskTimeout, "60s"),
STR_PROP(kClientRpcGetReducerFileGroupRpcAskTimeout, "60s"),
STR_PROP(kNetworkConnectTimeout, "10s"),
STR_PROP(kClientFetchTimeout, "600s"),
- NUM_PROP(kNetworkIoNumConnectionsPerPeer, "1"),
+ NUM_PROP(kNetworkIoNumConnectionsPerPeer, 1),
NUM_PROP(kNetworkIoClientThreads, 0),
NUM_PROP(kClientFetchMaxReqsInFlight, 3),
STR_PROP(
@@ -223,6 +224,11 @@ long CelebornConf::clientPushLimitInFlightSleepDeltaMs() const {
optionalProperty(kClientPushLimitInFlightSleepDeltaMs).value());
}
+// Ask timeout for requestPartition (revive) RPCs, read from
+// "celeborn.client.rpc.requestPartition.askTimeout" (default "60s").
+Timeout CelebornConf::clientRpcRequestPartitionLocationRpcAskTimeout() const {
+  const auto rawTimeout =
+      optionalProperty(kClientRpcRequestPartitionLocationAskTimeout).value();
+  return utils::toTimeout(toDuration(rawTimeout));
+}
+
Timeout CelebornConf::clientRpcGetReducerFileGroupRpcAskTimeout() const {
return utils::toTimeout(toDuration(
optionalProperty(kClientRpcGetReducerFileGroupRpcAskTimeout).value()));
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index d6ce24f18..530a59781 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -68,6 +68,10 @@ class CelebornConf : public BaseConf {
static constexpr std::string_view kClientPushLimitInFlightSleepDeltaMs{
"celeborn.client.push.limit.inFlight.sleepInterval"};
+ static constexpr std::string_view
+ kClientRpcRequestPartitionLocationAskTimeout{
+ "celeborn.client.rpc.requestPartition.askTimeout"};
+
static constexpr std::string_view kClientRpcGetReducerFileGroupRpcAskTimeout{
"celeborn.client.rpc.getReducerFileGroup.askTimeout"};
@@ -120,6 +124,8 @@ class CelebornConf : public BaseConf {
long clientPushLimitInFlightSleepDeltaMs() const;
+ Timeout clientRpcRequestPartitionLocationRpcAskTimeout() const;
+
Timeout clientRpcGetReducerFileGroupRpcAskTimeout() const;
Timeout networkConnectTimeout() const;
diff --git a/cpp/celeborn/protocol/PartitionLocation.cpp b/cpp/celeborn/protocol/PartitionLocation.cpp
index a1820c5df..333a4ef13 100644
--- a/cpp/celeborn/protocol/PartitionLocation.cpp
+++ b/cpp/celeborn/protocol/PartitionLocation.cpp
@@ -148,6 +148,18 @@ std::unique_ptr<PbPartitionLocation> PartitionLocation::toPbWithoutPeer()
return pbPartitionLocation;
}
+// Returns "<id>-<epoch>-<mode as integer>".
+std::string PartitionLocation::filename() const {
+  return std::to_string(id) + "-" + std::to_string(epoch) + "-" +
+      std::to_string(static_cast<int>(mode));
+}
+
+// Returns "<id>-<epoch>".
+std::string PartitionLocation::uniqueId() const {
+  std::string result = std::to_string(id);
+  result += '-';
+  result += std::to_string(epoch);
+  return result;
+}
+
+// Returns "<host>:<pushPort>", the key used for per-address push bookkeeping.
+std::string PartitionLocation::hostAndPushPort() const {
+  std::string address = host;
+  address += ':';
+  address += std::to_string(pushPort);
+  return address;
+}
+
StatusCode toStatusCode(int32_t code) {
CELEBORN_CHECK(code >= 0);
CELEBORN_CHECK(code <= StatusCode::TAIL);
diff --git a/cpp/celeborn/protocol/PartitionLocation.h b/cpp/celeborn/protocol/PartitionLocation.h
index 4e88d16de..16a535bd0 100644
--- a/cpp/celeborn/protocol/PartitionLocation.h
+++ b/cpp/celeborn/protocol/PartitionLocation.h
@@ -92,10 +92,11 @@ struct PartitionLocation {
std::unique_ptr<PbPartitionLocation> toPb() const;
- std::string filename() const {
- return std::to_string(id) + "-" + std::to_string(epoch) + "-" +
- std::to_string(mode);
- }
+ std::string filename() const;
+
+ std::string uniqueId() const;
+
+ std::string hostAndPushPort() const;
private:
static std::unique_ptr<PartitionLocation> fromPbWithoutPeer(
diff --git a/cpp/celeborn/utils/CelebornUtils.h b/cpp/celeborn/utils/CelebornUtils.h
index 669e4060e..ac1419914 100644
--- a/cpp/celeborn/utils/CelebornUtils.h
+++ b/cpp/celeborn/utils/CelebornUtils.h
@@ -62,6 +62,12 @@ inline Timeout toTimeout(Duration duration) {
return std::chrono::duration_cast<Timeout>(duration);
}
+/// Wall-clock milliseconds since the system_clock epoch (the Unix epoch on
+/// mainstream platforms). Not monotonic: it follows system time adjustments.
+inline uint64_t currentTimeMillis() {
+  using std::chrono::milliseconds;
+  using std::chrono::system_clock;
+  const auto nowMs =
+      std::chrono::time_point_cast<milliseconds>(system_clock::now());
+  return nowMs.time_since_epoch().count();
+}
+
/// parse string like "Any-Host-Str:Port#1:Port#2:...:Port#num", split into
/// {"Any-Host-Str", "Port#1", "Port#2", ..., "Port#num"}. Note that the
/// "Any-Host-Str" might contain ':' in IPV6 address.