Copilot commented on code in PR #3611:
URL: https://github.com/apache/celeborn/pull/3611#discussion_r2909014936
##########
cpp/celeborn/protocol/Encoders.cpp:
##########
@@ -33,5 +33,56 @@ std::string decode(memory::ReadOnlyByteBuffer& buffer) {
int size = buffer.read<int>();
return buffer.readToString(size);
}
+
+int encodedLength(const std::vector<std::string>& arr) {
+ int total = sizeof(int);
+ for (const auto& s : arr) {
+ total += encodedLength(s);
+ }
+ return total;
+}
+
+void encode(
+ memory::WriteOnlyByteBuffer& buffer,
+ const std::vector<std::string>& arr) {
+ buffer.write<int>(static_cast<int>(arr.size()));
+ for (const auto& s : arr) {
+ encode(buffer, s);
+ }
+}
+
+std::vector<std::string> decodeStringArray(memory::ReadOnlyByteBuffer& buffer)
{
+ int count = buffer.read<int>();
+ std::vector<std::string> result;
+ result.reserve(count);
+ for (int i = 0; i < count; i++) {
Review Comment:
`decodeStringArray` reads `count` from the buffer and immediately calls
`result.reserve(count)` without validating it. A malformed or corrupted message
can supply two kinds of bad `count`:
- A negative value converts to a huge `size_t`, so `reserve` throws `std::length_error`.
- An extremely large value forces a huge allocation.

Please validate before reserving:
- Require `count >= 0`.
- Require `count <= remainingSize() / sizeof(int)`, since every encoded string occupies at least `sizeof(int)` bytes for its length prefix.

Fail fast if either check fails.
##########
cpp/celeborn/protocol/Encoders.cpp:
##########
@@ -33,5 +33,56 @@ std::string decode(memory::ReadOnlyByteBuffer& buffer) {
int size = buffer.read<int>();
return buffer.readToString(size);
}
+
+int encodedLength(const std::vector<std::string>& arr) {
+ int total = sizeof(int);
+ for (const auto& s : arr) {
+ total += encodedLength(s);
+ }
+ return total;
+}
+
+void encode(
+ memory::WriteOnlyByteBuffer& buffer,
+ const std::vector<std::string>& arr) {
+ buffer.write<int>(static_cast<int>(arr.size()));
+ for (const auto& s : arr) {
+ encode(buffer, s);
+ }
+}
+
+std::vector<std::string> decodeStringArray(memory::ReadOnlyByteBuffer& buffer)
{
+ int count = buffer.read<int>();
+ std::vector<std::string> result;
+ result.reserve(count);
+ for (int i = 0; i < count; i++) {
+ result.push_back(decode(buffer));
+ }
+ return result;
+}
+
+int encodedLength(const std::vector<int32_t>& arr) {
+ return sizeof(int) + sizeof(int32_t) * arr.size();
+}
+
+void encode(
+ memory::WriteOnlyByteBuffer& buffer,
+ const std::vector<int32_t>& arr) {
+ buffer.write<int>(static_cast<int>(arr.size()));
+ for (auto val : arr) {
+ buffer.write<int32_t>(val);
+ }
+}
+
+std::vector<int32_t> decodeIntArray(memory::ReadOnlyByteBuffer& buffer) {
+ int count = buffer.read<int>();
+ std::vector<int32_t> result;
+ result.reserve(count);
+ for (int i = 0; i < count; i++) {
+ result.push_back(buffer.read<int32_t>());
Review Comment:
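`decodeIntArray` has the same issue as `decodeStringArray`: it reserves
`count` elements without validating the value read from the wire. Please
validate `count` before reserving or looping:
- It must be non-negative.
- It must satisfy `count <= remainingSize() / sizeof(int32_t)`, since each element is exactly `sizeof(int32_t)` bytes.

This avoids memory blowups or undefined behavior on malformed input.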
##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -269,11 +270,423 @@ int ShuffleClientImpl::pushData(
return body->remainingSize();
}
+int ShuffleClientImpl::mergeData(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ const uint8_t* data,
+ size_t offset,
+ size_t length,
+ int numMappers,
+ int numPartitions) {
+ const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ auto partitionLocationMap =
+ getPartitionLocation(shuffleId, numMappers, numPartitions);
+ CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+ auto partitionLocationOptional = partitionLocationMap->get(partitionId);
+ if (!partitionLocationOptional.has_value()) {
+ if (!revive(
+ shuffleId,
+ mapId,
+ attemptId,
+ partitionId,
+ -1,
+ nullptr,
+ protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
+ CELEBORN_FAIL(fmt::format(
+ "Revive for shuffleId {} partitionId {} failed.",
+ shuffleId,
+ partitionId));
+ }
+ partitionLocationOptional = partitionLocationMap->get(partitionId);
+ }
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ CELEBORN_CHECK(partitionLocationOptional.has_value());
+ auto partitionLocation = partitionLocationOptional.value();
+ auto pushState = getPushState(mapKey);
+ const int nextBatchId = pushState->nextBatchId();
+
+ CELEBORN_CHECK(
+ length <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Data length {} exceeds maximum supported size {}",
+ length,
+ std::numeric_limits<int>::max()));
+
+ const uint8_t* dataToWrite = data + offset;
+ int lengthToWrite = static_cast<int>(length);
+ std::unique_ptr<uint8_t[]> compressedBuffer;
+
+ if (shuffleCompressionEnabled_ && compressorFactory_) {
+ auto compressor = compressorFactory_();
+ const size_t compressedCapacity =
+ compressor->getDstCapacity(static_cast<int>(length));
+ compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+ const size_t compressedSize = compressor->compress(
+ dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
+
+ CELEBORN_CHECK(
+ compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Compressed size {} exceeds maximum supported size {}",
+ compressedSize,
+ std::numeric_limits<int>::max()));
+
+ lengthToWrite = static_cast<int>(compressedSize);
+ dataToWrite = compressedBuffer.get();
+ }
+
+ CELEBORN_CHECK(
+ static_cast<size_t>(lengthToWrite) <=
+ std::numeric_limits<size_t>::max() - kBatchHeaderSize,
+ fmt::format(
+ "Buffer size {} + header {} would overflow",
+ lengthToWrite,
+ kBatchHeaderSize));
+
+ auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+ kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
+ writeBuffer->writeLE<int>(mapId);
+ writeBuffer->writeLE<int>(attemptId);
+ writeBuffer->writeLE<int>(nextBatchId);
+ writeBuffer->writeLE<int>(lengthToWrite);
+ writeBuffer->writeFromBuffer(
+ dataToWrite, 0, static_cast<size_t>(lengthToWrite));
+
+ auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+ auto addressPairKey = genAddressPairKey(*partitionLocation);
+
+ bool shouldPush = pushState->addBatchData(
+ addressPairKey, partitionLocation, nextBatchId, std::move(body));
+
+ if (shouldPush) {
+ auto dataBatches = pushState->takeDataBatches(addressPairKey);
+ if (dataBatches) {
+ auto hostAndPushPort = partitionLocation->hostAndPushPort();
+ doPushMergedData(
+ hostAndPushPort,
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ numPartitions,
+ mapKey,
+ dataBatches->requireBatches(),
+ pushState,
+ conf_->clientPushMaxReviveTimes());
+ }
+ }
+
+ return static_cast<int>(kBatchHeaderSize) + lengthToWrite;
+}
+
+void ShuffleClientImpl::pushMergedData(
+ int shuffleId,
+ int mapId,
+ int attemptId) {
+ const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ auto pushState = getPushState(mapKey);
+ int pushBufferMaxSize = conf_->clientPushBufferMaxSize();
+
+ auto& batchesMap = pushState->getBatchesMap();
+ std::vector<std::string> keys;
+ batchesMap.forEach(
+ [&keys](const std::string& key, const std::shared_ptr<DataBatches>&) {
+ keys.push_back(key);
+ });
+
+ for (const auto& addressPairKey : keys) {
+ auto dataBatches = pushState->takeDataBatches(addressPairKey);
+ if (!dataBatches) {
+ continue;
+ }
+ while (dataBatches->getTotalSize() > 0) {
+ auto batches = dataBatches->requireBatches(pushBufferMaxSize);
+ if (batches.empty()) {
+ break;
+ }
+ auto hostAndPushPort = batches.front().loc->hostAndPushPort();
+ doPushMergedData(
+ hostAndPushPort,
+ shuffleId,
+ mapId,
+ attemptId,
+ 0,
+ 0,
+ mapKey,
+ std::move(batches),
+ pushState,
+ conf_->clientPushMaxReviveTimes());
+ }
+ }
+}
+
+void ShuffleClientImpl::doPushMergedData(
+ const std::string& hostAndPushPort,
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ std::vector<DataBatch> batches,
+ std::shared_ptr<PushState> pushState,
+ int remainReviveTimes) {
+ if (batches.empty()) {
+ return;
+ }
+
+ int groupedBatchId = pushState->nextBatchId();
+ int groupedBatchBytesSize = 0;
+ for (const auto& batch : batches) {
+ groupedBatchBytesSize += static_cast<int>(batch.body->size());
+ }
+
+ limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
+ pushState->addBatch(groupedBatchId, groupedBatchBytesSize, hostAndPushPort);
Review Comment:
In `doPushMergedData`, `groupedBatchBytesSize` sums `batch.body->size()`
into an `int` without overflow checks. If many batches are merged (or the
payload is large), the signed addition can overflow. Signed overflow is undefined
behavior in C++ and would corrupt in-flight byte accounting. Please accumulate
into `size_t`/`int64_t` and check that the total fits the type expected by
`PushState::addBatch` (or change `addBatch` to accept a wider type).
##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -269,11 +270,423 @@ int ShuffleClientImpl::pushData(
return body->remainingSize();
}
+int ShuffleClientImpl::mergeData(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ const uint8_t* data,
+ size_t offset,
+ size_t length,
+ int numMappers,
+ int numPartitions) {
+ const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ auto partitionLocationMap =
+ getPartitionLocation(shuffleId, numMappers, numPartitions);
+ CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+ auto partitionLocationOptional = partitionLocationMap->get(partitionId);
+ if (!partitionLocationOptional.has_value()) {
+ if (!revive(
+ shuffleId,
+ mapId,
+ attemptId,
+ partitionId,
+ -1,
+ nullptr,
+ protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
+ CELEBORN_FAIL(fmt::format(
+ "Revive for shuffleId {} partitionId {} failed.",
+ shuffleId,
+ partitionId));
+ }
+ partitionLocationOptional = partitionLocationMap->get(partitionId);
+ }
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ CELEBORN_CHECK(partitionLocationOptional.has_value());
+ auto partitionLocation = partitionLocationOptional.value();
+ auto pushState = getPushState(mapKey);
+ const int nextBatchId = pushState->nextBatchId();
+
+ CELEBORN_CHECK(
+ length <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Data length {} exceeds maximum supported size {}",
+ length,
+ std::numeric_limits<int>::max()));
+
+ const uint8_t* dataToWrite = data + offset;
+ int lengthToWrite = static_cast<int>(length);
+ std::unique_ptr<uint8_t[]> compressedBuffer;
+
+ if (shuffleCompressionEnabled_ && compressorFactory_) {
+ auto compressor = compressorFactory_();
+ const size_t compressedCapacity =
+ compressor->getDstCapacity(static_cast<int>(length));
+ compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+ const size_t compressedSize = compressor->compress(
+ dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
+
+ CELEBORN_CHECK(
+ compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Compressed size {} exceeds maximum supported size {}",
+ compressedSize,
+ std::numeric_limits<int>::max()));
+
+ lengthToWrite = static_cast<int>(compressedSize);
+ dataToWrite = compressedBuffer.get();
+ }
+
+ CELEBORN_CHECK(
+ static_cast<size_t>(lengthToWrite) <=
+ std::numeric_limits<size_t>::max() - kBatchHeaderSize,
+ fmt::format(
+ "Buffer size {} + header {} would overflow",
+ lengthToWrite,
+ kBatchHeaderSize));
+
+ auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+ kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
+ writeBuffer->writeLE<int>(mapId);
+ writeBuffer->writeLE<int>(attemptId);
+ writeBuffer->writeLE<int>(nextBatchId);
+ writeBuffer->writeLE<int>(lengthToWrite);
+ writeBuffer->writeFromBuffer(
+ dataToWrite, 0, static_cast<size_t>(lengthToWrite));
+
+ auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+ auto addressPairKey = genAddressPairKey(*partitionLocation);
+
+ bool shouldPush = pushState->addBatchData(
+ addressPairKey, partitionLocation, nextBatchId, std::move(body));
+
+ if (shouldPush) {
+ auto dataBatches = pushState->takeDataBatches(addressPairKey);
+ if (dataBatches) {
+ auto hostAndPushPort = partitionLocation->hostAndPushPort();
+ doPushMergedData(
+ hostAndPushPort,
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ numPartitions,
+ mapKey,
+ dataBatches->requireBatches(),
+ pushState,
+ conf_->clientPushMaxReviveTimes());
+ }
+ }
+
+ return static_cast<int>(kBatchHeaderSize) + lengthToWrite;
+}
Review Comment:
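`mergeData` returns `static_cast<int>(kBatchHeaderSize) + lengthToWrite`
without the overflow guard that exists in `pushData` (`lengthToWrite <= INT_MAX
- kBatchHeaderSize`). The existing check before `createWriteOnly` compares
against `std::numeric_limits<size_t>::max() - kBatchHeaderSize`. A non-negative
`int` cannot realistically exceed that limit, so the check does not protect this
`int` sum. If `lengthToWrite` is near `INT_MAX`, the signed addition overflows
(undefined behavior) and corrupts the return value and any downstream accounting.
Please add the same `INT_MAX - kBatchHeaderSize` check used in `pushData` before
building the buffer and returning the combined size.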
##########
cpp/celeborn/client/writer/DataBatches.cpp:
##########
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/DataBatches.h"
+
+namespace celeborn {
+namespace client {
+
+void DataBatches::addDataBatch(
+ std::shared_ptr<const protocol::PartitionLocation> loc,
+ int batchId,
+ std::unique_ptr<memory::ReadOnlyByteBuffer> body) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ int bodySize = static_cast<int>(body->size());
+ batches_.emplace_back(std::move(loc), batchId, std::move(body));
+ totalSize_ += bodySize;
+}
Review Comment:
`addDataBatch` casts `body->size()` to `int` and accumulates it into
`totalSize_` (also an `int`) with no overflow checks. Two things can go wrong:
- A body larger than `INT_MAX` narrows to a wrong, possibly negative, value.
- Many accumulated batches can overflow the signed sum, which is undefined behavior.

Either way, the push threshold logic breaks. Consider storing sizes as
`size_t`/`int64_t` and/or adding explicit bounds checks before casting.
##########
cpp/celeborn/client/writer/PushState.cpp:
##########
@@ -235,6 +236,32 @@ void PushState::cleanup() {
}
}
+bool PushState::addBatchData(
+ const std::string& addressPairKey,
+ std::shared_ptr<const protocol::PartitionLocation> loc,
+ int batchId,
+ std::unique_ptr<memory::ReadOnlyByteBuffer> body) {
+ auto batches = batchesMap_.computeIfAbsent(
+ addressPairKey, [&]() { return std::make_shared<DataBatches>(); });
+ batches->addDataBatch(std::move(loc), batchId, std::move(body));
+ return batches->getTotalSize() > pushBufferMaxSize_;
+}
Review Comment:
`PushState::cleanup()` is used when a mapper has already ended (see
`ShuffleClientImpl::checkMapperEnded`), but the new `batchesMap_` is not
cleared anywhere. That can retain buffered batch bodies in memory after mapper
end/cleanup. Consider clearing `batchesMap_` as part of `cleanup()` (or
otherwise ensuring pending `DataBatches` are released) to avoid memory
retention.
##########
cpp/celeborn/client/writer/DataBatches.cpp:
##########
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/DataBatches.h"
+
+namespace celeborn {
+namespace client {
+
+void DataBatches::addDataBatch(
+ std::shared_ptr<const protocol::PartitionLocation> loc,
+ int batchId,
+ std::unique_ptr<memory::ReadOnlyByteBuffer> body) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ int bodySize = static_cast<int>(body->size());
+ batches_.emplace_back(std::move(loc), batchId, std::move(body));
+ totalSize_ += bodySize;
+}
+
+int DataBatches::getTotalSize() const {
+ std::lock_guard<std::mutex> lock(mutex_);
+ return totalSize_;
+}
+
+std::vector<DataBatch> DataBatches::requireBatches() {
+ std::lock_guard<std::mutex> lock(mutex_);
+ totalSize_ = 0;
+ return std::move(batches_);
+}
+
+std::vector<DataBatch> DataBatches::requireBatches(int requestSize) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ if (requestSize >= totalSize_) {
+ totalSize_ = 0;
+ return std::move(batches_);
+ }
+ std::vector<DataBatch> result;
+ int currentSize = 0;
+ while (currentSize < requestSize && !batches_.empty()) {
+ int bodySize = static_cast<int>(batches_.front().body->size());
+ result.push_back(std::move(batches_.front()));
+ batches_.erase(batches_.begin());
+ currentSize += bodySize;
Review Comment:
`requireBatches(int)` calls `erase(batches_.begin())` inside the loop. Each
call shifts every remaining element, so draining k of n batches costs O(n·k),
which is O(n^2) in the worst case. This can become a bottleneck when many batches
are buffered. Consider one of these fixes:
- Find the split point first, move out the prefix, and remove it with a single range `erase`.
- Switch `batches_` to `std::deque`.
##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -269,11 +270,423 @@ int ShuffleClientImpl::pushData(
return body->remainingSize();
}
+int ShuffleClientImpl::mergeData(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ const uint8_t* data,
+ size_t offset,
+ size_t length,
+ int numMappers,
+ int numPartitions) {
+ const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ auto partitionLocationMap =
+ getPartitionLocation(shuffleId, numMappers, numPartitions);
+ CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+ auto partitionLocationOptional = partitionLocationMap->get(partitionId);
+ if (!partitionLocationOptional.has_value()) {
+ if (!revive(
+ shuffleId,
+ mapId,
+ attemptId,
+ partitionId,
+ -1,
+ nullptr,
+ protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
+ CELEBORN_FAIL(fmt::format(
+ "Revive for shuffleId {} partitionId {} failed.",
+ shuffleId,
+ partitionId));
+ }
+ partitionLocationOptional = partitionLocationMap->get(partitionId);
+ }
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ CELEBORN_CHECK(partitionLocationOptional.has_value());
+ auto partitionLocation = partitionLocationOptional.value();
+ auto pushState = getPushState(mapKey);
+ const int nextBatchId = pushState->nextBatchId();
+
+ CELEBORN_CHECK(
+ length <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Data length {} exceeds maximum supported size {}",
+ length,
+ std::numeric_limits<int>::max()));
+
+ const uint8_t* dataToWrite = data + offset;
+ int lengthToWrite = static_cast<int>(length);
+ std::unique_ptr<uint8_t[]> compressedBuffer;
+
+ if (shuffleCompressionEnabled_ && compressorFactory_) {
+ auto compressor = compressorFactory_();
+ const size_t compressedCapacity =
+ compressor->getDstCapacity(static_cast<int>(length));
+ compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+ const size_t compressedSize = compressor->compress(
+ dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
+
+ CELEBORN_CHECK(
+ compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Compressed size {} exceeds maximum supported size {}",
+ compressedSize,
+ std::numeric_limits<int>::max()));
+
+ lengthToWrite = static_cast<int>(compressedSize);
+ dataToWrite = compressedBuffer.get();
+ }
+
+ CELEBORN_CHECK(
+ static_cast<size_t>(lengthToWrite) <=
+ std::numeric_limits<size_t>::max() - kBatchHeaderSize,
+ fmt::format(
+ "Buffer size {} + header {} would overflow",
+ lengthToWrite,
+ kBatchHeaderSize));
+
+ auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+ kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
+ writeBuffer->writeLE<int>(mapId);
+ writeBuffer->writeLE<int>(attemptId);
+ writeBuffer->writeLE<int>(nextBatchId);
+ writeBuffer->writeLE<int>(lengthToWrite);
+ writeBuffer->writeFromBuffer(
+ dataToWrite, 0, static_cast<size_t>(lengthToWrite));
+
+ auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+ auto addressPairKey = genAddressPairKey(*partitionLocation);
+
+ bool shouldPush = pushState->addBatchData(
+ addressPairKey, partitionLocation, nextBatchId, std::move(body));
+
+ if (shouldPush) {
+ auto dataBatches = pushState->takeDataBatches(addressPairKey);
+ if (dataBatches) {
+ auto hostAndPushPort = partitionLocation->hostAndPushPort();
+ doPushMergedData(
+ hostAndPushPort,
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ numPartitions,
+ mapKey,
+ dataBatches->requireBatches(),
+ pushState,
+ conf_->clientPushMaxReviveTimes());
+ }
+ }
+
+ return static_cast<int>(kBatchHeaderSize) + lengthToWrite;
+}
+
+void ShuffleClientImpl::pushMergedData(
+ int shuffleId,
+ int mapId,
+ int attemptId) {
+ const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ auto pushState = getPushState(mapKey);
+ int pushBufferMaxSize = conf_->clientPushBufferMaxSize();
+
+ auto& batchesMap = pushState->getBatchesMap();
+ std::vector<std::string> keys;
+ batchesMap.forEach(
+ [&keys](const std::string& key, const std::shared_ptr<DataBatches>&) {
+ keys.push_back(key);
+ });
+
+ for (const auto& addressPairKey : keys) {
+ auto dataBatches = pushState->takeDataBatches(addressPairKey);
+ if (!dataBatches) {
+ continue;
+ }
+ while (dataBatches->getTotalSize() > 0) {
+ auto batches = dataBatches->requireBatches(pushBufferMaxSize);
+ if (batches.empty()) {
+ break;
+ }
+ auto hostAndPushPort = batches.front().loc->hostAndPushPort();
+ doPushMergedData(
+ hostAndPushPort,
+ shuffleId,
+ mapId,
+ attemptId,
+ 0,
+ 0,
+ mapKey,
+ std::move(batches),
+ pushState,
+ conf_->clientPushMaxReviveTimes());
+ }
+ }
+}
+
+void ShuffleClientImpl::doPushMergedData(
+ const std::string& hostAndPushPort,
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ std::vector<DataBatch> batches,
+ std::shared_ptr<PushState> pushState,
+ int remainReviveTimes) {
+ if (batches.empty()) {
+ return;
+ }
+
+ int groupedBatchId = pushState->nextBatchId();
+ int groupedBatchBytesSize = 0;
+ for (const auto& batch : batches) {
+ groupedBatchBytesSize += static_cast<int>(batch.body->size());
+ }
+
+ limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
+ pushState->addBatch(groupedBatchId, groupedBatchBytesSize, hostAndPushPort);
+
+ const int numBatches = static_cast<int>(batches.size());
+ std::vector<int> partitionIds(numBatches);
+ std::vector<std::string> partitionUniqueIds(numBatches);
+ std::vector<int32_t> offsets(numBatches);
+ int currentSize = 0;
+
+ std::vector<std::unique_ptr<memory::ReadOnlyByteBuffer>> bodyParts;
+ for (int i = 0; i < numBatches; i++) {
+ partitionIds[i] = batches[i].loc->id;
+ partitionUniqueIds[i] = batches[i].loc->uniqueId();
+ offsets[i] = currentSize;
+ currentSize += static_cast<int>(batches[i].body->size());
+ bodyParts.push_back(batches[i].body->clone());
+ }
Review Comment:
`offsets` are built from `currentSize` (an `int`) and stored into
`std::vector<int32_t>`. There are no checks that the running total fits in
`int32_t`, or even in `int`, so large merged payloads can overflow and generate
incorrect offsets (breaking server-side parsing). Please use a wider
accumulator (e.g., `int64_t`/`size_t`) and validate the final offsets are
within `INT32_MAX` before encoding.
##########
cpp/celeborn/client/writer/PushMergedDataCallback.cpp:
##########
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/PushMergedDataCallback.h"
+#include "celeborn/conf/CelebornConf.h"
+#include "celeborn/protocol/TransportMessage.h"
+
+namespace celeborn {
+namespace client {
+
+std::shared_ptr<PushMergedDataCallback> PushMergedDataCallback::create(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ const std::string& hostAndPushPort,
+ int groupedBatchId,
+ std::vector<DataBatch> batches,
+ std::vector<int> partitionIds,
+ std::shared_ptr<PushState> pushState,
+ std::weak_ptr<ShuffleClientImpl> weakClient,
+ int remainingReviveTimes) {
+ return std::shared_ptr<PushMergedDataCallback>(new PushMergedDataCallback(
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ numPartitions,
+ mapKey,
+ hostAndPushPort,
+ groupedBatchId,
+ std::move(batches),
+ std::move(partitionIds),
+ pushState,
+ weakClient,
+ remainingReviveTimes));
+}
+
+PushMergedDataCallback::PushMergedDataCallback(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ const std::string& hostAndPushPort,
+ int groupedBatchId,
+ std::vector<DataBatch> batches,
+ std::vector<int> partitionIds,
+ std::shared_ptr<PushState> pushState,
+ std::weak_ptr<ShuffleClientImpl> weakClient,
+ int remainingReviveTimes)
+ : shuffleId_(shuffleId),
+ mapId_(mapId),
+ attemptId_(attemptId),
+ numMappers_(numMappers),
+ numPartitions_(numPartitions),
+ mapKey_(mapKey),
+ hostAndPushPort_(hostAndPushPort),
+ groupedBatchId_(groupedBatchId),
+ batches_(std::move(batches)),
+ partitionIds_(std::move(partitionIds)),
+ pushState_(pushState),
+ weakClient_(weakClient),
+ remainingReviveTimes_(remainingReviveTimes) {}
+
+void PushMergedDataCallback::onSuccess(
+ std::unique_ptr<memory::ReadOnlyByteBuffer> response) {
+ auto sharedClient = weakClient_.lock();
+ if (!sharedClient) {
+ LOG(WARNING) << "ShuffleClientImpl has expired when "
+ "PushMergedDataCallbackOnSuccess, ignored, shuffle "
+ << shuffleId_ << " map " << mapId_ << " attempt " <<
attemptId_
+ << " groupedBatch " << groupedBatchId_ << ".";
+ return;
+ }
+
+ if (response->remainingSize() <= 0) {
+ pushState_->removeBatch(groupedBatchId_, hostAndPushPort_);
+ return;
+ }
+
+ protocol::StatusCode reason =
+ static_cast<protocol::StatusCode>(response->read<uint8_t>());
+ switch (reason) {
+ case protocol::StatusCode::MAP_ENDED: {
+ auto mapperEndSet = sharedClient->mapperEndSets().computeIfAbsent(
+ shuffleId_,
+ []() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
+ mapperEndSet->insert(mapId_);
+ pushState_->removeBatch(groupedBatchId_, hostAndPushPort_);
+ break;
+ }
+ case protocol::StatusCode::HARD_SPLIT:
+ case protocol::StatusCode::SOFT_SPLIT: {
+ VLOG(1) << "Push merged data to " << hostAndPushPort_
+ << " split required for shuffle " << shuffleId_ << " map "
+ << mapId_ << " attempt " << attemptId_ << " groupedBatch "
+ << groupedBatchId_ << ".";
+
+ if (response->remainingSize() > 0) {
+ // Parse PbPushMergedDataSplitPartitionInfo from TransportMessage
+ auto transportMsg = std::make_unique<protocol::TransportMessage>(
+ response->readToReadOnlyBuffer(response->remainingSize()));
+ PbPushMergedDataSplitPartitionInfo partitionInfo;
+ if (!partitionInfo.ParseFromString(transportMsg->payload())) {
+ pushState_->setException(std::make_unique<std::runtime_error>(
+ "Failed to parse PbPushMergedDataSplitPartitionInfo"));
+ return;
+ }
+
+ for (int i = 0; i < partitionInfo.splitpartitionindexes_size(); i++) {
+ int partitionIndex = partitionInfo.splitpartitionindexes(i);
+ int statusCode = partitionInfo.statuscodes(i);
+
+ if (statusCode ==
+ static_cast<int>(protocol::StatusCode::SOFT_SPLIT)) {
+ int partitionId = partitionIds_[partitionIndex];
+ if (!ShuffleClientImpl::newerPartitionLocationExists(
+ sharedClient->getPartitionLocationMap(shuffleId_).value(),
+ partitionId,
+ batches_[partitionIndex].loc->epoch)) {
Review Comment:
When handling `PbPushMergedDataSplitPartitionInfo`, the code assumes two things:
- `splitpartitionindexes(i)` is a valid index into `partitionIds_`/`batches_`.
- `statuscodes_size()` equals `splitpartitionindexes_size()`.

Since this data comes from the wire, an out-of-range (including negative) index
or mismatched array sizes can cause out-of-bounds access and crash. Before
indexing, please check both:
- The two sizes match.
- `0 <= partitionIndex < partitionIds_.size()`, and also `< batches_.size()`.

Fail the push with an exception if either check fails.
##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -269,11 +270,423 @@ int ShuffleClientImpl::pushData(
return body->remainingSize();
}
+int ShuffleClientImpl::mergeData(
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int partitionId,
+ const uint8_t* data,
+ size_t offset,
+ size_t length,
+ int numMappers,
+ int numPartitions) {
+ const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ auto partitionLocationMap =
+ getPartitionLocation(shuffleId, numMappers, numPartitions);
+ CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+ auto partitionLocationOptional = partitionLocationMap->get(partitionId);
+ if (!partitionLocationOptional.has_value()) {
+ if (!revive(
+ shuffleId,
+ mapId,
+ attemptId,
+ partitionId,
+ -1,
+ nullptr,
+ protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
+ CELEBORN_FAIL(fmt::format(
+ "Revive for shuffleId {} partitionId {} failed.",
+ shuffleId,
+ partitionId));
+ }
+ partitionLocationOptional = partitionLocationMap->get(partitionId);
+ }
+ if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+ return 0;
+ }
+
+ CELEBORN_CHECK(partitionLocationOptional.has_value());
+ auto partitionLocation = partitionLocationOptional.value();
+ auto pushState = getPushState(mapKey);
+ const int nextBatchId = pushState->nextBatchId();
+
+ CELEBORN_CHECK(
+ length <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Data length {} exceeds maximum supported size {}",
+ length,
+ std::numeric_limits<int>::max()));
+
+ const uint8_t* dataToWrite = data + offset;
+ int lengthToWrite = static_cast<int>(length);
+ std::unique_ptr<uint8_t[]> compressedBuffer;
+
+ if (shuffleCompressionEnabled_ && compressorFactory_) {
+ auto compressor = compressorFactory_();
+ const size_t compressedCapacity =
+ compressor->getDstCapacity(static_cast<int>(length));
+ compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+ const size_t compressedSize = compressor->compress(
+ dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
+
+ CELEBORN_CHECK(
+ compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
+ fmt::format(
+ "Compressed size {} exceeds maximum supported size {}",
+ compressedSize,
+ std::numeric_limits<int>::max()));
+
+ lengthToWrite = static_cast<int>(compressedSize);
+ dataToWrite = compressedBuffer.get();
+ }
+
+ CELEBORN_CHECK(
+ static_cast<size_t>(lengthToWrite) <=
+ std::numeric_limits<size_t>::max() - kBatchHeaderSize,
+ fmt::format(
+ "Buffer size {} + header {} would overflow",
+ lengthToWrite,
+ kBatchHeaderSize));
+
+ auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+ kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
+ writeBuffer->writeLE<int>(mapId);
+ writeBuffer->writeLE<int>(attemptId);
+ writeBuffer->writeLE<int>(nextBatchId);
+ writeBuffer->writeLE<int>(lengthToWrite);
+ writeBuffer->writeFromBuffer(
+ dataToWrite, 0, static_cast<size_t>(lengthToWrite));
+
+ auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+ auto addressPairKey = genAddressPairKey(*partitionLocation);
+
+ bool shouldPush = pushState->addBatchData(
+ addressPairKey, partitionLocation, nextBatchId, std::move(body));
+
+ if (shouldPush) {
+ auto dataBatches = pushState->takeDataBatches(addressPairKey);
+ if (dataBatches) {
+ auto hostAndPushPort = partitionLocation->hostAndPushPort();
+ doPushMergedData(
+ hostAndPushPort,
+ shuffleId,
+ mapId,
+ attemptId,
+ numMappers,
+ numPartitions,
+ mapKey,
+ dataBatches->requireBatches(),
+ pushState,
+ conf_->clientPushMaxReviveTimes());
+ }
+ }
+
+ return static_cast<int>(kBatchHeaderSize) + lengthToWrite;
+}
+
+void ShuffleClientImpl::pushMergedData(
+ int shuffleId,
+ int mapId,
+ int attemptId) {
+ const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+ auto pushState = getPushState(mapKey);
+ int pushBufferMaxSize = conf_->clientPushBufferMaxSize();
+
+ auto& batchesMap = pushState->getBatchesMap();
+ std::vector<std::string> keys;
+ batchesMap.forEach(
+ [&keys](const std::string& key, const std::shared_ptr<DataBatches>&) {
+ keys.push_back(key);
+ });
+
+ for (const auto& addressPairKey : keys) {
+ auto dataBatches = pushState->takeDataBatches(addressPairKey);
+ if (!dataBatches) {
+ continue;
+ }
+ while (dataBatches->getTotalSize() > 0) {
+ auto batches = dataBatches->requireBatches(pushBufferMaxSize);
+ if (batches.empty()) {
+ break;
+ }
+ auto hostAndPushPort = batches.front().loc->hostAndPushPort();
+ doPushMergedData(
+ hostAndPushPort,
+ shuffleId,
+ mapId,
+ attemptId,
+ 0,
+ 0,
+ mapKey,
+ std::move(batches),
+ pushState,
+ conf_->clientPushMaxReviveTimes());
+ }
+ }
+}
+
+void ShuffleClientImpl::doPushMergedData(
+ const std::string& hostAndPushPort,
+ int shuffleId,
+ int mapId,
+ int attemptId,
+ int numMappers,
+ int numPartitions,
+ const std::string& mapKey,
+ std::vector<DataBatch> batches,
+ std::shared_ptr<PushState> pushState,
+ int remainReviveTimes) {
+ if (batches.empty()) {
+ return;
+ }
+
+ int groupedBatchId = pushState->nextBatchId();
+ int groupedBatchBytesSize = 0;
+ for (const auto& batch : batches) {
+ groupedBatchBytesSize += static_cast<int>(batch.body->size());
+ }
+
+ limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
+ pushState->addBatch(groupedBatchId, groupedBatchBytesSize, hostAndPushPort);
+
+ const int numBatches = static_cast<int>(batches.size());
+ std::vector<int> partitionIds(numBatches);
+ std::vector<std::string> partitionUniqueIds(numBatches);
+ std::vector<int32_t> offsets(numBatches);
+ int currentSize = 0;
+
+ std::vector<std::unique_ptr<memory::ReadOnlyByteBuffer>> bodyParts;
+ for (int i = 0; i < numBatches; i++) {
+ partitionIds[i] = batches[i].loc->id;
+ partitionUniqueIds[i] = batches[i].loc->uniqueId();
+ offsets[i] = currentSize;
+ currentSize += static_cast<int>(batches[i].body->size());
+ bodyParts.push_back(batches[i].body->clone());
+ }
+
+ // Concatenate all bodies
+ std::unique_ptr<memory::ReadOnlyByteBuffer> mergedBody;
+ if (bodyParts.size() == 1) {
+ mergedBody = std::move(bodyParts[0]);
+ } else {
+ mergedBody = memory::ByteBuffer::concat(*bodyParts[0], *bodyParts[1]);
+ for (size_t i = 2; i < bodyParts.size(); i++) {
+ mergedBody = memory::ByteBuffer::concat(*mergedBody, *bodyParts[i]);
+ }
Review Comment:
`doPushMergedData` clones every batch body into `bodyParts` and then calls
`ByteBuffer::concat`, which clones/trim-copies internally as well. On top of
that, each loop iteration passes the accumulated `mergedBody` back into
`concat`, so the buffer built so far is cloned again on every pass. This
double cloning plus repeated `concat` calls adds unnecessary CPU/memory
overhead when many batches are merged. Consider avoiding the pre-clone
(`concat` already clones) and/or building a single chained `IOBuf`/ByteBuffer
in one pass instead of repeated `concat` calls.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]