This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 36cdc29ea [CELEBORN-2206][CIP-14] Support PushDataCallback in CppClient
36cdc29ea is described below

commit 36cdc29eab749fe24fa23cb654c252bdbde1bd85
Author: HolyLow <[email protected]>
AuthorDate: Fri Nov 28 15:22:48 2025 +0800

    [CELEBORN-2206][CIP-14] Support PushDataCallback in CppClient
    
    ### What changes were proposed in this pull request?
    This PR supports PushDataCallback in CppClient.
    
    ### Why are the changes needed?
    PushDataCallback is the building block of PushData logic of CppClient's writing procedure.
    
    ### Does this PR resolve a correctness bug?
    No.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Compilation and UTs.
    
    Closes #3543 from HolyLow/issue/celeborn-2206-support-PushDataCallback-in-cpp-client.
    
    Authored-by: HolyLow <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
---
 cpp/celeborn/client/CMakeLists.txt                 |   1 +
 cpp/celeborn/client/ShuffleClient.h                |  38 ++
 cpp/celeborn/client/tests/CMakeLists.txt           |   1 +
 cpp/celeborn/client/tests/PushDataCallbackTest.cpp | 389 +++++++++++++++++++++
 cpp/celeborn/client/writer/PushDataCallback.cpp    | 266 ++++++++++++++
 cpp/celeborn/client/writer/PushDataCallback.h      |  96 +++++
 cpp/celeborn/client/writer/ReviveManager.h         |   1 +
 cpp/celeborn/conf/CelebornConf.cpp                 |   8 +-
 cpp/celeborn/conf/CelebornConf.h                   |   6 +
 cpp/celeborn/protocol/PartitionLocation.cpp        |  12 +
 cpp/celeborn/protocol/PartitionLocation.h          |   9 +-
 cpp/celeborn/utils/CelebornUtils.h                 |   6 +
 12 files changed, 828 insertions(+), 5 deletions(-)

diff --git a/cpp/celeborn/client/CMakeLists.txt b/cpp/celeborn/client/CMakeLists.txt
index fa0290733..112b1f38d 100644
--- a/cpp/celeborn/client/CMakeLists.txt
+++ b/cpp/celeborn/client/CMakeLists.txt
@@ -19,6 +19,7 @@ add_library(
         writer/PushState.cpp
         writer/PushStrategy.cpp
         writer/ReviveManager.cpp
+        writer/PushDataCallback.cpp
         ShuffleClient.cpp
         compress/Decompressor.cpp
         compress/Lz4Decompressor.cpp
diff --git a/cpp/celeborn/client/ShuffleClient.h b/cpp/celeborn/client/ShuffleClient.h
index 41e953515..dc71a39bc 100644
--- a/cpp/celeborn/client/ShuffleClient.h
+++ b/cpp/celeborn/client/ShuffleClient.h
@@ -18,6 +18,9 @@
 #pragma once
 
 #include "celeborn/client/reader/CelebornInputStream.h"
+#include "celeborn/client/writer/PushDataCallback.h"
+#include "celeborn/client/writer/PushState.h"
+#include "celeborn/client/writer/ReviveManager.h"
 #include "celeborn/network/NettyRpcEndpointRef.h"
 
 namespace celeborn {
@@ -51,11 +54,15 @@ class ShuffleClient {
   virtual void shutdown() = 0;
 };
 
+class ReviveManager;
+class PushDataCallback;
+
 class ShuffleClientImpl
     : public ShuffleClient,
       public std::enable_shared_from_this<ShuffleClientImpl> {
  public:
   friend class ReviveManager;
+  friend class PushDataCallback;
 
   using PtrReviveRequest = std::shared_ptr<protocol::ReviveRequest>;
   using PartitionLocationMap = utils::ConcurrentHashMap<
@@ -104,12 +111,29 @@ class ShuffleClientImpl
       const std::shared_ptr<const conf::CelebornConf>& conf,
       const std::shared_ptr<network::TransportClientFactory>& clientFactory);
 
+  // TODO: currently this function serves as a stub. will be updated in future
+  //  commits.
+  virtual void submitRetryPushData(
+      int shuffleId,
+      std::unique_ptr<memory::ReadOnlyByteBuffer> body,
+      int batchId,
+      std::shared_ptr<PushDataCallback> pushDataCallback,
+      std::shared_ptr<PushState> pushState,
+      PtrReviveRequest request,
+      int remainReviveTimes,
+      long dueTimeMs) {}
+
   // TODO: currently this function serves as a stub. will be updated in future
   //  commits.
   virtual bool mapperEnded(int shuffleId, int mapId) {
     return true;
   }
 
+  // TODO: currently this function serves as a stub. will be updated in future
+  //  commits.
+  virtual void addRequestToReviveManager(
+      std::shared_ptr<protocol::ReviveRequest> reviveRequest) {}
+
   // TODO: currently this function serves as a stub. will be updated in future
   //  commits.
   virtual std::optional<std::unordered_map<int, int>> reviveBatch(
@@ -124,6 +148,16 @@ class ShuffleClientImpl
     return partitionLocationMaps_.get(shuffleId);
   }
 
+  virtual utils::
+      ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>&
+      mapperEndSets() {
+    return mapperEndSets_;
+  }
+
+  virtual void addPushDataRetryTask(folly::Func&& task) {
+    pushDataRetryPool_->add(std::move(task));
+  }
+
  private:
   // TODO: no support for WAIT as it is not used.
   static bool newerPartitionLocationExists(
@@ -140,12 +174,16 @@ class ShuffleClientImpl
   std::shared_ptr<const conf::CelebornConf> conf_;
   std::shared_ptr<network::NettyRpcEndpointRef> lifecycleManagerRef_;
   std::shared_ptr<network::TransportClientFactory> clientFactory_;
+  std::shared_ptr<folly::IOExecutor> pushDataRetryPool_;
+  std::shared_ptr<ReviveManager> reviveManager_;
   std::mutex mutex_;
   utils::ConcurrentHashMap<
       int,
       std::shared_ptr<protocol::GetReducerFileGroupResponse>>
       reducerFileGroupInfos_;
   utils::ConcurrentHashMap<int, PtrPartitionLocationMap> partitionLocationMaps_;
+  utils::ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>
+      mapperEndSets_;
 };
 } // namespace client
 } // namespace celeborn
diff --git a/cpp/celeborn/client/tests/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt
index 8183611c6..e19703f31 100644
--- a/cpp/celeborn/client/tests/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -16,6 +16,7 @@
 add_executable(
         celeborn_client_test
         WorkerPartitionReaderTest.cpp
+        PushDataCallbackTest.cpp
         PushStateTest.cpp
         ReviveManagerTest.cpp
         Lz4DecompressorTest.cpp
diff --git a/cpp/celeborn/client/tests/PushDataCallbackTest.cpp b/cpp/celeborn/client/tests/PushDataCallbackTest.cpp
new file mode 100644
index 000000000..171c8e714
--- /dev/null
+++ b/cpp/celeborn/client/tests/PushDataCallbackTest.cpp
@@ -0,0 +1,389 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/client/writer/PushDataCallback.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+
+namespace {
+class MockShuffleClient : public ShuffleClientImpl {
+ public:
+  friend class PushDataCallback;
+
+  using FuncOnSubmitRetryPushData = std::function<void(
+      int,
+      std::unique_ptr<memory::ReadOnlyByteBuffer>,
+      int,
+      std::shared_ptr<PushDataCallback>,
+      std::shared_ptr<PushState>,
+      PtrReviveRequest,
+      int,
+      long)>;
+  using FuncOnAddRequestToReviveManager =
+      std::function<void(std::shared_ptr<protocol::ReviveRequest>)>;
+  using FuncOnAddPushDataRetryTask = std::function<void(folly::Func&&)>;
+
+  static std::shared_ptr<MockShuffleClient> create() {
+    return std::shared_ptr<MockShuffleClient>(new MockShuffleClient());
+  }
+
+  virtual ~MockShuffleClient() = default;
+
+  bool mapperEnded(int shuffleId, int mapId) override {
+    return false;
+  }
+
+  std::optional<PtrPartitionLocationMap> getPartitionLocationMap(
+      int shuffleId) override {
+    return {std::make_shared<PartitionLocationMap>()};
+  }
+
+  void submitRetryPushData(
+      int shuffleId,
+      std::unique_ptr<memory::ReadOnlyByteBuffer> body,
+      int batchId,
+      std::shared_ptr<PushDataCallback> pushDataCallback,
+      std::shared_ptr<PushState> pushState,
+      PtrReviveRequest request,
+      int remainReviveTimes,
+      long dueTimeMs) override {
+    onSubmitRetryPushData_(
+        shuffleId,
+        std::move(body),
+        batchId,
+        pushDataCallback,
+        pushState,
+        request,
+        remainReviveTimes,
+        dueTimeMs);
+  }
+
+  void setOnSubmitRetryPushData(FuncOnSubmitRetryPushData&& func) {
+    onSubmitRetryPushData_ = func;
+  }
+
+  void addRequestToReviveManager(
+      std::shared_ptr<protocol::ReviveRequest> reviveRequest) override {
+    onAddRequestToReviveManager_(reviveRequest);
+  }
+
+  void setOnAddRequestToReviveManager(FuncOnAddRequestToReviveManager&& func) {
+    onAddRequestToReviveManager_ = func;
+  }
+
+  utils::ConcurrentHashMap<int, std::shared_ptr<utils::ConcurrentHashSet<int>>>&
+  mapperEndSets() override {
+    return ShuffleClientImpl::mapperEndSets();
+  }
+
+  void addPushDataRetryTask(folly::Func&& task) override {
+    onAddPushDataRetryTask_(std::move(task));
+  }
+
+  void setOnAddPushDataRetryTask(FuncOnAddPushDataRetryTask&& func) {
+    onAddPushDataRetryTask_ = func;
+  }
+
+ private:
+  MockShuffleClient()
+      : ShuffleClientImpl(
+            "mock",
+            std::make_shared<conf::CelebornConf>(),
+            nullptr) {}
+
+  FuncOnSubmitRetryPushData onSubmitRetryPushData_ =
+      [](int,
+         std::unique_ptr<memory::ReadOnlyByteBuffer>,
+         int,
+         std::shared_ptr<PushDataCallback>,
+         std::shared_ptr<PushState>,
+         PtrReviveRequest,
+         int,
+         long) { CELEBORN_UNREACHABLE("not expected to call this"); };
+  FuncOnAddRequestToReviveManager onAddRequestToReviveManager_ =
+      [](std::shared_ptr<protocol::ReviveRequest>) {
+        CELEBORN_UNREACHABLE("not expected to call this");
+      };
+  FuncOnAddPushDataRetryTask onAddPushDataRetryTask_ = [](folly::Func&&) {
+    CELEBORN_UNREACHABLE("not expected to call this");
+  };
+};
+
+std::unique_ptr<memory::ReadOnlyByteBuffer> createReadOnlyByteBuffer(
+    uint8_t code) {
+  auto writeBuffer = memory::ByteBuffer::createWriteOnly(1);
+  writeBuffer->write<uint8_t>(code);
+  return std::move(memory::ByteBuffer::toReadOnly(std::move(writeBuffer)));
+}
+} // namespace
+
+const int testPushReviveIntervalMs = 100;
+const int testShuffleId = 1000;
+const int testMapId = 1001;
+const int testAttemptId = 1002;
+const int testPartitionId = 1003;
+const int testNumMappers = 1004;
+const int testNumPartitions = 1005;
+const std::string testMapKey = "test-map-key";
+const int testBatchId = 1006;
+const auto testLastestLocation =
+    std::make_shared<const protocol::PartitionLocation>();
+
+TEST(PushDataCallbackTest, onSuccessAndNoOperation) {
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  auto mockClient = MockShuffleClient::create();
+  auto pushDataCallback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      1,
+      testLastestLocation);
+
+  auto response = memory::ReadOnlyByteBuffer::createEmptyBuffer();
+  pushDataCallback->onSuccess(std::move(response));
+}
+
+TEST(PushDataCallbackTest, onSuccessAndMapEnd) {
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  auto mockClient = MockShuffleClient::create();
+  EXPECT_EQ(mockClient->mapperEndSets().size(), 0);
+  auto pushDataCallback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      1,
+      testLastestLocation);
+
+  auto response = createReadOnlyByteBuffer(protocol::StatusCode::MAP_ENDED);
+  pushDataCallback->onSuccess(std::move(response));
+  auto& mapperEndSets = mockClient->mapperEndSets();
+  EXPECT_TRUE(mapperEndSets.containsKey(testShuffleId));
+  auto mapperEndSet = mapperEndSets.get(testShuffleId).value();
+  EXPECT_TRUE(mapperEndSet->contains(testMapId));
+}
+
+TEST(PushDataCallbackTest, onSuccessAndSoftSplit) {
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  auto mockClient = MockShuffleClient::create();
+
+  int addRequestCalledTimes = 0;
+  mockClient->setOnAddRequestToReviveManager(
+      [=, &addRequestCalledTimes](
+          std::shared_ptr<protocol::ReviveRequest> request) mutable {
+        EXPECT_EQ(request->shuffleId, testShuffleId);
+        EXPECT_EQ(request->mapId, testMapId);
+        EXPECT_EQ(request->attemptId, testAttemptId);
+        EXPECT_EQ(request->partitionId, testPartitionId);
+        EXPECT_EQ(request->cause, protocol::StatusCode::SOFT_SPLIT);
+
+        addRequestCalledTimes++;
+      });
+
+  auto pushDataCallback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      1,
+      testLastestLocation);
+
+  auto response = createReadOnlyByteBuffer(protocol::StatusCode::SOFT_SPLIT);
+  pushDataCallback->onSuccess(std::move(response));
+  EXPECT_GT(addRequestCalledTimes, 0);
+}
+
+TEST(PushDataCallbackTest, onSuccessAndHardSplit) {
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  const int testRemainingReviveTimes = 1;
+  auto mockClient = MockShuffleClient::create();
+
+  int addRequestCalledTimes = 0;
+  mockClient->setOnAddRequestToReviveManager(
+      [=, &addRequestCalledTimes](
+          std::shared_ptr<protocol::ReviveRequest> request) mutable {
+        EXPECT_EQ(request->shuffleId, testShuffleId);
+        EXPECT_EQ(request->mapId, testMapId);
+        EXPECT_EQ(request->attemptId, testAttemptId);
+        EXPECT_EQ(request->partitionId, testPartitionId);
+        EXPECT_EQ(request->cause, protocol::StatusCode::HARD_SPLIT);
+
+        addRequestCalledTimes++;
+      });
+  int addRetryTaskCalledTimes = 0;
+  mockClient->setOnAddPushDataRetryTask(
+      [&addRetryTaskCalledTimes](folly::Func&& task) {
+        addRetryTaskCalledTimes++;
+        task();
+      });
+  int submitRetryCalledTimes = 0;
+  mockClient->setOnSubmitRetryPushData(
+      [=, &submitRetryCalledTimes](
+          int shuffleId,
+          std::unique_ptr<memory::ReadOnlyByteBuffer> body,
+          int batchId,
+          std::shared_ptr<PushDataCallback> pushDataCallback,
+          std::shared_ptr<PushState> pushState,
+          ShuffleClientImpl::PtrReviveRequest request,
+          int remainReviveTimes,
+          long dueTimeMs) mutable {
+        EXPECT_EQ(shuffleId, testShuffleId);
+        EXPECT_EQ(batchId, testBatchId);
+        EXPECT_EQ(remainReviveTimes, testRemainingReviveTimes);
+
+        submitRetryCalledTimes++;
+      });
+
+  auto pushDataCallback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      testRemainingReviveTimes,
+      testLastestLocation);
+
+  auto response = createReadOnlyByteBuffer(protocol::StatusCode::HARD_SPLIT);
+  pushDataCallback->onSuccess(std::move(response));
+  EXPECT_GT(addRequestCalledTimes, 0);
+  EXPECT_GT(addRetryTaskCalledTimes, 0);
+  EXPECT_GT(submitRetryCalledTimes, 0);
+}
+
+TEST(PushDataCallbackTest, onFailureAndRevive) {
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  const int testRemainingReviveTimes = 1;
+  auto mockClient = MockShuffleClient::create();
+
+  int addRequestCalledTimes = 0;
+  mockClient->setOnAddRequestToReviveManager(
+      [=, &addRequestCalledTimes](
+          std::shared_ptr<protocol::ReviveRequest> request) mutable {
+        EXPECT_EQ(request->shuffleId, testShuffleId);
+        EXPECT_EQ(request->mapId, testMapId);
+        EXPECT_EQ(request->attemptId, testAttemptId);
+        EXPECT_EQ(request->partitionId, testPartitionId);
+
+        addRequestCalledTimes++;
+      });
+  int addRetryTaskCalledTimes = 0;
+  mockClient->setOnAddPushDataRetryTask(
+      [&addRetryTaskCalledTimes](folly::Func&& task) {
+        addRetryTaskCalledTimes++;
+        task();
+      });
+  int submitRetryCalledTimes = 0;
+  mockClient->setOnSubmitRetryPushData(
+      [=, &submitRetryCalledTimes](
+          int shuffleId,
+          std::unique_ptr<memory::ReadOnlyByteBuffer> body,
+          int batchId,
+          std::shared_ptr<PushDataCallback> pushDataCallback,
+          std::shared_ptr<PushState> pushState,
+          ShuffleClientImpl::PtrReviveRequest request,
+          int remainReviveTimes,
+          long dueTimeMs) mutable {
+        EXPECT_EQ(shuffleId, testShuffleId);
+        EXPECT_EQ(batchId, testBatchId);
+        EXPECT_EQ(remainReviveTimes, testRemainingReviveTimes - 1);
+
+        submitRetryCalledTimes++;
+      });
+
+  auto pushDataCallback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      testRemainingReviveTimes,
+      testLastestLocation);
+
+  auto exception = std::make_unique<std::runtime_error>("test");
+  pushDataCallback->onFailure(std::move(exception));
+  EXPECT_GT(addRequestCalledTimes, 0);
+  EXPECT_GT(addRetryTaskCalledTimes, 0);
+  EXPECT_GT(submitRetryCalledTimes, 0);
+}
+
+TEST(PushDataCallbackTest, onFailureAndNoRevive) {
+  const auto celebornConf = conf::CelebornConf();
+  auto pushState = std::make_shared<PushState>(celebornConf);
+  const int testRemainingReviveTimes = 0;
+  auto mockClient = MockShuffleClient::create();
+
+  auto pushDataCallback = PushDataCallback::create(
+      testShuffleId,
+      testMapId,
+      testAttemptId,
+      testPartitionId,
+      testNumMappers,
+      testNumPartitions,
+      testMapKey,
+      testBatchId,
+      memory::ReadOnlyByteBuffer::createEmptyBuffer(),
+      pushState,
+      mockClient->weak_from_this(),
+      testRemainingReviveTimes,
+      testLastestLocation);
+
+  auto exception = std::make_unique<std::runtime_error>("test");
+  pushDataCallback->onFailure(std::move(exception));
+  EXPECT_TRUE(pushState->exceptionExists());
+}
diff --git a/cpp/celeborn/client/writer/PushDataCallback.cpp b/cpp/celeborn/client/writer/PushDataCallback.cpp
new file mode 100644
index 000000000..43550ba7f
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushDataCallback.cpp
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "celeborn/client/writer/PushDataCallback.h"
+#include "celeborn/conf/CelebornConf.h"
+
+namespace celeborn {
+namespace client {
+
+std::shared_ptr<PushDataCallback> PushDataCallback::create(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    int numMappers,
+    int numPartitions,
+    const std::string& mapKey,
+    int batchId,
+    std::unique_ptr<memory::ReadOnlyByteBuffer> databody,
+    std::shared_ptr<PushState> pushState,
+    std::weak_ptr<ShuffleClientImpl> weakClient,
+    int remainingReviveTimes,
+    std::shared_ptr<const protocol::PartitionLocation> latestLocation) {
+  return std::shared_ptr<PushDataCallback>(new PushDataCallback(
+      shuffleId,
+      mapId,
+      attemptId,
+      partitionId,
+      numMappers,
+      numPartitions,
+      mapKey,
+      batchId,
+      std::move(databody),
+      pushState,
+      weakClient,
+      remainingReviveTimes,
+      latestLocation));
+}
+
+PushDataCallback::PushDataCallback(
+    int shuffleId,
+    int mapId,
+    int attemptId,
+    int partitionId,
+    int numMappers,
+    int numPartitions,
+    const std::string& mapKey,
+    int batchId,
+    std::unique_ptr<memory::ReadOnlyByteBuffer> databody,
+    std::shared_ptr<PushState> pushState,
+    std::weak_ptr<ShuffleClientImpl> weakClient,
+    int remainingReviveTimes,
+    std::shared_ptr<const protocol::PartitionLocation> latestLocation)
+    : shuffleId_(shuffleId),
+      mapId_(mapId),
+      attemptId_(attemptId),
+      partitionId_(partitionId),
+      numMappers_(numMappers),
+      numPartitions_(numPartitions),
+      mapKey_(mapKey),
+      batchId_(batchId),
+      databody_(std::move(databody)),
+      pushState_(pushState),
+      weakClient_(weakClient),
+      remainingReviveTimes_(remainingReviveTimes),
+      latestLocation_(latestLocation) {}
+
+void PushDataCallback::onSuccess(
+    std::unique_ptr<memory::ReadOnlyByteBuffer> response) {
+  auto sharedClient = weakClient_.lock();
+  if (!sharedClient) {
+    LOG(WARNING) << "ShuffleClientImpl has expired when "
+                    "PushDataCallbackOnSuccess, ignored, shuffle "
+                 << shuffleId_ << " map " << mapId_ << " attempt " << attemptId_
+                 << " partition " << partitionId_ << " batch " << batchId_
+                 << ".";
+    return;
+  }
+  if (response->remainingSize() <= 0) {
+    pushState_->onSuccess(latestLocation_->hostAndPushPort());
+    pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
+    return;
+  }
+  protocol::StatusCode reason =
+      static_cast<protocol::StatusCode>(response->read<uint8_t>());
+  switch (reason) {
+    case protocol::StatusCode::MAP_ENDED: {
+      auto mapperEndSet = sharedClient->mapperEndSets().computeIfAbsent(
+          shuffleId_,
+          []() { return std::make_shared<utils::ConcurrentHashSet<int>>(); });
+      mapperEndSet->insert(mapId_);
+      break;
+    }
+    case protocol::StatusCode::SOFT_SPLIT: {
+      VLOG(1) << "Push data to " << latestLocation_->hostAndPushPort()
+              << " soft split required for shuffle " << shuffleId_ << " map "
+              << mapId_ << " attempt " << attemptId_ << " partition "
+              << partitionId_ << " batch " << batchId_ << ".";
+      if (!ShuffleClientImpl::newerPartitionLocationExists(
+              sharedClient->getPartitionLocationMap(shuffleId_).value(),
+              partitionId_,
+              latestLocation_->epoch)) {
+        auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+            shuffleId_,
+            mapId_,
+            attemptId_,
+            partitionId_,
+            latestLocation_->epoch,
+            latestLocation_,
+            protocol::StatusCode::SOFT_SPLIT);
+        sharedClient->addRequestToReviveManager(reviveRequest);
+      }
+      pushState_->onSuccess(latestLocation_->hostAndPushPort());
+      pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
+      break;
+    }
+    case protocol::StatusCode::HARD_SPLIT: {
+      VLOG(1) << "Push data to " << latestLocation_->hostAndPushPort()
+              << " hard split required for shuffle " << shuffleId_ << " map "
+              << mapId_ << " attempt " << attemptId_ << " partition "
+              << partitionId_ << " batch " << batchId_ << ".";
+      reviveAndRetryPushData(*sharedClient, protocol::StatusCode::HARD_SPLIT);
+      break;
+    }
+    case protocol::StatusCode::PUSH_DATA_SUCCESS_PRIMARY_CONGESTED: {
+      VLOG(1) << "Push data to " << latestLocation_->hostAndPushPort()
+              << " primary congestion required for shuffle " << shuffleId_
+              << " map " << mapId_ << " attempt " << attemptId_ << " partition "
+              << partitionId_ << " batch " << batchId_ << ".";
+      pushState_->onCongestControl(latestLocation_->hostAndPushPort());
+      pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
+      break;
+    }
+    case protocol::StatusCode::PUSH_DATA_SUCCESS_REPLICA_CONGESTED: {
+      VLOG(1) << "Push data to " << latestLocation_->hostAndPushPort()
+              << " replicate congestion required for shuffle " << shuffleId_
+              << " map " << mapId_ << " attempt " << attemptId_ << " partition "
+              << partitionId_ << " batch " << batchId_ << ".";
+      pushState_->onCongestControl(latestLocation_->hostAndPushPort());
+      pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
+      break;
+    }
+    default: {
+      // This is treated as success.
+      LOG(WARNING) << "unhandled PushData success protocol::StatusCode: "
+                   << reason;
+    }
+  }
+}
+
+void PushDataCallback::onFailure(std::unique_ptr<std::exception> exception) {
+  auto sharedClient = weakClient_.lock();
+  if (!sharedClient) {
+    LOG(WARNING) << "ShuffleClientImpl has expired when "
+                    "PushDataCallbackOnFailure, ignored, shuffle "
+                 << shuffleId_ << " map " << mapId_ << " attempt " << attemptId_
+                 << " partition " << partitionId_ << " batch " << batchId_
+                 << ".";
+    return;
+  }
+  if (pushState_->exceptionExists()) {
+    return;
+  }
+
+  LOG(ERROR) << "Push data to " << latestLocation_->hostAndPushPort()
+             << " failed for shuffle " << shuffleId_ << " map " << mapId_
+             << " attempt " << attemptId_ << " partition " << partitionId_
+             << " batch " << batchId_ << ", remain revive times "
+             << remainingReviveTimes_;
+
+  if (remainingReviveTimes_ <= 0) {
+    // TODO: set more specific exception.
+    pushState_->setException(std::move(exception));
+    return;
+  }
+
+  if (sharedClient->mapperEnded(shuffleId_, mapId_)) {
+    pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
+    LOG(INFO) << "Push data to " << latestLocation_->hostAndPushPort()
+              << " failed but mapper already ended for shuffle " << shuffleId_
+              << " map " << mapId_ << " attempt " << attemptId_ << " partition "
+              << partitionId_ << " batch " << batchId_
+              << ", remain revive times " << remainingReviveTimes_ << ".";
+    return;
+  }
+  remainingReviveTimes_--;
+  // TODO: we use PRIMARY exception as the dummy value here, but the cause
+  //  should be extracted from error msg. Especially, the cause should tell if
+  //  the exception is from PRIMARY or REPLICATE.
+  protocol::StatusCode cause =
+      protocol::StatusCode::PUSH_DATA_CONNECTION_EXCEPTION_PRIMARY;
+  reviveAndRetryPushData(*sharedClient, cause);
+}
+
+void PushDataCallback::updateLatestLocation(
+    std::shared_ptr<const protocol::PartitionLocation> latestLocation) {
+  pushState_->addBatch(batchId_, latestLocation->hostAndPushPort());
+  pushState_->removeBatch(batchId_, latestLocation_->hostAndPushPort());
+  latestLocation_ = latestLocation;
+}
+
+void PushDataCallback::reviveAndRetryPushData(
+    ShuffleClientImpl& shuffleClient,
+    protocol::StatusCode cause) {
+  auto reviveRequest = std::make_shared<protocol::ReviveRequest>(
+      shuffleId_,
+      mapId_,
+      attemptId_,
+      partitionId_,
+      latestLocation_->epoch,
+      latestLocation_,
+      cause);
+  VLOG(1) << "addRequest to reviveManager, shuffleId "
+          << reviveRequest->shuffleId << " mapId " << reviveRequest->mapId
+          << " attemptId " << reviveRequest->attemptId << " partitionId "
+          << reviveRequest->partitionId << " batchId " << batchId_ << " epoch "
+          << reviveRequest->epoch;
+  shuffleClient.addRequestToReviveManager(reviveRequest);
+  long dueTimeMs = utils::currentTimeMillis() +
+      shuffleClient.conf_->clientRpcRequestPartitionLocationRpcAskTimeout() /
+          utils::MS(1);
+  shuffleClient.addPushDataRetryTask(
+      [weakClient = this->weakClient_,
+       shuffleId = this->shuffleId_,
+       body = this->databody_->clone(),
+       batchId = this->batchId_,
+       callback = shared_from_this(),
+       pushState = this->pushState_,
+       reviveRequest,
+       remainingReviveTimes = this->remainingReviveTimes_,
+       dueTimeMs]() {
+        auto sharedClient = weakClient.lock();
+        if (!sharedClient) {
+          LOG(WARNING) << "ShuffleClientImpl has expired when "
+                          "PushDataFailureCallback, ignored, shuffleId "
+                       << shuffleId;
+          return;
+        }
+        sharedClient->submitRetryPushData(
+            shuffleId,
+            body->clone(),
+            batchId,
+            callback,
+            pushState,
+            reviveRequest,
+            remainingReviveTimes,
+            dueTimeMs);
+      });
+}
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/PushDataCallback.h b/cpp/celeborn/client/writer/PushDataCallback.h
new file mode 100644
index 000000000..9916cd191
--- /dev/null
+++ b/cpp/celeborn/client/writer/PushDataCallback.h
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "celeborn/client/ShuffleClient.h"
+#include "celeborn/client/writer/PushState.h"
+
+namespace celeborn {
+namespace client {
+
+class ShuffleClientImpl;
+
+// RpcResponseCallback used for a pushed data batch. It can queue a revive
+// request for its partition and schedule the batch for a retried push (see
+// reviveAndRetryPushData()).
+class PushDataCallback : public network::RpcResponseCallback,
+                         public std::enable_shared_from_this<PushDataCallback> {
+ public:
+  // The only way to obtain a PushDataCallback. The instance has to be owned by
+  // a std::shared_ptr for shared_from_this() to work, hence the factory.
+  static std::shared_ptr<PushDataCallback> create(
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      int numMappers,
+      int numPartitions,
+      const std::string& mapKey,
+      int batchId,
+      std::unique_ptr<memory::ReadOnlyByteBuffer> databody,
+      std::shared_ptr<PushState> pushState,
+      std::weak_ptr<ShuffleClientImpl> weakClient,
+      int remainingReviveTimes,
+      std::shared_ptr<const protocol::PartitionLocation> latestLocation);
+
+  // network::RpcResponseCallback interface.
+  void onSuccess(std::unique_ptr<memory::ReadOnlyByteBuffer> response) override;
+
+  void onFailure(std::unique_ptr<std::exception> exception) override;
+
+  // A revive may move the partition to a new location. Whoever learns the new
+  // location must pass it in here so that this callback keeps using the
+  // latest one.
+  void updateLatestLocation(
+      std::shared_ptr<const protocol::PartitionLocation> latestLocation);
+
+ private:
+  // Private on purpose: instances must come from create() so that they are
+  // always owned by a std::shared_ptr, as shared_from_this() requires.
+  PushDataCallback(
+      int shuffleId,
+      int mapId,
+      int attemptId,
+      int partitionId,
+      int numMappers,
+      int numPartitions,
+      const std::string& mapKey,
+      int batchId,
+      std::unique_ptr<memory::ReadOnlyByteBuffer> databody,
+      std::shared_ptr<PushState> pushState,
+      std::weak_ptr<ShuffleClientImpl> weakClient,
+      int remainingReviveTimes,
+      std::shared_ptr<const protocol::PartitionLocation> latestLocation);
+
+  // Queues a ReviveRequest carrying `cause` with the client's revive manager,
+  // then registers a retry task that calls submitRetryPushData() with a clone
+  // of the data body. If the client has expired by the time the task runs,
+  // the task only logs a warning.
+  void reviveAndRetryPushData(
+      ShuffleClientImpl& shuffleClient,
+      protocol::StatusCode cause);
+
+  const int shuffleId_;
+  const int mapId_;
+  const int attemptId_;
+  const int partitionId_;
+  const int numMappers_;
+  const int numPartitions_;
+  const std::string mapKey_;
+  const int batchId_;
+  const std::unique_ptr<memory::ReadOnlyByteBuffer> databody_;
+  const std::shared_ptr<PushState> pushState_;
+  // Held weakly; the retry task locks it before use.
+  const std::weak_ptr<ShuffleClientImpl> weakClient_;
+  int remainingReviveTimes_;
+  // Replaced through updateLatestLocation().
+  std::shared_ptr<const protocol::PartitionLocation> latestLocation_;
+};
+
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/writer/ReviveManager.h b/cpp/celeborn/client/writer/ReviveManager.h
index 4f3b7962e..925e2b924 100644
--- a/cpp/celeborn/client/writer/ReviveManager.h
+++ b/cpp/celeborn/client/writer/ReviveManager.h
@@ -26,6 +26,7 @@
 namespace celeborn {
 namespace client {
 
+class ShuffleClientImpl;
 /// ReviveManager is responsible for buffering the ReviveRequests, and issue
 /// the revive requests periodically in batches.
 class ReviveManager : public std::enable_shared_from_this<ReviveManager> {
diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp
index ab1a7abf0..1d58516b3 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -142,10 +142,11 @@ const std::unordered_map<std::string, folly::Optional<std::string>>
         NUM_PROP(kClientPushMaxReqsInFlightTotal, 256),
         NUM_PROP(kClientPushLimitInFlightTimeoutMs, 240000),
         NUM_PROP(kClientPushLimitInFlightSleepDeltaMs, 50),
+        STR_PROP(kClientRpcRequestPartitionLocationAskTimeout, "60s"),
         STR_PROP(kClientRpcGetReducerFileGroupRpcAskTimeout, "60s"),
         STR_PROP(kNetworkConnectTimeout, "10s"),
         STR_PROP(kClientFetchTimeout, "600s"),
-        NUM_PROP(kNetworkIoNumConnectionsPerPeer, "1"),
+        NUM_PROP(kNetworkIoNumConnectionsPerPeer, 1),
         NUM_PROP(kNetworkIoClientThreads, 0),
         NUM_PROP(kClientFetchMaxReqsInFlight, 3),
         STR_PROP(
@@ -223,6 +224,11 @@ long CelebornConf::clientPushLimitInFlightSleepDeltaMs() const {
       optionalProperty(kClientPushLimitInFlightSleepDeltaMs).value());
 }
 
+// Reads kClientRpcRequestPartitionLocationAskTimeout and converts its duration
+// string into a Timeout.
+Timeout CelebornConf::clientRpcRequestPartitionLocationRpcAskTimeout() const {
+  const auto configured =
+      optionalProperty(kClientRpcRequestPartitionLocationAskTimeout).value();
+  return utils::toTimeout(toDuration(configured));
+}
+
 Timeout CelebornConf::clientRpcGetReducerFileGroupRpcAskTimeout() const {
   return utils::toTimeout(toDuration(
       optionalProperty(kClientRpcGetReducerFileGroupRpcAskTimeout).value()));
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index d6ce24f18..530a59781 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -68,6 +68,10 @@ class CelebornConf : public BaseConf {
   static constexpr std::string_view kClientPushLimitInFlightSleepDeltaMs{
       "celeborn.client.push.limit.inFlight.sleepInterval"};
 
+  // Config key of the ask timeout for the request-partition-location RPC;
+  // the built-in default is "60s".
+  static constexpr std::string_view
+      kClientRpcRequestPartitionLocationAskTimeout{
+          "celeborn.client.rpc.requestPartition.askTimeout"};
+
   static constexpr std::string_view kClientRpcGetReducerFileGroupRpcAskTimeout{
       "celeborn.client.rpc.getReducerFileGroup.askTimeout"};
 
@@ -120,6 +124,8 @@ class CelebornConf : public BaseConf {
 
   long clientPushLimitInFlightSleepDeltaMs() const;
 
+  // Timeout configured under kClientRpcRequestPartitionLocationAskTimeout.
+  Timeout clientRpcRequestPartitionLocationRpcAskTimeout() const;
+
   Timeout clientRpcGetReducerFileGroupRpcAskTimeout() const;
 
   Timeout networkConnectTimeout() const;
diff --git a/cpp/celeborn/protocol/PartitionLocation.cpp b/cpp/celeborn/protocol/PartitionLocation.cpp
index a1820c5df..333a4ef13 100644
--- a/cpp/celeborn/protocol/PartitionLocation.cpp
+++ b/cpp/celeborn/protocol/PartitionLocation.cpp
@@ -148,6 +148,18 @@ std::unique_ptr<PbPartitionLocation> PartitionLocation::toPbWithoutPeer()
   return pbPartitionLocation;
 }
 
+// Builds "<id>-<epoch>-<mode>", with mode written as its integer value.
+std::string PartitionLocation::filename() const {
+  const auto modeCode = static_cast<int>(mode);
+  return fmt::format("{}-{}-{}", id, epoch, modeCode);
+}
+
+// Builds "<id>-<epoch>".
+std::string PartitionLocation::uniqueId() const {
+  auto idWithEpoch = fmt::format("{}-{}", id, epoch);
+  return idWithEpoch;
+}
+
+// Builds "<host>:<pushPort>".
+std::string PartitionLocation::hostAndPushPort() const {
+  auto endpoint = fmt::format("{}:{}", host, pushPort);
+  return endpoint;
+}
+
 StatusCode toStatusCode(int32_t code) {
   CELEBORN_CHECK(code >= 0);
   CELEBORN_CHECK(code <= StatusCode::TAIL);
diff --git a/cpp/celeborn/protocol/PartitionLocation.h b/cpp/celeborn/protocol/PartitionLocation.h
index 4e88d16de..16a535bd0 100644
--- a/cpp/celeborn/protocol/PartitionLocation.h
+++ b/cpp/celeborn/protocol/PartitionLocation.h
@@ -92,10 +92,11 @@ struct PartitionLocation {
 
   std::unique_ptr<PbPartitionLocation> toPb() const;
 
-  std::string filename() const {
-    return std::to_string(id) + "-" + std::to_string(epoch) + "-" +
-        std::to_string(mode);
-  }
+  // Returns "<id>-<epoch>-<mode as int>".
+  std::string filename() const;
+
+  // Returns "<id>-<epoch>".
+  std::string uniqueId() const;
+
+  // Returns "<host>:<pushPort>".
+  std::string hostAndPushPort() const;
 
  private:
   static std::unique_ptr<PartitionLocation> fromPbWithoutPeer(
diff --git a/cpp/celeborn/utils/CelebornUtils.h b/cpp/celeborn/utils/CelebornUtils.h
index 669e4060e..ac1419914 100644
--- a/cpp/celeborn/utils/CelebornUtils.h
+++ b/cpp/celeborn/utils/CelebornUtils.h
@@ -62,6 +62,12 @@ inline Timeout toTimeout(Duration duration) {
   return std::chrono::duration_cast<Timeout>(duration);
 }
 
+/// Milliseconds elapsed since the std::chrono::system_clock epoch, i.e.
+/// wall-clock time rather than a monotonic reading.
+inline uint64_t currentTimeMillis() {
+  using std::chrono::milliseconds;
+  const auto elapsed = std::chrono::system_clock::now().time_since_epoch();
+  return std::chrono::duration_cast<milliseconds>(elapsed).count();
+}
+
 /// parse string like "Any-Host-Str:Port#1:Port#2:...:Port#num", split into
 /// {"Any-Host-Str", "Port#1", "Port#2", ..., "Port#num"}. Note that the
 /// "Any-Host-Str" might contain ':' in IPV6 address.


Reply via email to