afterincomparableyum commented on code in PR #3611:
URL: https://github.com/apache/celeborn/pull/3611#discussion_r3007699196


##########
cpp/celeborn/client/ShuffleClient.cpp:
##########
@@ -269,11 +270,447 @@ int ShuffleClientImpl::pushData(
   return body->remainingSize();
 }
 
+// Buffers one record for (shuffleId, mapId, attemptId, partitionId) on the
+// client side, to be pushed to the worker in merged batches. The record is
+// optionally compressed, framed with a batch header, and handed to the
+// per-map PushState; when PushState reports the buffer is ready, the
+// accumulated batches for the target worker are flushed via
+// doPushMergedData().
+//
+// Returns header + payload bytes accepted, or 0 if the mapper attempt has
+// already ended (the data is then dropped as a duplicate/late write).
+int ShuffleClientImpl::mergeData(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    const uint8_t* data,
+    size_t offset,
+    size_t length,
+    int numMappers,
+    int numPartitions) {
+  const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+  // Fast path: another attempt of this mapper already committed; drop the
+  // write instead of pushing stale data.
+  if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+    return 0;
+  }
+
+  auto partitionLocationMap =
+      getPartitionLocation(shuffleId, numMappers, numPartitions);
+  CELEBORN_CHECK_NOT_NULL(partitionLocationMap);
+  auto partitionLocationOptional = partitionLocationMap->get(partitionId);
+  // No known location for this partition yet: ask the LifecycleManager to
+  // revive it, then re-read the (possibly updated) location map.
+  if (!partitionLocationOptional.has_value()) {
+    if (!revive(
+            shuffleId,
+            mapId,
+            attemptId,
+            partitionId,
+            -1,
+            nullptr,
+            protocol::StatusCode::PUSH_DATA_FAIL_NON_CRITICAL_CAUSE)) {
+      CELEBORN_FAIL(fmt::format(
+          "Revive for shuffleId {} partitionId {} failed.",
+          shuffleId,
+          partitionId));
+    }
+    partitionLocationOptional = partitionLocationMap->get(partitionId);
+  }
+  // Revive may block; the mapper can end in the meantime, so re-check
+  // before doing any buffering work.
+  if (checkMapperEnded(shuffleId, mapId, mapKey)) {
+    return 0;
+  }
+
+  CELEBORN_CHECK(partitionLocationOptional.has_value());
+  auto partitionLocation = partitionLocationOptional.value();
+  auto pushState = getPushState(mapKey);
+  const int nextBatchId = pushState->nextBatchId();
+
+  // The batch-length field in the header is a 32-bit int, so the payload
+  // must fit in int range.
+  CELEBORN_CHECK(
+      length <= static_cast<size_t>(std::numeric_limits<int>::max()),
+      fmt::format(
+          "Data length {} exceeds maximum supported size {}",
+          length,
+          std::numeric_limits<int>::max()));
+
+  const uint8_t* dataToWrite = data + offset;
+  int lengthToWrite = static_cast<int>(length);
+  std::unique_ptr<uint8_t[]> compressedBuffer;
+
+  // Optionally compress the payload before framing. compressedBuffer keeps
+  // the compressed bytes alive until they are copied into writeBuffer below.
+  if (shuffleCompressionEnabled_ && compressorFactory_) {
+    auto compressor = compressorFactory_();
+    const size_t compressedCapacity =
+        compressor->getDstCapacity(static_cast<int>(length));
+    compressedBuffer = std::make_unique<uint8_t[]>(compressedCapacity);
+    const size_t compressedSize = compressor->compress(
+        dataToWrite, 0, static_cast<int>(length), compressedBuffer.get(), 0);
+
+    CELEBORN_CHECK(
+        compressedSize <= static_cast<size_t>(std::numeric_limits<int>::max()),
+        fmt::format(
+            "Compressed size {} exceeds maximum supported size {}",
+            compressedSize,
+            std::numeric_limits<int>::max()));
+
+    lengthToWrite = static_cast<int>(compressedSize);
+    dataToWrite = compressedBuffer.get();
+  }
+
+  // Guard the size_t addition used to size writeBuffer below.
+  CELEBORN_CHECK(
+      static_cast<size_t>(lengthToWrite) <=
+          std::numeric_limits<size_t>::max() - kBatchHeaderSize,
+      fmt::format(
+          "Buffer size {} + header {} would overflow",
+          lengthToWrite,
+          kBatchHeaderSize));
+
+  // Frame the batch: a little-endian header of
+  // [mapId, attemptId, batchId, length] followed by the payload bytes.
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(
+      kBatchHeaderSize + static_cast<size_t>(lengthToWrite));
+  writeBuffer->writeLE<int>(mapId);
+  writeBuffer->writeLE<int>(attemptId);
+  writeBuffer->writeLE<int>(nextBatchId);
+  writeBuffer->writeLE<int>(lengthToWrite);
+  writeBuffer->writeFromBuffer(
+      dataToWrite, 0, static_cast<size_t>(lengthToWrite));
+
+  auto body = memory::ByteBuffer::toReadOnly(std::move(writeBuffer));
+  // Batches are grouped per worker (host:pushPort pair), not per partition.
+  auto addressPairKey = genAddressPairKey(*partitionLocation);
+
+  // addBatchData buffers the framed batch; a true return presumably means
+  // the buffered size for this worker crossed the push threshold and should
+  // be flushed now — confirm against PushState::addBatchData.
+  bool shouldPush = pushState->addBatchData(
+      addressPairKey, partitionLocation, nextBatchId, std::move(body));
+
+  if (shouldPush) {
+    // takeDataBatches may return null if a concurrent flush already drained
+    // this worker's buffer.
+    auto dataBatches = pushState->takeDataBatches(addressPairKey);
+    if (dataBatches) {
+      auto hostAndPushPort = partitionLocation->hostAndPushPort();
+      doPushMergedData(
+          hostAndPushPort,
+          shuffleId,
+          mapId,
+          attemptId,
+          numMappers,
+          numPartitions,
+          mapKey,
+          dataBatches->requireBatches(),
+          pushState,
+          conf_->clientPushMaxReviveTimes());
+    }
+  }
+
+  // Ensure the int return (header + payload) cannot overflow.
+  CELEBORN_CHECK_LE(
+      lengthToWrite,
+      std::numeric_limits<int>::max() - static_cast<int>(kBatchHeaderSize),
+      "Batch bytes size {} + header {} would overflow int",
+      lengthToWrite,
+      kBatchHeaderSize);
+  return static_cast<int>(kBatchHeaderSize) + lengthToWrite;
+}
+
+// Flushes every batch still buffered in this map task's PushState,
+// draining each worker's queue in chunks of at most
+// clientPushBufferMaxSize bytes. Typically invoked when the map task
+// finishes, so no partially-filled buffers are left behind.
+void ShuffleClientImpl::pushMergedData(
+    int shuffleId,
+    int mapId,
+    int attemptId) {
+  const auto mapKey = utils::makeMapKey(shuffleId, mapId, attemptId);
+  auto pushState = getPushState(mapKey);
+  int pushBufferMaxSize = conf_->clientPushBufferMaxSize();
+
+  // Snapshot the worker keys first: takeDataBatches() below removes entries
+  // from batchesMap, which must not happen while forEach is iterating it.
+  auto& batchesMap = pushState->getBatchesMap();
+  std::vector<std::string> keys;
+  batchesMap.forEach(
+      [&keys](const std::string& key, const std::shared_ptr<DataBatches>&) {
+        keys.push_back(key);
+      });
+
+  for (const auto& addressPairKey : keys) {
+    // May be null if a concurrent push already drained this worker's queue.
+    auto dataBatches = pushState->takeDataBatches(addressPairKey);
+    if (!dataBatches) {
+      continue;
+    }
+    // Drain in size-bounded chunks so each merged request stays within the
+    // configured push buffer limit.
+    while (dataBatches->getTotalSize() > 0) {
+      auto batches = dataBatches->requireBatches(pushBufferMaxSize);
+      if (batches.empty()) {
+        break;
+      }
+      // All batches in a chunk target the same worker, so the first
+      // location's host:pushPort addresses the whole request.
+      auto hostAndPushPort = batches.front().loc->hostAndPushPort();
+      // numMappers/numPartitions are passed as 0 here — presumably unused on
+      // the final-flush path; confirm against doPushMergedData's revive
+      // handling.
+      doPushMergedData(
+          hostAndPushPort,
+          shuffleId,
+          mapId,
+          attemptId,
+          0,
+          0,
+          mapKey,
+          std::move(batches),
+          pushState,
+          conf_->clientPushMaxReviveTimes());
+    }
+  }
+}
+
+void ShuffleClientImpl::doPushMergedData(
+    const std::string& hostAndPushPort,
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int numMappers,
+    int numPartitions,
+    const std::string& mapKey,
+    std::vector<DataBatch> batches,
+    std::shared_ptr<PushState> pushState,
+    int remainReviveTimes) {
+  if (batches.empty()) {
+    return;
+  }
+
+  int groupedBatchId = pushState->nextBatchId();
+  int groupedBatchBytesSize = 0;
+  for (const auto& batch : batches) {
+    CELEBORN_CHECK_LE(
+        batch.body->size(),
+        static_cast<size_t>(INT_MAX),
+        "Batch body size {} exceeds INT_MAX",
+        batch.body->size());
+    int bodySize = static_cast<int>(batch.body->size());
+    CELEBORN_CHECK_LE(
+        bodySize,
+        INT_MAX - groupedBatchBytesSize,
+        "Grouped batch size would overflow: adding {} to {}",
+        bodySize,
+        groupedBatchBytesSize);
+    groupedBatchBytesSize += bodySize;
+  }
+
+  limitMaxInFlight(mapKey, *pushState, hostAndPushPort);
+  pushState->addBatch(groupedBatchId, groupedBatchBytesSize, hostAndPushPort);
+
+  const int numBatches = static_cast<int>(batches.size());
+  std::vector<int> partitionIds(numBatches);
+  std::vector<std::string> partitionUniqueIds(numBatches);
+  std::vector<int32_t> offsets(numBatches);
+  int currentSize = 0;
+
+  // TODO: Each clone() + concat() copies data. Java avoids this with Netty's
+  // zero-copy CompositeByteBuf. We need to build something like an IOBuf chain
+  // directly (ex: via appendToChain) to eliminate the double-clone and O(n^2)
+  // re-copy from repeated pairwise concat.
+  std::vector<std::unique_ptr<memory::ReadOnlyByteBuffer>> bodyParts;
+  for (int i = 0; i < numBatches; i++) {
+    partitionIds[i] = batches[i].loc->id;
+    partitionUniqueIds[i] = batches[i].loc->uniqueId();
+    offsets[i] = currentSize;
+    int bodySize = static_cast<int>(batches[i].body->size());
+    CELEBORN_DCHECK_LE(bodySize, INT_MAX - currentSize);
+    currentSize += bodySize;
+    bodyParts.push_back(batches[i].body->clone());
+  }
+
+  // Concatenate all bodies
+  std::unique_ptr<memory::ReadOnlyByteBuffer> mergedBody;
+  if (bodyParts.size() == 1) {
+    mergedBody = std::move(bodyParts[0]);
+  } else {
+    mergedBody = memory::ByteBuffer::concat(*bodyParts[0], *bodyParts[1]);
+    for (size_t i = 2; i < bodyParts.size(); i++) {
+      mergedBody = memory::ByteBuffer::concat(*mergedBody, *bodyParts[i]);
+    }
+  }
+
+  auto shuffleKey = utils::makeShuffleKey(appUniqueId_, shuffleId);
+  network::PushMergedData pushMergedData(
+      network::Message::nextRequestId(),
+      protocol::PartitionLocation::Mode::PRIMARY,
+      shuffleKey,
+      std::move(partitionUniqueIds),
+      std::move(offsets),
+      std::move(mergedBody));
+
+  auto callback = PushMergedDataCallback::create(
+      shuffleId,
+      mapId,
+      attemptId,
+      numMappers,
+      numPartitions,
+      mapKey,
+      hostAndPushPort,
+      groupedBatchId,
+      std::move(batches),
+      std::move(partitionIds),
+      pushState,
+      weak_from_this(),
+      remainReviveTimes);
+
+  auto host = hostAndPushPort.substr(0, hostAndPushPort.find(':'));
+  auto portStr = hostAndPushPort.substr(hostAndPushPort.find(':') + 1);
+  auto port = static_cast<uint16_t>(std::stoi(portStr));

Review Comment:
   done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to