This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new f35b6b80a [CELEBORN-2206][CIP-14] Support PushData and Revive in Cpp's ShuffleClient
f35b6b80a is described below
commit f35b6b80ac13af3cb28cc8522bc209256a289f76
Author: HolyLow <[email protected]>
AuthorDate: Mon Dec 8 14:00:36 2025 +0800
[CELEBORN-2206][CIP-14] Support PushData and Revive in Cpp's ShuffleClient
### What changes were proposed in this pull request?
This PR adds support for PushData and Revive in the C++ ShuffleClient so that the C++ module can write data to the Celeborn server.
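
For illustration, here is a minimal write-path sketch against the new public API in `ShuffleClient.h`. The include path, app ID, host, port, and shuffle/map/partition IDs are illustrative placeholders, not part of this patch:

```cpp
#include <memory>
#include <string>
#include <vector>

#include "celeborn/client/ShuffleClient.h" // assumed include path

int main() {
  auto conf = std::make_shared<celeborn::conf::CelebornConf>();
  // An endpoint owns the shared push-retry thread pool and transport client
  // factory, and may be reused by multiple ShuffleClients.
  celeborn::client::ShuffleClientEndpoint endpoint(conf);
  auto shuffleClient = celeborn::client::ShuffleClientImpl::create(
      "app-unique-id", conf, endpoint);
  std::string host = "lifecycle-manager-host"; // placeholder
  shuffleClient->setupLifecycleManagerRef(host, /*port=*/9097);

  const std::vector<uint8_t> record = {0, 1, 2, 3};
  // pushData frames the payload with a (mapId, attemptId, batchId, length)
  // header and returns the size of the pushed batch.
  shuffleClient->pushData(
      /*shuffleId=*/0, /*mapId=*/0, /*attemptId=*/0, /*partitionId=*/0,
      record.data(), /*offset=*/0, /*length=*/record.size(),
      /*numMappers=*/1, /*numPartitions=*/1);
  // Blocks until no batches are in flight, then reports the mapper as ended.
  shuffleClient->mapperEnd(
      /*shuffleId=*/0, /*mapId=*/0, /*attemptId=*/0, /*numMappers=*/1);
  shuffleClient->cleanup(/*shuffleId=*/0, /*mapId=*/0, /*attemptId=*/0);
  return 0;
}
```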
### Why are the changes needed?
This PR enables the C++ module to write to the Celeborn server.
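
The new write path is tuned through the configs this patch registers in `CelebornConf`. A hedged sketch of overriding them via `registerProperty` (usage mirrors `DataSumWithReaderClient.cpp`; the string values shown equal the registered defaults):

```cpp
#include <memory>

#include "celeborn/conf/CelebornConf.h" // assumed include path

int main() {
  auto conf = std::make_shared<celeborn::conf::CelebornConf>();
  using Conf = celeborn::conf::CelebornConf;
  // Keys added by this patch; values below match the new defaults.
  conf->registerProperty(Conf::kClientPushRetryThreads, "8");
  conf->registerProperty(Conf::kClientPushTimeout, "120s");
  conf->registerProperty(Conf::kClientPushMaxReviveTimes, "5");
  conf->registerProperty(Conf::kClientRegisterShuffleMaxRetries, "3");
  conf->registerProperty(Conf::kClientRegisterShuffleRetryWait, "3s");
  return 0;
}
```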
### Does this PR resolve a correctness bug?
No.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Compilation.
Closes #3553 from HolyLow/issue/celeborn-2215-support-PushData-and-Revive-in-cpp-ShuffleClient.
Authored-by: HolyLow <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
cpp/celeborn/client/ShuffleClient.cpp | 540 ++++++++++++++++++++-
cpp/celeborn/client/ShuffleClient.h | 137 ++++--
cpp/celeborn/client/tests/PushDataCallbackTest.cpp | 10 +-
cpp/celeborn/client/tests/ReviveManagerTest.cpp | 11 +-
cpp/celeborn/conf/CelebornConf.cpp | 97 ++--
cpp/celeborn/conf/CelebornConf.h | 39 +-
cpp/celeborn/memory/ByteBuffer.h | 7 +-
cpp/celeborn/network/TransportClient.cpp | 11 +-
cpp/celeborn/network/TransportClient.h | 5 +-
cpp/celeborn/tests/DataSumWithReaderClient.cpp | 6 +-
cpp/celeborn/utils/CelebornUtils.cpp | 4 +
cpp/celeborn/utils/CelebornUtils.h | 8 +
12 files changed, 805 insertions(+), 70 deletions(-)
diff --git a/cpp/celeborn/client/ShuffleClient.cpp b/cpp/celeborn/client/ShuffleClient.cpp
index 22d19256c..807fac2e0 100644
--- a/cpp/celeborn/client/ShuffleClient.cpp
+++ b/cpp/celeborn/client/ShuffleClient.cpp
@@ -21,19 +21,46 @@
namespace celeborn {
namespace client {
+
+ShuffleClientEndpoint::ShuffleClientEndpoint(
+ const std::shared_ptr<const conf::CelebornConf>& conf)
+ : conf_(conf),
+ pushDataRetryPool_(std::make_shared<folly::IOThreadPoolExecutor>(
+ conf_->clientPushRetryThreads(),
+ std::make_shared<folly::NamedThreadFactory>(
+ "client-pushdata-retrier"))),
+ clientFactory_(std::make_shared<network::TransportClientFactory>(conf_)) {
+}
+
+std::shared_ptr<folly::IOThreadPoolExecutor>
+ShuffleClientEndpoint::pushDataRetryPool() const {
+ return pushDataRetryPool_;
+}
+
+std::shared_ptr<network::TransportClientFactory>
+ShuffleClientEndpoint::clientFactory() const {
+ return clientFactory_;
+}
+
std::shared_ptr<ShuffleClientImpl> ShuffleClientImpl::create(
const std::string& appUniqueId,
const std::shared_ptr<const conf::CelebornConf>& conf,
- const std::shared_ptr<network::TransportClientFactory>& clientFactory) {
+ const ShuffleClientEndpoint& clientEndpoint) {
return std::shared_ptr<ShuffleClientImpl>(
- new ShuffleClientImpl(appUniqueId, conf, clientFactory));
+ new ShuffleClientImpl(appUniqueId, conf, clientEndpoint));
}
ShuffleClientImpl::ShuffleClientImpl(
const std::string& appUniqueId,
const std::shared_ptr<const conf::CelebornConf>& conf,
- const std::shared_ptr<network::TransportClientFactory>& clientFactory)
- : appUniqueId_(appUniqueId), conf_(conf), clientFactory_(clientFactory) {}
+ const ShuffleClientEndpoint& clientEndpoint)
+ : appUniqueId_(appUniqueId),
+ conf_(conf),
+ clientFactory_(clientEndpoint.clientFactory()),
+ pushDataRetryPool_(clientEndpoint.pushDataRetryPool()) {
+ CELEBORN_CHECK_NOT_NULL(clientFactory_);
+ CELEBORN_CHECK_NOT_NULL(pushDataRetryPool_);
+}
void ShuffleClientImpl::setupLifecycleManagerRef(std::string& host, int port) {
auto managerClient = clientFactory_->createClient(host, port);
@@ -47,6 +74,8 @@ void ShuffleClientImpl::setupLifecycleManagerRef(std::string& host, int port) {
port,
managerClient,
*conf_);
+
+ initReviveManagerLocked();
}
}
@@ -54,6 +83,172 @@ void ShuffleClientImpl::setupLifecycleManagerRef(
std::shared_ptr<network::NettyRpcEndpointRef>& lifecycleManagerRef) {
std::lock_guard<std::mutex> lock(mutex_);
lifecycleManagerRef_ = lifecycleManagerRef;
+
+ initReviveManagerLocked();
+}
+
+std::shared_ptr<utils::ConcurrentHashMap<
+ int,
+ std::shared_ptr<const protocol::PartitionLocation>>>
+ShuffleClientImpl::getPartitionLocation(
+ int shuffleId,
+ int numMappers,
+ int numPartitions) {
+ auto partitionLocationOptional = partitionLocationMaps_.get(shuffleId);
+ if (partitionLocationOptional.has_value()) {
+ return partitionLocationOptional.value();
+ }
+
+ registerShuffle(shuffleId, numMappers, numPartitions);
+
+ partitionLocationOptional = partitionLocationMaps_.get(shuffleId);
+ CELEBORN_CHECK(
+ partitionLocationOptional.has_value(),
+ "partitionLocation is empty because registerShuffle failed");
+ auto partitionLocationMap = partitionLocationOptional.value();
+ CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+ return partitionLocationMap;
+}
+
+int ShuffleClientImpl::pushData(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ const uint8_t* data,
+ size_t offset,
+ size_t length,
+ int numMappers,
+ int numPartitions) {
+ const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ auto partitionLocationMap =
+ getPartitionLocation(shuffleId, numMappers, numPartitions);
+ CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+ auto partitionLocationOptional = partitionLocationMap->get(partitionId);
+ if (!partitionLocationOptional.has_value()) {
+ if (!revive(
+ shuffleId,
+ mapId,
+ attemptId,
+ partitionId,
+ -1,
+ nullptr,
+ protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
+ CELEBORN_FAIL(fmt::format(
+ "Revive for shuffleId {} partitionId {} failed.",
+ shuffleId,
+ partitionId));
+ }
+ partitionLocationOptional = partitionLocationMap->get(partitionId);
+ }
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ CELEBORN_CHECK(partitionLocationOptional.has_value());
+ auto partitionLocation = partitionLocationOptional.value();
+ auto pushState = getPushState(mapKey);
+ const int nextBatchId = pushState->nextBatchId();
+
+ // TODO: compression in writing is not supported.
+
+ auto writeBuffer =
+ memory::ByteBuffer::createWriteOnly(kBatchHeaderSize + length);
+ // TODO: the java side uses Platform to write the data. We simply assume
+ // littleEndian here.
+ writeBuffer->writeLE<int>(mapId);
+ writeBuffer->writeLE<int>(attemptId);
+ writeBuffer->writeLE<int>(nextBatchId);
+ writeBuffer->writeLE<int>(length);
+ writeBuffer->writeFromBuffer(data, offset, length);
+
+ auto hostAndPushPort = partitionLocation->hostAndPushPort();
+ // Check limit.
+ limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
+ // Add inFlight requests.
+ pushState->addBatch(nextBatchId, hostAndPushPort);
+ // Build pushData request.
+ const auto shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId);
+ auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+ network::PushData pushData(
+ network::Message::nextRequestId(),
+ protocol::PartitionLocation::Mode::PRIMARY,
+ shuffleKey,
+ partitionLocation->uniqueId(),
+ body->clone());
+ // Build callback.
+ auto pushDataCallback = PushDataCallback::create(
+ shuffleId,
+ mapId,
+ attemptId,
+ partitionId,
+ numMappers,
+ numPartitions,
+ mapKey,
+ nextBatchId,
+ body->clone(),
+ pushState,
+ weak_from_this(),
+ conf_->clientPushMaxReviveTimes(),
+ partitionLocation);
+ // Do push data.
+ auto client = clientFactory_->createClient(
+ partitionLocation->host, partitionLocation->pushPort, partitionId);
+ client->pushDataAsync(
+ pushData, conf_->clientPushDataTimeout(), pushDataCallback);
+ return body->remainingSize();
+}
+
+void ShuffleClientImpl::mapperEnd(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers) {
+ mapPartitionMapperEnd(shuffleId, mapId, attemptId, numMappers, -1);
+}
+
+void ShuffleClientImpl::mapPartitionMapperEnd(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int partitionId) {
+ auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ auto pushState = getPushState(mapKey);
+
+ try {
+ limitZeroInFlight(mapKey, *pushState);
+
+ auto mapperEndResponse =
+ lifecycleManagerRef_
+ ->askSync<protocol::MapperEnd, protocol::MapperEndResponse>(
+ protocol::MapperEnd{
+ shuffleId, mapId, attemptId, numMappers, partitionId});
+ if (mapperEndResponse->status != protocol::StatusCode::SUCCESS) {
+ CELEBORN_FAIL(
+ "MapperEnd failed. protocol::StatusCode " +
+ std::to_string(mapperEndResponse->status));
+ }
+ } catch (std::exception& e) {
+ LOG(ERROR) << "mapperEnd failed, error msg: " << e.what();
+ pushStates_.erase(mapKey);
+ CELEBORN_FAIL(e.what());
+ }
+ pushStates_.erase(mapKey);
+}
+
+void ShuffleClientImpl::cleanup(int shuffleId, int mapId, int attemptId) {
+ auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ auto pushStateOptional = pushStates_.erase(mapKey);
+ if (pushStateOptional.has_value()) {
+ auto pushState = pushStateOptional.value();
+ pushState->setException(
+ std::make_unique<std::runtime_error>(mapKey + " is cleaned up"));
+ }
}
std::unique_ptr<CelebornInputStream> ShuffleClientImpl::readPartition(
@@ -143,6 +338,343 @@ bool ShuffleClientImpl::cleanupShuffle(int shuffleId) {
return true;
}
+std::shared_ptr<PushState> ShuffleClientImpl::getPushState(
+ const std::string& mapKey) {
+ return pushStates_.computeIfAbsent(
+ mapKey, [&]() { return std::make_shared<PushState>(*conf_); });
+}
+
+void ShuffleClientImpl::initReviveManagerLocked() {
+ if (!reviveManager_) {
+ std::string uniqueName = appUniqueId_;
+ uniqueName += std::to_string(utils::currentTimeNanos());
+ reviveManager_ =
+ ReviveManager::create(uniqueName, *conf_, weak_from_this());
+ }
+}
+
+void ShuffleClientImpl::registerShuffle(
+ int shuffleId,
+ int numMappers,
+ int numPartitions) {
+ auto shuffleMutex = shuffleMutexes_.computeIfAbsent(
+ shuffleId, []() { return std::make_shared<std::mutex>(); });
+ // RegisterShuffle might be issued concurrently; we only allow one request
+ // for each shuffleId.
+ std::lock_guard<std::mutex> lock(*shuffleMutex);
+ if (partitionLocationMaps_.containsKey(shuffleId)) {
+ return;
+ }
+ CELEBORN_CHECK(
+ lifecycleManagerRef_, "lifecycleManagerRef_ is not initialized");
+ const int maxRetries = conf_->clientRegisterShuffleMaxRetries();
+ int numRetries = 1;
+ for (; numRetries <= maxRetries; numRetries++) {
+ try {
+ // Send the query request to lifecycleManager.
+ auto registerShuffleResponse = lifecycleManagerRef_->askSync<
+ protocol::RegisterShuffle,
+ protocol::RegisterShuffleResponse>(
+ protocol::RegisterShuffle{shuffleId, numMappers, numPartitions},
+ conf_->clientRpcRegisterShuffleRpcAskTimeout());
+
+ switch (registerShuffleResponse->status) {
+ case protocol::StatusCode::SUCCESS: {
+ VLOG(1) << "success to registerShuffle, shuffleId " << shuffleId
+ << " numMappers " << numMappers << " numPartitions "
+ << numPartitions;
+ auto partitionLocationMap = std::make_shared<utils::ConcurrentHashMap<
+ int,
+ std::shared_ptr<const protocol::PartitionLocation>>>();
+ auto& partitionLocations =
+ registerShuffleResponse->partitionLocations;
+ for (auto i = 0; i < partitionLocations.size(); i++) {
+ auto id = partitionLocations[i]->id;
+ partitionLocationMap->set(id, std::move(partitionLocations[i]));
+ }
+ partitionLocationMaps_.set(
+ shuffleId, std::move(partitionLocationMap));
+ return;
+ }
+ default: {
+ LOG(ERROR)
+ << "LifecycleManager request slots return protocol::StatusCode "
+ << registerShuffleResponse->status << " , shuffleId " << shuffleId
+ << " numMappers " << numMappers << " numPartitions "
+ << numPartitions << " , retry again, remain retry times "
+ << maxRetries - numRetries;
+ }
+ }
+ } catch (std::exception& e) {
+ CELEBORN_FAIL(fmt::format(
+ "registerShuffle encounters error after {} tries, "
+ "shuffleId {} numMappers {} numPartitions {}, errorMsg: {}",
+ numRetries,
+ shuffleId,
+ numMappers,
+ numPartitions,
+ e.what()));
+ break;
+ }
+ std::this_thread::sleep_for(conf_->clientRegisterShuffleRetryWait());
+ }
+ partitionLocationMaps_.set(shuffleId, nullptr);
+ CELEBORN_FAIL(fmt::format(
+ "registerShuffle failed after {} tries, "
+ "shuffleId {} numMappers {} numPartitions {}",
+ maxRetries,
+ shuffleId,
+ numMappers,
+ numPartitions));
+}
+
+void ShuffleClientImpl::submitRetryPushData(
+ int shuffleId,
+ std::unique_ptr<memory::ReadOnlyByteBuffer> body,
+ int batchId,
+ std::shared_ptr<PushDataCallback> pushDataCallback,
+ std::shared_ptr<PushState> pushState,
+ PtrReviveRequest request,
+ int remainReviveTimes,
+ long dueTimeMs) {
+ long reviveWaitTimeMs = dueTimeMs - utils::currentTimeMillis();
+ long accumulatedTimeMs = 0;
+ const long deltaMs = 50;
+ while (request->reviveStatus.load() ==
+ protocol::StatusCode::REVIVE_INITIALIZED &&
+ accumulatedTimeMs <= reviveWaitTimeMs) {
+ std::this_thread::sleep_for(utils::MS(deltaMs));
+ accumulatedTimeMs += deltaMs;
+ }
+ if (mapperEnded(shuffleId, request->mapId)) {
+ if (request->loc) {
+ VLOG(1) << "Revive for push data success, but the mapper already ended "
+ "for shuffle "
+ << shuffleId << " map " << request->mapId << " attempt "
+ << request->attemptId << " partition " << request->partitionId
+ << " batch " << batchId << " location hostAndPushPort "
+ << request->loc->hostAndPushPort() << ".";
+ pushState->removeBatch(batchId, request->loc->hostAndPushPort());
+ } else {
+ VLOG(1) << "Revive for push data success, but the mapper already ended "
+ "for shuffle "
+ << shuffleId << " map " << request->mapId << " attempt "
+ << request->attemptId << " partition " << request->partitionId
+ << " batch " << batchId << " no location available.";
+ }
+ return;
+ }
+ if (request->reviveStatus.load() != protocol::StatusCode::SUCCESS) {
+ // TODO: the exception message here should be assembled.
+ pushDataCallback->onFailure(std::make_unique<std::exception>());
+ return;
+ }
+ auto locationMapOptional = partitionLocationMaps_.get(shuffleId);
+ CELEBORN_CHECK(locationMapOptional.has_value());
+ auto newLocationOptional =
+ locationMapOptional.value()->get(request->partitionId);
+ CELEBORN_CHECK(newLocationOptional.has_value());
+ auto newLocation = newLocationOptional.value();
+ LOG(INFO) << "Revive for push data success, new location for shuffle "
+ << shuffleId << " map " << request->mapId << " attempt "
+ << request->attemptId << " partition " << request->partitionId
+ << " batch " << batchId << " is location hostAndPushPort "
+ << newLocation->hostAndPushPort() << ".";
+ pushDataCallback->updateLatestLocation(newLocation);
+
+ try {
+ CELEBORN_CHECK_GT(remainReviveTimes, 0, "no remainReviveTime left");
+ network::PushData pushData(
+ network::Message::nextRequestId(),
+ protocol::PartitionLocation::Mode::PRIMARY,
+ utils::makeShuffleKey(appUniqueId_, shuffleId),
+ newLocation->uniqueId(),
+ std::move(body));
+ auto client = clientFactory_->createClient(
+ newLocation->host, newLocation->pushPort, request->partitionId);
+ client->pushDataAsync(
+ pushData, conf_->clientPushDataTimeout(), pushDataCallback);
+ } catch (const std::exception& e) {
+ LOG(ERROR) << "Exception raised while pushing data for shuffle "
+ << shuffleId << " map " << request->mapId << " attempt "
+ << request->attemptId << " partition " << request->partitionId
+ << " batch " << batchId << " location hostAndPushPort "
+ << newLocation->hostAndPushPort() << " errorMsg " << e.what()
+ << ".";
+ // TODO: The failure should be treated better.
+ pushDataCallback->onFailure(std::make_unique<std::exception>(e));
+ }
+}
+
+bool ShuffleClientImpl::checkMapperEnded(
+ int shuffleId,
+ int mapId,
+ const std::string& mapKey) {
+ if (mapperEnded(shuffleId, mapId)) {
+ VLOG(1) << "Mapper already ended for shuffle " << shuffleId << " map "
+ << mapId;
+ if (auto pushStateOptional = pushStates_.get(mapKey);
+ pushStateOptional.has_value()) {
+ auto pushState = pushStateOptional.value();
+ pushState->cleanup();
+ }
+ return true;
+ }
+ return false;
+}
+
+bool ShuffleClientImpl::mapperEnded(int shuffleId, int mapId) {
+ if (auto mapperEndSetOptional = mapperEndSets_.get(shuffleId);
+ mapperEndSetOptional.has_value() &&
+ mapperEndSetOptional.value()->contains(mapId)) {
+ return true;
+ }
+ if (stageEnded(shuffleId)) {
+ return true;
+ }
+ return false;
+}
+
+bool ShuffleClientImpl::stageEnded(int shuffleId) {
+ return stageEndShuffleSet_.contains(shuffleId);
+}
+
+void ShuffleClientImpl::addRequestToReviveManager(
+ std::shared_ptr<protocol::ReviveRequest> reviveRequest) {
+ reviveManager_->addRequest(std::move(reviveRequest));
+}
+
+std::optional<std::unordered_map<int, int>> ShuffleClientImpl::reviveBatch(
+ int shuffleId,
+ const std::unordered_set<int>& mapIds,
+ const std::unordered_map<int, PtrReviveRequest>& requests) {
+ std::unordered_map<int, int> result;
+ auto partitionLocationMap = partitionLocationMaps_.get(shuffleId).value();
+ std::unordered_map<int, std::shared_ptr<const protocol::PartitionLocation>>
+ oldLocationMap;
+ protocol::Revive revive;
+ revive.shuffleId = shuffleId;
+ revive.mapIds.insert(mapIds.begin(), mapIds.end());
+ for (auto& [partitionId, request] : requests) {
+ oldLocationMap[request->partitionId] = request->loc;
+ revive.reviveRequests.insert(request);
+ }
+ try {
+ auto response =
+ lifecycleManagerRef_
+ ->askSync<protocol::Revive, protocol::ChangeLocationResponse>(
+ revive,
+ conf_->clientRpcRequestPartitionLocationRpcAskTimeout());
+ auto mapperEndSet = mapperEndSets_.computeIfAbsent(shuffleId, []() {
+ return std::make_shared<utils::ConcurrentHashSet<int>>();
+ });
+ for (auto endedMapId : response->endedMapIds) {
+ mapperEndSet->insert(endedMapId);
+ }
+ for (auto& partitionInfo : response->partitionInfos) {
+ switch (partitionInfo.status) {
+ case protocol::StatusCode::SUCCESS: {
+ partitionLocationMap->set(
+ partitionInfo.partitionId, partitionInfo.partition);
+ break;
+ }
+ case protocol::StatusCode::STAGE_ENDED: {
+ stageEndShuffleSet_.insert(shuffleId);
+ return {std::move(result)};
+ }
+ case protocol::StatusCode::SHUFFLE_NOT_REGISTERED: {
+ LOG(ERROR) << "shuffleId " << shuffleId << " not registered!";
+ return std::nullopt;
+ }
+ default: {
+ // noop
+ }
+ }
+ result[partitionInfo.partitionId] = partitionInfo.status;
+ }
+ return {std::move(result)};
+ } catch (std::exception& e) {
+ LOG(ERROR) << "reviveBatch failed: " << e.what();
+ return std::nullopt;
+ }
+}
+
+bool ShuffleClientImpl::revive(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ int epoch,
+ std::shared_ptr<const protocol::PartitionLocation> oldLocation,
+ protocol::StatusCode cause) {
+ auto request = std::make_shared<protocol::ReviveRequest>(
+ shuffleId, mapId, attemptId, partitionId, epoch, oldLocation, cause);
+ auto resultOptional =
+ reviveBatch(shuffleId, {mapId}, {{partitionId, request}});
+ if (mapperEnded(shuffleId, mapId)) {
+ VLOG(1) << "Revive success, but the mapper ended for shuffle " << shuffleId
+ << " map " << mapId << " attempt " << attemptId << " partition"
+ << partitionId << ", just return true(Assume revive
successfully).";
+ return true;
+ }
+ if (resultOptional.has_value()) {
+ auto result = resultOptional.value();
+ return result.find(partitionId) != result.end() &&
+ result[partitionId] == protocol::StatusCode::SUCCESS;
+ }
+ return false;
+}
+
+void ShuffleClientImpl::limitMaxInFlight(
+ const std::string& mapKey,
+ PushState& pushState,
+ const std::string& hostAndPushPort) {
+ bool reachLimit = pushState.limitMaxInFlight(hostAndPushPort);
+ if (reachLimit) {
+ auto msg = fmt::format(
+ "Waiting timeout for task {} while limiting max "
+ "in-flight requests to {}.",
+ mapKey,
+ hostAndPushPort);
+ if (auto exceptionMsgOptional = pushState.getExceptionMsg();
+ exceptionMsgOptional.has_value()) {
+ msg += " PushState exception: " + exceptionMsgOptional.value();
+ }
+ CELEBORN_FAIL(msg);
+ }
+}
+
+void ShuffleClientImpl::limitZeroInFlight(
+ const std::string& mapKey,
+ PushState& pushState) {
+ bool reachLimit = pushState.limitZeroInFlight();
+ if (reachLimit) {
+ auto msg = fmt::format(
+ "Waiting timeout for task {} while limiting zero "
+ "in-flight requests.",
+ mapKey);
+ if (auto exceptionMsgOptional = pushState.getExceptionMsg();
+ exceptionMsgOptional.has_value()) {
+ msg += " PushState exception: " + exceptionMsgOptional.value();
+ }
+ CELEBORN_FAIL(msg);
+ }
+}
+
+std::optional<ShuffleClientImpl::PtrPartitionLocationMap>
+ShuffleClientImpl::getPartitionLocationMap(int shuffleId) {
+ return partitionLocationMaps_.get(shuffleId);
+}
+
+utils::ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>&
+ShuffleClientImpl::mapperEndSets() {
+ return mapperEndSets_;
+}
+
+void ShuffleClientImpl::addPushDataRetryTask(folly::Func&& task) {
+ pushDataRetryPool_->add(std::move(task));
+}
+
bool ShuffleClientImpl::newerPartitionLocationExists(
std::shared_ptr<utils::ConcurrentHashMap<
int,
diff --git a/cpp/celeborn/client/ShuffleClient.h b/cpp/celeborn/client/ShuffleClient.h
index dc71a39bc..3e8cb9d37 100644
--- a/cpp/celeborn/client/ShuffleClient.h
+++ b/cpp/celeborn/client/ShuffleClient.h
@@ -32,6 +32,25 @@ class ShuffleClient {
virtual void setupLifecycleManagerRef(
std::shared_ptr<network::NettyRpcEndpointRef>& lifecycleManagerRef) = 0;
+ virtual int pushData(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ const uint8_t* data,
+ size_t offset,
+ size_t length,
+ int numMappers,
+ int numPartitions) = 0;
+
+ // TODO: PushMergedData is not supported yet.
+
+ virtual void
+ mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers) = 0;
+
+ // Cleanup states of a map task.
+ virtual void cleanup(int shuffleId, int mapId, int attemptId) = 0;
+
virtual void updateReducerFileGroup(int shuffleId) = 0;
virtual std::unique_ptr<CelebornInputStream> readPartition(
@@ -57,6 +76,23 @@ class ShuffleClient {
class ReviveManager;
class PushDataCallback;
+/// ShuffleClientEndpoint holds all the resources of ShuffleClient, including
+/// thread pools and client factories. The endpoint can be reused by multiple
+/// ShuffleClients to avoid creating too many resources.
+class ShuffleClientEndpoint {
+ public:
+ ShuffleClientEndpoint(const std::shared_ptr<const conf::CelebornConf>& conf);
+
+ std::shared_ptr<folly::IOThreadPoolExecutor> pushDataRetryPool() const;
+
+ std::shared_ptr<network::TransportClientFactory> clientFactory() const;
+
+ private:
+ const std::shared_ptr<const conf::CelebornConf> conf_;
+ std::shared_ptr<folly::IOThreadPoolExecutor> pushDataRetryPool_;
+ std::shared_ptr<network::TransportClientFactory> clientFactory_;
+};
+
class ShuffleClientImpl
: public ShuffleClient,
public std::enable_shared_from_this<ShuffleClientImpl> {
@@ -75,13 +111,41 @@ class ShuffleClientImpl
static std::shared_ptr<ShuffleClientImpl> create(
const std::string& appUniqueId,
const std::shared_ptr<const conf::CelebornConf>& conf,
- const std::shared_ptr<network::TransportClientFactory>& clientFactory);
+ const ShuffleClientEndpoint& clientEndpoint);
void setupLifecycleManagerRef(std::string& host, int port) override;
void setupLifecycleManagerRef(std::shared_ptr<network::NettyRpcEndpointRef>&
lifecycleManagerRef) override;
+ std::shared_ptr<utils::ConcurrentHashMap<
+ int,
+ std::shared_ptr<const protocol::PartitionLocation>>>
+ getPartitionLocation(int shuffleId, int numMappers, int numPartitions);
+
+ int pushData(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ const uint8_t* data,
+ size_t offset,
+ size_t length,
+ int numMappers,
+ int numPartitions) override;
+
+ void mapperEnd(int shuffleId, int mapId, int attemptId, int numMappers)
+ override;
+
+ void mapPartitionMapperEnd(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int partitionId);
+
+ void cleanup(int shuffleId, int mapId, int attemptId) override;
+
std::unique_ptr<CelebornInputStream> readPartition(
int shuffleId,
int partitionId,
@@ -109,10 +173,8 @@ class ShuffleClientImpl
ShuffleClientImpl(
const std::string& appUniqueId,
const std::shared_ptr<const conf::CelebornConf>& conf,
- const std::shared_ptr<network::TransportClientFactory>& clientFactory);
+ const ShuffleClientEndpoint& clientEndpoint);
- // TODO: currently this function serves as a stub. will be updated in future
- // commits.
virtual void submitRetryPushData(
int shuffleId,
std::unique_ptr<memory::ReadOnlyByteBuffer> body,
@@ -121,44 +183,58 @@ class ShuffleClientImpl
std::shared_ptr<PushState> pushState,
PtrReviveRequest request,
int remainReviveTimes,
- long dueTimeMs) {}
+ long dueTimeMs);
- // TODO: currently this function serves as a stub. will be updated in future
- // commits.
- virtual bool mapperEnded(int shuffleId, int mapId) {
- return true;
- }
+ virtual bool mapperEnded(int shuffleId, int mapId);
- // TODO: currently this function serves as a stub. will be updated in future
- // commits.
virtual void addRequestToReviveManager(
- std::shared_ptr<protocol::ReviveRequest> reviveRequest) {}
+ std::shared_ptr<protocol::ReviveRequest> reviveRequest);
- // TODO: currently this function serves as a stub. will be updated in future
- // commits.
virtual std::optional<std::unordered_map<int, int>> reviveBatch(
int shuffleId,
const std::unordered_set<int>& mapIds,
- const std::unordered_map<int, PtrReviveRequest>& requests) {
- return std::nullopt;
- }
+ const std::unordered_map<int, PtrReviveRequest>& requests);
virtual std::optional<PtrPartitionLocationMap> getPartitionLocationMap(
- int shuffleId) {
- return partitionLocationMaps_.get(shuffleId);
- }
+ int shuffleId);
virtual utils::
ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>&
- mapperEndSets() {
- return mapperEndSets_;
- }
+ mapperEndSets();
- virtual void addPushDataRetryTask(folly::Func&& task) {
- pushDataRetryPool_->add(std::move(task));
- }
+ virtual void addPushDataRetryTask(folly::Func&& task);
private:
+ std::shared_ptr<PushState> getPushState(const std::string& mapKey);
+
+ void initReviveManagerLocked();
+
+ void registerShuffle(int shuffleId, int numMappers, int numPartitions);
+
+ bool checkMapperEnded(int shuffleId, int mapId, const std::string& mapKey);
+
+ bool stageEnded(int shuffleId);
+
+ bool revive(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ int epoch,
+ std::shared_ptr<const protocol::PartitionLocation> oldLocation,
+ protocol::StatusCode cause);
+
+ // Check whether the pushState's ongoing package num has reached the max
+ // limit; if so, block until it decreases below the limit.
+ void limitMaxInFlight(
+ const std::string& mapKey,
+ PushState& pushState,
+ const std::string& hostAndPushPort);
+
+ // Check whether the pushState's ongoing package num has reached zero; if
+ // not, block until it decreases to zero.
+ void limitZeroInFlight(const std::string& mapKey, PushState& pushState);
+
// TODO: no support for WAIT as it is not used.
static bool newerPartitionLocationExists(
std::shared_ptr<utils::ConcurrentHashMap<
@@ -170,6 +246,8 @@ class ShuffleClientImpl
std::shared_ptr<protocol::GetReducerFileGroupResponse>
getReducerFileGroupInfo(int shuffleId);
+ static constexpr size_t kBatchHeaderSize = 4 * 4;
+
const std::string appUniqueId_;
std::shared_ptr<const conf::CelebornConf> conf_;
std::shared_ptr<network::NettyRpcEndpointRef> lifecycleManagerRef_;
@@ -177,13 +255,18 @@ class ShuffleClientImpl
std::shared_ptr<folly::IOExecutor> pushDataRetryPool_;
std::shared_ptr<ReviveManager> reviveManager_;
std::mutex mutex_;
+ utils::ConcurrentHashMap<int, std::shared_ptr<std::mutex>> shuffleMutexes_;
utils::ConcurrentHashMap<
int,
std::shared_ptr<protocol::GetReducerFileGroupResponse>>
reducerFileGroupInfos_;
utils::ConcurrentHashMap<int, PtrPartitionLocationMap>
partitionLocationMaps_;
+ utils::ConcurrentHashMap<std::string, std::shared_ptr<PushState>> pushStates_;
utils::ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>
mapperEndSets_;
+ utils::ConcurrentHashSet<int> stageEndShuffleSet_;
+
+ // TODO: pushExcludedWorker is not supported yet
};
} // namespace client
} // namespace celeborn
diff --git a/cpp/celeborn/client/tests/PushDataCallbackTest.cpp b/cpp/celeborn/client/tests/PushDataCallbackTest.cpp
index 171c8e714..f91637488 100644
--- a/cpp/celeborn/client/tests/PushDataCallbackTest.cpp
+++ b/cpp/celeborn/client/tests/PushDataCallbackTest.cpp
@@ -106,7 +106,9 @@ class MockShuffleClient : public ShuffleClientImpl {
: ShuffleClientImpl(
"mock",
std::make_shared<conf::CelebornConf>(),
- nullptr) {}
+ dummyEndpoint()) {}
+
+ static const ShuffleClientEndpoint& dummyEndpoint();
FuncOnSubmitRetryPushData onSubmitRetryPushData_ =
[](int,
@@ -126,6 +128,12 @@ class MockShuffleClient : public ShuffleClientImpl {
};
};
+const ShuffleClientEndpoint& MockShuffleClient::dummyEndpoint() {
+ static auto conf = std::make_shared<conf::CelebornConf>();
+ static auto dummy = ShuffleClientEndpoint(conf);
+ return dummy;
+}
+
std::unique_ptr<memory::ReadOnlyByteBuffer> createReadOnlyByteBuffer(
uint8_t code) {
auto writeBuffer = memory::ByteBuffer::createWriteOnly(1);
diff --git a/cpp/celeborn/client/tests/ReviveManagerTest.cpp b/cpp/celeborn/client/tests/ReviveManagerTest.cpp
index 8db844485..7bae1ab2a 100644
--- a/cpp/celeborn/client/tests/ReviveManagerTest.cpp
+++ b/cpp/celeborn/client/tests/ReviveManagerTest.cpp
@@ -72,7 +72,10 @@ class MockShuffleClient : public ShuffleClientImpl {
: ShuffleClientImpl(
"mock",
std::make_shared<conf::CelebornConf>(),
- nullptr) {}
+ dummyEndpoint()) {}
+
+ static const ShuffleClientEndpoint& dummyEndpoint();
+
std::function<bool(int, int)> onMapperEnded_ = [](int, int) { return false; };
std::function<std::optional<std::unordered_map<int, int>>(
int,
@@ -89,6 +92,12 @@ class MockShuffleClient : public ShuffleClientImpl {
return {std::make_shared<PartitionLocationMap>()};
};
};
+
+const ShuffleClientEndpoint& MockShuffleClient::dummyEndpoint() {
+ static auto conf = std::make_shared<conf::CelebornConf>();
+ static auto dummy = ShuffleClientEndpoint(conf);
+ return dummy;
+}
} // namespace
class ReviveManagerTest : public testing::Test {
diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp
index 1d58516b3..50b48aa7f 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -131,39 +131,50 @@ Duration toDuration(const std::string& str) {
} // namespace
-const std::unordered_map<std::string, folly::Optional<std::string>>
- CelebornConf::kDefaultProperties = {
- STR_PROP(kRpcAskTimeout, "60s"),
- STR_PROP(kRpcLookupTimeout, "30s"),
- STR_PROP(kClientPushReviveInterval, "100ms"),
- NUM_PROP(kClientPushReviveBatchSize, 2048),
- STR_PROP(kClientPushLimitStrategy, kSimplePushStrategy),
- NUM_PROP(kClientPushMaxReqsInFlightPerWorker, 32),
- NUM_PROP(kClientPushMaxReqsInFlightTotal, 256),
- NUM_PROP(kClientPushLimitInFlightTimeoutMs, 240000),
- NUM_PROP(kClientPushLimitInFlightSleepDeltaMs, 50),
- STR_PROP(kClientRpcRequestPartitionLocationAskTimeout, "60s"),
- STR_PROP(kClientRpcGetReducerFileGroupRpcAskTimeout, "60s"),
- STR_PROP(kNetworkConnectTimeout, "10s"),
- STR_PROP(kClientFetchTimeout, "600s"),
- NUM_PROP(kNetworkIoNumConnectionsPerPeer, 1),
- NUM_PROP(kNetworkIoClientThreads, 0),
- NUM_PROP(kClientFetchMaxReqsInFlight, 3),
- STR_PROP(
- kShuffleCompressionCodec,
- protocol::toString(protocol::CompressionCodec::NONE)),
- NUM_PROP(kShuffleCompressionZstdCompressLevel, 1),
- // NUM_PROP(kNumExample, 50'000),
- // BOOL_PROP(kBoolExample, false),
-};
+const std::unordered_map<std::string, folly::Optional<std::string>>&
+CelebornConf::defaultProperties() {
+ static const std::unordered_map<std::string, folly::Optional<std::string>>
+ defaultProp = {
+ STR_PROP(kRpcAskTimeout, "60s"),
+ STR_PROP(kRpcLookupTimeout, "30s"),
+ STR_PROP(kClientIoConnectionTimeout, "300s"),
+ STR_PROP(kClientRpcRegisterShuffleAskTimeout, "60s"),
+ NUM_PROP(kClientRegisterShuffleMaxRetries, 3),
+ STR_PROP(kClientRegisterShuffleRetryWait, "3s"),
+ NUM_PROP(kClientPushRetryThreads, 8),
+ STR_PROP(kClientPushTimeout, "120s"),
+ STR_PROP(kClientPushReviveInterval, "100ms"),
+ NUM_PROP(kClientPushReviveBatchSize, 2048),
+ NUM_PROP(kClientPushMaxReviveTimes, 5),
+ STR_PROP(kClientPushLimitStrategy, kSimplePushStrategy),
+ NUM_PROP(kClientPushMaxReqsInFlightPerWorker, 32),
+ NUM_PROP(kClientPushMaxReqsInFlightTotal, 256),
+ NUM_PROP(kClientPushLimitInFlightTimeoutMs, 240000),
+ NUM_PROP(kClientPushLimitInFlightSleepDeltaMs, 50),
+ STR_PROP(kClientRpcRequestPartitionLocationAskTimeout, "60s"),
+ STR_PROP(kClientRpcGetReducerFileGroupRpcAskTimeout, "60s"),
+ STR_PROP(kNetworkConnectTimeout, "10s"),
+ STR_PROP(kClientFetchTimeout, "600s"),
+ NUM_PROP(kNetworkIoNumConnectionsPerPeer, 1),
+ NUM_PROP(kNetworkIoClientThreads, 0),
+ NUM_PROP(kClientFetchMaxReqsInFlight, 3),
+ STR_PROP(
+ kShuffleCompressionCodec,
+ protocol::toString(protocol::CompressionCodec::NONE)),
+ NUM_PROP(kShuffleCompressionZstdCompressLevel, 1),
+ // NUM_PROP(kNumExample, 50'000),
+ // BOOL_PROP(kBoolExample, false),
+ };
+ return defaultProp;
+}
CelebornConf::CelebornConf() {
- registeredProps_ = kDefaultProperties;
+ registeredProps_ = defaultProperties();
}
CelebornConf::CelebornConf(const std::string& filename) {
initialize(filename);
- registeredProps_ = kDefaultProperties;
+ registeredProps_ = defaultProperties();
}
CelebornConf::CelebornConf(const CelebornConf& other) {
@@ -193,6 +204,34 @@ Timeout CelebornConf::rpcLookupTimeout() const {
toDuration(optionalProperty(kRpcLookupTimeout).value()));
}
+Timeout CelebornConf::clientIoConnectionTimeout() const {
+ return utils::toTimeout(
+ toDuration(optionalProperty(kClientIoConnectionTimeout).value()));
+}
+
+Timeout CelebornConf::clientRpcRegisterShuffleRpcAskTimeout() const {
+ return utils::toTimeout(toDuration(
+ optionalProperty(kClientRpcRegisterShuffleAskTimeout).value()));
+}
+
+int CelebornConf::clientRegisterShuffleMaxRetries() const {
+ return std::stoi(optionalProperty(kClientRegisterShuffleMaxRetries).value());
+}
+
+Timeout CelebornConf::clientRegisterShuffleRetryWait() const {
+ return utils::toTimeout(
+ toDuration(optionalProperty(kClientRegisterShuffleRetryWait).value()));
+}
+
+int CelebornConf::clientPushRetryThreads() const {
+ return std::stoi(optionalProperty(kClientPushRetryThreads).value());
+}
+
+Timeout CelebornConf::clientPushDataTimeout() const {
+ return utils::toTimeout(
+ toDuration(optionalProperty(kClientPushTimeout).value()));
+}
+
Timeout CelebornConf::clientPushReviveInterval() const {
return utils::toTimeout(
toDuration(optionalProperty(kClientPushReviveInterval).value()));
@@ -202,6 +241,10 @@ int CelebornConf::clientPushReviveBatchSize() const {
return std::stoi(optionalProperty(kClientPushReviveBatchSize).value());
}
+int CelebornConf::clientPushMaxReviveTimes() const {
+ return std::stoi(optionalProperty(kClientPushMaxReviveTimes).value());
+}
+
std::string CelebornConf::clientPushLimitStrategy() const {
return optionalProperty(kClientPushLimitStrategy).value();
}
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index 530a59781..fb4294d2c 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -37,20 +37,41 @@ namespace conf {
class CelebornConf : public BaseConf {
public:
- static const std::unordered_map<std::string, folly::Optional<std::string>>
- kDefaultProperties;
+ static const std::unordered_map<std::string, folly::Optional<std::string>>&
+ defaultProperties();
static constexpr std::string_view kRpcAskTimeout{"celeborn.rpc.askTimeout"};
static constexpr std::string_view kRpcLookupTimeout{
"celeborn.rpc.lookupTimeout"};
+ static constexpr std::string_view kClientIoConnectionTimeout{
+ "celeborn.client.io.connectionTimeout"};
+
+ static constexpr std::string_view kClientRpcRegisterShuffleAskTimeout{
+ "celeborn.client.rpc.registerShuffle.askTimeout"};
+
+ static constexpr std::string_view kClientRegisterShuffleMaxRetries{
+ "celeborn.client.registerShuffle.maxRetries"};
+
+ static constexpr std::string_view kClientRegisterShuffleRetryWait{
+ "celeborn.client.registerShuffle.retryWait"};
+
+ static constexpr std::string_view kClientPushRetryThreads{
+ "celeborn.client.push.retry.threads"};
+
+ static constexpr std::string_view kClientPushTimeout{
+ "celeborn.client.push.timeout"};
+
static constexpr std::string_view kClientPushReviveInterval{
"celeborn.client.push.revive.interval"};
static constexpr std::string_view kClientPushReviveBatchSize{
"celeborn.client.push.revive.batchSize"};
+ static constexpr std::string_view kClientPushMaxReviveTimes{
+ "celeborn.client.push.revive.maxRetries"};
+
static constexpr std::string_view kClientPushLimitStrategy{
"celeborn.client.push.limit.strategy"};
@@ -110,10 +131,24 @@ class CelebornConf : public BaseConf {
Timeout rpcLookupTimeout() const;
+ Timeout clientIoConnectionTimeout() const;
+
+ Timeout clientRpcRegisterShuffleRpcAskTimeout() const;
+
+ int clientRegisterShuffleMaxRetries() const;
+
+ Timeout clientRegisterShuffleRetryWait() const;
+
+ int clientPushRetryThreads() const;
+
+ Timeout clientPushDataTimeout() const;
+
Timeout clientPushReviveInterval() const;
int clientPushReviveBatchSize() const;
+ int clientPushMaxReviveTimes() const;
+
std::string clientPushLimitStrategy() const;
int clientPushMaxReqsInFlightPerWorker() const;
diff --git a/cpp/celeborn/memory/ByteBuffer.h b/cpp/celeborn/memory/ByteBuffer.h
index 0c1027fe2..434d51b7f 100644
--- a/cpp/celeborn/memory/ByteBuffer.h
+++ b/cpp/celeborn/memory/ByteBuffer.h
@@ -172,8 +172,11 @@ class WriteOnlyByteBuffer : public ByteBuffer {
appender_->push(reinterpret_cast<const uint8_t*>(ptr), data.size());
}
- void writeFromBuffer(const void* data, const size_t len) const {
- appender_->push(static_cast<const uint8_t*>(data), len);
+ void writeFromBuffer(
+ const uint8_t* data,
+ const size_t offset,
+ const size_t length) const {
+ appender_->push(data + offset, length);
}
size_t size() const {
diff --git a/cpp/celeborn/network/TransportClient.cpp b/cpp/celeborn/network/TransportClient.cpp
index c3e4d81fd..c4b865d1c 100644
--- a/cpp/celeborn/network/TransportClient.cpp
+++ b/cpp/celeborn/network/TransportClient.cpp
@@ -165,7 +165,7 @@ SerializePipeline::Ptr MessagePipelineFactory::newPipeline(
}
TransportClientFactory::TransportClientFactory(
- const std::shared_ptr<conf::CelebornConf>& conf) {
+ const std::shared_ptr<const conf::CelebornConf>& conf) {
numConnectionsPerPeer_ = conf->networkIoNumConnectionsPerPeer();
rpcLookupTimeout_ = conf->rpcLookupTimeout();
connectTimeout_ = conf->networkConnectTimeout();
@@ -180,6 +180,13 @@ TransportClientFactory::TransportClientFactory(
std::shared_ptr<TransportClient> TransportClientFactory::createClient(
const std::string& host,
uint16_t port) {
+ return createClient(host, port, std::rand());
+}
+
+std::shared_ptr<TransportClient> TransportClientFactory::createClient(
+ const std::string& host,
+ uint16_t port,
+ int32_t partitionId) {
auto address = folly::SocketAddress(host, port);
auto pool = clientPools_.withLock([&](auto& registry) {
auto iter = registry.find(address);
@@ -191,7 +198,7 @@ std::shared_ptr<TransportClient> TransportClientFactory::createClient(
registry[address] = createdPool;
return createdPool;
});
- auto clientId = std::rand() % numConnectionsPerPeer_;
+ auto clientId = partitionId % numConnectionsPerPeer_;
{
std::lock_guard<std::mutex> lock(pool->mutex);
// TODO: auto-disconnect if the connection is idle for a long time?
diff --git a/cpp/celeborn/network/TransportClient.h b/cpp/celeborn/network/TransportClient.h
index e3ece7d22..78c87414c 100644
--- a/cpp/celeborn/network/TransportClient.h
+++ b/cpp/celeborn/network/TransportClient.h
@@ -119,12 +119,15 @@ class MessagePipelineFactory
class TransportClientFactory {
public:
- TransportClientFactory(const std::shared_ptr<conf::CelebornConf>& conf);
+ TransportClientFactory(const std::shared_ptr<const conf::CelebornConf>& conf);
virtual std::shared_ptr<TransportClient> createClient(
const std::string& host,
uint16_t port);
+ virtual std::shared_ptr<TransportClient>
+ createClient(const std::string& host, uint16_t port, int32_t partitionId);
+
private:
struct ClientPool {
std::mutex mutex;
diff --git a/cpp/celeborn/tests/DataSumWithReaderClient.cpp b/cpp/celeborn/tests/DataSumWithReaderClient.cpp
index ac62d13fe..533303323 100644
--- a/cpp/celeborn/tests/DataSumWithReaderClient.cpp
+++ b/cpp/celeborn/tests/DataSumWithReaderClient.cpp
@@ -44,10 +44,10 @@ int main(int argc, char** argv) {
auto conf = std::make_shared<celeborn::conf::CelebornConf>();
conf->registerProperty(
celeborn::conf::CelebornConf::kShuffleCompressionCodec, compressCodec);
- auto clientFactory =
- std::make_shared<celeborn::network::TransportClientFactory>(conf);
+ auto clientEndpoint =
+ std::make_shared<celeborn::client::ShuffleClientEndpoint>(conf);
auto shuffleClient = celeborn::client::ShuffleClientImpl::create(
- appUniqueId, conf, clientFactory);
+ appUniqueId, conf, *clientEndpoint);
shuffleClient->setupLifecycleManagerRef(
lifecycleManagerHost, lifecycleManagerPort);
diff --git a/cpp/celeborn/utils/CelebornUtils.cpp b/cpp/celeborn/utils/CelebornUtils.cpp
index 24340cb79..675850445 100644
--- a/cpp/celeborn/utils/CelebornUtils.cpp
+++ b/cpp/celeborn/utils/CelebornUtils.cpp
@@ -23,6 +23,10 @@ std::string makeShuffleKey(const std::string& appId, const int shuffleId) {
return appId + "-" + std::to_string(shuffleId);
}
+std::string makeMapKey(int shuffleId, int mapId, int attemptId) {
+ return fmt::format("{}-{}-{}", shuffleId, mapId, attemptId);
+}
+
void writeUTF(memory::WriteOnlyByteBuffer& buffer, const std::string& msg) {
buffer.write<short>(msg.size());
buffer.writeFromString(msg);
diff --git a/cpp/celeborn/utils/CelebornUtils.h b/cpp/celeborn/utils/CelebornUtils.h
index ac1419914..f79f329a5 100644
--- a/cpp/celeborn/utils/CelebornUtils.h
+++ b/cpp/celeborn/utils/CelebornUtils.h
@@ -48,6 +48,8 @@ std::vector<T> toVector(const std::set<T>& in) {
std::string makeShuffleKey(const std::string& appId, int shuffleId);
+std::string makeMapKey(int shuffleId, int mapId, int attemptId);
+
void writeUTF(memory::WriteOnlyByteBuffer& buffer, const std::string& msg);
void writeRpcAddress(
@@ -68,6 +70,12 @@ inline uint64_t currentTimeMillis() {
.count();
}
+inline uint64_t currentTimeNanos() {
+ return std::chrono::duration_cast<std::chrono::nanoseconds>(
+ std::chrono::high_resolution_clock::now().time_since_epoch())
+ .count();
+}
+
/// parse string like "Any-Host-Str:Port#1:Port#2:...:Port#num", split into
/// {"Any-Host-Str", "Port#1", "Port#2", ..., "Port#num"}. Note that the
/// "Any-Host-Str" might contain ':' in IPV6 address.