gengliangwang commented on code in PR #55738:
URL: https://github.com/apache/spark/pull/55738#discussion_r3222729880


##########
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala:
##########
@@ -1618,6 +1647,116 @@ private[spark] class DAGScheduler(
     }
   }
 
+  /**
+   * Returns true when this just-completed shuffle map task should have its 
output corrupted by
+   * the test-only fetch-failure injection. We corrupt only the partition-0 
task, and only on
+   * the stage attempt that first successfully completes partition 0 - latched 
into
+   * injectShuffleFetchFailuresCorruptedAttempt. Recomputes (later attempts) 
of that partition
+   * are left clean so the consumer can make progress on its retry. The latch 
is per-shuffle,
+   * so non-leaf stages whose earlier attempts failed on fetch from upstream 
are still
+   * corrupted on their first successful attempt.
+   */
+  private def shouldCorruptShuffleOutputForTest(shuffleId: Int, task: 
Task[_]): Boolean = {
+    if (task.partitionId != 0) return false
+    val recorded = injectShuffleFetchFailuresCorruptedAttempt.computeIfAbsent(
+      shuffleId, _ => task.stageAttemptId)
+    recorded == task.stageAttemptId
+  }
+
+  /**
+   * Apply the test-only fetch-failure injection to this just-completed map 
task: with
+   * DOWNSTREAM_DELAY > 0 record the mapId so 
maybeApplyDelayedCorruptionForTest can corrupt
+   * it later, otherwise update the MapStatus location to
+   * injectShuffleFetchFailuresInvalidBlockManagerId inline.
+   */
+  private def corruptShuffleOutputForTest(shuffleId: Int, status: MapStatus): 
Unit = {
+    val downstreamDelay =
+      sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY)
+    if (downstreamDelay > 0) {
+      injectShuffleFetchFailuresPendingDelayedCorruption.put(shuffleId, 
status.mapId)
+    } else {
+      status.updateLocation(injectShuffleFetchFailuresInvalidBlockManagerId)
+    }
+  }
+
+  /**
+   * For INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE: returns true 
when this shuffle map
+   * task is the recompute of a partition whose previous successful attempt 
was the one corrupted
+   * by INJECT_SHUFFLE_FETCH_FAILURES. Forcing the mismatch on the recompute 
drives the rollback
+   * path - downstream ShuffleMapStages get cleaned up and re-run fully, 
downstream ResultStages
+   * are aborted.
+   */
+  private def isForcedChecksumMismatchForTest(shuffleId: Int, task: Task[_]): 
Boolean = {
+    if 
(!sc.conf.get(config.Tests.INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE))
 return false
+    if (task.partitionId != 0) return false
+    val recorded =
+      injectShuffleFetchFailuresCorruptedAttempt.getOrDefault(shuffleId, -1)
+    recorded >= 0 && recorded != task.stageAttemptId
+  }
+
+  /**
+   * Apply the deferred mapper-0 corruption (configured via
+   * INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY for ShuffleMapStage 
consumers and
+   * INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY for ResultStage 
consumers) when enough
+   * consumer tasks have succeeded. Walks the just-completed stage's direct 
shuffle parents,
+   * increments the per-shuffle consumer-success counter, and corrupts the 
registered MapStatus
+   * when the counter reaches the configured delay.
+   */
+  private def maybeApplyDelayedCorruptionForTest(stage: Stage): Unit = {
+    if (!sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES)) return
+    if (injectShuffleFetchFailuresPendingDelayedCorruption.isEmpty) return
+    val isResultStage = stage.isInstanceOf[ResultStage]
+    val delay = if (isResultStage) {
+      
sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_RESULT_STAGE_DELAY)
+    } else {
+      sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY)
+    }
+    if (delay <= 0) return  // delay == 0 was already handled at submission 
time
+
+    val parentShuffleIds = stage.parents.collect {
+      case sms: ShuffleMapStage => sms.shuffleDep.shuffleId
+    }
+    parentShuffleIds.foreach { shuffleId =>
+      if 
(injectShuffleFetchFailuresPendingDelayedCorruption.containsKey(shuffleId)) {
+        val newCount = 
injectShuffleFetchFailuresDownstreamSuccessCount.merge(shuffleId, 1, _ + _)
+        if (newCount >= delay) {
+          val mapId = 
injectShuffleFetchFailuresPendingDelayedCorruption.remove(shuffleId)
+          mapOutputTracker.updateMapOutput(
+            shuffleId, mapId, injectShuffleFetchFailuresInvalidBlockManagerId)

Review Comment:
   Both delayed-corruption call sites (here and the one in 
`maybePreemptiveCorruptionForResultStage`) reach the executor only through 
`MapOutputTracker.updateMapOutput`, which calls 
`invalidateSerializedMapOutputStatusCache()` but **does not** call 
`incrementEpoch()`. Compare `unregisterMapOutput` 
(`MapOutputTracker.scala:874`) which DOES bump the epoch precisely so executors 
invalidate their `MapOutputTrackerWorker.mapStatuses` cache and re-fetch.
   
   In `local[2]` (where these tests run) `env.mapOutputTracker` is 
`MapOutputTrackerMaster` shared with the driver (`Executor.scala:865-872` 
documents this), so the update is immediately visible. In any non-local 
deployment, executors that already fetched this shuffle's statuses would keep 
their cached (valid) location and the FetchFailed cascade would not fire. The 
original inline-corruption code (`status.updateLocation` *before* 
`registerMapOutput`) didn't have this limitation because it mutated the status 
before workers ever fetched. Suggest calling 
`mapOutputTracker.incrementEpoch()` after each `updateMapOutput` here, or 
noting the local-mode-only constraint in the method docstrings.



##########
core/src/main/scala/org/apache/spark/internal/config/Tests.scala:
##########
@@ -41,8 +41,39 @@ private[spark] object Tests {
 
   val INJECT_SHUFFLE_FETCH_FAILURES =
     ConfigBuilder("spark.testing.injectShuffleFetchFailures")
-      .doc("Injecting fetch failures for shuffle stages by providing an 
invalid BlockManager " +
-        "location for the first stage attempt. Testing only flag!")
+      .doc("Corrupt the registered MapStatus of partition 0 on the first 
successful attempt " +
+        "of every shuffle map stage, to induce downstream FetchFailed and 
stage retry. " +
+        "Testing only.")

Review Comment:
   This says corruption happens "on the first successful attempt of every 
shuffle map stage". That is true for the latch, but when only the master 
switch is set, the *visible* timing is now governed by the new 
`DOWNSTREAM_DELAY=1` default, which defers the actual `updateMapOutput` until 
after 1 consumer task succeeds. A reader who only reads this flag's doc would 
think setting it alone gives inline corruption (matching the old semantics), 
but the new default behavior is materially different. Worth either pointing at 
`INJECT_SHUFFLE_FETCH_FAILURES_DOWNSTREAM_DELAY` here ("actual timing also 
depends on the *_DELAY flags below") or briefly noting the deferred-by-default 
behavior.



##########
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala:
##########
@@ -2271,6 +2414,10 @@ private[spark] class DAGScheduler(
           taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
         }
 
+        if (Utils.isTesting) {
+          maybeApplyDelayedCorruptionForTest(stage)

Review Comment:
   Minor: this fires before the `task match { case rt: ResultTask ... case smt: 
ShuffleMapTask ... }` block, so it runs for every Success event, including 
those for which the rest of the handler is skipped via `ignoreOldTaskAttempts`. 
As a result, a late completion from a rolled-back consumer attempt still ticks 
the per-shuffle success counter. This probably never matters for the 
deterministic local-mode tests this targets; if that is the case, it is worth a 
one-line note in `maybeApplyDelayedCorruptionForTest`'s docstring. Otherwise, 
gate this call on `!ignoreOldTaskAttempts`.



##########
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala:
##########
@@ -2342,21 +2489,16 @@ private[spark] class DAGScheduler(
                 // The epoch of the task is acceptable (i.e., the task was 
launched after the most
                 // recent failure we're aware of for the executor), so mark 
the task's output as
                 // available.
-                // For testing purposes, inject fetch failures controlled from 
the driver-side by
-                // supplying an invalid location.
                 if (Utils.isTesting &&
                     sc.conf.get(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES) &&
-                    task.stageAttemptId == 0) {
-                  val currentLocation = status.location
-                  val invalidLocation = BlockManagerId(
-                    execId = BlockManagerId.INVALID_EXECUTOR_ID,
-                    host = currentLocation.host,
-                    port = currentLocation.port,
-                    topologyInfo = currentLocation.topologyInfo)
-                  status.updateLocation(invalidLocation)
+                    
shouldCorruptShuffleOutputForTest(shuffleStage.shuffleDep.shuffleId, task)) {
+                  
corruptShuffleOutputForTest(shuffleStage.shuffleDep.shuffleId, status)
                 }
-                val isChecksumMismatched = mapOutputTracker.registerMapOutput(
+                val realChecksumMismatched = 
mapOutputTracker.registerMapOutput(
                   shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
+                val isChecksumMismatched = realChecksumMismatched ||
+                  (Utils.isTesting &&
+                    
isForcedChecksumMismatchForTest(shuffleStage.shuffleDep.shuffleId, task))

Review Comment:
   The original local was `isChecksumMismatched`. Now the tracker result is 
`realChecksumMismatched` and the OR'd combined value reuses the old name on the 
next line — the original name now refers to a different concept, and "real" 
implies a "fake" version exists (which is what the test forcing produces, but 
the reader has no setup for that). Either name the tracker result something 
neutral, or just inline the OR and drop the intermediate:
   ```suggestion
                   val isChecksumMismatched = 
mapOutputTracker.registerMapOutput(
                     shuffleStage.shuffleDep.shuffleId, smt.partitionId, 
status) ||
                     (Utils.isTesting &&
                       
isForcedChecksumMismatchForTest(shuffleStage.shuffleDep.shuffleId, task))
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to