afterincomparableyum commented on code in PR #3611:
URL: https://github.com/apache/celeborn/pull/3611#discussion_r2915655985
##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -269,11 +270,423 @@ int ShuffleClientImpl::pushData(
return body->remainingSize();
}
+/// Serializes one batch of record data for (shuffleId, mapId, attemptId,
+/// partitionId) and appends it to the map attempt's push state. When
+/// addBatchData reports that a push is due, the buffered batches for the
+/// target worker address pair are flushed via doPushMergedData.
+///
+/// @param data      Base pointer of the caller's buffer.
+/// @param offset    Byte offset of the payload within `data`.
+/// @param length    Payload size in bytes; must not exceed INT_MAX.
+/// @return 0 if the mapper has already ended. Otherwise, the number of bytes
+///         queued for this batch: batch header plus payload, where the payload
+///         is the compressed form when compression is enabled.
+/// Throws (via CELEBORN_FAIL / CELEBORN_CHECK*) if revive fails, if no
+/// partition location is available, or if the batch size cannot be
+/// represented as an int.
+int ShuffleClientImpl::mergeData(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    const uint8_t* data,
+    size_t offset,
+    size_t length,
+    int numMappers,
+    int numPartitions) {
+  const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+  if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+    return 0;
+  }
+
+  // Reject oversized input before any side effect (revive RPC, batch id
+  // acquisition), so that a failed call leaves no state behind.
+  CELEBORN_CHECK_LE(
+      length,
+      static_cast<size_t>(std::numeric_limits<int>::max()),
+      fmt::format(
+          "Data length {} exceeds maximum supported size {}",
+          length,
+          std::numeric_limits<int>::max()));
+
+  auto partitionLocationMap =
+      getPartitionLocation(shuffleId, numMappers, numPartitions);
+  CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+  auto partitionLocationOptional = partitionLocationMap->get(partitionId);
+  if (!partitionLocationOptional.has_value()) {
+    // No location is cached for this partition yet: revive, then re-read it.
+    if (!revive(
+            shuffleId,
+            mapId,
+            attemptId,
+            partitionId,
+            -1,
+            nullptr,
+            protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
+      CELEBORN_FAIL(fmt::format(
+          "Revive for shuffleId {} partitionId {} failed.",
+          shuffleId,
+          partitionId));
+    }
+    partitionLocationOptional = partitionLocationMap->get(partitionId);
+  }
+  // Re-check: the mapper may have ended while the location was being resolved.
+  if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+    return 0;
+  }
+
+  CELEBORN_CHECK(partitionLocationOptional.has_value());
+  auto partitionLocation = partitionLocationOptional.value();
+
+  const uint8_t* dataToWrite = data + offset;
+  int lengthToWrite = static_cast<int>(length);
+  // Owns the compressed copy, if any; dataToWrite then points into it.
+  std::unique_ptr<uint8_t[]> compressedBuffer;
+
+  if (shuffleCompressionEnabled_ && compressorFactory_) {
+    auto compressor = compressorFactory_();
+    const size_t compressedCapacity =
+        compressor->getDstCapacity(lengthToWrite);
+    compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+    const size_t compressedSize = compressor->compress(
+        dataToWrite, 0, lengthToWrite, compressedBuffer.get(), 0);
+
+    CELEBORN_CHECK_LE(
+        compressedSize,
+        static_cast<size_t>(std::numeric_limits<int>::max()),
+        fmt::format(
+            "Compressed size {} exceeds maximum supported size {}",
+            compressedSize,
+            std::numeric_limits<int>::max()));
+
+    lengthToWrite = static_cast<int>(compressedSize);
+    dataToWrite = compressedBuffer.get();
+  }
+
+  // The function returns header + payload as an int, so that sum must fit in
+  // an int. (A size_t-based check could never fail here, because
+  // lengthToWrite <= INT_MAX.)
+  const int headerSize = static_cast<int>(kBatchHeaderSize);
+  CELEBORN_CHECK_LE(
+      lengthToWrite,
+      std::numeric_limits<int>::max() - headerSize,
+      fmt::format(
+          "Buffer size {} + header {} would overflow",
+          lengthToWrite,
+          kBatchHeaderSize));
+
+  // Acquire the batch id only after every validation has passed.
+  auto pushState = getPushState(mapKey);
+  const int nextBatchId = pushState->nextBatchId();
+
+  // Batch layout: mapId, attemptId, batchId and payload length, each written
+  // as a little-endian int, followed by the payload bytes.
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+      kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
+  writeBuffer->writeLE<int>(mapId);
+  writeBuffer->writeLE<int>(attemptId);
+  writeBuffer->writeLE<int>(nextBatchId);
+  writeBuffer->writeLE<int>(lengthToWrite);
+  writeBuffer->writeFromBuffer(
+      dataToWrite, 0, static_cast<size_t>(lengthToWrite));
+
+  auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+  auto addressPairKey = genAddressPairKey(*partitionLocation);
+
+  const bool shouldPush = pushState->addBatchData(
+      addressPairKey, partitionLocation, nextBatchId, std::move(body));
+
+  if (shouldPush) {
+    auto dataBatches = pushState->takeDataBatches(addressPairKey);
+    if (dataBatches) {
+      auto hostAndPushPort = partitionLocation->hostAndPushPort();
+      doPushMergedData(
+          hostAndPushPort,
+          shuffleId,
+          mapId,
+          attemptId,
+          numMappers,
+          numPartitions,
+          mapKey,
+          dataBatches->requireBatches(),
+          pushState,
+          conf_->clientPushMaxReviveTimes());
+    }
+  }
+
+  return headerSize + lengthToWrite;
+}
Review Comment:
It would be cleaner to use CELEBORN_CHECK_LE here. I will make the same change
to pushData in a future PR; for this mergeData, I will use CELEBORN_CHECK_LE
instead of plain CELEBORN_CHECK.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]