This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 5d6ce216a [CELEBORN-1983][FOLLOWUP] Fix fetch fail not throw due to 
reach spark maxTaskFailures
5d6ce216a is described below

commit 5d6ce216a3267e66d86e52d2127fba75d898a1bf
Author: Xianming Lei <[email protected]>
AuthorDate: Thu Nov 20 11:51:22 2025 +0800

    [CELEBORN-1983][FOLLOWUP] Fix fetch fail not throw due to reach spark 
maxTaskFailures
    
    ### What changes were proposed in this pull request?
    Fix the FetchFailed exception not being thrown when a task reaches Spark's `maxTaskFailures`.
    
    ### Why are the changes needed?
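    The condition `ti.attemptNumber() >= maxTaskFails - 1` may never be evaluated for
    the current attempt, because the loop returns `false` as soon as it meets another
    attempt that is still running. Suppose the task has four attempts (attempts 0, 1, 2
    and 3): attempts 0 and 1 have already failed, attempts 2 and 3 are running, and
    attempt 3 is the one calling `reportFetchFailed`. The loop visits attempt 2 first and
    returns `false`, so the `maxTaskFails` check for attempt 3 is never reached, and the
    final result is `false` although the expected result is `true`.
    Therefore, the decision must not depend on the order in which attempts are visited.
    This patch counts the task's failed attempts over the whole loop (the current attempt,
    other attempts that have already reported a fetch failure, and other attempts in
    `FAILED` or `UNKNOWN` state). Only after the loop does it decide to throw the
    FetchFailed exception: when that count reaches `maxTaskFails`, or when no other
    attempt is still running.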
    
    <img width="3558" height="608" alt="image" src="https://github.com/user-attachments/assets/2a0af3e7-912e-420e-a864-4c525d07e251" />
    <img width="2332" height="814" alt="image" src="https://github.com/user-attachments/assets/bf832091-56d5-41b8-b58a-502e409d67a8" />
    
    ### Does this PR resolve a correctness bug?
    
    No.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing UTs.
    
    Closes #3531 from leixm/follow_CELEBORN-1983.
    
    Authored-by: Xianming Lei <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
---
 .../apache/spark/shuffle/celeborn/SparkUtils.java  | 43 +++++++++++++++++-----
 .../apache/spark/shuffle/celeborn/SparkUtils.java  | 43 +++++++++++++++++-----
 2 files changed, 66 insertions(+), 20 deletions(-)

diff --git 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index bd57cd329..a3ec12a1e 100644
--- 
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ 
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -335,16 +335,19 @@ public class SparkUtils {
         if (taskAttempts == null) return true;
 
         TaskInfo taskInfo = taskAttempts._1();
+        int failedTaskAttempts = 1;
+        boolean hasRunningAttempt = false;
         for (TaskInfo ti : taskAttempts._2()) {
           if (ti.taskId() != taskId) {
             if (reportedStageTaskIds.contains(ti.taskId())) {
               logger.info(
-                  "StageId={} index={} taskId={} attempt={} another attempt {} 
has reported shuffle fetch failure, ignore it.",
+                  "StageId={} index={} taskId={} attempt={} another attempt {} 
has reported shuffle fetch failure.",
                   stageId,
                   taskInfo.index(),
                   taskId,
                   taskInfo.attemptNumber(),
                   ti.attemptNumber());
+              failedTaskAttempts += 1;
             } else if (ti.successful()) {
               logger.info(
                   "StageId={} index={} taskId={} attempt={} another attempt {} 
is successful.",
@@ -362,22 +365,42 @@ public class SparkUtils {
                   taskId,
                   taskInfo.attemptNumber(),
                   ti.attemptNumber());
-              return false;
-            }
-          } else {
-            if (ti.attemptNumber() >= maxTaskFails - 1) {
-              logger.warn(
-                  "StageId={} index={} taskId={} attemptNumber {} reach 
maxTaskFails {}.",
+              hasRunningAttempt = true;
+            } else if ("FAILED".equals(ti.status()) || 
"UNKNOWN".equals(ti.status())) {
+              // For KILLED state task, Spark does not count the number of 
failures
+              // For UNKNOWN state task, Spark does count the number of 
failures
+              // For FAILED state task, Spark decides whether to count the 
failure based on the
+              // different failure reasons. Since we cannot obtain the failure
+              // reason here, we will count all tasks in FAILED state.
+              logger.info(
+                  "StageId={} index={} taskId={} attempt={} another attempt {} 
status={}.",
                   stageId,
                   taskInfo.index(),
                   taskId,
+                  taskInfo.attemptNumber(),
                   ti.attemptNumber(),
-                  maxTaskFails);
-              return true;
+                  ti.status());
+              failedTaskAttempts += 1;
             }
           }
         }
-        return true;
+        // The following situations should trigger a FetchFailed exception:
+        //  1. If failedTaskAttempts >= maxTaskFails
+        //  2. If no other taskAttempts are running
+        if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) {
+          logger.warn(
+              "StageId={}, index={}, taskId={}, attemptNumber={}: Task failure 
count {} reached "
+                  + "maximum allowed failures {} or no running attempt 
exists.",
+              stageId,
+              taskInfo.index(),
+              taskId,
+              taskInfo.attemptNumber(),
+              failedTaskAttempts,
+              maxTaskFails);
+          return true;
+        } else {
+          return false;
+        }
       } else {
         logger.error(
             "Can not get TaskSetManager for taskId: {}, ignore it. (This 
typically occurs when: "
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
index 82b3cf405..2b30a2020 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SparkUtils.java
@@ -450,16 +450,19 @@ public class SparkUtils {
         if (taskAttempts == null) return true;
 
         TaskInfo taskInfo = taskAttempts._1();
+        int failedTaskAttempts = 1;
+        boolean hasRunningAttempt = false;
         for (TaskInfo ti : taskAttempts._2()) {
           if (ti.taskId() != taskId) {
             if (reportedStageTaskIds.contains(ti.taskId())) {
               LOG.info(
-                  "StageId={} index={} taskId={} attempt={} another attempt {} 
has reported shuffle fetch failure, ignore it.",
+                  "StageId={} index={} taskId={} attempt={} another attempt {} 
has reported shuffle fetch failure.",
                   stageId,
                   taskInfo.index(),
                   taskId,
                   taskInfo.attemptNumber(),
                   ti.attemptNumber());
+              failedTaskAttempts += 1;
             } else if (ti.successful()) {
               LOG.info(
                   "StageId={} index={} taskId={} attempt={} another attempt {} 
is successful.",
@@ -477,22 +480,42 @@ public class SparkUtils {
                   taskId,
                   taskInfo.attemptNumber(),
                   ti.attemptNumber());
-              return false;
-            }
-          } else {
-            if (ti.attemptNumber() >= maxTaskFails - 1) {
-              LOG.warn(
-                  "StageId={} index={} taskId={} attemptNumber {} reach 
maxTaskFails {}.",
+              hasRunningAttempt = true;
+            } else if ("FAILED".equals(ti.status()) || 
"UNKNOWN".equals(ti.status())) {
+              // For KILLED state task, Spark does not count the number of 
failures
+              // For UNKNOWN state task, Spark does count the number of 
failures
+              // For FAILED state task, Spark decides whether to count the 
failure based on the
+              // different failure reasons. Since we cannot obtain the failure
+              // reason here, we will count all tasks in FAILED state.
+              LOG.info(
+                  "StageId={} index={} taskId={} attempt={} another attempt {} 
status={}.",
                   stageId,
                   taskInfo.index(),
                   taskId,
+                  taskInfo.attemptNumber(),
                   ti.attemptNumber(),
-                  maxTaskFails);
-              return true;
+                  ti.status());
+              failedTaskAttempts += 1;
             }
           }
         }
-        return true;
+        // The following situations should trigger a FetchFailed exception:
+        //  1. If failedTaskAttempts >= maxTaskFails
+        //  2. If no other taskAttempts are running
+        if (failedTaskAttempts >= maxTaskFails || !hasRunningAttempt) {
+          LOG.warn(
+              "StageId={}, index={}, taskId={}, attemptNumber={}: Task failure 
count {} reached "
+                  + "maximum allowed failures {} or no running attempt 
exists.",
+              stageId,
+              taskInfo.index(),
+              taskId,
+              taskInfo.attemptNumber(),
+              failedTaskAttempts,
+              maxTaskFails);
+          return true;
+        } else {
+          return false;
+        }
       } else {
         LOG.error(
             "Can not get TaskSetManager for taskId: {}, ignore it. (This 
typically occurs when: "

Reply via email to