JunRuiLee commented on code in PR #24771: URL: https://github.com/apache/flink/pull/24771#discussion_r1611044994
########## flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultBatchJobRecoveryHandler.java: ########## @@ -0,0 +1,840 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.scheduler.adaptivebatch; + +import org.apache.flink.api.common.JobStatus; +import org.apache.flink.configuration.BatchExecutionOptions; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.runtime.JobException; +import org.apache.flink.runtime.clusterframework.types.ResourceID; +import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor; +import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor; +import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.executiongraph.Execution; +import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; +import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; +import org.apache.flink.runtime.executiongraph.InternalExecutionGraphAccessor; +import 
org.apache.flink.runtime.executiongraph.JobVertexInputInfo; +import org.apache.flink.runtime.executiongraph.ResultPartitionBytes; +import org.apache.flink.runtime.executiongraph.TaskExecutionStateTransition; +import org.apache.flink.runtime.executiongraph.failover.FailoverStrategy; +import org.apache.flink.runtime.failure.FailureEnricherUtils; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; +import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.jobmaster.event.ExecutionJobVertexInitializedEvent; +import org.apache.flink.runtime.jobmaster.event.ExecutionVertexFinishedEvent; +import org.apache.flink.runtime.jobmaster.event.ExecutionVertexResetEvent; +import org.apache.flink.runtime.jobmaster.event.JobEvent; +import org.apache.flink.runtime.jobmaster.event.JobEventManager; +import org.apache.flink.runtime.jobmaster.event.JobEventReplayHandler; +import org.apache.flink.runtime.operators.coordination.OperatorCoordinatorHolder; +import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup; +import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID; +import org.apache.flink.runtime.shuffle.DefaultShuffleMasterSnapshotContext; +import org.apache.flink.runtime.shuffle.PartitionWithMetrics; +import org.apache.flink.runtime.shuffle.ShuffleDescriptor; +import org.apache.flink.runtime.shuffle.ShuffleMaster; +import org.apache.flink.runtime.shuffle.ShuffleMasterSnapshot; +import org.apache.flink.util.clock.Clock; +import org.apache.flink.util.clock.SystemClock; +import org.apache.flink.util.function.ConsumerWithException; +import org.apache.flink.util.function.QuadConsumerWithException; +import org.apache.flink.util.function.TriConsumer; + +import org.apache.flink.shaded.guava31.com.google.common.collect.Sets; + 
+import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.apache.flink.runtime.operators.coordination.OperatorCoordinator.NO_CHECKPOINT; +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** Default implementation of {@link BatchJobRecoveryHandler} and {@link JobEventReplayHandler}. */ +public class DefaultBatchJobRecoveryHandler + implements BatchJobRecoveryHandler, JobEventReplayHandler { + + private Logger log; + + private final JobEventManager jobEventManager; + + private final Configuration jobMasterConfiguration; + + private ExecutionGraph executionGraph; + + /** The timestamp (via {@link Clock#relativeTimeMillis()}) of the last snapshot. 
*/ + private long lastSnapshotRelativeTime; + + private final Set<JobVertexID> needToSnapshotJobVertices = new HashSet<>(); + + private ShuffleMaster<?> shuffleMaster; + + private static final ResourceID UNKNOWN_PRODUCER = ResourceID.generate(); + + private long snapshotMinPauseMills; + + private Clock clock; + + private ComponentMainThreadExecutor mainThreadExecutor; + + private FailoverStrategy failoverStrategy; + + private TriConsumer<Throwable, Long, CompletableFuture<Map<String, String>>> failJobFunction; + + private ConsumerWithException<Set<ExecutionVertexID>, Exception> + resetVerticesInRecoveringFunction; + + private final Map<ExecutionVertexID, ExecutionVertexFinishedEvent> + executionVertexFinishedEventMap = new LinkedHashMap<>(); + + private final List<ExecutionJobVertexInitializedEvent> jobVertexInitializedEvents = + new ArrayList<>(); + + private Consumer<Map<IntermediateResultPartitionID, ResultPartitionBytes>> + updateResultPartitionBytesMetricsFunction; + + private QuadConsumerWithException< + ExecutionJobVertex, + Integer, + Map<IntermediateDataSetID, JobVertexInputInfo>, + Long, + JobException> + initializeJobVertexFunction; + + private Consumer<List<ExecutionJobVertex>> updateTopologyFunction; + + /** + * This collection signifies the group of vertices where, upon a failure of any single execution + * instance, all executions within the job vertex need to be restarted once. 
+ */ + private final Set<JobVertexID> requiredRestartJobVertices = new HashSet<>(); + + public DefaultBatchJobRecoveryHandler( + JobEventManager jobEventManager, Configuration jobMasterConfiguration) { + this.jobEventManager = jobEventManager; + this.jobMasterConfiguration = jobMasterConfiguration; + } + + @Override + public void initialize( + Logger log, + ExecutionGraph executionGraph, + ShuffleMaster<?> shuffleMaster, + ComponentMainThreadExecutor mainThreadExecutor, + FailoverStrategy failoverStrategy, + TriConsumer<Throwable, Long, CompletableFuture<Map<String, String>>> failJobFunction, + ConsumerWithException<Set<ExecutionVertexID>, Exception> + resetVerticesInRecoveringFunction, + Consumer<Map<IntermediateResultPartitionID, ResultPartitionBytes>> + updateResultPartitionBytesMetricsFunction, + QuadConsumerWithException< + ExecutionJobVertex, + Integer, + Map<IntermediateDataSetID, JobVertexInputInfo>, + Long, + JobException> + initializeJobVertexFunction, + Consumer<List<ExecutionJobVertex>> updateTopologyFunction) { + + this.log = checkNotNull(log); + this.executionGraph = checkNotNull(executionGraph); + this.shuffleMaster = checkNotNull(shuffleMaster); + this.snapshotMinPauseMills = + jobMasterConfiguration + .get(BatchExecutionOptions.JOB_RECOVERY_SNAPSHOT_MIN_PAUSE) + .toMillis(); + this.clock = SystemClock.getInstance(); + this.mainThreadExecutor = checkNotNull(mainThreadExecutor); + this.failoverStrategy = checkNotNull(failoverStrategy); + this.failJobFunction = checkNotNull(failJobFunction); + this.resetVerticesInRecoveringFunction = checkNotNull(resetVerticesInRecoveringFunction); + this.updateResultPartitionBytesMetricsFunction = + checkNotNull(updateResultPartitionBytesMetricsFunction); + this.initializeJobVertexFunction = checkNotNull(initializeJobVertexFunction); + this.updateTopologyFunction = checkNotNull(updateTopologyFunction); + + try { + jobEventManager.start(); + } catch (Throwable throwable) { + failJobFunction.accept( + throwable, + 
System.currentTimeMillis(), + FailureEnricherUtils.EMPTY_FAILURE_LABELS); + } + } + + @Override + public void stop(boolean clearUp) { + jobEventManager.stop(clearUp); + } + + @Override + public void startRecovering( + Consumer<Set<JobVertexID>> recoverFinishedListener, Runnable recoverFailedListener) { + mainThreadExecutor.assertRunningInMainThread(); + + startRecovering(); + + shuffleMaster.notifyPartitionRecoveryStarted(executionGraph.getJobID()); + + if (!jobEventManager.replay(this)) { + log.warn( + "Fail to replay log for {}, will start the job as a new one.", + executionGraph.getJobID()); + recoverFailed(recoverFailedListener, failJobFunction); + return; + } + log.info("Replay all job events successfully."); + + CompletableFuture<Collection<PartitionWithMetrics>> existingPartitions = + ((InternalExecutionGraphAccessor) executionGraph) + .getShuffleMaster() + .getAllPartitionWithMetrics(executionGraph.getJobID()); + + existingPartitions.whenCompleteAsync( + (partitions, throwable) -> { + if (throwable != null) { + recoverFailed(recoverFailedListener, failJobFunction); + } + try { + recoverPartitions(partitions, updateResultPartitionBytesMetricsFunction); + recoverFinished(recoverFinishedListener); + } catch (Exception exception) { + recoverFailed(recoverFailedListener, failJobFunction); + } + }, + mainThreadExecutor); + } + + @Override + public boolean needRecover() { + try { + return jobEventManager.hasJobEvents(); + } catch (Throwable throwable) { + failJobFunction.accept( + throwable, + System.currentTimeMillis(), + FailureEnricherUtils.EMPTY_FAILURE_LABELS); + return false; + } + } + + @Override + public boolean isRecovering() { + return executionGraph.getState() == JobStatus.RECONCILING; + } + + private void restoreShuffleMaster(List<ShuffleMasterSnapshot> snapshots) { + checkState(shuffleMaster.supportsBatchSnapshot()); + shuffleMaster.restoreState(snapshots); + } + + private void startRecovering() { + log.info("Try to recover from JM failover."); + 
executionGraph.transitionState(JobStatus.CREATED, JobStatus.RECONCILING); + } + + private void restoreOperatorCoordinators( + Map<OperatorID, byte[]> snapshots, Map<OperatorID, JobVertexID> operatorToJobVertex) + throws Exception { + for (Map.Entry<OperatorID, byte[]> entry : snapshots.entrySet()) { + OperatorID operatorId = entry.getKey(); + JobVertexID jobVertexId = checkNotNull(operatorToJobVertex.get(operatorId)); + ExecutionJobVertex jobVertex = getExecutionJobVertex(jobVertexId); + log.info( + "Restore operator coordinators of {} from job event, checkpointId {}.", + jobVertex.getName(), + NO_CHECKPOINT); + + for (OperatorCoordinatorHolder holder : jobVertex.getOperatorCoordinators()) { + if (holder.coordinator().supportsBatchSnapshot()) { + byte[] snapshot = snapshots.get(holder.operatorId()); + holder.resetToCheckpoint(NO_CHECKPOINT, snapshot); + } + } + } + + reviseVertices(); + } + + @Override + public void startReplay() { + // do nothing. + } + + @Override + public void replayOneEvent(JobEvent jobEvent) { + if (jobEvent instanceof ExecutionVertexFinishedEvent) { + ExecutionVertexFinishedEvent event = (ExecutionVertexFinishedEvent) jobEvent; + executionVertexFinishedEventMap.put(event.getExecutionVertexId(), event); + } else if (jobEvent instanceof ExecutionVertexResetEvent) { + ExecutionVertexResetEvent event = (ExecutionVertexResetEvent) jobEvent; + for (ExecutionVertexID executionVertexId : event.getExecutionVertexIds()) { + executionVertexFinishedEventMap.remove(executionVertexId); + } + } else if (jobEvent instanceof ExecutionJobVertexInitializedEvent) { + jobVertexInitializedEvents.add((ExecutionJobVertexInitializedEvent) jobEvent); + } else { + throw new IllegalStateException("Unsupported job event " + jobEvent); + } + } + + @Override + public void finalizeReplay() throws Exception { + // recover job vertex initialization info and update topology + long currentTimeMillis = System.currentTimeMillis(); + final List<ExecutionJobVertex> 
initializedJobVertices = new ArrayList<>(); + for (ExecutionJobVertexInitializedEvent event : jobVertexInitializedEvents) { + final ExecutionJobVertex jobVertex = getExecutionJobVertex(event.getJobVertexId()); + initializeJobVertexFunction.accept( + jobVertex, + event.getParallelism(), + event.getJobVertexInputInfos(), + currentTimeMillis); + initializedJobVertices.add(jobVertex); + } + updateTopologyFunction.accept(initializedJobVertices); + + // remove the last batch of vertices that do not record the states of operator coordinator + // and shuffle master + LinkedList<ExecutionVertexFinishedEvent> finishedEvents = + new LinkedList<>(executionVertexFinishedEventMap.values()); + while (!finishedEvents.isEmpty() + && !finishedEvents.getLast().hasOperatorCoordinatorAndShuffleMasterSnapshots()) { + finishedEvents.removeLast(); + } + + if (finishedEvents.isEmpty()) { + return; + } + + // find the last operator coordinator state for each operator coordinator + Map<OperatorID, byte[]> operatorCoordinatorSnapshots = new HashMap<>(); + + List<ShuffleMasterSnapshot> shuffleMasterSnapshots = new ArrayList<>(); + + // transition states of all vertices + for (ExecutionVertexFinishedEvent event : finishedEvents) { + JobVertexID jobVertexId = event.getExecutionVertexId().getJobVertexId(); + ExecutionJobVertex jobVertex = executionGraph.getJobVertex(jobVertexId); + checkState(jobVertex.isInitialized()); + + int subTaskIndex = event.getExecutionVertexId().getSubtaskIndex(); + Execution execution = + jobVertex.getTaskVertices()[subTaskIndex].getCurrentExecutionAttempt(); + // recover execution info. 
+ execution.recoverExecution( + event.getExecutionAttemptId(), + event.getTaskManagerLocation(), + event.getUserAccumulators(), + event.getIOMetrics()); + + // recover operator coordinator + for (Map.Entry<OperatorID, CompletableFuture<byte[]>> entry : + event.getOperatorCoordinatorSnapshotFutures().entrySet()) { + checkState(entry.getValue().isDone()); + operatorCoordinatorSnapshots.put(entry.getKey(), entry.getValue().get()); + } + + // recover shuffle master + if (event.getShuffleMasterSnapshotFuture() != null) { + ShuffleMasterSnapshot shuffleMasterSnapshot = + event.getShuffleMasterSnapshotFuture().get(); + if (shuffleMasterSnapshot.isIncremental()) { + shuffleMasterSnapshots.add(shuffleMasterSnapshot); + } else { + shuffleMasterSnapshots = Arrays.asList(shuffleMasterSnapshot); + } + } + } + + // restore operator coordinator state if needed. + final Map<OperatorID, JobVertexID> operatorToJobVertex = new HashMap<>(); + for (ExecutionJobVertex jobVertex : executionGraph.getAllVertices().values()) { + if (!jobVertex.isInitialized()) { + continue; + } + + for (OperatorCoordinatorHolder holder : jobVertex.getOperatorCoordinators()) { + operatorToJobVertex.put(holder.operatorId(), jobVertex.getJobVertexId()); + } + } + + try { + restoreOperatorCoordinators(operatorCoordinatorSnapshots, operatorToJobVertex); + } catch (Exception exception) { + log.warn("Restore coordinator operator failed.", exception); + throw exception; + } + + // restore shuffle master + restoreShuffleMaster(shuffleMasterSnapshots); + } + + @Override + public void notifyExecutionVertexReset(Collection<ExecutionVertexID> vertices) { + // write execute vertex reset event. 
+ checkState(!isRecovering()); + jobEventManager.writeEvent(new ExecutionVertexResetEvent(new ArrayList<>(vertices)), false); + } + + @Override + public void notifyExecutionJobVertexInitialization( + JobVertexID jobVertexId, + int parallelism, + Map<IntermediateDataSetID, JobVertexInputInfo> jobVertexInputInfos) { + // write execution job vertex initialized event. + checkState(!isRecovering()); + jobEventManager.writeEvent( + new ExecutionJobVertexInitializedEvent( + jobVertexId, parallelism, jobVertexInputInfos), + false); + } + + @Override + public void notifyExecutionFinished( + ExecutionVertexID executionVertexId, TaskExecutionStateTransition taskExecutionState) { + checkState(!isRecovering()); + + checkState(taskExecutionState.getExecutionState() == ExecutionState.FINISHED); + Execution execution = getExecutionVertex(executionVertexId).getCurrentExecutionAttempt(); + + // check whether the job vertex is finished. + ExecutionJobVertex jobVertex = execution.getVertex().getJobVertex(); + boolean jobVertexFinished = jobVertex.getAggregateState() == ExecutionState.FINISHED; + + // snapshot operator coordinators if needed. + needToSnapshotJobVertices.add(executionVertexId.getJobVertexId()); + final Map<OperatorID, CompletableFuture<byte[]>> operatorCoordinatorSnapshotFutures = + new HashMap<>(); + CompletableFuture<ShuffleMasterSnapshot> shuffleMasterSnapshotFuture = null; + long currentRelativeTime = clock.relativeTimeMillis(); + if (jobVertexFinished + || (currentRelativeTime - lastSnapshotRelativeTime >= snapshotMinPauseMills)) { + // operator coordinator + operatorCoordinatorSnapshotFutures.putAll(snapshotOperatorCoordinators()); + lastSnapshotRelativeTime = currentRelativeTime; + needToSnapshotJobVertices.clear(); + + // shuffle master + shuffleMasterSnapshotFuture = snapshotShuffleMaster(); + } + + // write job event. 
+ jobEventManager.writeEvent( + new ExecutionVertexFinishedEvent( + execution.getAttemptId(), + execution.getAssignedResourceLocation(), + operatorCoordinatorSnapshotFutures, + shuffleMasterSnapshotFuture, + execution.getIOMetrics(), + execution.getUserAccumulators()), + jobVertexFinished); + } + + private Map<OperatorID, CompletableFuture<byte[]>> snapshotOperatorCoordinators() { + + final Map<OperatorID, CompletableFuture<byte[]>> snapshotFutures = new HashMap<>(); + + for (JobVertexID jobVertexId : needToSnapshotJobVertices) { + final ExecutionJobVertex jobVertex = checkNotNull(getExecutionJobVertex(jobVertexId)); + + log.info( + "Snapshot operator coordinators of {} to job event, checkpointId {}.", + jobVertex.getName(), + NO_CHECKPOINT); + + for (OperatorCoordinatorHolder holder : jobVertex.getOperatorCoordinators()) { + if (holder.coordinator().supportsBatchSnapshot()) { + final CompletableFuture<byte[]> checkpointFuture = new CompletableFuture<>(); + holder.checkpointCoordinator(NO_CHECKPOINT, checkpointFuture); + snapshotFutures.put(holder.operatorId(), checkpointFuture); + } + } + } + return snapshotFutures; + } + + private CompletableFuture<ShuffleMasterSnapshot> snapshotShuffleMaster() { + + checkState(shuffleMaster.supportsBatchSnapshot()); + CompletableFuture<ShuffleMasterSnapshot> shuffleMasterSnapshotFuture = + new CompletableFuture<>(); + shuffleMaster.snapshotState( + shuffleMasterSnapshotFuture, new DefaultShuffleMasterSnapshotContext()); + return shuffleMasterSnapshotFuture; + } + + private void reviseVertices() throws Exception { + Set<ExecutionVertexID> verticesToRestart = new HashSet<>(); + + for (ExecutionJobVertex jobVertex : executionGraph.getAllVertices().values()) { + if (!jobVertex.isInitialized() || jobVertex.getOperatorCoordinators().isEmpty()) { + continue; + } + + boolean allSupportsBatchSnapshot = + jobVertex.getOperatorCoordinators().stream() + .allMatch(holder -> holder.coordinator().supportsBatchSnapshot()); + + 
Set<ExecutionVertexID> unfinishedTasks = + Arrays.stream(jobVertex.getTaskVertices()) + .filter(vertex -> vertex.getExecutionState() != ExecutionState.FINISHED) + .map( + executionVertex -> { + // transition to terminal state to allow reset it + executionVertex + .getCurrentExecutionAttempt() + .transitionState(ExecutionState.CANCELED); + return executionVertex.getID(); + }) + .collect(Collectors.toSet()); + + if (allSupportsBatchSnapshot) { + log.info( + "All operator coordinators of jobVertex {} support batch snapshot, " + + "add {} unfinished tasks to revise.", + jobVertex.getName(), + unfinishedTasks.size()); + verticesToRestart.addAll(unfinishedTasks); + } else if (unfinishedTasks.isEmpty()) { + log.info( + "JobVertex {} is finished, but not all of its operator coordinators support " + + "batch snapshot. Therefore, if any single task within it requires " + + "a restart in the future, all tasks associated with this JobVertex " + + "need to be restarted as well.", + jobVertex.getName()); + requiredRestartJobVertices.add(jobVertex.getJobVertexId()); + } else { + log.info( + "Restart all tasks of jobVertex {} because it has not been finished and not " + + "all of its operator coordinators support batch snapshot.", + jobVertex.getName()); + verticesToRestart.addAll( + Arrays.stream(jobVertex.getTaskVertices()) + .map(ExecutionVertex::getID) + .collect(Collectors.toSet())); + } + } + + resetVerticesInRecovering(verticesToRestart); + } + + private void resetVerticesInRecovering(Set<ExecutionVertexID> verticesToRestart) + throws Exception { + checkState(isRecovering()); + + Set<JobVertexID> extraNeedToRestartJobVertices = + verticesToRestart.stream() + .map(ExecutionVertexID::getJobVertexId) + .filter(requiredRestartJobVertices::contains) + .collect(Collectors.toSet()); + + requiredRestartJobVertices.removeAll(extraNeedToRestartJobVertices); + + verticesToRestart.addAll( + extraNeedToRestartJobVertices.stream() + .flatMap( + jobVertexId -> { + ExecutionJobVertex 
jobVertex = + getExecutionJobVertex(jobVertexId); + return Arrays.stream(jobVertex.getTaskVertices()) + .map(ExecutionVertex::getID); + }) + .collect(Collectors.toSet())); + + // we only reset tasks which are not CREATED. + Set<ExecutionVertexID> verticesToReset = + verticesToRestart.stream() + .filter( + executionVertexID -> + getExecutionVertex(executionVertexID).getExecutionState() + != ExecutionState.CREATED) + .collect(Collectors.toSet()); + + resetVerticesInRecoveringFunction.accept(verticesToReset); + } + + private void recoverFailed( + Runnable recoverFailedListener, + TriConsumer<Throwable, Long, CompletableFuture<Map<String, String>>> failJobFunction) { + String message = + String.format( + "Job %s recover failed from JM failover, fail global.", + executionGraph.getJobID()); + log.warn(message); + shuffleMaster.notifyPartitionRecoveryCompleted(executionGraph.getJobID()); + executionGraph.transitionState(JobStatus.RECONCILING, JobStatus.RUNNING); + + // clear job events and restart job event manager. 
+ jobEventManager.stop(true); + try { + jobEventManager.start(); + } catch (Throwable throwable) { + failJobFunction.accept( + throwable, + System.currentTimeMillis(), + FailureEnricherUtils.EMPTY_FAILURE_LABELS); + return; + } + + recoverFailedListener.run(); + } + + private void recoverFinished(Consumer<Set<JobVertexID>> recoverFinishedListener) { + log.info("Job {} successfully recovered from JM failover", executionGraph.getJobID()); + shuffleMaster.notifyPartitionRecoveryCompleted(executionGraph.getJobID()); + + executionGraph.transitionState(JobStatus.RECONCILING, JobStatus.RUNNING); + checkExecutionGraphState(); + recoverFinishedListener.accept(requiredRestartJobVertices); + } + + private void checkExecutionGraphState() { + for (ExecutionVertex executionVertex : executionGraph.getAllExecutionVertices()) { + ExecutionState state = executionVertex.getExecutionState(); + checkState(state == ExecutionState.CREATED || state == ExecutionState.FINISHED); + } + } + + private void recoverPartitions( + Collection<PartitionWithMetrics> partitionWithMetrics, + Consumer<Map<IntermediateResultPartitionID, ResultPartitionBytes>> + updateResultPartitionBytesMetricsFunction) + throws Exception { + mainThreadExecutor.assertRunningInMainThread(); + + Set<ResultPartitionID> actualPartitions = + partitionWithMetrics.stream() + .map(PartitionWithMetrics::getPartition) + .map(ShuffleDescriptor::getResultPartitionID) + .collect(Collectors.toSet()); + + ReconcileResult reconcileResult = reconcilePartitions(actualPartitions); + + log.info( + "All unknown partitions {}, missing partitions {}, available partitions {}", + reconcileResult.unknownPartitions, + reconcileResult.missingPartitions, + reconcileResult.availablePartitions); + + // release all unknown partitions + ((InternalExecutionGraphAccessor) executionGraph) + .getPartitionTracker() + .stopTrackingAndReleasePartitions(reconcileResult.unknownPartitions); + + // start tracking all available partitions + 
Map<IntermediateResultPartitionID, ResultPartitionBytes> availablePartitionBytes = + new HashMap<>(); + partitionWithMetrics.stream() + .filter( + partitionAndMetric -> + reconcileResult.availablePartitions.contains( + partitionAndMetric.getPartition().getResultPartitionID())) + .forEach( + partitionAndMetric -> { + ShuffleDescriptor shuffleDescriptor = partitionAndMetric.getPartition(); + + // we cannot get the producer id when using remote shuffle + ResourceID producerTaskExecutorId = UNKNOWN_PRODUCER; + if (shuffleDescriptor.storesLocalResourcesOn().isPresent()) { + producerTaskExecutorId = + shuffleDescriptor.storesLocalResourcesOn().get(); + } + IntermediateResultPartition partition = + executionGraph.getResultPartitionOrThrow( + shuffleDescriptor + .getResultPartitionID() + .getPartitionId()); + ((InternalExecutionGraphAccessor) executionGraph) + .getPartitionTracker() + .startTrackingPartition( + producerTaskExecutorId, + Execution.createResultPartitionDeploymentDescriptor( + partition, shuffleDescriptor)); + + availablePartitionBytes.put( + shuffleDescriptor.getResultPartitionID().getPartitionId(), + partitionAndMetric.getPartitionMetrics().getPartitionBytes()); + }); + + // recover the produced partitions for executions + Map< + ExecutionVertexID, + Map<IntermediateResultPartitionID, ResultPartitionDeploymentDescriptor>> + allDescriptors = new HashMap<>(); + ((InternalExecutionGraphAccessor) executionGraph) + .getPartitionTracker() + .getAllTrackedNonClusterPartitions() + .forEach( + descriptor -> { + ExecutionVertexID vertexId = + descriptor + .getShuffleDescriptor() + .getResultPartitionID() + .getProducerId() + .getExecutionVertexId(); + if (!allDescriptors.containsKey(vertexId)) { + allDescriptors.put(vertexId, new HashMap<>()); + } + + allDescriptors + .get(vertexId) + .put(descriptor.getPartitionId(), descriptor); + }); + + allDescriptors.forEach( + (vertexId, descriptors) -> + getExecutionVertex(vertexId) + .getCurrentExecutionAttempt() + 
.recoverProducedPartitions(descriptors)); + + // recover result partition bytes + updateResultPartitionBytesMetricsFunction.accept(availablePartitionBytes); + + // restart all producers of missing partitions + List<ExecutionVertexID> missingPartitionVertices = + reconcileResult.missingPartitions.stream() + .map(ResultPartitionID::getPartitionId) + .map(this::getProducer) + .map(ExecutionVertex::getID) + .collect(Collectors.toList()); + + // get all vertices need to restart according failover strategy. + Set<ExecutionVertexID> verticesToReset = new HashSet<>(); + for (ExecutionVertexID executionVertexId : missingPartitionVertices) { + if (!verticesToReset.contains(executionVertexId)) { + verticesToReset.addAll( + failoverStrategy.getTasksNeedingRestart(executionVertexId, null)); Review Comment: This is because, right after the operator coordinators are restored, the upstream partitions are not yet available. If we called `failoverStrategy.getTasksNeedingRestart` at that point, all upstream tasks would be mistakenly included. So I do this only after partition recovery has finished. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org