This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new fddb81754 [CELEBORN-2226][CIP-14] Support RetryFetchChunk functionality for Cel…
fddb81754 is described below
commit fddb81754b03326f15df0b84ea1568a0621b7b88
Author: afterincomparableyum <[email protected]>
AuthorDate: Sun Mar 1 09:54:10 2026 +0800
[CELEBORN-2226][CIP-14] Support RetryFetchChunk functionality for Cel…
Implement chunk-fetch retry logic in CelebornInputStream::getNextChunk(),
matching the Java CelebornInputStream behavior. When a chunk fetch fails, the
retry loop excludes the failed worker, switches to the peer replica (if
available), and sleeps between retry rounds before creating a new reader.
- Added getLocation() to the PartitionReader interface and WorkerPartitionReader.
- Replaced the stub getNextChunk() with full retry logic: excluded-worker
  checks, peer switching, configurable retry count, and sleep between retries.
- Updated moveToNextChunk() and moveToNextReader() to handle nullable returns
  from getNextChunk().
- Added a unit test for WorkerPartitionReader::getLocation().
- Added unit tests for the getNextChunk() retry logic.
CI and build pass.
Closes #3605 from afterincomparableyum/cpp-client/celeborn-2226.
Authored-by: afterincomparableyum <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
cpp/celeborn/client/reader/CelebornInputStream.cpp | 66 +++++-
.../client/reader/WorkerPartitionReader.cpp | 4 +
cpp/celeborn/client/reader/WorkerPartitionReader.h | 8 +
.../client/tests/CelebornInputStreamRetryTest.cpp | 244 +++++++++++++++++++++
.../client/tests/WorkerPartitionReaderTest.cpp | 36 +++
5 files changed, 351 insertions(+), 7 deletions(-)
diff --git a/cpp/celeborn/client/reader/CelebornInputStream.cpp b/cpp/celeborn/client/reader/CelebornInputStream.cpp
index f81152321..80d400ec1 100644
--- a/cpp/celeborn/client/reader/CelebornInputStream.cpp
+++ b/cpp/celeborn/client/reader/CelebornInputStream.cpp
@@ -157,7 +157,9 @@ bool CelebornInputStream::moveToNextChunk() {
if (currReader_->hasNext()) {
currChunk_ = getNextChunk();
- return true;
+ if (currChunk_) {
+ return true;
+ }
}
if (currLocationIndex_ < locations_.size()) {
moveToNextReader();
@@ -169,11 +171,59 @@ bool CelebornInputStream::moveToNextChunk() {
std::unique_ptr<memory::ReadOnlyByteBuffer>
CelebornInputStream::getNextChunk() {
- // TODO: support the failure retrying, including excluding the failed
- // location, open a reader to read from the location's peer.
- auto chunk = currReader_->next();
- verifyChunk(chunk);
- return std::move(chunk);
+ while (fetchChunkRetryCnt_ < fetchChunkMaxRetry_) {
+ try {
+ if (isExcluded(currReader_->getLocation())) {
+ CELEBORN_FAIL(
+ "Fetch data from excluded worker! {}",
+ currReader_->getLocation().hostAndFetchPort());
+ }
+ if (!currReader_->hasNext()) {
+ return nullptr;
+ }
+ auto chunk = currReader_->next();
+ verifyChunk(chunk);
+ return std::move(chunk);
+ } catch (const std::exception& e) {
+ auto failedLocation = currReader_->getLocation();
+ shuffleClient_->excludeFailedFetchLocation(
+ failedLocation.hostAndFetchPort(), e);
+ fetchChunkRetryCnt_++;
+ currReader_ = nullptr;
+
+ if (fetchChunkRetryCnt_ == fetchChunkMaxRetry_) {
+ CELEBORN_FAIL(
+ "Fetch chunk failed for {} times for location {}. Error: {}",
+ fetchChunkRetryCnt_,
+ failedLocation.hostAndFetchPort(),
+ e.what());
+ }
+
+ if (failedLocation.hasPeer() && !readSkewPartitionWithoutMapRange_) {
+ LOG(WARNING) << "Fetch chunk failed " << fetchChunkRetryCnt_ << "/"
+ << fetchChunkMaxRetry_ << " times for location "
+ << failedLocation.hostAndFetchPort()
+ << ", change to peer. Error: " << e.what();
+ // fetchChunkRetryCnt_ % 2 == 0 means both replicas have been tried,
+ // so sleep before next try.
+ if (fetchChunkRetryCnt_ % 2 == 0) {
+ std::this_thread::sleep_for(retryWait_);
+ }
+ currReader_ = createReaderWithRetry(*failedLocation.getPeer());
+ } else {
+ LOG(WARNING) << "Fetch chunk failed " << fetchChunkRetryCnt_ << "/"
+ << fetchChunkMaxRetry_ << " times for location "
+ << failedLocation.hostAndFetchPort()
+ << ". Error: " << e.what();
+ std::this_thread::sleep_for(retryWait_);
+ // TODO: Pass checkpoint metadata when supported to skip
+ // already-read chunks, improving retry performance.
+ currReader_ = createReaderWithRetry(failedLocation);
+ }
+ }
+ }
+
+ CELEBORN_FAIL("Fetch chunk failed!");
}
void CelebornInputStream::verifyChunk(
@@ -204,7 +254,9 @@ void CelebornInputStream::moveToNextReader() {
currLocationIndex_++;
if (currReader_->hasNext()) {
currChunk_ = getNextChunk();
- return;
+ if (currChunk_) {
+ return;
+ }
}
moveToNextReader();
}
diff --git a/cpp/celeborn/client/reader/WorkerPartitionReader.cpp b/cpp/celeborn/client/reader/WorkerPartitionReader.cpp
index a7983a01e..7fb2c5de7 100644
--- a/cpp/celeborn/client/reader/WorkerPartitionReader.cpp
+++ b/cpp/celeborn/client/reader/WorkerPartitionReader.cpp
@@ -73,6 +73,10 @@ WorkerPartitionReader::~WorkerPartitionReader() {
client_->sendRpcRequestWithoutResponse(request);
}
+const protocol::PartitionLocation& WorkerPartitionReader::getLocation() const {
+ return location_;
+}
+
bool WorkerPartitionReader::hasNext() {
return toConsumeChunkId_ < streamHandler_->numChunks;
}
diff --git a/cpp/celeborn/client/reader/WorkerPartitionReader.h b/cpp/celeborn/client/reader/WorkerPartitionReader.h
index db22da8c2..bda9d73f5 100644
--- a/cpp/celeborn/client/reader/WorkerPartitionReader.h
+++ b/cpp/celeborn/client/reader/WorkerPartitionReader.h
@@ -29,6 +29,8 @@ class PartitionReader {
virtual bool hasNext() = 0;
virtual std::unique_ptr<memory::ReadOnlyByteBuffer> next() = 0;
+
+ virtual const protocol::PartitionLocation& getLocation() const = 0;
};
class WorkerPartitionReader
@@ -51,6 +53,8 @@ class WorkerPartitionReader
std::unique_ptr<memory::ReadOnlyByteBuffer> next() override;
+ const protocol::PartitionLocation& getLocation() const override;
+
private:
// Disable creating the object directly to make sure that
// std::enable_shared_from_this works properly.
@@ -88,6 +92,10 @@ class WorkerPartitionReader
static constexpr auto kDefaultConsumeIter = std::chrono::milliseconds(500);
// TODO: add other params, such as fetchChunkRetryCnt, fetchChunkMaxRetries
+ // TODO: add TEST_CLIENT_FETCH_FAILURE support (matching Java's testFetch
+ // flag) to enable integration testing of getNextChunk() retry logic, as
+ // done by Java's ReadWriteTestWithFailures. This will likely be done once
+ // full C++ write support is achieved.
};
} // namespace client
} // namespace celeborn
diff --git a/cpp/celeborn/client/tests/CelebornInputStreamRetryTest.cpp b/cpp/celeborn/client/tests/CelebornInputStreamRetryTest.cpp
index 564860459..f37ef103a 100644
--- a/cpp/celeborn/client/tests/CelebornInputStreamRetryTest.cpp
+++ b/cpp/celeborn/client/tests/CelebornInputStreamRetryTest.cpp
@@ -27,6 +27,7 @@ using namespace celeborn::client;
using namespace celeborn::network;
using namespace celeborn::protocol;
using namespace celeborn::conf;
+using namespace celeborn::memory;
namespace {
using MS = std::chrono::milliseconds;
@@ -180,6 +181,117 @@ std::shared_ptr<CelebornConf> makeTestConf(bool replicateEnabled = true) {
CelebornConf::kClientFetchExcludeWorkerOnFailureEnabled, "true");
return conf;
}
+// Creates a valid PbStreamHandler RpcResponse for WorkerPartitionReader
+// construction. Each copy is safe to use independently (RpcResponse clones
+// the body on copy).
+RpcResponse makeStreamHandlerResponse(int numChunks = 1) {
+ PbStreamHandler pb;
+ pb.set_streamid(100);
+ pb.set_numchunks(numChunks);
+ for (int i = 0; i < numChunks; i++) {
+ pb.add_chunkoffsets(i);
+ }
+ pb.set_fullpath("test-fullpath");
+ TransportMessage transportMessage(STREAM_HANDLER, pb.SerializeAsString());
+ return RpcResponse(1111, transportMessage.toReadOnlyByteBuffer());
+}
+
+// Creates a chunk buffer
+std::unique_ptr<ReadOnlyByteBuffer> makeChunkBuffer(
+ int mapId,
+ int attemptId,
+ int batchId,
+ const std::string& payload) {
+ const size_t totalSize = 4 * sizeof(int) + payload.size();
+ auto buffer = ByteBuffer::createWriteOnly(totalSize, false);
+ buffer->writeLE<int>(mapId);
+ buffer->writeLE<int>(attemptId);
+ buffer->writeLE<int>(batchId);
+ buffer->writeLE<int>(static_cast<int>(payload.size()));
+ buffer->writeFromString(payload);
+ return ByteBuffer::toReadOnly(std::move(buffer));
+}
+
+// A TransportClient whose sendRpcRequestSync always returns a valid
+// stream handler, and whose fetchChunkAsync walks through a pre-configured
+// sequence of success/failure behaviors. Used to exercise the getNextChunk()
+// retry loop independently of reader-creation failures.
+class SequencedMockTransportClient : public TransportClient {
+ public:
+ SequencedMockTransportClient()
+ : TransportClient(nullptr, nullptr, MS(100)),
+ streamHandlerResponse_(makeStreamHandlerResponse()) {}
+
+ RpcResponse sendRpcRequestSync(const RpcRequest& request, Timeout timeout)
+ override {
+ return streamHandlerResponse_;
+ }
+
+ void sendRpcRequestWithoutResponse(const RpcRequest& request) override {}
+
+ void fetchChunkAsync(
+ const StreamChunkSlice& streamChunkSlice,
+ const RpcRequest& request,
+ FetchChunkSuccessCallback onSuccess,
+ FetchChunkFailureCallback onFailure) override {
+ auto idx = fetchCallIdx_++;
+ if (idx < fetchBehaviors_.size()) {
+ fetchBehaviors_[idx](streamChunkSlice, onSuccess, onFailure);
+ }
+ }
+
+ using FetchBehavior = std::function<void(
+ const StreamChunkSlice&,
+ FetchChunkSuccessCallback,
+ FetchChunkFailureCallback)>;
+
+ void addFetchSuccess(std::unique_ptr<ReadOnlyByteBuffer> chunk) {
+ auto iobuf = std::shared_ptr<folly::IOBuf>(chunk->getData());
+ fetchBehaviors_.push_back([iobuf](
+ const StreamChunkSlice& slice,
+ FetchChunkSuccessCallback onSuccess,
+ FetchChunkFailureCallback) {
+ onSuccess(slice, ByteBuffer::createReadOnly(iobuf->clone(), false));
+ });
+ }
+
+ void addFetchFailure(const std::string& errorMessage) {
+ fetchBehaviors_.push_back([errorMessage](
+ const StreamChunkSlice& slice,
+ FetchChunkSuccessCallback,
+ FetchChunkFailureCallback onFailure) {
+ onFailure(slice, std::make_unique<std::runtime_error>(errorMessage));
+ });
+ }
+
+ private:
+ RpcResponse streamHandlerResponse_;
+ std::vector<FetchBehavior> fetchBehaviors_;
+ size_t fetchCallIdx_{0};
+};
+
+class SequencedMockClientFactory : public TransportClientFactory {
+ public:
+ explicit SequencedMockClientFactory(
+ std::shared_ptr<SequencedMockTransportClient> client)
+ : TransportClientFactory(std::make_shared<CelebornConf>()),
+ client_(std::move(client)) {}
+
+ std::shared_ptr<TransportClient> createClient(
+ const std::string& host,
+ uint16_t port) override {
+ hosts_.push_back(host);
+ return client_;
+ }
+
+ const std::vector<std::string>& hosts() const {
+ return hosts_;
+ }
+
+ private:
+ std::shared_ptr<SequencedMockTransportClient> client_;
+ std::vector<std::string> hosts_;
+};
} // namespace
// Verifies that createReaderWithRetry exhausts all retries and throws.
@@ -408,4 +520,136 @@ TEST(CelebornInputStreamRetryTest, replicationDoublesMaxRetries) {
// With maxRetriesForEachReplica=2 and replication enabled,
// fetchChunkMaxRetry = 2 * 2 = 4 total attempts
EXPECT_EQ(factory->hosts().size(), 4u);
+}
+
+// getNextChunk() retry tests
+// These tests exercise the retry loop inside getNextChunk(), which is
+// triggered when a successfully-created reader's next() call fails during
+// chunk fetching.
+
+// Verifies that when a chunk fetch fails on the primary, getNextChunk()
+// switches to the peer replica and successfully reads data on retry.
+TEST(CelebornInputStreamRetryTest, fetchChunkRetrySucceedsWithPeerSwitch) {
+ auto mockClient = std::make_shared<SequencedMockTransportClient>();
+ mockClient->addFetchFailure("chunk fetch failed");
+ const std::string payload = "hello";
+ mockClient->addFetchSuccess(makeChunkBuffer(0, 0, 0, payload));
+
+ auto factory = std::make_shared<SequencedMockClientFactory>(mockClient);
+ auto conf = makeTestConf(true);
+ auto excludedWorkers =
+ std::make_shared<CelebornInputStream::FetchExcludedWorkers>();
+ StubShuffleClient shuffleClient(conf, excludedWorkers);
+
+ auto location = makeLocationWithPeer();
+ std::vector<std::shared_ptr<const PartitionLocation>> locations;
+ locations.push_back(std::move(location));
+ std::vector<int> attempts = {0};
+
+ CelebornInputStream stream(
+ "test-shuffle-key",
+ conf,
+ factory,
+ std::move(locations),
+ attempts,
+ 0,
+ 0,
+ 100,
+ false,
+ excludedWorkers,
+ &shuffleClient);
+
+ std::vector<uint8_t> buffer(payload.size());
+ int bytesRead = stream.read(buffer.data(), 0, payload.size());
+ EXPECT_EQ(bytesRead, payload.size());
+ EXPECT_EQ(std::string(buffer.begin(), buffer.end()), payload);
+
+ auto& hosts = factory->hosts();
+ ASSERT_GE(hosts.size(), 2u);
+ EXPECT_EQ(hosts[0], "primary-host");
+ EXPECT_EQ(hosts[1], "replica-host");
+}
+
+// Verifies that getNextChunk() throws after exhausting all chunk-fetch retries.
+TEST(CelebornInputStreamRetryTest, fetchChunkRetryExhaustsAllRetries) {
+ auto mockClient = std::make_shared<SequencedMockTransportClient>();
+ for (int i = 0; i < 4; i++) {
+ mockClient->addFetchFailure("chunk fetch failed");
+ }
+
+ auto factory = std::make_shared<SequencedMockClientFactory>(mockClient);
+ auto conf = makeTestConf(true);
+ auto excludedWorkers =
+ std::make_shared<CelebornInputStream::FetchExcludedWorkers>();
+ StubShuffleClient shuffleClient(conf, excludedWorkers);
+
+ auto location = makeLocationWithPeer();
+ std::vector<std::shared_ptr<const PartitionLocation>> locations;
+ locations.push_back(std::move(location));
+ std::vector<int> attempts = {0};
+
+ EXPECT_THROW(
+ CelebornInputStream(
+ "test-shuffle-key",
+ conf,
+ factory,
+ std::move(locations),
+ attempts,
+ 0,
+ 0,
+ 100,
+ false,
+ excludedWorkers,
+ &shuffleClient),
+ std::exception);
+
+ auto& hosts = factory->hosts();
+ EXPECT_EQ(hosts.size(), 4u);
+ EXPECT_EQ(hosts[0], "primary-host");
+ for (size_t i = 1; i < hosts.size(); i++) {
+ EXPECT_EQ(hosts[i], "replica-host");
+ }
+}
+
+// Verifies that without a peer, getNextChunk() retries the same location
+// and succeeds on the second attempt.
+TEST(CelebornInputStreamRetryTest, fetchChunkRetryNoPeerRetriesSameLocation) {
+ auto mockClient = std::make_shared<SequencedMockTransportClient>();
+ mockClient->addFetchFailure("chunk fetch failed");
+ const std::string payload = "world";
+ mockClient->addFetchSuccess(makeChunkBuffer(0, 0, 0, payload));
+
+ auto factory = std::make_shared<SequencedMockClientFactory>(mockClient);
+ auto conf = makeTestConf(false);
+ auto excludedWorkers =
+ std::make_shared<CelebornInputStream::FetchExcludedWorkers>();
+ StubShuffleClient shuffleClient(conf, excludedWorkers);
+
+ auto location = makeLocationWithoutPeer();
+ std::vector<std::shared_ptr<const PartitionLocation>> locations;
+ locations.push_back(std::move(location));
+ std::vector<int> attempts = {0};
+
+ CelebornInputStream stream(
+ "test-shuffle-key",
+ conf,
+ factory,
+ std::move(locations),
+ attempts,
+ 0,
+ 0,
+ 100,
+ false,
+ excludedWorkers,
+ &shuffleClient);
+
+ std::vector<uint8_t> buffer(payload.size());
+ int bytesRead = stream.read(buffer.data(), 0, payload.size());
+ EXPECT_EQ(bytesRead, payload.size());
+ EXPECT_EQ(std::string(buffer.begin(), buffer.end()), payload);
+
+ for (const auto& host : factory->hosts()) {
+ EXPECT_EQ(host, "solo-host");
+ }
+ EXPECT_EQ(factory->hosts().size(), 2u);
}
\ No newline at end of file
diff --git a/cpp/celeborn/client/tests/WorkerPartitionReaderTest.cpp b/cpp/celeborn/client/tests/WorkerPartitionReaderTest.cpp
index d5a07a992..17703387a 100644
--- a/cpp/celeborn/client/tests/WorkerPartitionReaderTest.cpp
+++ b/cpp/celeborn/client/tests/WorkerPartitionReaderTest.cpp
@@ -227,3 +227,39 @@ TEST(WorkerPartitionReaderTest, fetchChunkFailure) {
EXPECT_TRUE(partitionReader->hasNext());
EXPECT_THROW(partitionReader->next(), std::exception);
}
+
+TEST(WorkerPartitionReaderTest, getLocationReturnsCorrectLocation) {
+ MockTransportClientFactory mockedClientFactory;
+ auto conf = std::make_shared<CelebornConf>();
+ auto transportClient = mockedClientFactory.getClient();
+
+ PbStreamHandler pb;
+ pb.set_streamid(100);
+ pb.set_numchunks(0);
+ pb.set_fullpath("test-fullpath");
+ TransportMessage transportMessage(STREAM_HANDLER, pb.SerializeAsString());
+ RpcResponse response =
+ RpcResponse(1111, transportMessage.toReadOnlyByteBuffer());
+ transportClient->setSyncResponse(response);
+
+ PartitionLocation location;
+ location.id = 42;
+ location.epoch = 7;
+ location.host = "test-location-host";
+ location.rpcPort = 0;
+ location.pushPort = 0;
+ location.fetchPort = 9999;
+ location.replicatePort = 0;
+ location.mode = PartitionLocation::PRIMARY;
+ location.storageInfo = std::make_unique<StorageInfo>();
+ location.storageInfo->type = StorageInfo::HDD;
+
+ auto reader = WorkerPartitionReader::create(
+ conf, "shuffle-key", location, 0, 100, &mockedClientFactory);
+
+ const auto& loc = reader->getLocation();
+ EXPECT_EQ(loc.host, "test-location-host");
+ EXPECT_EQ(loc.fetchPort, 9999);
+ EXPECT_EQ(loc.id, 42);
+ EXPECT_EQ(loc.epoch, 7);
+}