Copilot commented on code in PR #3611:
URL: https://github.com/apache/celeborn/pull/3611#discussion_r2958952212
##########
cpp/celeborn/protocol/Encoders.cpp:
##########
@@ -33,5 +34,70 @@ std::string decode(memory::ReadOnlyByteBuffer& buffer) {
int size = buffer.read<int>();
return buffer.readToString(size);
}
+
+int encodedLength(const std::vector<std::string>& arr) {
+ int total = sizeof(int);
+ for (const auto& s : arr) {
+ total += encodedLength(s);
+ }
+ return total;
+}
+
+void encode(
+ memory::WriteOnlyByteBuffer& buffer,
+ const std::vector<std::string>& arr) {
+ buffer.write<int>(static_cast<int>(arr.size()));
+ for (const auto& s : arr) {
+ encode(buffer, s);
+ }
Review Comment:
`encode(..., const std::vector<std::string>&)` writes `arr.size()` as an
`int` via `static_cast<int>(arr.size())` without checking it fits. If
`arr.size()` exceeds `INT_MAX`, this truncates/overflows and produces an
invalid wire format. Add an explicit size check (or use a fixed-width unsigned
length type consistently) before casting.
##########
cpp/celeborn/protocol/Encoders.cpp:
##########
@@ -33,5 +34,70 @@ std::string decode(memory::ReadOnlyByteBuffer& buffer) {
int size = buffer.read<int>();
return buffer.readToString(size);
}
+
+int encodedLength(const std::vector<std::string>& arr) {
+ int total = sizeof(int);
+ for (const auto& s : arr) {
+ total += encodedLength(s);
+ }
+ return total;
+}
+
+void encode(
+ memory::WriteOnlyByteBuffer& buffer,
+ const std::vector<std::string>& arr) {
+ buffer.write<int>(static_cast<int>(arr.size()));
+ for (const auto& s : arr) {
+ encode(buffer, s);
+ }
+}
+
+std::vector<std::string> decodeStringArray(memory::ReadOnlyByteBuffer& buffer)
{
+ int count = buffer.read<int>();
+ CELEBORN_CHECK_GE(count, 0, "Invalid string array count: {}", count);
+ CELEBORN_CHECK_LE(
+ static_cast<size_t>(count) * sizeof(int),
+ buffer.remainingSize(),
+ "String array count {} exceeds remaining buffer size {}",
+ count,
+ buffer.remainingSize());
+ std::vector<std::string> result;
+ result.reserve(count);
+ for (int i = 0; i < count; i++) {
+ result.push_back(decode(buffer));
+ }
+ return result;
+}
+
// Wire size of an int32 array: a 4-byte count plus 4 bytes per element.
int encodedLength(const std::vector<int32_t>& arr) {
  // Explicit cast: the size_t expression would otherwise narrow implicitly
  // into the int return type.
  return static_cast<int>(sizeof(int) + sizeof(int32_t) * arr.size());
}
+
+void encode(
+ memory::WriteOnlyByteBuffer& buffer,
+ const std::vector<int32_t>& arr) {
+ buffer.write<int>(static_cast<int>(arr.size()));
+ for (auto val : arr) {
+ buffer.write<int32_t>(val);
+ }
+}
+
+std::vector<int32_t> decodeIntArray(memory::ReadOnlyByteBuffer& buffer) {
+ int count = buffer.read<int>();
+ CELEBORN_CHECK_GE(count, 0, "Invalid int array count: {}", count);
+ CELEBORN_CHECK_LE(
+ static_cast<size_t>(count) * sizeof(int32_t),
+ buffer.remainingSize(),
Review Comment:
`decodeIntArray` does `static_cast<size_t>(count) * sizeof(int32_t)` for the
bounds check. This multiplication can overflow for large `count` values and
allow an invalid count to pass the check. Prefer `count <= remainingSize /
sizeof(int32_t)` to make the bound overflow-safe.
##########
cpp/celeborn/client/writer/PushMergedDataCallback.cpp:
##########
@@ -0,0 +1,343 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/PushMergedDataCallback.h"
+#include "celeborn/conf/CelebornConf.h"
+#include "celeborn/protocol/TransportMessage.h"
+#include "celeborn/utils/Exceptions.h"
+
+namespace celeborn {
+namespace client {
+
+std::shared_ptr<PushMergedDataCallback> PushMergedDataCallback::create(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ const std::string& hostAndPushPort,
+ int groupedBatchId,
+ std::vector<DataBatch> batches,
+ std::vector<int> partitionIds,
+ std::shared_ptr<PushState> pushState,
+ std::weak_ptr<ShuffleClientImpl> weakClient,
+ int remainingReviveTimes) {
+ return std::shared_ptr<PushMergedDataCallback>(new PushMergedDataCallback(
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ numPartitions,
+ mapKey,
+ hostAndPushPort,
+ groupedBatchId,
+ std::move(batches),
+ std::move(partitionIds),
+ pushState,
+ weakClient,
+ remainingReviveTimes));
+}
+
+PushMergedDataCallback::PushMergedDataCallback(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ const std::string& hostAndPushPort,
+ int groupedBatchId,
+ std::vector<DataBatch> batches,
+ std::vector<int> partitionIds,
+ std::shared_ptr<PushState> pushState,
+ std::weak_ptr<ShuffleClientImpl> weakClient,
+ int remainingReviveTimes)
+ : shuffleId_(shuffleId),
+ mapId_(mapId),
+ attemptId_(attemptId),
+ numMappers_(numMappers),
+ numPartitions_(numPartitions),
+ mapKey_(mapKey),
+ hostAndPushPort_(hostAndPushPort),
+ groupedBatchId_(groupedBatchId),
+ batches_(std::move(batches)),
+ partitionIds_(std::move(partitionIds)),
+ pushState_(pushState),
+ weakClient_(weakClient),
+ remainingReviveTimes_(remainingReviveTimes) {}
+
+void PushMergedDataCallback::onSuccess(
+ std::unique_ptr<memory::ReadOnlyByteBuffer> response) {
+ auto sharedClient = weakClient_.lock();
+ if (!sharedClient) {
+ LOG(WARNING) << "ShuffleClientImpl has expired when "
+ "PushMergedDataCallbackOnSuccess, ignored, shuffle "
+ << shuffleId_ << " map " << mapId_ << " attempt " <<
attemptId_
+ << " groupedBatch " << groupedBatchId_ << ".";
+ return;
+ }
+
+ if (response->remainingSize() <= 0) {
Review Comment:
On successful PushMergedData responses with an empty body, the callback
removes the batch but never calls `pushState_->onSuccess(hostAndPushPort_)`.
This means the push strategy never gets success signals for merged pushes,
which can skew congestion control / throttling behavior. Call `onSuccess`
before `removeBatch` for the success paths.
##########
cpp/celeborn/protocol/Encoders.cpp:
##########
@@ -33,5 +34,70 @@ std::string decode(memory::ReadOnlyByteBuffer& buffer) {
int size = buffer.read<int>();
return buffer.readToString(size);
}
+
+int encodedLength(const std::vector<std::string>& arr) {
+ int total = sizeof(int);
+ for (const auto& s : arr) {
+ total += encodedLength(s);
+ }
+ return total;
+}
+
+void encode(
+ memory::WriteOnlyByteBuffer& buffer,
+ const std::vector<std::string>& arr) {
+ buffer.write<int>(static_cast<int>(arr.size()));
+ for (const auto& s : arr) {
+ encode(buffer, s);
+ }
+}
+
+std::vector<std::string> decodeStringArray(memory::ReadOnlyByteBuffer& buffer)
{
+ int count = buffer.read<int>();
+ CELEBORN_CHECK_GE(count, 0, "Invalid string array count: {}", count);
+ CELEBORN_CHECK_LE(
+ static_cast<size_t>(count) * sizeof(int),
+ buffer.remainingSize(),
+ "String array count {} exceeds remaining buffer size {}",
+ count,
+ buffer.remainingSize());
Review Comment:
`decodeStringArray` uses `static_cast<size_t>(count) * sizeof(int)` when
checking remaining buffer size. If `count` is very large, the multiplication
can overflow `size_t` and wrap, potentially bypassing the bounds check and
leading to OOM / out-of-bounds reads. Use a division-based bound like `count <=
remainingSize / sizeof(int)` (and keep the `count >= 0` check) to avoid
overflow.
##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -269,11 +270,447 @@ int ShuffleClientImpl::pushData(
return body->remainingSize();
}
+int ShuffleClientImpl::mergeData(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ const uint8_t* data,
+ size_t offset,
+ size_t length,
+ int numMappers,
+ int numPartitions) {
+ const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ auto partitionLocationMap =
+ getPartitionLocation(shuffleId, numMappers, numPartitions);
+ CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+ auto partitionLocationOptional = partitionLocationMap->get(partitionId);
+ if (!partitionLocationOptional.has_value()) {
+ if (!revive(
+ shuffleId,
+ mapId,
+ attemptId,
+ partitionId,
+ -1,
+ nullptr,
+ protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
+ CELEBORN_FAIL(fmt::format(
+ "Revive for shuffleId {} partitionId {} failed.",
+ shuffleId,
+ partitionId));
+ }
+ partitionLocationOptional = partitionLocationMap->get(partitionId);
+ }
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ CELEBORN_CHECK(partitionLocationOptional.has_value());
+ auto partitionLocation = partitionLocationOptional.value();
+ auto pushState = getPushState(mapKey);
+ const int nextBatchId = pushState->nextBatchId();
+
+ CELEBORN_CHECK(
+ length <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Data length {} exceeds maximum supported size {}",
+ length,
+ std::numeric_limits<int>::max()));
+
+ const uint8_t* dataToWrite = data + offset;
+ int lengthToWrite = static_cast<int>(length);
+ std::unique_ptr<uint8_t[]> compressedBuffer;
+
+ if (shuffleCompressionEnabled_ && compressorFactory_) {
+ auto compressor = compressorFactory_();
+ const size_t compressedCapacity =
+ compressor->getDstCapacity(static_cast<int>(length));
+ compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+ const size_t compressedSize = compressor->compress(
+ dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
+
+ CELEBORN_CHECK(
+ compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Compressed size {} exceeds maximum supported size {}",
+ compressedSize,
+ std::numeric_limits<int>::max()));
+
+ lengthToWrite = static_cast<int>(compressedSize);
+ dataToWrite = compressedBuffer.get();
+ }
+
+ CELEBORN_CHECK(
+ static_cast<size_t>(lengthToWrite) <=
+ std::numeric_limits<size_t>::max() - kBatchHeaderSize,
+ fmt::format(
+ "Buffer size {} + header {} would overflow",
+ lengthToWrite,
+ kBatchHeaderSize));
+
+ auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+ kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
+ writeBuffer->writeLE<int>(mapId);
+ writeBuffer->writeLE<int>(attemptId);
+ writeBuffer->writeLE<int>(nextBatchId);
+ writeBuffer->writeLE<int>(lengthToWrite);
+ writeBuffer->writeFromBuffer(
+ dataToWrite, 0, static_cast<size_t>(lengthToWrite));
+
+ auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+ auto addressPairKey = genAddressPairKey(*partitionLocation);
+
+ bool shouldPush = pushState->addBatchData(
+ addressPairKey, partitionLocation, nextBatchId, std::move(body));
+
+ if (shouldPush) {
+ auto dataBatches = pushState->takeDataBatches(addressPairKey);
+ if (dataBatches) {
+ auto hostAndPushPort = partitionLocation->hostAndPushPort();
+ doPushMergedData(
+ hostAndPushPort,
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ numPartitions,
+ mapKey,
+ dataBatches->requireBatches(),
+ pushState,
+ conf_->clientPushMaxReviveTimes());
+ }
+ }
+
+ CELEBORN_CHECK_LE(
+ lengthToWrite,
+ std::numeric_limits<int>::max() - static_cast<int>(kBatchHeaderSize),
+ "Batch bytes size {} + header {} would overflow int",
+ lengthToWrite,
+ kBatchHeaderSize);
+ return static_cast<int>(kBatchHeaderSize) + lengthToWrite;
+}
+
+void ShuffleClientImpl::pushMergedData(
+ int shuffleId,
+ int mapId,
+ int attemptId) {
+ const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ auto pushState = getPushState(mapKey);
+ int pushBufferMaxSize = conf_->clientPushBufferMaxSize();
+
+ auto& batchesMap = pushState->getBatchesMap();
+ std::vector<std::string> keys;
+ batchesMap.forEach(
+ [&keys](const std::string& key, const std::shared_ptr<DataBatches>&) {
+ keys.push_back(key);
+ });
+
+ for (const auto& addressPairKey : keys) {
+ auto dataBatches = pushState->takeDataBatches(addressPairKey);
+ if (!dataBatches) {
+ continue;
+ }
+ while (dataBatches->getTotalSize() > 0) {
+ auto batches = dataBatches->requireBatches(pushBufferMaxSize);
+ if (batches.empty()) {
+ break;
+ }
+ auto hostAndPushPort = batches.front().loc->hostAndPushPort();
+ doPushMergedData(
+ hostAndPushPort,
+ shuffleId,
+ mapId,
+ attemptId,
+ 0,
+ 0,
+ mapKey,
+ std::move(batches),
+ pushState,
+ conf_->clientPushMaxReviveTimes());
+ }
+ }
+}
+
+void ShuffleClientImpl::doPushMergedData(
+ const std::string& hostAndPushPort,
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ std::vector<DataBatch> batches,
+ std::shared_ptr<PushState> pushState,
+ int remainReviveTimes) {
+ if (batches.empty()) {
+ return;
+ }
+
+ int groupedBatchId = pushState->nextBatchId();
+ int groupedBatchBytesSize = 0;
+ for (const auto& batch : batches) {
+ CELEBORN_CHECK_LE(
+ batch.body->size(),
+ static_cast<size_t>(INT_MAX),
+ "Batch body size {} exceeds INT_MAX",
+ batch.body->size());
+ int bodySize = static_cast<int>(batch.body->size());
+ CELEBORN_CHECK_LE(
+ bodySize,
+ INT_MAX - groupedBatchBytesSize,
+ "Grouped batch size would overflow: adding {} to {}",
+ bodySize,
+ groupedBatchBytesSize);
+ groupedBatchBytesSize += bodySize;
+ }
+
+ limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
+ pushState->addBatch(groupedBatchId, groupedBatchBytesSize, hostAndPushPort);
+
+ const int numBatches = static_cast<int>(batches.size());
+ std::vector<int> partitionIds(numBatches);
+ std::vector<std::string> partitionUniqueIds(numBatches);
+ std::vector<int32_t> offsets(numBatches);
+ int currentSize = 0;
+
+ // TODO: Each clone() + concat() copies data. Java avoids this with Netty's
+ // zero-copy CompositeByteBuf. We need to build something like an IOBuf chain
+ // directly (ex: via appendToChain) to eliminate the double-clone and O(n^2)
+ // re-copy from repeated pairwise concat.
+ std::vector<std::unique_ptr<memory::ReadOnlyByteBuffer>> bodyParts;
+ for (int i = 0; i < numBatches; i++) {
+ partitionIds[i] = batches[i].loc->id;
+ partitionUniqueIds[i] = batches[i].loc->uniqueId();
+ offsets[i] = currentSize;
+ int bodySize = static_cast<int>(batches[i].body->size());
+ CELEBORN_DCHECK_LE(bodySize, INT_MAX - currentSize);
+ currentSize += bodySize;
+ bodyParts.push_back(batches[i].body->clone());
+ }
+
+ // Concatenate all bodies
+ std::unique_ptr<memory::ReadOnlyByteBuffer> mergedBody;
+ if (bodyParts.size() == 1) {
+ mergedBody = std::move(bodyParts[0]);
+ } else {
+ mergedBody = memory::ByteBuffer::concat(*bodyParts[0], *bodyParts[1]);
+ for (size_t i = 2; i < bodyParts.size(); i++) {
+ mergedBody = memory::ByteBuffer::concat(*mergedBody, *bodyParts[i]);
+ }
+ }
+
+ auto shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId);
+ network::PushMergedData pushMergedData(
+ network::Message::nextRequestId(),
+ protocol::PartitionLocation::Mode::PRIMARY,
+ shuffleKey,
+ std::move(partitionUniqueIds),
+ std::move(offsets),
+ std::move(mergedBody));
+
+ auto callback = PushMergedDataCallback::create(
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ numPartitions,
+ mapKey,
+ hostAndPushPort,
+ groupedBatchId,
+ std::move(batches),
+ std::move(partitionIds),
+ pushState,
+ weak_from_this(),
+ remainReviveTimes);
+
+ auto host = hostAndPushPort.substr(0, hostAndPushPort.find(':'));
+ auto portStr = hostAndPushPort.substr(hostAndPushPort.find(':') + 1);
+ auto port = static_cast<uint16_t>(std::stoi(portStr));
Review Comment:
`doPushMergedData` parses `hostAndPushPort` using `find(':')`/`substr`,
which breaks for IPv6 hosts containing ':' (and also assumes ':' always
exists). There is already `utils::parseColonSeparatedHostPorts()` explicitly
designed to handle IPv6 (see CelebornUtils.h). Prefer that helper or use
`batches.front().loc->host` + `pushPort` to avoid incorrect parsing and
potential exceptions.
##########
cpp/celeborn/client/writer/PushMergedDataCallback.cpp:
##########
@@ -0,0 +1,343 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/PushMergedDataCallback.h"
+#include "celeborn/conf/CelebornConf.h"
+#include "celeborn/protocol/TransportMessage.h"
+#include "celeborn/utils/Exceptions.h"
+
+namespace celeborn {
+namespace client {
+
+std::shared_ptr<PushMergedDataCallback> PushMergedDataCallback::create(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ const std::string& hostAndPushPort,
+ int groupedBatchId,
+ std::vector<DataBatch> batches,
+ std::vector<int> partitionIds,
+ std::shared_ptr<PushState> pushState,
+ std::weak_ptr<ShuffleClientImpl> weakClient,
+ int remainingReviveTimes) {
+ return std::shared_ptr<PushMergedDataCallback>(new PushMergedDataCallback(
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ numPartitions,
+ mapKey,
+ hostAndPushPort,
+ groupedBatchId,
+ std::move(batches),
+ std::move(partitionIds),
+ pushState,
+ weakClient,
+ remainingReviveTimes));
+}
+
+PushMergedDataCallback::PushMergedDataCallback(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ const std::string& hostAndPushPort,
+ int groupedBatchId,
+ std::vector<DataBatch> batches,
+ std::vector<int> partitionIds,
+ std::shared_ptr<PushState> pushState,
+ std::weak_ptr<ShuffleClientImpl> weakClient,
+ int remainingReviveTimes)
+ : shuffleId_(shuffleId),
+ mapId_(mapId),
+ attemptId_(attemptId),
+ numMappers_(numMappers),
+ numPartitions_(numPartitions),
+ mapKey_(mapKey),
+ hostAndPushPort_(hostAndPushPort),
+ groupedBatchId_(groupedBatchId),
+ batches_(std::move(batches)),
+ partitionIds_(std::move(partitionIds)),
+ pushState_(pushState),
+ weakClient_(weakClient),
+ remainingReviveTimes_(remainingReviveTimes) {}
+
+void PushMergedDataCallback::onSuccess(
+ std::unique_ptr<memory::ReadOnlyByteBuffer> response) {
+ auto sharedClient = weakClient_.lock();
+ if (!sharedClient) {
+ LOG(WARNING) << "ShuffleClientImpl has expired when "
+ "PushMergedDataCallbackOnSuccess, ignored, shuffle "
+ << shuffleId_ << " map " << mapId_ << " attempt " <<
attemptId_
+ << " groupedBatch " << groupedBatchId_ << ".";
+ return;
+ }
+
+ if (response->remainingSize() <= 0) {
+ pushState_->removeBatch(groupedBatchId_, hostAndPushPort_);
+ return;
+ }
+
+ protocol::StatusCode reason =
+ static_cast<protocol::StatusCode>(response->read<uint8_t>());
+ switch (reason) {
+ case protocol::StatusCode::MAP_ENDED: {
+ auto mapperEndSet = sharedClient->mapperEndSets().computeIfAbsent(
+ shuffleId_,
+ []() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
+ mapperEndSet->insert(mapId_);
+ pushState_->removeBatch(groupedBatchId_, hostAndPushPort_);
+ break;
Review Comment:
In the `MAP_ENDED` success case, the callback updates `mapperEndSets` and
removes the in-flight batch, but it also never calls
`pushState_->onSuccess(hostAndPushPort_)`. For consistency with
`PushDataCallback`/Java client behavior and to keep congestion control
accurate, mark the host as successful before removing the batch.
##########
cpp/celeborn/client/writer/PushMergedDataCallback.cpp:
##########
@@ -0,0 +1,343 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/PushMergedDataCallback.h"
+#include "celeborn/conf/CelebornConf.h"
+#include "celeborn/protocol/TransportMessage.h"
+#include "celeborn/utils/Exceptions.h"
+
+namespace celeborn {
+namespace client {
+
+std::shared_ptr<PushMergedDataCallback> PushMergedDataCallback::create(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ const std::string& hostAndPushPort,
+ int groupedBatchId,
+ std::vector<DataBatch> batches,
+ std::vector<int> partitionIds,
+ std::shared_ptr<PushState> pushState,
+ std::weak_ptr<ShuffleClientImpl> weakClient,
+ int remainingReviveTimes) {
+ return std::shared_ptr<PushMergedDataCallback>(new PushMergedDataCallback(
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ numPartitions,
+ mapKey,
+ hostAndPushPort,
+ groupedBatchId,
+ std::move(batches),
+ std::move(partitionIds),
+ pushState,
+ weakClient,
+ remainingReviveTimes));
+}
+
+PushMergedDataCallback::PushMergedDataCallback(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ const std::string& hostAndPushPort,
+ int groupedBatchId,
+ std::vector<DataBatch> batches,
+ std::vector<int> partitionIds,
+ std::shared_ptr<PushState> pushState,
+ std::weak_ptr<ShuffleClientImpl> weakClient,
+ int remainingReviveTimes)
+ : shuffleId_(shuffleId),
+ mapId_(mapId),
+ attemptId_(attemptId),
+ numMappers_(numMappers),
+ numPartitions_(numPartitions),
+ mapKey_(mapKey),
+ hostAndPushPort_(hostAndPushPort),
+ groupedBatchId_(groupedBatchId),
+ batches_(std::move(batches)),
+ partitionIds_(std::move(partitionIds)),
+ pushState_(pushState),
+ weakClient_(weakClient),
+ remainingReviveTimes_(remainingReviveTimes) {}
+
+void PushMergedDataCallback::onSuccess(
+ std::unique_ptr<memory::ReadOnlyByteBuffer> response) {
+ auto sharedClient = weakClient_.lock();
+ if (!sharedClient) {
+ LOG(WARNING) << "ShuffleClientImpl has expired when "
+ "PushMergedDataCallbackOnSuccess, ignored, shuffle "
+ << shuffleId_ << " map " << mapId_ << " attempt " <<
attemptId_
+ << " groupedBatch " << groupedBatchId_ << ".";
+ return;
+ }
+
+ if (response->remainingSize() <= 0) {
+ pushState_->removeBatch(groupedBatchId_, hostAndPushPort_);
+ return;
+ }
+
+ protocol::StatusCode reason =
+ static_cast<protocol::StatusCode>(response->read<uint8_t>());
+ switch (reason) {
+ case protocol::StatusCode::MAP_ENDED: {
+ auto mapperEndSet = sharedClient->mapperEndSets().computeIfAbsent(
+ shuffleId_,
+ []() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
+ mapperEndSet->insert(mapId_);
+ pushState_->removeBatch(groupedBatchId_, hostAndPushPort_);
+ break;
+ }
+ case protocol::StatusCode::HARD_SPLIT:
+ case protocol::StatusCode::SOFT_SPLIT: {
+ VLOG(1) << "Push merged data to " << hostAndPushPort_
+ << " split required for shuffle " << shuffleId_ << " map "
+ << mapId_ << " attempt " << attemptId_ << " groupedBatch "
+ << groupedBatchId_ << ".";
+
+ if (response->remainingSize() > 0) {
+ // Parse PbPushMergedDataSplitPartitionInfo from TransportMessage
+ auto transportMsg = std::make_unique<protocol::TransportMessage>(
+ response->readToReadOnlyBuffer(response->remainingSize()));
+ PbPushMergedDataSplitPartitionInfo partitionInfo;
+ if (!partitionInfo.ParseFromString(transportMsg->payload())) {
+ pushState_->setException(std::make_unique<std::runtime_error>(
+ "Failed to parse PbPushMergedDataSplitPartitionInfo"));
+ return;
+ }
+
+ CELEBORN_CHECK_EQ(
+ partitionInfo.statuscodes_size(),
+ partitionInfo.splitpartitionindexes_size(),
+ "Mismatched sizes: statuscodes {} vs splitpartitionindexes {}",
+ partitionInfo.statuscodes_size(),
+ partitionInfo.splitpartitionindexes_size());
+ const int numBatches = static_cast<int>(batches_.size());
+ for (int i = 0; i < partitionInfo.splitpartitionindexes_size(); i++) {
+ int partitionIndex = partitionInfo.splitpartitionindexes(i);
+ CELEBORN_CHECK_GE(partitionIndex, 0);
+ CELEBORN_CHECK_LT(
+ partitionIndex,
+ numBatches,
+ "Partition index {} out of range [0, {})",
+ partitionIndex,
+ numBatches);
+ int statusCode = partitionInfo.statuscodes(i);
+
+ if (statusCode ==
+ static_cast<int>(protocol::StatusCode::SOFT_SPLIT)) {
+ int partitionId = partitionIds_[partitionIndex];
+ if (!ShuffleClientImpl::newerPartitionLocationExists(
+ sharedClient->getPartitionLocationMap(shuffleId_).value(),
+ partitionId,
+ batches_[partitionIndex].loc->epoch)) {
+ auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+ shuffleId_,
+ mapId_,
+ attemptId_,
+ partitionId,
+ batches_[partitionIndex].loc->epoch,
+ batches_[partitionIndex].loc,
+ protocol::StatusCode::SOFT_SPLIT);
+ sharedClient->addRequestToReviveManager(reviveRequest);
+ }
+ }
+ }
+
+ // For any HARD_SPLIT partitions, need to resubmit
+ std::vector<DataBatch> batchesToRetry;
+ std::vector<std::shared_ptr<protocol::ReviveRequest>> reviveRequests;
+ for (int i = 0; i < partitionInfo.splitpartitionindexes_size(); i++) {
+ int partitionIndex = partitionInfo.splitpartitionindexes(i);
+ CELEBORN_DCHECK_GE(partitionIndex, 0);
+ CELEBORN_DCHECK_LT(partitionIndex, numBatches);
+ int statusCode = partitionInfo.statuscodes(i);
+ if (statusCode ==
+ static_cast<int>(protocol::StatusCode::HARD_SPLIT)) {
+ int partitionId = partitionIds_[partitionIndex];
+ auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+ shuffleId_,
+ mapId_,
+ attemptId_,
+ partitionId,
+ batches_[partitionIndex].loc->epoch,
+ batches_[partitionIndex].loc,
+ protocol::StatusCode::HARD_SPLIT);
+ sharedClient->addRequestToReviveManager(reviveRequest);
+ reviveRequests.push_back(reviveRequest);
+ batchesToRetry.push_back(std::move(batches_[partitionIndex]));
+ }
+ }
+
+ if (!batchesToRetry.empty()) {
+ long dueTimeMs = utils::currentTimeMillis() +
+ sharedClient->conf_
+ ->clientRpcRequestPartitionLocationRpcAskTimeout() /
+ utils::MS(1);
+ sharedClient->submitRetryPushMergedData(
+ shuffleId_,
+ mapId_,
+ attemptId_,
+ numMappers_,
+ numPartitions_,
+ mapKey_,
+ std::move(batchesToRetry),
+ std::move(reviveRequests),
+ groupedBatchId_,
+ pushState_,
+ remainingReviveTimes_,
+ dueTimeMs);
+ } else {
+ pushState_->removeBatch(groupedBatchId_, hostAndPushPort_);
+ }
+ } else {
+ // Old worker without per-partition split info: revive all batches
+ std::vector<std::shared_ptr<protocol::ReviveRequest>> reviveRequests;
+ for (size_t i = 0; i < batches_.size(); i++) {
+ auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+ shuffleId_,
+ mapId_,
+ attemptId_,
+ partitionIds_[i],
+ batches_[i].loc->epoch,
+ batches_[i].loc,
+ reason);
+ sharedClient->addRequestToReviveManager(reviveRequest);
+ reviveRequests.push_back(reviveRequest);
+ }
+
+ long dueTimeMs = utils::currentTimeMillis() +
+ sharedClient->conf_
+ ->clientRpcRequestPartitionLocationRpcAskTimeout() /
+ utils::MS(1);
+ sharedClient->submitRetryPushMergedData(
+ shuffleId_,
+ mapId_,
+ attemptId_,
+ numMappers_,
+ numPartitions_,
+ mapKey_,
+ std::move(batches_),
+ std::move(reviveRequests),
+ groupedBatchId_,
+ pushState_,
+ remainingReviveTimes_,
+ dueTimeMs);
+ }
+ break;
+ }
+ case protocol::StatusCode::PUSH_DATA_SUCCESS_PRIMARY_CONGESTED: {
+ VLOG(1) << "Push merged data to " << hostAndPushPort_
+ << " primary congestion for shuffle " << shuffleId_ << " map "
+ << mapId_ << " attempt " << attemptId_ << " groupedBatch "
+ << groupedBatchId_ << ".";
+ pushState_->onCongestControl(hostAndPushPort_);
+ pushState_->removeBatch(groupedBatchId_, hostAndPushPort_);
+ break;
+ }
+ case protocol::StatusCode::PUSH_DATA_SUCCESS_REPLICA_CONGESTED: {
+ VLOG(1) << "Push merged data to " << hostAndPushPort_
+ << " replicate congestion for shuffle " << shuffleId_ << " map "
+ << mapId_ << " attempt " << attemptId_ << " groupedBatch "
+ << groupedBatchId_ << ".";
+ pushState_->onCongestControl(hostAndPushPort_);
+ pushState_->removeBatch(groupedBatchId_, hostAndPushPort_);
+ break;
+ }
+ default: {
+ LOG(WARNING) << "unhandled PushMergedData success StatusCode: " <<
reason;
+ pushState_->removeBatch(groupedBatchId_, hostAndPushPort_);
+ }
Review Comment:
The default (non-split, non-congestion) PushMergedData success path removes
the batch but never calls `pushState_->onSuccess(hostAndPushPort_)`. Without
success feedback, `PushStrategy` may remain overly throttled. Add an
`onSuccess` call in this default-success branch too.
##########
cpp/celeborn/client/writer/PushState.cpp:
##########
@@ -235,6 +237,32 @@ void PushState::cleanup() {
}
}
+bool PushState::addBatchData(
+ const std::string& addressPairKey,
+ std::shared_ptr<const protocol::PartitionLocation> loc,
+ int batchId,
+ std::unique_ptr<memory::ReadOnlyByteBuffer> body) {
+ auto batches = batchesMap_.computeIfAbsent(
+ addressPairKey, [&]() { return std::make_shared<DataBatches>(); });
+ batches->addDataBatch(std::move(loc), batchId, std::move(body));
+ return batches->getTotalSize() > pushBufferMaxSize_;
+}
+
+std::shared_ptr<DataBatches> PushState::takeDataBatches(
+ const std::string& addressPairKey) {
+ auto result = batchesMap_.get(addressPairKey);
+ if (result.has_value()) {
+ batchesMap_.erase(addressPairKey);
+ return result.value();
Review Comment:
`PushState::takeDataBatches` does a non-atomic `get()` followed by
`erase()`. Because `ConcurrentHashMap` is mutex-protected per operation, two
threads can both `get()` the same `DataBatches` before either `erase()`,
leading to duplicate pushes of the same batches. Make this a single atomic
operation by using `batchesMap_.erase(addressPairKey)` (which already returns
the removed value) and return that.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]