This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.3 by this push:
     new 7d2c47db0 [CELEBORN-1059] Fix callback not update if push worker 
excluded during retry
7d2c47db0 is described below

commit 7d2c47db0f3bd25c7386fbbf7a6793c527796146
Author: onebox-li <[email protected]>
AuthorDate: Wed Nov 1 10:23:50 2023 +0800

    [CELEBORN-1059] Fix callback not update if push worker excluded during retry
    
    ### What changes were proposed in this pull request?
    When a push-data retry and its revive succeed in
ShuffleClientImpl#submitRetryPushData, if the new location is excluded, the
callback's `latest` location has not yet been updated when
wrappedCallback.onFailure is called in
ShuffleClientImpl#isPushTargetWorkerExcluded. Therefore there may be problems
with the subsequent revive.
    
    ### Why are the changes needed?
    Ditto
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Manual test.
    
    Closes #2005 from onebox-li/improve-push-exclude.
    
    Authored-by: onebox-li <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
    (cherry picked from commit cd8acf89c968aad47ae3edcd5b63edc1a76721c7)
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 .../apache/celeborn/client/ShuffleClientImpl.java  | 38 ++++++++++------------
 .../celeborn/client/ChangePartitionManager.scala   |  4 ++-
 .../celeborn/tests/spark/RetryReviveTest.scala     |  2 +-
 3 files changed, 21 insertions(+), 23 deletions(-)

diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 85379a535..c532d4f2a 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -272,6 +272,7 @@ public class ShuffleClientImpl extends ShuffleClient {
           partitionId,
           batchId,
           newLoc);
+      pushDataRpcResponseCallback.updateLatestPartition(newLoc);
       try {
         if (!isPushTargetWorkerExcluded(newLoc, pushDataRpcResponseCallback)) {
           if (!testRetryRevive || remainReviveTimes < 1) {
@@ -281,7 +282,6 @@ public class ShuffleClientImpl extends ShuffleClient {
             String shuffleKey = Utils.makeShuffleKey(appUniqueId, shuffleId);
             PushData newPushData =
                 new PushData(PRIMARY_MODE, shuffleKey, newLoc.getUniqueId(), 
newBuffer);
-            pushDataRpcResponseCallback.updateLatestPartition(newLoc);
             client.pushData(newPushData, pushDataTimeout, 
pushDataRpcResponseCallback);
           } else {
             throw new RuntimeException(
@@ -633,18 +633,17 @@ public class ShuffleClientImpl extends ShuffleClient {
 
   void excludeWorkerByCause(StatusCode cause, PartitionLocation oldLocation) {
     if (pushExcludeWorkerOnFailureEnabled && oldLocation != null) {
-      if (cause == StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY) {
-        pushExcludedWorkers.add(oldLocation.hostAndPushPort());
-      } else if (cause == StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_PRIMARY) {
-        pushExcludedWorkers.add(oldLocation.hostAndPushPort());
-      } else if (cause == StatusCode.PUSH_DATA_TIMEOUT_PRIMARY) {
-        pushExcludedWorkers.add(oldLocation.hostAndPushPort());
-      } else if (cause == StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA) 
{
-        pushExcludedWorkers.add(oldLocation.getPeer().hostAndPushPort());
-      } else if (cause == StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_REPLICA) {
-        pushExcludedWorkers.add(oldLocation.getPeer().hostAndPushPort());
-      } else if (cause == StatusCode.PUSH_DATA_TIMEOUT_REPLICA) {
-        pushExcludedWorkers.add(oldLocation.getPeer().hostAndPushPort());
+      switch (cause) {
+        case PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY:
+        case PUSH_DATA_CONNECTION_EXCEPTION_PRIMARY:
+        case PUSH_DATA_TIMEOUT_PRIMARY:
+          pushExcludedWorkers.add(oldLocation.hostAndPushPort());
+          break;
+        case PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA:
+        case PUSH_DATA_CONNECTION_EXCEPTION_REPLICA:
+        case PUSH_DATA_TIMEOUT_REPLICA:
+          pushExcludedWorkers.add(oldLocation.getPeer().hostAndPushPort());
+          break;
       }
     }
   }
@@ -905,10 +904,10 @@ public class ShuffleClientImpl extends ShuffleClient {
             PartitionLocation latest = loc;
 
             @Override
-            public void updateLatestPartition(PartitionLocation latest) {
-              pushState.addBatch(nextBatchId, latest.hostAndPushPort());
+            public void updateLatestPartition(PartitionLocation newloc) {
+              pushState.addBatch(nextBatchId, newloc.hostAndPushPort());
               pushState.removeBatch(nextBatchId, 
this.latest.hostAndPushPort());
-              this.latest = latest;
+              this.latest = newloc;
             }
 
             @Override
@@ -1003,12 +1002,10 @@ public class ShuffleClientImpl extends ShuffleClient {
 
             @Override
             public void onFailure(Throwable e) {
-              StatusCode cause = getPushDataFailCause(e.getMessage());
-
               if (pushState.exception.get() != null) {
                 return;
               }
-
+              StatusCode cause = getPushDataFailCause(e.getMessage());
               if (remainReviveTimes <= 0) {
                 if (e instanceof CelebornIOException) {
                   callback.onFailure(e);
@@ -1383,11 +1380,10 @@ public class ShuffleClientImpl extends ShuffleClient {
 
           @Override
           public void onFailure(Throwable e) {
-            StatusCode cause = getPushDataFailCause(e.getMessage());
-
             if (pushState.exception.get() != null) {
               return;
             }
+            StatusCode cause = getPushDataFailCause(e.getMessage());
             if (remainReviveTimes <= 0) {
               if (e instanceof CelebornIOException) {
                 callback.onFailure(e);
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index e7b1a45b2..4edda41da 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -67,6 +67,8 @@ class ChangePartitionManager(
 
   private var batchHandleChangePartition: Option[ScheduledFuture[_]] = _
 
+  private val testRetryRevive = conf.testRetryRevive
+
   def start(): Unit = {
     batchHandleChangePartition = batchHandleChangePartitionSchedulerThread.map 
{
       // noinspection ConvertExpressionToSAM
@@ -204,7 +206,7 @@ class ChangePartitionManager(
     logWarning(s"Batch handle change partition for $changes")
 
     // Exclude all failed workers
-    if (changePartitions.exists(_.causes.isDefined)) {
+    if (changePartitions.exists(_.causes.isDefined) && !testRetryRevive) {
       changePartitions.filter(_.causes.isDefined).foreach { changePartition =>
         lifecycleManager.workerStatusTracker.excludeWorkerFromPartition(
           shuffleId,
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
index 2ddc895d8..4d0b42ded 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
@@ -51,7 +51,7 @@ class RetryReviveTest extends AnyFunSuite
       .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
       .getOrCreate()
     val result = ss.sparkContext.parallelize(1 to 1000, 2)
-      .map { i => (i, Range(1, 1000).mkString(",")) }.groupByKey(16).collect()
+      .map { i => (i, Range(1, 1000).mkString(",")) }.groupByKey(4).collect()
     assert(result.size == 1000)
     ss.stop()
   }

Reply via email to