http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationV2.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationV2.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationV2.java deleted file mode 100644 index 83c188c..0000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationV2.java +++ /dev/null @@ -1,458 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.flink.runtime.checkpoint; - -//import com.google.common.collect.Lists; -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.runtime.executiongraph.Execution; -import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; -import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.ChainedStateHandle; -import org.apache.flink.runtime.state.KeyGroupRange; -import org.apache.flink.runtime.state.KeyGroupsStateHandle; -import org.apache.flink.runtime.state.KeyedStateHandle; -import org.apache.flink.runtime.state.OperatorStateHandle; -import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.TaskStateHandles; -import org.apache.flink.util.Preconditions; -import org.slf4j.Logger; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -/** - * This class encapsulates the operation of assigning restored state when restoring from a checkpoint that works on the - * granularity of operators. This is the case for checkpoints that were created either with a Flink version >= 1.3 or - * 1.2 if the savepoint only contains {@link SubtaskState}s for which the length of contained - * {@link ChainedStateHandle}s is equal to 1. 
- */ -public class StateAssignmentOperationV2 { - - private final Logger logger; - private final Map<JobVertexID, ExecutionJobVertex> tasks; - private final Map<JobVertexID, TaskState> taskStates; - private final boolean allowNonRestoredState; - - public StateAssignmentOperationV2( - Logger logger, - Map<JobVertexID, ExecutionJobVertex> tasks, - Map<JobVertexID, TaskState> taskStates, - boolean allowNonRestoredState) { - - this.logger = Preconditions.checkNotNull(logger); - this.tasks = Preconditions.checkNotNull(tasks); - this.taskStates = Preconditions.checkNotNull(taskStates); - this.allowNonRestoredState = allowNonRestoredState; - } - - public boolean assignStates() throws Exception { - Map<JobVertexID, TaskState> localStates = new HashMap<>(taskStates); - Map<JobVertexID, ExecutionJobVertex> localTasks = this.tasks; - - Set<JobVertexID> allOperatorIDs = new HashSet<>(); - for (ExecutionJobVertex executionJobVertex : tasks.values()) { - //allOperatorIDs.addAll(Lists.newArrayList(executionJobVertex.getOperatorIDs())); - } - for (Map.Entry<JobVertexID, TaskState> taskGroupStateEntry : taskStates.entrySet()) { - TaskState taskState = taskGroupStateEntry.getValue(); - //----------------------------------------find operator for state--------------------------------------------- - - if (!allOperatorIDs.contains(taskGroupStateEntry.getKey())) { - if (allowNonRestoredState) { - logger.info("Skipped checkpoint state for operator {}.", taskState.getJobVertexID()); - continue; - } else { - throw new IllegalStateException("There is no operator for the state " + taskState.getJobVertexID()); - } - } - } - - for (Map.Entry<JobVertexID, ExecutionJobVertex> task : localTasks.entrySet()) { - final ExecutionJobVertex executionJobVertex = task.getValue(); - - // find the states of all operators belonging to this task - JobVertexID[] operatorIDs = null;//executionJobVertex.getOperatorIDs(); - JobVertexID[] altOperatorIDs = null;//executionJobVertex.getUserDefinedOperatorIDs(); - 
List<TaskState> operatorStates = new ArrayList<>(); - boolean statelessTask = true; - for (int x = 0; x < operatorIDs.length; x++) { - JobVertexID operatorID = altOperatorIDs[x] == null - ? operatorIDs[x] - : altOperatorIDs[x]; - - TaskState operatorState = localStates.remove(operatorID); - if (operatorState == null) { - operatorState = new TaskState( - operatorID, - executionJobVertex.getParallelism(), - executionJobVertex.getMaxParallelism(), - 1); - } else { - statelessTask = false; - } - operatorStates.add(operatorState); - } - if (statelessTask) { // skip tasks where no operator has any state - continue; - } - - assignAttemptState(task.getValue(), operatorStates); - } - - return true; - } - - private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<TaskState> operatorStates) { - - JobVertexID[] operatorIDs = null;//executionJobVertex.getOperatorIDs(); - - //1. first compute the new parallelism - checkParallelismPreconditions(operatorStates, executionJobVertex); - - int newParallelism = executionJobVertex.getParallelism(); - - List<KeyGroupRange> keyGroupPartitions = null;//StateAssignmentOperationUtils.createKeyGroupPartitions( - //executionJobVertex.getMaxParallelism(), - //newParallelism); - - //2. Redistribute the operator state. - /** - * - * Redistribute ManagedOperatorStates and RawOperatorStates from old parallelism to new parallelism. 
- * - * The old ManagedOperatorStates with old parallelism 3: - * - * parallelism0 parallelism1 parallelism2 - * op0 states0,0 state0,1 state0,2 - * op1 - * op2 states2,0 state2,1 state1,2 - * op3 states3,0 state3,1 state3,2 - * - * The new ManagedOperatorStates with new parallelism 4: - * - * parallelism0 parallelism1 parallelism2 parallelism3 - * op0 state0,0 state0,1 state0,2 state0,3 - * op1 - * op2 state2,0 state2,1 state2,2 state2,3 - * op3 state3,0 state3,1 state3,2 state3,3 - */ - List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates = new ArrayList<>(); - List<List<Collection<OperatorStateHandle>>> newRawOperatorStates = new ArrayList<>(); - - reDistributePartitionableStates(operatorStates, newParallelism, newManagedOperatorStates, newRawOperatorStates); - - - //3. Compute TaskStateHandles of every subTask in the executionJobVertex - /** - * An executionJobVertex's all state handles needed to restore are something like a matrix - * - * parallelism0 parallelism1 parallelism2 parallelism3 - * op0 sh(0,0) sh(0,1) sh(0,2) sh(0,3) - * op1 sh(1,0) sh(1,1) sh(1,2) sh(1,3) - * op2 sh(2,0) sh(2,1) sh(2,2) sh(2,3) - * op3 sh(3,0) sh(3,1) sh(3,2) sh(3,3) - * - * we will compute the state handles column by column. 
- * - */ - for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) { - - Execution currentExecutionAttempt = executionJobVertex.getTaskVertices()[subTaskIndex] - .getCurrentExecutionAttempt(); - - List<StreamStateHandle> subNonPartitionableState = new ArrayList<>(); - - Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> subKeyedState = null; - - List<Collection<OperatorStateHandle>> subManagedOperatorState = new ArrayList<>(); - List<Collection<OperatorStateHandle>> subRawOperatorState = new ArrayList<>(); - - - for (int operatorIndex = 0; operatorIndex < operatorIDs.length; operatorIndex++) { - TaskState operatorState = operatorStates.get(operatorIndex); - int oldParallelism = operatorState.getParallelism(); - - // NonPartitioned State - - reAssignSubNonPartitionedStates( - operatorState, - subTaskIndex, - newParallelism, - oldParallelism, - subNonPartitionableState); - - // PartitionedState - reAssignSubPartitionableState(newManagedOperatorStates, - newRawOperatorStates, - subTaskIndex, - operatorIndex, - subManagedOperatorState, - subRawOperatorState); - - // KeyedState - if (operatorIndex == operatorIDs.length - 1) { - subKeyedState = reAssignSubKeyedStates(operatorState, - keyGroupPartitions, - subTaskIndex, - newParallelism, - oldParallelism); - - } - } - - - // check if a stateless task - if (!allElementsAreNull(subNonPartitionableState) || - !allElementsAreNull(subManagedOperatorState) || - !allElementsAreNull(subRawOperatorState) || - subKeyedState != null) { - - TaskStateHandles taskStateHandles = new TaskStateHandles( - - new ChainedStateHandle<>(subNonPartitionableState), - subManagedOperatorState, - subRawOperatorState, - subKeyedState != null ? subKeyedState.f0 : null, - subKeyedState != null ? 
subKeyedState.f1 : null); - - currentExecutionAttempt.setInitialState(taskStateHandles); - } - } - } - - - public void checkParallelismPreconditions(List<TaskState> operatorStates, ExecutionJobVertex executionJobVertex) { - - for (TaskState taskState : operatorStates) { - //StateAssignmentOperation.checkParallelismPreconditions(taskState, executionJobVertex, this.logger); - } - } - - - private void reAssignSubPartitionableState( - List<List<Collection<OperatorStateHandle>>> newMangedOperatorStates, - List<List<Collection<OperatorStateHandle>>> newRawOperatorStates, - int subTaskIndex, int operatorIndex, - List<Collection<OperatorStateHandle>> subManagedOperatorState, - List<Collection<OperatorStateHandle>> subRawOperatorState) { - - if (newMangedOperatorStates.get(operatorIndex) != null) { - subManagedOperatorState.add(newMangedOperatorStates.get(operatorIndex).get(subTaskIndex)); - } else { - subManagedOperatorState.add(null); - } - if (newRawOperatorStates.get(operatorIndex) != null) { - subRawOperatorState.add(newRawOperatorStates.get(operatorIndex).get(subTaskIndex)); - } else { - subRawOperatorState.add(null); - } - - - } - - private Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> reAssignSubKeyedStates( - TaskState operatorState, - List<KeyGroupRange> keyGroupPartitions, - int subTaskIndex, - int newParallelism, - int oldParallelism) { - - Collection<KeyedStateHandle> subManagedKeyedState; - Collection<KeyedStateHandle> subRawKeyedState; - - if (newParallelism == oldParallelism) { - if (operatorState.getState(subTaskIndex) != null) { - KeyedStateHandle oldSubManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState(); - KeyedStateHandle oldSubRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState(); - subManagedKeyedState = oldSubManagedKeyedState != null ? Collections.singletonList( - oldSubManagedKeyedState) : null; - subRawKeyedState = oldSubRawKeyedState != null ? 
Collections.singletonList( - oldSubRawKeyedState) : null; - } else { - subManagedKeyedState = null; - subRawKeyedState = null; - } - } else { - subManagedKeyedState = getManagedKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex)); - subRawKeyedState = getRawKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex)); - } - if (subManagedKeyedState == null && subRawKeyedState == null) { - return null; - } - return new Tuple2<>(subManagedKeyedState, subRawKeyedState); - } - - - private <X> boolean allElementsAreNull(List<X> nonPartitionableStates) { - for (Object streamStateHandle : nonPartitionableStates) { - if (streamStateHandle != null) { - return false; - } - } - return true; - } - - - private void reAssignSubNonPartitionedStates( - TaskState operatorState, - int subTaskIndex, - int newParallelism, - int oldParallelism, - List<StreamStateHandle> subNonPartitionableState) { - if (oldParallelism == newParallelism) { - if (operatorState.getState(subTaskIndex) != null && - !operatorState.getState(subTaskIndex).getLegacyOperatorState().isEmpty()) { - subNonPartitionableState.add(operatorState.getState(subTaskIndex).getLegacyOperatorState().get(0)); - } else { - subNonPartitionableState.add(null); - } - } else { - subNonPartitionableState.add(null); - } - } - - private void reDistributePartitionableStates( - List<TaskState> operatorStates, int newParallelism, - List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates, - List<List<Collection<OperatorStateHandle>>> newRawOperatorStates) { - - //collect the old partitionalbe state - List<List<OperatorStateHandle>> oldManagedOperatorStates = new ArrayList<>(); - List<List<OperatorStateHandle>> oldRawOperatorStates = new ArrayList<>(); - - collectPartionableStates(operatorStates, oldManagedOperatorStates, oldRawOperatorStates); - - - //redistribute - OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE; - - for (int operatorIndex = 0; 
operatorIndex < operatorStates.size(); operatorIndex++) { - int oldParallelism = operatorStates.get(operatorIndex).getParallelism(); - //newManagedOperatorStates.add(StateAssignmentOperationUtils.applyRepartitioner(opStateRepartitioner, - // oldManagedOperatorStates.get(operatorIndex), oldParallelism, newParallelism)); - //newRawOperatorStates.add(StateAssignmentOperationUtils.applyRepartitioner(opStateRepartitioner, - // oldRawOperatorStates.get(operatorIndex), oldParallelism, newParallelism)); - - } - } - - - private void collectPartionableStates( - List<TaskState> operatorStates, - List<List<OperatorStateHandle>> managedOperatorStates, - List<List<OperatorStateHandle>> rawOperatorStates) { - - for (TaskState operatorState : operatorStates) { - List<OperatorStateHandle> managedOperatorState = null; - List<OperatorStateHandle> rawOperatorState = null; - - for (int i = 0; i < operatorState.getParallelism(); i++) { - SubtaskState subtaskState = operatorState.getState(i); - if (subtaskState != null) { - if (subtaskState.getManagedOperatorState() != null && - subtaskState.getManagedOperatorState().getLength() > 0 && - subtaskState.getManagedOperatorState().get(0) != null) { - if (managedOperatorState == null) { - managedOperatorState = new ArrayList<>(); - } - managedOperatorState.add(subtaskState.getManagedOperatorState().get(0)); - } - - if (subtaskState.getRawOperatorState() != null && - subtaskState.getRawOperatorState().getLength() > 0 && - subtaskState.getRawOperatorState().get(0) != null) { - if (rawOperatorState == null) { - rawOperatorState = new ArrayList<>(); - } - rawOperatorState.add(subtaskState.getRawOperatorState().get(0)); - } - } - - } - managedOperatorStates.add(managedOperatorState); - rawOperatorStates.add(rawOperatorState); - } - } - - - /** - * Collect {@link KeyGroupsStateHandle managedKeyedStateHandles} which have intersection with given - * {@link KeyGroupRange} from {@link TaskState operatorState} - * - * @param operatorState all state 
handles of a operator - * @param subtaskKeyGroupRange the KeyGroupRange of a subtask - * @return all managedKeyedStateHandles which have intersection with given KeyGroupRange - */ - public static List<KeyedStateHandle> getManagedKeyedStateHandles( - TaskState operatorState, - KeyGroupRange subtaskKeyGroupRange) { - - List<KeyedStateHandle> subtaskKeyedStateHandles = null; - - for (int i = 0; i < operatorState.getParallelism(); i++) { - if (operatorState.getState(i) != null && operatorState.getState(i).getManagedKeyedState() != null) { - KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getManagedKeyedState().getIntersection(subtaskKeyGroupRange); - - if (intersectedKeyedStateHandle != null) { - if (subtaskKeyedStateHandles == null) { - subtaskKeyedStateHandles = new ArrayList<>(); - } - subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); - } - } - } - - return subtaskKeyedStateHandles; - } - - /** - * Collect {@link KeyGroupsStateHandle rawKeyedStateHandles} which have intersection with given - * {@link KeyGroupRange} from {@link TaskState operatorState} - * - * @param operatorState all state handles of a operator - * @param subtaskKeyGroupRange the KeyGroupRange of a subtask - * @return all rawKeyedStateHandles which have intersection with given KeyGroupRange - */ - public static List<KeyedStateHandle> getRawKeyedStateHandles( - TaskState operatorState, - KeyGroupRange subtaskKeyGroupRange) { - - List<KeyedStateHandle> subtaskKeyedStateHandles = null; - - for (int i = 0; i < operatorState.getParallelism(); i++) { - if (operatorState.getState(i) != null && operatorState.getState(i).getRawKeyedState() != null) { - KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getRawKeyedState().getIntersection(subtaskKeyGroupRange); - - if (intersectedKeyedStateHandle != null) { - if (subtaskKeyedStateHandles == null) { - subtaskKeyedStateHandles = new ArrayList<>(); - } - 
subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); - } - } - } - - return subtaskKeyedStateHandles; - } -}
http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java index 4f5f536..aa5c516 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java @@ -34,7 +34,11 @@ import java.util.Objects; * tasks of a {@link org.apache.flink.runtime.jobgraph.JobVertex}. * * This class basically groups all non-partitioned state and key-group state belonging to the same job vertex together. + * + * @deprecated Internal class for savepoint backwards compatibility. Don't use for other purposes. */ +@Deprecated +@SuppressWarnings("deprecation") public class TaskState implements CompositeStateHandle { private static final long serialVersionUID = -4845578005863201810L; http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/Savepoint.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/Savepoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/Savepoint.java index 79ec596..a7cf4b5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/Savepoint.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/Savepoint.java @@ -21,6 +21,7 @@ package org.apache.flink.runtime.checkpoint.savepoint; import org.apache.flink.core.io.Versioned; import org.apache.flink.runtime.checkpoint.CheckpointIDCounter; import org.apache.flink.runtime.checkpoint.MasterState; +import 
org.apache.flink.runtime.checkpoint.OperatorState; import org.apache.flink.runtime.checkpoint.TaskState; import java.util.Collection; @@ -54,8 +55,10 @@ public interface Savepoint extends Versioned { * * <p>These are used to restore the snapshot state. * + * @deprecated Only kept for backwards-compatibility with versions < 1.3. Will be removed in the future. * @return Snapshotted task states */ + @Deprecated Collection<TaskState> getTaskStates(); /** @@ -64,6 +67,15 @@ public interface Savepoint extends Versioned { Collection<MasterState> getMasterStates(); /** + * Returns the snapshotted operator states. + * + * <p>These are used to restore the snapshot state. + * + * @return Snapshotted operator states + */ + Collection<OperatorState> getOperatorStates(); + + /** * Disposes the savepoint. */ void dispose() throws Exception; http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoader.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoader.java index 8ee38da..38db7c2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoader.java @@ -22,9 +22,10 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.runtime.checkpoint.CheckpointProperties; import org.apache.flink.runtime.checkpoint.CompletedCheckpoint; -import org.apache.flink.runtime.checkpoint.TaskState; +import org.apache.flink.runtime.checkpoint.OperatorState; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import 
org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.StreamStateHandle; import org.slf4j.Logger; @@ -67,32 +68,42 @@ public class SavepointLoader { final Tuple2<Savepoint, StreamStateHandle> savepointAndHandle = SavepointStore.loadSavepointWithHandle(savepointPath, classLoader); - final Savepoint savepoint = savepointAndHandle.f0; + Savepoint savepoint = savepointAndHandle.f0; final StreamStateHandle metadataHandle = savepointAndHandle.f1; - final Map<JobVertexID, TaskState> taskStates = new HashMap<>(savepoint.getTaskStates().size()); + if (savepoint.getTaskStates() != null) { + savepoint = SavepointV2.convertToOperatorStateSavepointV2(tasks, savepoint); + } + // generate mapping from operator to task + Map<OperatorID, ExecutionJobVertex> operatorToJobVertexMapping = new HashMap<>(); + for (ExecutionJobVertex task : tasks.values()) { + for (OperatorID operatorID : task.getOperatorIDs()) { + operatorToJobVertexMapping.put(operatorID, task); + } + } + // (2) validate it (parallelism, etc) boolean expandedToLegacyIds = false; - // (2) validate it (parallelism, etc) - for (TaskState taskState : savepoint.getTaskStates()) { + HashMap<OperatorID, OperatorState> operatorStates = new HashMap<>(savepoint.getOperatorStates().size()); + for (OperatorState operatorState : savepoint.getOperatorStates()) { - ExecutionJobVertex executionJobVertex = tasks.get(taskState.getJobVertexID()); + ExecutionJobVertex executionJobVertex = operatorToJobVertexMapping.get(operatorState.getOperatorID()); // on the first time we can not find the execution job vertex for an id, we also consider alternative ids, // for example as generated from older flink versions, to provide backwards compatibility. 
if (executionJobVertex == null && !expandedToLegacyIds) { - tasks = ExecutionJobVertex.includeLegacyJobVertexIDs(tasks); - executionJobVertex = tasks.get(taskState.getJobVertexID()); + operatorToJobVertexMapping = ExecutionJobVertex.includeAlternativeOperatorIDs(operatorToJobVertexMapping); + executionJobVertex = operatorToJobVertexMapping.get(operatorState.getOperatorID()); expandedToLegacyIds = true; - LOG.info("Could not find ExecutionJobVertex. Including legacy JobVertexIDs in search."); + LOG.info("Could not find ExecutionJobVertex. Including user-defined OperatorIDs in search."); } if (executionJobVertex != null) { - if (executionJobVertex.getMaxParallelism() == taskState.getMaxParallelism() + if (executionJobVertex.getMaxParallelism() == operatorState.getMaxParallelism() || !executionJobVertex.isMaxParallelismConfigured()) { - taskStates.put(taskState.getJobVertexID(), taskState); + operatorStates.put(operatorState.getOperatorID(), operatorState); } else { String msg = String.format("Failed to rollback to savepoint %s. " + "Max parallelism mismatch between savepoint state and new program. " + @@ -100,21 +111,21 @@ public class SavepointLoader { "max parallelism %d. This indicates that the program has been changed " + "in a non-compatible way after the savepoint.", savepoint, - taskState.getJobVertexID(), - taskState.getMaxParallelism(), + operatorState.getOperatorID(), + operatorState.getMaxParallelism(), executionJobVertex.getMaxParallelism()); throw new IllegalStateException(msg); } } else if (allowNonRestoredState) { - LOG.info("Skipping savepoint state for operator {}.", taskState.getJobVertexID()); + LOG.info("Skipping savepoint state for operator {}.", operatorState.getOperatorID()); } else { String msg = String.format("Failed to rollback to savepoint %s. " + "Cannot map savepoint state for operator %s to the new program, " + "because the operator is not available in the new program. 
If " + "you want to allow to skip this, you can set the --allowNonRestoredState " + "option on the CLI.", - savepointPath, taskState.getJobVertexID()); + savepointPath, operatorState.getOperatorID()); throw new IllegalStateException(msg); } @@ -122,8 +133,17 @@ public class SavepointLoader { // (3) convert to checkpoint so the system can fall back to it CheckpointProperties props = CheckpointProperties.forStandardSavepoint(); - return new CompletedCheckpoint(jobId, savepoint.getCheckpointId(), 0L, 0L, - taskStates, savepoint.getMasterStates(), props, metadataHandle, savepointPath); + + return new CompletedCheckpoint( + jobId, + savepoint.getCheckpointId(), + 0L, + 0L, + operatorStates, + savepoint.getMasterStates(), + props, + metadataHandle, + savepointPath); } // ------------------------------------------------------------------------ http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1.java index 196c870..daf5b7f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.checkpoint.savepoint; import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.OperatorState; import org.apache.flink.runtime.checkpoint.TaskState; import org.apache.flink.util.Preconditions; @@ -68,6 +69,11 @@ public class SavepointV1 implements Savepoint { } @Override + public Collection<OperatorState> getOperatorStates() { + return null; + } + + @Override public void dispose() throws Exception { // since 
checkpoints are never deserialized into this format, // this method should never be called http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java index ae9f4a9..aaa8cdd 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.checkpoint.savepoint; import org.apache.flink.core.fs.Path; -import org.apache.flink.runtime.checkpoint.MasterState; import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.checkpoint.TaskState; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -38,7 +37,6 @@ import java.io.DataOutputStream; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -97,7 +95,7 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV2> { } } - return new SavepointV2(checkpointId, taskStates, Collections.<MasterState>emptyList()); + return new SavepointV2(checkpointId, taskStates); } public void serializeOld(SavepointV1 savepoint, DataOutputStream dos) throws IOException { http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java ---------------------------------------------------------------------- diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java index 100982d..6a3b57f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2.java @@ -19,10 +19,23 @@ package org.apache.flink.runtime.checkpoint.savepoint; import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.OperatorState; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; +import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.checkpoint.TaskState; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static org.apache.flink.util.Preconditions.checkNotNull; @@ -38,20 +51,48 @@ public class SavepointV2 implements Savepoint { /** The checkpoint ID */ private final long checkpointId; - /** The task states */ + /** + * The task states + * @deprecated Only kept for backwards-compatibility with versions < 1.3. Will be removed in the future. 
+ */ + @Deprecated private final Collection<TaskState> taskStates; + /** The operator states */ + private final Collection<OperatorState> operatorStates; + /** The states generated by the CheckpointCoordinator */ private final Collection<MasterState> masterStates; - + /** @deprecated Only kept for backwards-compatibility with versions < 1.3. Will be removed in the future. */ + @Deprecated public SavepointV2(long checkpointId, Collection<TaskState> taskStates) { - this(checkpointId, taskStates, Collections.<MasterState>emptyList()); + this( + checkpointId, + null, + checkNotNull(taskStates, "taskStates"), + Collections.<MasterState>emptyList() + ); + } + + public SavepointV2(long checkpointId, Collection<OperatorState> operatorStates, Collection<MasterState> masterStates) { + this( + checkpointId, + checkNotNull(operatorStates, "operatorStates"), + null, + masterStates + ); } - public SavepointV2(long checkpointId, Collection<TaskState> taskStates, Collection<MasterState> masterStates) { + private SavepointV2( + long checkpointId, + Collection<OperatorState> operatorStates, + Collection<TaskState> taskStates, + Collection<MasterState> masterStates) { + this.checkpointId = checkpointId; - this.taskStates = checkNotNull(taskStates, "taskStates"); + this.operatorStates = operatorStates; + this.taskStates = taskStates; this.masterStates = checkNotNull(masterStates, "masterStates"); } @@ -66,6 +107,11 @@ public class SavepointV2 implements Savepoint { } @Override + public Collection<OperatorState> getOperatorStates() { + return operatorStates; + } + + @Override public Collection<TaskState> getTaskStates() { return taskStates; } @@ -77,10 +123,10 @@ public class SavepointV2 implements Savepoint { @Override public void dispose() throws Exception { - for (TaskState taskState : taskStates) { - taskState.discardState(); + for (OperatorState operatorState : operatorStates) { + operatorState.discardState(); } - taskStates.clear(); + operatorStates.clear(); masterStates.clear(); 
} @@ -88,4 +134,97 @@ public class SavepointV2 implements Savepoint { public String toString() { return "Checkpoint Metadata (version=" + VERSION + ')'; } + + /** + * Converts the {@link Savepoint} containing {@link TaskState TaskStates} to an equivalent savepoint containing + * {@link OperatorState OperatorStates}. + * + * @param savepoint savepoint to convert + * @param tasks map of all vertices and their job vertex ids + * @return converted completed checkpoint + * @deprecated Only kept for backwards-compatibility with versions < 1.3. Will be removed in the future. + * */ + @Deprecated + public static Savepoint convertToOperatorStateSavepointV2( + Map<JobVertexID, ExecutionJobVertex> tasks, + Savepoint savepoint) { + + if (savepoint.getOperatorStates() != null) { + return savepoint; + } + + boolean expandedToLegacyIds = false; + + Map<OperatorID, OperatorState> operatorStates = new HashMap<>(savepoint.getTaskStates().size() << 1); + + for (TaskState taskState : savepoint.getTaskStates()) { + ExecutionJobVertex jobVertex = tasks.get(taskState.getJobVertexID()); + + // on the first time we can not find the execution job vertex for an id, we also consider alternative ids, + // for example as generated from older flink versions, to provide backwards compatibility. 
+ if (jobVertex == null && !expandedToLegacyIds) { + tasks = ExecutionJobVertex.includeLegacyJobVertexIDs(tasks); + jobVertex = tasks.get(taskState.getJobVertexID()); + expandedToLegacyIds = true; + } + + List<OperatorID> operatorIDs = jobVertex.getOperatorIDs(); + + for (int subtaskIndex = 0; subtaskIndex < jobVertex.getParallelism(); subtaskIndex++) { + SubtaskState subtaskState = taskState.getState(subtaskIndex); + + if (subtaskState == null) { + continue; + } + + @SuppressWarnings("deprecation") + ChainedStateHandle<StreamStateHandle> nonPartitionedState = + subtaskState.getLegacyOperatorState(); + ChainedStateHandle<OperatorStateHandle> partitioneableState = + subtaskState.getManagedOperatorState(); + ChainedStateHandle<OperatorStateHandle> rawOperatorState = + subtaskState.getRawOperatorState(); + + for (int chainIndex = 0; chainIndex < taskState.getChainLength(); chainIndex++) { + + // task consists of multiple operators so we have to break the state apart + for (int o = 0; o < operatorIDs.size(); o++) { + OperatorID operatorID = operatorIDs.get(o); + OperatorState operatorState = operatorStates.get(operatorID); + + if (operatorState == null) { + operatorState = new OperatorState( + operatorID, + jobVertex.getParallelism(), + jobVertex.getMaxParallelism()); + operatorStates.put(operatorID, operatorState); + } + + KeyedStateHandle managedKeyedState = null; + KeyedStateHandle rawKeyedState = null; + + // only the head operator retains the keyed state + if (o == operatorIDs.size() - 1) { + managedKeyedState = subtaskState.getManagedKeyedState(); + rawKeyedState = subtaskState.getRawKeyedState(); + } + + OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState( + nonPartitionedState != null ? nonPartitionedState.get(o) : null, + partitioneableState != null ? partitioneableState.get(o) : null, + rawOperatorState != null ? 
rawOperatorState.get(o) : null, + managedKeyedState, + rawKeyedState); + + operatorState.putState(subtaskIndex, operatorSubtaskState); + } + } + } + } + + return new SavepointV2( + savepoint.getCheckpointId(), + operatorStates.values(), + savepoint.getMasterStates()); + } } http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java index 307ea16..1b5f2c6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Serializer.java @@ -20,10 +20,9 @@ package org.apache.flink.runtime.checkpoint.savepoint; import org.apache.flink.core.fs.Path; import org.apache.flink.runtime.checkpoint.MasterState; -import org.apache.flink.runtime.checkpoint.SubtaskState; -import org.apache.flink.runtime.checkpoint.TaskState; -import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.checkpoint.OperatorState; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyGroupRangeOffsets; import org.apache.flink.runtime.state.KeyGroupsStateHandle; @@ -97,25 +96,25 @@ class SavepointV2Serializer implements SavepointSerializer<SavepointV2> { serializeMasterState(ms, dos); } - // third: task states - final Collection<TaskState> taskStates = checkpointMetadata.getTaskStates(); - dos.writeInt(taskStates.size()); + // third: operator 
states + Collection<OperatorState> operatorStates = checkpointMetadata.getOperatorStates(); + dos.writeInt(operatorStates.size()); - for (TaskState taskState : checkpointMetadata.getTaskStates()) { - // Vertex ID - dos.writeLong(taskState.getJobVertexID().getLowerPart()); - dos.writeLong(taskState.getJobVertexID().getUpperPart()); + for (OperatorState operatorState : operatorStates) { + // Operator ID + dos.writeLong(operatorState.getOperatorID().getLowerPart()); + dos.writeLong(operatorState.getOperatorID().getUpperPart()); // Parallelism - int parallelism = taskState.getParallelism(); + int parallelism = operatorState.getParallelism(); dos.writeInt(parallelism); - dos.writeInt(taskState.getMaxParallelism()); - dos.writeInt(taskState.getChainLength()); + dos.writeInt(operatorState.getMaxParallelism()); + dos.writeInt(1); // Sub task states - Map<Integer, SubtaskState> subtaskStateMap = taskState.getSubtaskStates(); + Map<Integer, OperatorSubtaskState> subtaskStateMap = operatorState.getSubtaskStates(); dos.writeInt(subtaskStateMap.size()); - for (Map.Entry<Integer, SubtaskState> entry : subtaskStateMap.entrySet()) { + for (Map.Entry<Integer, OperatorSubtaskState> entry : subtaskStateMap.entrySet()) { dos.writeInt(entry.getKey()); serializeSubtaskState(entry.getValue(), dos); } @@ -147,31 +146,32 @@ class SavepointV2Serializer implements SavepointSerializer<SavepointV2> { throw new IOException("invalid number of master states: " + numMasterStates); } - // third: task states - final int numTaskStates = dis.readInt(); - final ArrayList<TaskState> taskStates = new ArrayList<>(numTaskStates); + // third: operator states + int numTaskStates = dis.readInt(); + List<OperatorState> operatorStates = new ArrayList<>(numTaskStates); for (int i = 0; i < numTaskStates; i++) { - JobVertexID jobVertexId = new JobVertexID(dis.readLong(), dis.readLong()); + OperatorID jobVertexId = new OperatorID(dis.readLong(), dis.readLong()); int parallelism = dis.readInt(); int maxParallelism = 
dis.readInt(); int chainLength = dis.readInt(); // Add task state - TaskState taskState = new TaskState(jobVertexId, parallelism, maxParallelism, chainLength); - taskStates.add(taskState); + OperatorState taskState = new OperatorState(jobVertexId, parallelism, maxParallelism); + operatorStates.add(taskState); // Sub task states int numSubTaskStates = dis.readInt(); for (int j = 0; j < numSubTaskStates; j++) { int subtaskIndex = dis.readInt(); - SubtaskState subtaskState = deserializeSubtaskState(dis); + + OperatorSubtaskState subtaskState = deserializeSubtaskState(dis); taskState.putState(subtaskIndex, subtaskState); } } - return new SavepointV2(checkpointId, taskStates, masterStates); + return new SavepointV2(checkpointId, operatorStates, masterStates); } // ------------------------------------------------------------------------ @@ -235,35 +235,32 @@ class SavepointV2Serializer implements SavepointSerializer<SavepointV2> { // task state (de)serialization methods // ------------------------------------------------------------------------ - private static void serializeSubtaskState(SubtaskState subtaskState, DataOutputStream dos) throws IOException { + private static void serializeSubtaskState(OperatorSubtaskState subtaskState, DataOutputStream dos) throws IOException { dos.writeLong(-1); - ChainedStateHandle<StreamStateHandle> nonPartitionableState = subtaskState.getLegacyOperatorState(); + StreamStateHandle nonPartitionableState = subtaskState.getLegacyOperatorState(); - int len = nonPartitionableState != null ? nonPartitionableState.getLength() : 0; + int len = nonPartitionableState != null ? 
1 : 0; dos.writeInt(len); - for (int i = 0; i < len; ++i) { - StreamStateHandle stateHandle = nonPartitionableState.get(i); - serializeStreamStateHandle(stateHandle, dos); + if (len == 1) { + serializeStreamStateHandle(nonPartitionableState, dos); } - ChainedStateHandle<OperatorStateHandle> operatorStateBackend = subtaskState.getManagedOperatorState(); + OperatorStateHandle operatorStateBackend = subtaskState.getManagedOperatorState(); - len = operatorStateBackend != null ? operatorStateBackend.getLength() : 0; + len = operatorStateBackend != null ? 1 : 0; dos.writeInt(len); - for (int i = 0; i < len; ++i) { - OperatorStateHandle stateHandle = operatorStateBackend.get(i); - serializeOperatorStateHandle(stateHandle, dos); + if (len == 1) { + serializeOperatorStateHandle(operatorStateBackend, dos); } - ChainedStateHandle<OperatorStateHandle> operatorStateFromStream = subtaskState.getRawOperatorState(); + OperatorStateHandle operatorStateFromStream = subtaskState.getRawOperatorState(); - len = operatorStateFromStream != null ? operatorStateFromStream.getLength() : 0; + len = operatorStateFromStream != null ? 
1 : 0; dos.writeInt(len); - for (int i = 0; i < len; ++i) { - OperatorStateHandle stateHandle = operatorStateFromStream.get(i); - serializeOperatorStateHandle(stateHandle, dos); + if (len == 1) { + serializeOperatorStateHandle(operatorStateFromStream, dos); } KeyedStateHandle keyedStateBackend = subtaskState.getManagedKeyedState(); @@ -273,49 +270,28 @@ class SavepointV2Serializer implements SavepointSerializer<SavepointV2> { serializeKeyedStateHandle(keyedStateStream, dos); } - private static SubtaskState deserializeSubtaskState(DataInputStream dis) throws IOException { + private static OperatorSubtaskState deserializeSubtaskState(DataInputStream dis) throws IOException { // Duration field has been removed from SubtaskState long ignoredDuration = dis.readLong(); int len = dis.readInt(); - List<StreamStateHandle> nonPartitionableState = new ArrayList<>(len); - for (int i = 0; i < len; ++i) { - StreamStateHandle streamStateHandle = deserializeStreamStateHandle(dis); - nonPartitionableState.add(streamStateHandle); - } - + StreamStateHandle nonPartitionableState = len == 0 ? null : deserializeStreamStateHandle(dis); len = dis.readInt(); - List<OperatorStateHandle> operatorStateBackend = new ArrayList<>(len); - for (int i = 0; i < len; ++i) { - OperatorStateHandle streamStateHandle = deserializeOperatorStateHandle(dis); - operatorStateBackend.add(streamStateHandle); - } + OperatorStateHandle operatorStateBackend = len == 0 ? null : deserializeOperatorStateHandle(dis); len = dis.readInt(); - List<OperatorStateHandle> operatorStateStream = new ArrayList<>(len); - for (int i = 0; i < len; ++i) { - OperatorStateHandle streamStateHandle = deserializeOperatorStateHandle(dis); - operatorStateStream.add(streamStateHandle); - } + OperatorStateHandle operatorStateStream = len == 0 ? 
null : deserializeOperatorStateHandle(dis); KeyedStateHandle keyedStateBackend = deserializeKeyedStateHandle(dis); KeyedStateHandle keyedStateStream = deserializeKeyedStateHandle(dis); - ChainedStateHandle<StreamStateHandle> nonPartitionableStateChain = - new ChainedStateHandle<>(nonPartitionableState); - - ChainedStateHandle<OperatorStateHandle> operatorStateBackendChain = - new ChainedStateHandle<>(operatorStateBackend); - - ChainedStateHandle<OperatorStateHandle> operatorStateStreamChain = - new ChainedStateHandle<>(operatorStateStream); - return new SubtaskState( - nonPartitionableStateChain, - operatorStateBackendChain, - operatorStateStreamChain, + return new OperatorSubtaskState( + nonPartitionableState, + operatorStateBackend, + operatorStateStream, keyedStateBackend, keyedStateStream); } http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java index 5fbce4d..2e5de64 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java @@ -41,6 +41,7 @@ import org.apache.flink.runtime.jobgraph.JobEdge; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.configuration.JobManagerOptions; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup; import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; import org.apache.flink.runtime.state.KeyGroupRangeAssignment; @@ -68,6 +69,24 @@ public class ExecutionJobVertex 
implements AccessExecutionJobVertex, Archiveable private final ExecutionGraph graph; private final JobVertex jobVertex; + + /** + * The IDs of all operators contained in this execution job vertex. + * + * The ID's are stored depth-first post-order; for the forking chain below the ID's would be stored as [D, E, B, C, A]. + * A - B - D + * \ \ + * C E + * This is the same order that operators are stored in the {@code StreamTask}. + */ + private final List<OperatorID> operatorIDs; + + /** + * The alternative IDs of all operators contained in this execution job vertex. + * + * The ID's are in the same order as {@link ExecutionJobVertex#operatorIDs}. + */ + private final List<OperatorID> userDefinedOperatorIds; private final ExecutionVertex[] taskVertices; @@ -139,6 +158,8 @@ public class ExecutionJobVertex implements AccessExecutionJobVertex, Archiveable this.serializedTaskInformation = null; this.taskVertices = new ExecutionVertex[numTaskVertices]; + this.operatorIDs = Collections.unmodifiableList(jobVertex.getOperatorIDs()); + this.userDefinedOperatorIds = Collections.unmodifiableList(jobVertex.getUserDefinedOperatorIDs()); this.inputs = new ArrayList<>(jobVertex.getInputs().size()); @@ -214,6 +235,24 @@ public class ExecutionJobVertex implements AccessExecutionJobVertex, Archiveable finishedSubtasks = new boolean[parallelism]; } + /** + * Returns a list containing the IDs of all operators contained in this execution job vertex. + * + * @return list containing the IDs of all contained operators + */ + public List<OperatorID> getOperatorIDs() { + return operatorIDs; + } + + /** + * Returns a list containing the alternative IDs of all operators contained in this execution job vertex. 
+ * + * @return list containing alternative the IDs of all contained operators + */ + public List<OperatorID> getUserDefinedOperatorIDs() { + return userDefinedOperatorIds; + } + public void setMaxParallelism(int maxParallelismDerived) { Preconditions.checkState(!maxParallelismConfigured, @@ -731,6 +770,30 @@ public class ExecutionJobVertex implements AccessExecutionJobVertex, Archiveable return expanded; } + public static Map<OperatorID, ExecutionJobVertex> includeAlternativeOperatorIDs( + Map<OperatorID, ExecutionJobVertex> operatorMapping) { + + Map<OperatorID, ExecutionJobVertex> expanded = new HashMap<>(2 * operatorMapping.size()); + // first include all existing ids + expanded.putAll(operatorMapping); + + // now expand and add user-defined ids + for (ExecutionJobVertex executionJobVertex : operatorMapping.values()) { + if (executionJobVertex != null) { + JobVertex jobVertex = executionJobVertex.getJobVertex(); + if (jobVertex != null) { + for (OperatorID operatorID : jobVertex.getUserDefinedOperatorIDs()) { + if (operatorID != null) { + expanded.put(operatorID, executionJobVertex); + } + } + } + } + } + + return expanded; + } + @Override public ArchivedExecutionJobVertex archive() { return new ArchivedExecutionJobVertex(this); http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/InputFormatVertex.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/InputFormatVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/InputFormatVertex.java index c4fc907..5627ac7 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/InputFormatVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/InputFormatVertex.java @@ -39,8 +39,8 @@ public class InputFormatVertex extends JobVertex { super(name, id); } - public InputFormatVertex(String name, 
JobVertexID id, List<JobVertexID> alternativeIds) { - super(name, id, alternativeIds); + public InputFormatVertex(String name, JobVertexID id, List<JobVertexID> alternativeIds, List<OperatorID> operatorIds, List<OperatorID> alternativeOperatorIds) { + super(name, id, alternativeIds, operatorIds, alternativeOperatorIds); } http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java index 1180db4..4f52895 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java @@ -50,8 +50,15 @@ public class JobVertex implements java.io.Serializable { /** The ID of the vertex. */ private final JobVertexID id; + /** The alternative IDs of the vertex. */ private final ArrayList<JobVertexID> idAlternatives = new ArrayList<>(); + /** The IDs of all operators contained in this vertex. */ + private final ArrayList<OperatorID> operatorIDs = new ArrayList<>(); + + /** The alternative IDs of all operators contained in this vertex. */ + private final ArrayList<OperatorID> operatorIdsAlternatives = new ArrayList<>(); + /** List of produced data sets, one per writer */ private final ArrayList<IntermediateDataSet> results = new ArrayList<IntermediateDataSet>(); @@ -125,6 +132,9 @@ public class JobVertex implements java.io.Serializable { public JobVertex(String name, JobVertexID id) { this.name = name == null ? DEFAULT_NAME : name; this.id = id == null ? 
new JobVertexID() : id; + // the id lists must have the same size + this.operatorIDs.add(OperatorID.fromJobVertexID(this.id)); + this.operatorIdsAlternatives.add(null); } /** @@ -133,11 +143,16 @@ public class JobVertex implements java.io.Serializable { * @param name The name of the new job vertex. * @param primaryId The id of the job vertex. * @param alternativeIds The alternative ids of the job vertex. + * @param operatorIds The ids of all operators contained in this job vertex. + * @param alternativeOperatorIds The alternative ids of all operators contained in this job vertex- */ - public JobVertex(String name, JobVertexID primaryId, List<JobVertexID> alternativeIds) { + public JobVertex(String name, JobVertexID primaryId, List<JobVertexID> alternativeIds, List<OperatorID> operatorIds, List<OperatorID> alternativeOperatorIds) { + Preconditions.checkArgument(operatorIds.size() == alternativeOperatorIds.size()); this.name = name == null ? DEFAULT_NAME : name; this.id = primaryId == null ? new JobVertexID() : primaryId; this.idAlternatives.addAll(alternativeIds); + this.operatorIDs.addAll(operatorIds); + this.operatorIdsAlternatives.addAll(alternativeOperatorIds); } // -------------------------------------------------------------------------------------------- @@ -196,6 +211,14 @@ public class JobVertex implements java.io.Serializable { return this.inputs.size(); } + public List<OperatorID> getOperatorIDs() { + return operatorIDs; + } + + public List<OperatorID> getUserDefinedOperatorIDs() { + return operatorIdsAlternatives; + } + /** * Returns the vertex's configuration object which can be used to pass custom settings to the task at runtime. 
* http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OperatorID.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OperatorID.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OperatorID.java new file mode 100644 index 0000000..0e378de --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OperatorID.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.jobgraph; + +import org.apache.flink.util.AbstractID; + +/** + * A class for statistically unique operator IDs. 
+ */ +public class OperatorID extends AbstractID { + + private static final long serialVersionUID = 1L; + + public OperatorID() { + super(); + } + + public OperatorID(byte[] bytes) { + super(bytes); + } + + public OperatorID(long lowerPart, long upperPart) { + super(lowerPart, upperPart); + } + + public static OperatorID fromJobVertexID(JobVertexID id) { + return new OperatorID(id.getLowerPart(), id.getUpperPart()); + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java index 9f94f2f..d293eea 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorExternalizedCheckpointsTest.java @@ -34,6 +34,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.state.filesystem.FileStateHandle; @@ -190,7 +191,9 @@ public class CheckpointCoordinatorExternalizedCheckpointsTest { false); for (ExecutionVertex vertex : vertices) { - assertEquals(checkpoint.getTaskState(vertex.getJobvertexId()), loaded.getTaskState(vertex.getJobvertexId())); + for (OperatorID 
operatorID : vertex.getJobVertex().getOperatorIDs()) { + assertEquals(checkpoint.getOperatorStates().get(operatorID), loaded.getOperatorStates().get(operatorID)); + } } } http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java index 90b7fe7..6e20be3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java @@ -25,10 +25,13 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; +import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.TestLogger; import org.junit.Test; import org.junit.runner.RunWith; -import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; @@ -38,8 +41,10 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Matchers.anyInt; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static 
org.mockito.Mockito.when; @RunWith(PowerMockRunner.class) @PrepareForTest(PendingCheckpoint.class) @@ -85,8 +90,26 @@ public class CheckpointCoordinatorFailureTest extends TestLogger { final long checkpointId = coord.getPendingCheckpoints().keySet().iterator().next(); SubtaskState subtaskState = mock(SubtaskState.class); - PowerMockito.when(subtaskState.getLegacyOperatorState()).thenReturn(null); - PowerMockito.when(subtaskState.getManagedOperatorState()).thenReturn(null); + + StreamStateHandle legacyHandle = mock(StreamStateHandle.class); + ChainedStateHandle<StreamStateHandle> chainedLegacyHandle = mock(ChainedStateHandle.class); + when(chainedLegacyHandle.get(anyInt())).thenReturn(legacyHandle); + when(subtaskState.getLegacyOperatorState()).thenReturn(chainedLegacyHandle); + + OperatorStateHandle managedHandle = mock(OperatorStateHandle.class); + ChainedStateHandle<OperatorStateHandle> chainedManagedHandle = mock(ChainedStateHandle.class); + when(chainedManagedHandle.get(anyInt())).thenReturn(managedHandle); + when(subtaskState.getManagedOperatorState()).thenReturn(chainedManagedHandle); + + OperatorStateHandle rawHandle = mock(OperatorStateHandle.class); + ChainedStateHandle<OperatorStateHandle> chainedRawHandle = mock(ChainedStateHandle.class); + when(chainedRawHandle.get(anyInt())).thenReturn(rawHandle); + when(subtaskState.getRawOperatorState()).thenReturn(chainedRawHandle); + + KeyedStateHandle managedKeyedHandle = mock(KeyedStateHandle.class); + when(subtaskState.getRawKeyedState()).thenReturn(managedKeyedHandle); + KeyedStateHandle managedRawHandle = mock(KeyedStateHandle.class); + when(subtaskState.getManagedKeyedState()).thenReturn(managedRawHandle); AcknowledgeCheckpoint acknowledgeMessage = new AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId, new CheckpointMetrics(), subtaskState); @@ -102,7 +125,11 @@ public class CheckpointCoordinatorFailureTest extends TestLogger { assertTrue(pendingCheckpoint.isDiscarded()); // make sure that the 
subtask state has been discarded after we could not complete it. - verify(subtaskState).discardState(); + verify(subtaskState.getLegacyOperatorState().get(0)).discardState(); + verify(subtaskState.getManagedOperatorState().get(0)).discardState(); + verify(subtaskState.getRawOperatorState().get(0)).discardState(); + verify(subtaskState.getManagedKeyedState()).discardState(); + verify(subtaskState.getRawKeyedState()).discardState(); } private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore { http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java index 7c271a7..d6daa4e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java @@ -27,6 +27,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; @@ -228,7 +229,7 @@ public class CheckpointCoordinatorMasterHooksTest { final CompletedCheckpoint checkpoint = new CompletedCheckpoint( jid, checkpointId, 123L, 125L, - Collections.<JobVertexID, TaskState>emptyMap(), + Collections.<OperatorID, OperatorState>emptyMap(), 
masterHookStates, CheckpointProperties.forStandardCheckpoint(), null, @@ -282,7 +283,7 @@ public class CheckpointCoordinatorMasterHooksTest { final CompletedCheckpoint checkpoint = new CompletedCheckpoint( jid, checkpointId, 123L, 125L, - Collections.<JobVertexID, TaskState>emptyMap(), + Collections.<OperatorID, OperatorState>emptyMap(), masterHookStates, CheckpointProperties.forStandardCheckpoint(), null,