afterincomparableyum commented on code in PR #3611:
URL: https://github.com/apache/celeborn/pull/3611#discussion_r2915670235


##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -269,11 +270,423 @@ int ShuffleClientImpl::pushData(
   return body->remainingSize();
 }
 
+// Buffers one record for (shuffleId, mapId, attemptId, partitionId) into the
+// client-side merge buffer, optionally compressing it first, and triggers an
+// actual push to the worker only when the accumulated data for that worker
+// address grows large enough (signalled by PushState::addBatchData).
+// Returns the number of bytes accepted into the buffer (batch header plus
+// payload), or 0 when the mapper has already ended and the data is dropped.
+int ShuffleClientImpl::mergeData(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    const uint8_t* data,
+    size_t offset,
+    size_t length,
+    int numMappers,
+    int numPartitions) {
+  const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+  // Mapper already finished (e.g. another attempt committed); drop silently.
+  if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+    return 0;
+  }
+
+  auto partitionLocationMap =
+      getPartitionLocation(shuffleId, numMappers, numPartitions);
+  CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+  auto partitionLocationOptional = partitionLocationMap->get(partitionId);
+  if (!partitionLocationOptional.has_value()) {
+    // No known location for this partition yet: revive it before buffering.
+    // Revive failure here is fatal for this push.
+    if (!revive(
+            shuffleId,
+            mapId,
+            attemptId,
+            partitionId,
+            -1,
+            nullptr,
+            protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
+      CELEBORN_FAIL(fmt::format(
+          "Revive for shuffleId {} partitionId {} failed.",
+          shuffleId,
+          partitionId));
+    }
+    // Revive succeeded; re-read the refreshed location.
+    partitionLocationOptional = partitionLocationMap->get(partitionId);
+  }
+  // Re-check: the mapper may have ended while the revive was in flight.
+  if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+    return 0;
+  }
+
+  CELEBORN_CHECK(partitionLocationOptional.has_value());
+  auto partitionLocation = partitionLocationOptional.value();
+  auto pushState = getPushState(mapKey);
+  const int nextBatchId = pushState->nextBatchId();
+
+  // The batch header records the payload length as an int, so anything
+  // larger than INT_MAX cannot be represented — reject it up front.
+  CELEBORN_CHECK(
+      length <= static_cast<size_t>(std::numeric_limits<int>::max()),
+      fmt::format(
+          "Data length {} exceeds maximum supported size {}",
+          length,
+          std::numeric_limits<int>::max()));
+
+  const uint8_t* dataToWrite = data + offset;
+  int lengthToWrite = static_cast<int>(length);
+  std::unique_ptr<uint8_t[]> compressedBuffer;
+
+  // Optionally compress the payload. compressedBuffer owns the compressed
+  // bytes and keeps them alive until they are copied into writeBuffer below.
+  if (shuffleCompressionEnabled_ && compressorFactory_) {
+    auto compressor = compressorFactory_();
+    const size_t compressedCapacity =
+        compressor->getDstCapacity(static_cast<int>(length));
+    compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+    const size_t compressedSize = compressor->compress(
+        dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
+
+    // Same int-length constraint applies to the compressed form.
+    CELEBORN_CHECK(
+        compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
+        fmt::format(
+            "Compressed size {} exceeds maximum supported size {}",
+            compressedSize,
+            std::numeric_limits<int>::max()));
+
+    lengthToWrite = static_cast<int>(compressedSize);
+    dataToWrite = compressedBuffer.get();
+  }
+
+  // Guard kBatchHeaderSize + lengthToWrite against size_t overflow before
+  // sizing the write buffer.
+  CELEBORN_CHECK(
+      static_cast<size_t>(lengthToWrite) <=
+          std::numeric_limits<size_t>::max() - kBatchHeaderSize,
+      fmt::format(
+          "Buffer size {} + header {} would overflow",
+          lengthToWrite,
+          kBatchHeaderSize));
+
+  // Batch layout (little-endian header): mapId, attemptId, batchId,
+  // payload length, then the payload bytes themselves.
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+      kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
+  writeBuffer->writeLE<int>(mapId);
+  writeBuffer->writeLE<int>(attemptId);
+  writeBuffer->writeLE<int>(nextBatchId);
+  writeBuffer->writeLE<int>(lengthToWrite);
+  writeBuffer->writeFromBuffer(
+      dataToWrite, 0, static_cast<size_t>(lengthToWrite));
+
+  auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+  auto addressPairKey = genAddressPairKey(*partitionLocation);
+
+  // addBatchData buffers the batch and returns true when the accumulated
+  // data for this worker address should be flushed — presumably when it
+  // crosses the push-buffer threshold; confirm against PushState.
+  bool shouldPush = pushState->addBatchData(
+      addressPairKey, partitionLocation, nextBatchId, std::move(body));
+
+  if (shouldPush) {
+    // Take ownership of everything buffered for this worker address and
+    // send it as one merged push request.
+    auto dataBatches = pushState->takeDataBatches(addressPairKey);
+    if (dataBatches) {
+      auto hostAndPushPort = partitionLocation->hostAndPushPort();
+      doPushMergedData(
+          hostAndPushPort,
+          shuffleId,
+          mapId,
+          attemptId,
+          numMappers,
+          numPartitions,
+          mapKey,
+          dataBatches->requireBatches(),
+          pushState,
+          conf_->clientPushMaxReviveTimes());
+    }
+  }
+
+  // Report header + payload bytes as accepted into the merge buffer.
+  return static_cast<int>(kBatchHeaderSize) + lengthToWrite;
+}
+
+// Flushes every batch still buffered for this map task, regardless of size,
+// so no merged data remains in the client-side buffer for (shuffleId,
+// mapId, attemptId).
+void ShuffleClientImpl::pushMergedData(
+    int shuffleId,
+    int mapId,
+    int attemptId) {
+  const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+  auto pushState = getPushState(mapKey);
+  int pushBufferMaxSize = conf_->clientPushBufferMaxSize();
+
+  // Snapshot the worker-address keys first so the map is never mutated
+  // (via takeDataBatches below) while being iterated.
+  auto& batchesMap = pushState->getBatchesMap();
+  std::vector<std::string> keys;
+  batchesMap.forEach(
+      [&keys](const std::string& key, const std::shared_ptr<DataBatches>&) {
+        keys.push_back(key);
+      });
+
+  for (const auto& addressPairKey : keys) {
+    auto dataBatches = pushState->takeDataBatches(addressPairKey);
+    if (!dataBatches) {
+      // Already taken (e.g. flushed by a concurrent mergeData); skip.
+      continue;
+    }
+    // Drain the buffered batches in chunks of at most pushBufferMaxSize
+    // bytes per merged push request.
+    while (dataBatches->getTotalSize() > 0) {
+      auto batches = dataBatches->requireBatches(pushBufferMaxSize);
+      if (batches.empty()) {
+        break;
+      }
+      auto hostAndPushPort = batches.front().loc->hostAndPushPort();
+      // NOTE(review): numMappers/numPartitions are passed as 0 on this
+      // flush path — presumably unused by doPushMergedData here; confirm.
+      doPushMergedData(
+          hostAndPushPort,
+          shuffleId,
+          mapId,
+          attemptId,
+          0,
+          0,
+          mapKey,
+          std::move(batches),
+          pushState,
+          conf_->clientPushMaxReviveTimes());
+    }
+  }
+}
+
+void ShuffleClientImpl::doPushMergedData(
+    const std::string& hostAndPushPort,
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int numMappers,
+    int numPartitions,
+    const std::string& mapKey,
+    std::vector<DataBatch> batches,
+    std::shared_ptr<PushState> pushState,
+    int remainReviveTimes) {
+  if (batches.empty()) {
+    return;
+  }
+
+  int groupedBatchId = pushState->nextBatchId();
+  int groupedBatchBytesSize = 0;
+  for (const auto& batch : batches) {
+    groupedBatchBytesSize += static_cast<int>(batch.body->size());
+  }
+
+  limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
+  pushState->addBatch(groupedBatchId, groupedBatchBytesSize, hostAndPushPort);
+
+  const int numBatches = static_cast<int>(batches.size());
+  std::vector<int> partitionIds(numBatches);
+  std::vector<std::string> partitionUniqueIds(numBatches);
+  std::vector<int32_t> offsets(numBatches);
+  int currentSize = 0;
+
+  std::vector<std::unique_ptr<memory::ReadOnlyByteBuffer>> bodyParts;
+  for (int i = 0; i < numBatches; i++) {
+    partitionIds[i] = batches[i].loc->id;
+    partitionUniqueIds[i] = batches[i].loc->uniqueId();
+    offsets[i] = currentSize;
+    currentSize += static_cast<int>(batches[i].body->size());
+    bodyParts.push_back(batches[i].body->clone());
+  }

Review Comment:
   The body sizes were already validated in the overflow-checked loop above, so 
currentSize here is guaranteed to stay within int range. I'll still add a 
CELEBORN_DCHECK as a compromise, though in practice it should never overflow. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to