JunRuiLee commented on code in PR #24771:
URL: https://github.com/apache/flink/pull/24771#discussion_r1611044994


##########
flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultBatchJobRecoveryHandler.java:
##########
@@ -0,0 +1,840 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.scheduler.adaptivebatch;
+
+import org.apache.flink.api.common.JobStatus;
+import org.apache.flink.configuration.BatchExecutionOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.JobException;
+import org.apache.flink.runtime.clusterframework.types.ResourceID;
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
+import org.apache.flink.runtime.executiongraph.InternalExecutionGraphAccessor;
+import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
+import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
+import org.apache.flink.runtime.executiongraph.TaskExecutionStateTransition;
+import org.apache.flink.runtime.executiongraph.failover.FailoverStrategy;
+import org.apache.flink.runtime.failure.FailureEnricherUtils;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.jobmaster.event.ExecutionJobVertexInitializedEvent;
+import org.apache.flink.runtime.jobmaster.event.ExecutionVertexFinishedEvent;
+import org.apache.flink.runtime.jobmaster.event.ExecutionVertexResetEvent;
+import org.apache.flink.runtime.jobmaster.event.JobEvent;
+import org.apache.flink.runtime.jobmaster.event.JobEventManager;
+import org.apache.flink.runtime.jobmaster.event.JobEventReplayHandler;
+import org.apache.flink.runtime.operators.coordination.OperatorCoordinatorHolder;
+import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
+import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+import org.apache.flink.runtime.shuffle.DefaultShuffleMasterSnapshotContext;
+import org.apache.flink.runtime.shuffle.PartitionWithMetrics;
+import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
+import org.apache.flink.runtime.shuffle.ShuffleMaster;
+import org.apache.flink.runtime.shuffle.ShuffleMasterSnapshot;
+import org.apache.flink.util.clock.Clock;
+import org.apache.flink.util.clock.SystemClock;
+import org.apache.flink.util.function.ConsumerWithException;
+import org.apache.flink.util.function.QuadConsumerWithException;
+import org.apache.flink.util.function.TriConsumer;
+
+import org.apache.flink.shaded.guava31.com.google.common.collect.Sets;
+
+import org.slf4j.Logger;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.runtime.operators.coordination.OperatorCoordinator.NO_CHECKPOINT;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Default implementation of {@link BatchJobRecoveryHandler} and {@link JobEventReplayHandler}. */
+public class DefaultBatchJobRecoveryHandler
+        implements BatchJobRecoveryHandler, JobEventReplayHandler {
+
+    private Logger log;
+
+    private final JobEventManager jobEventManager;
+
+    private final Configuration jobMasterConfiguration;
+
+    private ExecutionGraph executionGraph;
+
+    /** The timestamp (via {@link Clock#relativeTimeMillis()}) of the last snapshot. */
+    private long lastSnapshotRelativeTime;
+
+    private final Set<JobVertexID> needToSnapshotJobVertices = new HashSet<>();
+
+    private ShuffleMaster<?> shuffleMaster;
+
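+    /** Placeholder producer ID for recovered partitions whose producing task executor is unknown. */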
+    private static final ResourceID UNKNOWN_PRODUCER = ResourceID.generate();
+
+    private long snapshotMinPauseMills;
+
+    private Clock clock;
+
+    private ComponentMainThreadExecutor mainThreadExecutor;
+
+    private FailoverStrategy failoverStrategy;
+
+    private TriConsumer<Throwable, Long, CompletableFuture<Map<String, String>>> failJobFunction;
+
+    private ConsumerWithException<Set<ExecutionVertexID>, Exception>
+            resetVerticesInRecoveringFunction;
+
+    private final Map<ExecutionVertexID, ExecutionVertexFinishedEvent>
+            executionVertexFinishedEventMap = new LinkedHashMap<>();
+
+    private final List<ExecutionJobVertexInitializedEvent> jobVertexInitializedEvents =
+            new ArrayList<>();
+
+    private Consumer<Map<IntermediateResultPartitionID, ResultPartitionBytes>>
+            updateResultPartitionBytesMetricsFunction;
+
+    private QuadConsumerWithException<
+                    ExecutionJobVertex,
+                    Integer,
+                    Map<IntermediateDataSetID, JobVertexInputInfo>,
+                    Long,
+                    JobException>
+            initializeJobVertexFunction;
+
+    private Consumer<List<ExecutionJobVertex>> updateTopologyFunction;
+
+    /**
+     * This collection signifies the group of vertices where, upon a failure of any single execution
+     * instance, all executions within the job vertex need to be restarted once.
+     */
+    private final Set<JobVertexID> requiredRestartJobVertices = new HashSet<>();
+
+    public DefaultBatchJobRecoveryHandler(
+            JobEventManager jobEventManager, Configuration jobMasterConfiguration) {
+        this.jobEventManager = jobEventManager;
+        this.jobMasterConfiguration = jobMasterConfiguration;
+    }
+
+    @Override
+    public void initialize(
+            Logger log,
+            ExecutionGraph executionGraph,
+            ShuffleMaster<?> shuffleMaster,
+            ComponentMainThreadExecutor mainThreadExecutor,
+            FailoverStrategy failoverStrategy,
+            TriConsumer<Throwable, Long, CompletableFuture<Map<String, String>>> failJobFunction,
+            ConsumerWithException<Set<ExecutionVertexID>, Exception>
+                    resetVerticesInRecoveringFunction,
+            Consumer<Map<IntermediateResultPartitionID, ResultPartitionBytes>>
+                    updateResultPartitionBytesMetricsFunction,
+            QuadConsumerWithException<
+                            ExecutionJobVertex,
+                            Integer,
+                            Map<IntermediateDataSetID, JobVertexInputInfo>,
+                            Long,
+                            JobException>
+                    initializeJobVertexFunction,
+            Consumer<List<ExecutionJobVertex>> updateTopologyFunction) {
+
+        this.log = checkNotNull(log);
+        this.executionGraph = checkNotNull(executionGraph);
+        this.shuffleMaster = checkNotNull(shuffleMaster);
+        this.snapshotMinPauseMills =
+                jobMasterConfiguration
+                        .get(BatchExecutionOptions.JOB_RECOVERY_SNAPSHOT_MIN_PAUSE)
+                        .toMillis();
+        this.clock = SystemClock.getInstance();
+        this.mainThreadExecutor = checkNotNull(mainThreadExecutor);
+        this.failoverStrategy = checkNotNull(failoverStrategy);
+        this.failJobFunction = checkNotNull(failJobFunction);
+        this.resetVerticesInRecoveringFunction = checkNotNull(resetVerticesInRecoveringFunction);
+        this.updateResultPartitionBytesMetricsFunction =
+                checkNotNull(updateResultPartitionBytesMetricsFunction);
+        this.initializeJobVertexFunction = checkNotNull(initializeJobVertexFunction);
+        this.updateTopologyFunction = checkNotNull(updateTopologyFunction);
+
+        try {
+            jobEventManager.start();
+        } catch (Throwable throwable) {
+            failJobFunction.accept(
+                    throwable,
+                    System.currentTimeMillis(),
+                    FailureEnricherUtils.EMPTY_FAILURE_LABELS);
+        }
+    }
+
+    @Override
+    public void stop(boolean clearUp) {
+        jobEventManager.stop(clearUp);
+    }
+
+    @Override
+    public void startRecovering(
+            Consumer<Set<JobVertexID>> recoverFinishedListener, Runnable recoverFailedListener) {
+        mainThreadExecutor.assertRunningInMainThread();
+
+        startRecovering();
+
+        shuffleMaster.notifyPartitionRecoveryStarted(executionGraph.getJobID());
+
+        if (!jobEventManager.replay(this)) {
+            log.warn(
+                    "Fail to replay log for {}, will start the job as a new 
one.",
+                    executionGraph.getJobID());
+            recoverFailed(recoverFailedListener, failJobFunction);
+            return;
+        }
+        log.info("Replayed all job events successfully.");
+
+        CompletableFuture<Collection<PartitionWithMetrics>> existingPartitions =
+                ((InternalExecutionGraphAccessor) executionGraph)
+                        .getShuffleMaster()
+                        .getAllPartitionWithMetrics(executionGraph.getJobID());
+
+        existingPartitions.whenCompleteAsync(
+                (partitions, throwable) -> {
+                    if (throwable != null) {
+                        recoverFailed(recoverFailedListener, failJobFunction);
+                        // do not proceed with a null partitions collection
+                        return;
+                    }
+                    try {
+                        recoverPartitions(partitions, updateResultPartitionBytesMetricsFunction);
+                        recoverFinished(recoverFinishedListener);
+                    } catch (Exception exception) {
+                        recoverFailed(recoverFailedListener, failJobFunction);
+                    }
+                },
+                mainThreadExecutor);
+    }
+
+    @Override
+    public boolean needRecover() {
+        try {
+            return jobEventManager.hasJobEvents();
+        } catch (Throwable throwable) {
+            failJobFunction.accept(
+                    throwable,
+                    System.currentTimeMillis(),
+                    FailureEnricherUtils.EMPTY_FAILURE_LABELS);
+            return false;
+        }
+    }
+
+    @Override
+    public boolean isRecovering() {
+        return executionGraph.getState() == JobStatus.RECONCILING;
+    }
+
+    private void restoreShuffleMaster(List<ShuffleMasterSnapshot> snapshots) {
+        checkState(shuffleMaster.supportsBatchSnapshot());
+        shuffleMaster.restoreState(snapshots);
+    }
+
+    private void startRecovering() {
+        log.info("Try to recover from JM failover.");
+        executionGraph.transitionState(JobStatus.CREATED, JobStatus.RECONCILING);
+    }
+
+    private void restoreOperatorCoordinators(
+            Map<OperatorID, byte[]> snapshots, Map<OperatorID, JobVertexID> operatorToJobVertex)
+            throws Exception {
+        for (Map.Entry<OperatorID, byte[]> entry : snapshots.entrySet()) {
+            OperatorID operatorId = entry.getKey();
+            JobVertexID jobVertexId = checkNotNull(operatorToJobVertex.get(operatorId));
+            ExecutionJobVertex jobVertex = getExecutionJobVertex(jobVertexId);
+            log.info(
+                    "Restore operator coordinators of {} from job event, checkpointId {}.",
+                    jobVertex.getName(),
+                    NO_CHECKPOINT);
+
+            for (OperatorCoordinatorHolder holder : jobVertex.getOperatorCoordinators()) {
+                if (holder.coordinator().supportsBatchSnapshot()) {
+                    byte[] snapshot = snapshots.get(holder.operatorId());
+                    holder.resetToCheckpoint(NO_CHECKPOINT, snapshot);
+                }
+            }
+        }
+
+        reviseVertices();
+    }
+
+    @Override
+    public void startReplay() {
+        // do nothing.
+    }
+
+    @Override
+    public void replayOneEvent(JobEvent jobEvent) {
+        if (jobEvent instanceof ExecutionVertexFinishedEvent) {
+            ExecutionVertexFinishedEvent event = (ExecutionVertexFinishedEvent) jobEvent;
+            executionVertexFinishedEventMap.put(event.getExecutionVertexId(), event);
+        } else if (jobEvent instanceof ExecutionVertexResetEvent) {
+            ExecutionVertexResetEvent event = (ExecutionVertexResetEvent) jobEvent;
+            for (ExecutionVertexID executionVertexId : event.getExecutionVertexIds()) {
+                executionVertexFinishedEventMap.remove(executionVertexId);
+            }
+        } else if (jobEvent instanceof ExecutionJobVertexInitializedEvent) {
+            jobVertexInitializedEvents.add((ExecutionJobVertexInitializedEvent) jobEvent);
+        } else {
+            throw new IllegalStateException("Unsupported job event " + jobEvent);
+        }
+    }
+
+    @Override
+    public void finalizeReplay() throws Exception {
+        // recover job vertex initialization info and update topology
+        long currentTimeMillis = System.currentTimeMillis();
+        final List<ExecutionJobVertex> initializedJobVertices = new ArrayList<>();
+        for (ExecutionJobVertexInitializedEvent event : jobVertexInitializedEvents) {
+            final ExecutionJobVertex jobVertex = getExecutionJobVertex(event.getJobVertexId());
+            initializeJobVertexFunction.accept(
+                    jobVertex,
+                    event.getParallelism(),
+                    event.getJobVertexInputInfos(),
+                    currentTimeMillis);
+            initializedJobVertices.add(jobVertex);
+        }
+        updateTopologyFunction.accept(initializedJobVertices);
+
+        // remove the last batch of vertices that do not record the states of operator coordinator
+        // and shuffle master
+        LinkedList<ExecutionVertexFinishedEvent> finishedEvents =
+                new LinkedList<>(executionVertexFinishedEventMap.values());
+        while (!finishedEvents.isEmpty()
+                && !finishedEvents.getLast().hasOperatorCoordinatorAndShuffleMasterSnapshots()) {
+            finishedEvents.removeLast();
+        }
+
+        if (finishedEvents.isEmpty()) {
+            return;
+        }
+
+        // find the last operator coordinator state for each operator coordinator
+        Map<OperatorID, byte[]> operatorCoordinatorSnapshots = new HashMap<>();
+
+        List<ShuffleMasterSnapshot> shuffleMasterSnapshots = new ArrayList<>();
+
+        // transition states of all vertices
+        for (ExecutionVertexFinishedEvent event : finishedEvents) {
+            JobVertexID jobVertexId = event.getExecutionVertexId().getJobVertexId();
+            ExecutionJobVertex jobVertex = executionGraph.getJobVertex(jobVertexId);
+            checkState(jobVertex.isInitialized());
+
+            int subTaskIndex = event.getExecutionVertexId().getSubtaskIndex();
+            Execution execution =
+                    jobVertex.getTaskVertices()[subTaskIndex].getCurrentExecutionAttempt();
+            // recover execution info.
+            execution.recoverExecution(
+                    event.getExecutionAttemptId(),
+                    event.getTaskManagerLocation(),
+                    event.getUserAccumulators(),
+                    event.getIOMetrics());
+
+            // recover operator coordinator
+            for (Map.Entry<OperatorID, CompletableFuture<byte[]>> entry :
+                    event.getOperatorCoordinatorSnapshotFutures().entrySet()) {
+                checkState(entry.getValue().isDone());
+                operatorCoordinatorSnapshots.put(entry.getKey(), entry.getValue().get());
+            }
+
+            // recover shuffle master
+            if (event.getShuffleMasterSnapshotFuture() != null) {
+                ShuffleMasterSnapshot shuffleMasterSnapshot =
+                        event.getShuffleMasterSnapshotFuture().get();
+                if (shuffleMasterSnapshot.isIncremental()) {
+                    shuffleMasterSnapshots.add(shuffleMasterSnapshot);
+                } else {
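+                    // a full (non-incremental) snapshot supersedes all previously collected ones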
+                    shuffleMasterSnapshots = Arrays.asList(shuffleMasterSnapshot);
+                }
+            }
+        }
+
+        // restore operator coordinator state if needed.
+        final Map<OperatorID, JobVertexID> operatorToJobVertex = new HashMap<>();
+        for (ExecutionJobVertex jobVertex : executionGraph.getAllVertices().values()) {
+            if (!jobVertex.isInitialized()) {
+                continue;
+            }
+
+            for (OperatorCoordinatorHolder holder : jobVertex.getOperatorCoordinators()) {
+                operatorToJobVertex.put(holder.operatorId(), jobVertex.getJobVertexId());
+            }
+        }
+
+        try {
+            restoreOperatorCoordinators(operatorCoordinatorSnapshots, operatorToJobVertex);
+        } catch (Exception exception) {
+            log.warn("Restoring operator coordinators failed.", exception);
+            throw exception;
+        }
+
+        // restore shuffle master
+        restoreShuffleMaster(shuffleMasterSnapshots);
+    }
+
+    @Override
+    public void notifyExecutionVertexReset(Collection<ExecutionVertexID> vertices) {
+        // write execution vertex reset event.
+        checkState(!isRecovering());
+        jobEventManager.writeEvent(new ExecutionVertexResetEvent(new ArrayList<>(vertices)), false);
+    }
+
+    @Override
+    public void notifyExecutionJobVertexInitialization(
+            JobVertexID jobVertexId,
+            int parallelism,
+            Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos) {
+        // write execution job vertex initialized event.
+        checkState(!isRecovering());
+        jobEventManager.writeEvent(
+                new ExecutionJobVertexInitializedEvent(
+                        jobVertexId, parallelism, jobVertexInputInfos),
+                false);
+    }
+
+    @Override
+    public void notifyExecutionFinished(
+            ExecutionVertexID executionVertexId, TaskExecutionStateTransition taskExecutionState) {
+        checkState(!isRecovering());
+
+        checkState(taskExecutionState.getExecutionState() == ExecutionState.FINISHED);
+        Execution execution = getExecutionVertex(executionVertexId).getCurrentExecutionAttempt();
+
+        // check whether the job vertex is finished.
+        ExecutionJobVertex jobVertex = execution.getVertex().getJobVertex();
+        boolean jobVertexFinished = jobVertex.getAggregateState() == ExecutionState.FINISHED;
+
+        // snapshot operator coordinators if needed.
+        needToSnapshotJobVertices.add(executionVertexId.getJobVertexId());
+        final Map<OperatorID, CompletableFuture<byte[]>> operatorCoordinatorSnapshotFutures =
+                new HashMap<>();
+        CompletableFuture<ShuffleMasterSnapshot> shuffleMasterSnapshotFuture = null;
+        long currentRelativeTime = clock.relativeTimeMillis();
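+        // snapshot when the job vertex has finished, otherwise at most once per min-pause interval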
+        if (jobVertexFinished
+                || (currentRelativeTime - lastSnapshotRelativeTime >= snapshotMinPauseMills)) {
+            // operator coordinator
+            operatorCoordinatorSnapshotFutures.putAll(snapshotOperatorCoordinators());
+            lastSnapshotRelativeTime = currentRelativeTime;
+            needToSnapshotJobVertices.clear();
+
+            // shuffle master
+            shuffleMasterSnapshotFuture = snapshotShuffleMaster();
+        }
+
+        // write job event.
+        jobEventManager.writeEvent(
+                new ExecutionVertexFinishedEvent(
+                        execution.getAttemptId(),
+                        execution.getAssignedResourceLocation(),
+                        operatorCoordinatorSnapshotFutures,
+                        shuffleMasterSnapshotFuture,
+                        execution.getIOMetrics(),
+                        execution.getUserAccumulators()),
+                jobVertexFinished);
+    }
+
+    private Map<OperatorID, CompletableFuture<byte[]>> snapshotOperatorCoordinators() {
+
+        final Map<OperatorID, CompletableFuture<byte[]>> snapshotFutures = new HashMap<>();
+
+        for (JobVertexID jobVertexId : needToSnapshotJobVertices) {
+            final ExecutionJobVertex jobVertex = checkNotNull(getExecutionJobVertex(jobVertexId));
+
+            log.info(
+                    "Snapshot operator coordinators of {} to job event, 
checkpointId {}.",
+                    jobVertex.getName(),
+                    NO_CHECKPOINT);
+
+            for (OperatorCoordinatorHolder holder : jobVertex.getOperatorCoordinators()) {
+                if (holder.coordinator().supportsBatchSnapshot()) {
+                    final CompletableFuture<byte[]> checkpointFuture = new CompletableFuture<>();
+                    holder.checkpointCoordinator(NO_CHECKPOINT, checkpointFuture);
+                    snapshotFutures.put(holder.operatorId(), checkpointFuture);
+                }
+            }
+        }
+        return snapshotFutures;
+    }
+
+    private CompletableFuture<ShuffleMasterSnapshot> snapshotShuffleMaster() {
+
+        checkState(shuffleMaster.supportsBatchSnapshot());
+        CompletableFuture<ShuffleMasterSnapshot> shuffleMasterSnapshotFuture =
+                new CompletableFuture<>();
+        shuffleMaster.snapshotState(
+                shuffleMasterSnapshotFuture, new DefaultShuffleMasterSnapshotContext());
+        return shuffleMasterSnapshotFuture;
+    }
+
+    private void reviseVertices() throws Exception {
+        Set<ExecutionVertexID> verticesToRestart = new HashSet<>();
+
+        for (ExecutionJobVertex jobVertex : executionGraph.getAllVertices().values()) {
+            if (!jobVertex.isInitialized() || jobVertex.getOperatorCoordinators().isEmpty()) {
+                continue;
+            }
+
+            boolean allSupportsBatchSnapshot =
+                    jobVertex.getOperatorCoordinators().stream()
+                            .allMatch(holder -> holder.coordinator().supportsBatchSnapshot());
+
+            Set<ExecutionVertexID> unfinishedTasks =
+                    Arrays.stream(jobVertex.getTaskVertices())
+                            .filter(vertex -> vertex.getExecutionState() != ExecutionState.FINISHED)
+                            .map(
+                                    executionVertex -> {
+                                        // transition to a terminal state to allow resetting it
+                                        executionVertex
+                                                .getCurrentExecutionAttempt()
+                                                .transitionState(ExecutionState.CANCELED);
+                                        return executionVertex.getID();
+                                    })
+                            .collect(Collectors.toSet());
+
+            if (allSupportsBatchSnapshot) {
+                log.info(
+                        "All operator coordinators of jobVertex {} support 
batch snapshot, "
+                                + "add {} unfinished tasks to revise.",
+                        jobVertex.getName(),
+                        unfinishedTasks.size());
+                verticesToRestart.addAll(unfinishedTasks);
+            } else if (unfinishedTasks.isEmpty()) {
+                log.info(
+                        "JobVertex {} is finished, but not all of its operator 
coordinators support "
+                                + "batch snapshot. Therefore, if any single 
task within it requires "
+                                + "a restart in the future, all tasks 
associated with this JobVertex "
+                                + "need to be restarted as well.",
+                        jobVertex.getName());
+                requiredRestartJobVertices.add(jobVertex.getJobVertexId());
+            } else {
+                log.info(
+                        "Restart all tasks of jobVertex {} because it has not 
been finished and not "
+                                + "all of its operator coordinators support 
batch snapshot.",
+                        jobVertex.getName());
+                verticesToRestart.addAll(
+                        Arrays.stream(jobVertex.getTaskVertices())
+                                .map(ExecutionVertex::getID)
+                                .collect(Collectors.toSet()));
+            }
+        }
+
+        resetVerticesInRecovering(verticesToRestart);
+    }
+
+    private void resetVerticesInRecovering(Set<ExecutionVertexID> verticesToRestart)
+            throws Exception {
+        checkState(isRecovering());
+
+        Set<JobVertexID> extraNeedToRestartJobVertices =
+                verticesToRestart.stream()
+                        .map(ExecutionVertexID::getJobVertexId)
+                        .filter(requiredRestartJobVertices::contains)
+                        .collect(Collectors.toSet());
+
+        requiredRestartJobVertices.removeAll(extraNeedToRestartJobVertices);
+
+        verticesToRestart.addAll(
+                extraNeedToRestartJobVertices.stream()
+                        .flatMap(
+                                jobVertexId -> {
+                                    ExecutionJobVertex jobVertex =
+                                            getExecutionJobVertex(jobVertexId);
+                                    return Arrays.stream(jobVertex.getTaskVertices())
+                                            .map(ExecutionVertex::getID);
+                                })
+                        .collect(Collectors.toSet()));
+
+        // we only reset tasks which are not CREATED.
+        Set<ExecutionVertexID> verticesToReset =
+                verticesToRestart.stream()
+                        .filter(
+                                executionVertexID ->
+                                        getExecutionVertex(executionVertexID).getExecutionState()
+                                                != ExecutionState.CREATED)
+                        .collect(Collectors.toSet());
+
+        resetVerticesInRecoveringFunction.accept(verticesToReset);
+    }
+
+    private void recoverFailed(
+            Runnable recoverFailedListener,
+            TriConsumer<Throwable, Long, CompletableFuture<Map<String, String>>> failJobFunction) {
+        String message =
+                String.format(
+                        "Job %s recover failed from JM failover, fail global.",
+                        executionGraph.getJobID());
+        log.warn(message);
+        shuffleMaster.notifyPartitionRecoveryCompleted(executionGraph.getJobID());
+        executionGraph.transitionState(JobStatus.RECONCILING, JobStatus.RUNNING);
+
+        // clear job events and restart job event manager.
+        jobEventManager.stop(true);
+        try {
+            jobEventManager.start();
+        } catch (Throwable throwable) {
+            failJobFunction.accept(
+                    throwable,
+                    System.currentTimeMillis(),
+                    FailureEnricherUtils.EMPTY_FAILURE_LABELS);
+            return;
+        }
+
+        recoverFailedListener.run();
+    }
+
+    private void recoverFinished(Consumer<Set<JobVertexID>> recoverFinishedListener) {
+        log.info("Job {} successfully recovered from JM failover.", executionGraph.getJobID());
+        shuffleMaster.notifyPartitionRecoveryCompleted(executionGraph.getJobID());
+
+        executionGraph.transitionState(JobStatus.RECONCILING, JobStatus.RUNNING);
+        checkExecutionGraphState();
+        recoverFinishedListener.accept(requiredRestartJobVertices);
+    }
+
+    private void checkExecutionGraphState() {
+        for (ExecutionVertex executionVertex : executionGraph.getAllExecutionVertices()) {
+            ExecutionState state = executionVertex.getExecutionState();
+            checkState(state == ExecutionState.CREATED || state == ExecutionState.FINISHED);
+        }
+    }
+
+    private void recoverPartitions(
+            Collection<PartitionWithMetrics> partitionWithMetrics,
+            Consumer<Map<IntermediateResultPartitionID, ResultPartitionBytes>>
+                    updateResultPartitionBytesMetricsFunction)
+            throws Exception {
+        mainThreadExecutor.assertRunningInMainThread();
+
+        Set<ResultPartitionID> actualPartitions =
+                partitionWithMetrics.stream()
+                        .map(PartitionWithMetrics::getPartition)
+                        .map(ShuffleDescriptor::getResultPartitionID)
+                        .collect(Collectors.toSet());
+
+        ReconcileResult reconcileResult = reconcilePartitions(actualPartitions);
+
+        log.info(
+                "All unknown partitions {}, missing partitions {}, available 
partitions {}",
+                reconcileResult.unknownPartitions,
+                reconcileResult.missingPartitions,
+                reconcileResult.availablePartitions);
+
+        // release all unknown partitions
+        ((InternalExecutionGraphAccessor) executionGraph)
+                .getPartitionTracker()
+                .stopTrackingAndReleasePartitions(reconcileResult.unknownPartitions);
+
+        // start tracking all available partitions
+        Map<IntermediateResultPartitionID, ResultPartitionBytes> availablePartitionBytes =
+                new HashMap<>();
+        partitionWithMetrics.stream()
+                .filter(
+                        partitionAndMetric ->
+                                reconcileResult.availablePartitions.contains(
+                                        partitionAndMetric.getPartition().getResultPartitionID()))
+                .forEach(
+                        partitionAndMetric -> {
+                            ShuffleDescriptor shuffleDescriptor = partitionAndMetric.getPartition();
+
+                            // we cannot get the producer id when using remote shuffle
+                            ResourceID producerTaskExecutorId = UNKNOWN_PRODUCER;
+                            if (shuffleDescriptor.storesLocalResourcesOn().isPresent()) {
+                                producerTaskExecutorId =
+                                        shuffleDescriptor.storesLocalResourcesOn().get();
+                            }
+                            IntermediateResultPartition partition =
+                                    executionGraph.getResultPartitionOrThrow(
+                                            shuffleDescriptor
+                                                    .getResultPartitionID()
+                                                    .getPartitionId());
+                            ((InternalExecutionGraphAccessor) executionGraph)
+                                    .getPartitionTracker()
+                                    .startTrackingPartition(
+                                            producerTaskExecutorId,
+                                            Execution.createResultPartitionDeploymentDescriptor(
+                                                    partition, shuffleDescriptor));
+
+                            availablePartitionBytes.put(
+                                    shuffleDescriptor.getResultPartitionID().getPartitionId(),
+                                    partitionAndMetric.getPartitionMetrics().getPartitionBytes());
+                        });
+
+        // recover the produced partitions for executions
+        Map<
+                        ExecutionVertexID,
+                        Map<IntermediateResultPartitionID, ResultPartitionDeploymentDescriptor>>
+                allDescriptors = new HashMap<>();
+        ((InternalExecutionGraphAccessor) executionGraph)
+                .getPartitionTracker()
+                .getAllTrackedNonClusterPartitions()
+                .forEach(
+                        descriptor -> {
+                            ExecutionVertexID vertexId =
+                                    descriptor
+                                            .getShuffleDescriptor()
+                                            .getResultPartitionID()
+                                            .getProducerId()
+                                            .getExecutionVertexId();
+                            if (!allDescriptors.containsKey(vertexId)) {
+                                allDescriptors.put(vertexId, new HashMap<>());
+                            }
+
+                            allDescriptors
+                                    .get(vertexId)
+                                    .put(descriptor.getPartitionId(), descriptor);
+                        });
+
+        allDescriptors.forEach(
+                (vertexId, descriptors) ->
+                        getExecutionVertex(vertexId)
+                                .getCurrentExecutionAttempt()
+                                .recoverProducedPartitions(descriptors));
+
+        // recover result partition bytes
+        updateResultPartitionBytesMetricsFunction.accept(availablePartitionBytes);
+
+        // restart all producers of missing partitions
+        List<ExecutionVertexID> missingPartitionVertices =
+                reconcileResult.missingPartitions.stream()
+                        .map(ResultPartitionID::getPartitionId)
+                        .map(this::getProducer)
+                        .map(ExecutionVertex::getID)
+                        .collect(Collectors.toList());
+
+        // get all vertices that need to be restarted according to the failover strategy.
+        Set<ExecutionVertexID> verticesToReset = new HashSet<>();
+        for (ExecutionVertexID executionVertexId : missingPartitionVertices) {
+            if (!verticesToReset.contains(executionVertexId)) {
+                verticesToReset.addAll(
+                        failoverStrategy.getTasksNeedingRestart(executionVertexId, null));

Review Comment:
   Because after restoring the operator coordinators, the upstream partitions are not yet available. If we called `failoverStrategy.getTasksNeedingRestart` at that point, all upstream tasks would be mistakenly included. So I do this only after partition recovery has finished.
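   
   To make the ordering concern concrete, here is a minimal, self-contained toy model (plain Java with hypothetical names, not the Flink APIs): a region-style restart rule that also restarts the producer of any consumed partition not yet tracked as available. Querying such a rule before `recoverPartitions` has registered the existing partitions therefore over-includes upstream tasks.
   
   ```java
   import java.util.*;
   
   public class RestartOrderingDemo {
       // Restart rule sketch: a task restarts together with the producers of any
       // of its input partitions that are not currently available.
       static Set<String> tasksNeedingRestart(
               String task,
               Set<String> inputPartitions,
               Set<String> availablePartitions,
               Map<String, String> producerOfPartition) {
           Set<String> result = new LinkedHashSet<>();
           result.add(task);
           for (String partition : inputPartitions) {
               if (!availablePartitions.contains(partition)) {
                   result.add(producerOfPartition.get(partition));
               }
           }
           return result;
       }
   
       public static void main(String[] args) {
           // upstream task A produces partition p, which downstream task B consumes
           Map<String, String> producerOfPartition = Map.of("p", "A");
           Set<String> availablePartitions = new HashSet<>();
   
           // queried before partition recovery: p looks missing, so the upstream
           // producer A is mistakenly included -> prints [B, A]
           System.out.println(
                   tasksNeedingRestart("B", Set.of("p"), availablePartitions, producerOfPartition));
   
           // queried after partition recovery has registered p -> prints only [B]
           availablePartitions.add("p");
           System.out.println(
                   tasksNeedingRestart("B", Set.of("p"), availablePartitions, producerOfPartition));
       }
   }
   ```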


