[FLINK-5892] Add new StateAssignmentOperationV2
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/8045faba Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/8045faba Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/8045faba Branch: refs/heads/table-retraction Commit: 8045fabac736cc8c6b48fda8328cf91f329dc3bf Parents: 591841f Author: guowei.mgw <guowei....@gmail.com> Authored: Mon Apr 24 11:47:47 2017 +0200 Committer: zentol <ches...@apache.org> Committed: Fri Apr 28 20:09:11 2017 +0200 ---------------------------------------------------------------------- .../checkpoint/StateAssignmentOperationV2.java | 458 +++++++++++++++++++ 1 file changed, 458 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/8045faba/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationV2.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationV2.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationV2.java new file mode 100644 index 0000000..83c188c --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationV2.java @@ -0,0 +1,458 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.runtime.checkpoint; + +//import com.google.common.collect.Lists; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.runtime.executiongraph.Execution; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.KeyedStateHandle; +import org.apache.flink.runtime.state.OperatorStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.TaskStateHandles; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * This class encapsulates the operation of assigning restored state when restoring from a checkpoint that works on the + * granularity of operators. This is the case for checkpoints that were created either with a Flink version >= 1.3 or + * 1.2 if the savepoint only contains {@link SubtaskState}s for which the length of contained + * {@link ChainedStateHandle}s is equal to 1. + */ +public class StateAssignmentOperationV2 { + + private final Logger logger; + private final Map<JobVertexID, ExecutionJobVertex> tasks; + private final Map<JobVertexID, TaskState> taskStates; + private final boolean allowNonRestoredState; + + public StateAssignmentOperationV2( + Logger logger, + Map<JobVertexID, ExecutionJobVertex> tasks, + Map<JobVertexID, TaskState> taskStates, + boolean allowNonRestoredState) { + + this.logger = Preconditions.checkNotNull(logger); + this.tasks = Preconditions.checkNotNull(tasks); + this.taskStates = Preconditions.checkNotNull(taskStates); + this.allowNonRestoredState = allowNonRestoredState; + } + + public boolean assignStates() throws Exception { + Map<JobVertexID, TaskState> localStates = new HashMap<>(taskStates); + Map<JobVertexID, ExecutionJobVertex> localTasks = this.tasks; + + Set<JobVertexID> allOperatorIDs = new HashSet<>(); + for (ExecutionJobVertex executionJobVertex : tasks.values()) { + //allOperatorIDs.addAll(Lists.newArrayList(executionJobVertex.getOperatorIDs())); + } + for (Map.Entry<JobVertexID, TaskState> taskGroupStateEntry : taskStates.entrySet()) { + TaskState taskState = taskGroupStateEntry.getValue(); + //----------------------------------------find operator for state--------------------------------------------- + + if (!allOperatorIDs.contains(taskGroupStateEntry.getKey())) { + if (allowNonRestoredState) { + logger.info("Skipped checkpoint state for operator {}.", taskState.getJobVertexID()); + continue; + } else { + throw new IllegalStateException("There is no operator for the state " + taskState.getJobVertexID()); + } + } + } + + for (Map.Entry<JobVertexID, ExecutionJobVertex> task : localTasks.entrySet()) { + final ExecutionJobVertex executionJobVertex = task.getValue(); + + // find the states of all operators belonging to this task + JobVertexID[] operatorIDs = null;//executionJobVertex.getOperatorIDs(); + JobVertexID[] altOperatorIDs = null;//executionJobVertex.getUserDefinedOperatorIDs(); + List<TaskState> operatorStates = new ArrayList<>(); + boolean statelessTask = true; + for (int x = 0; x < operatorIDs.length; x++) { + JobVertexID operatorID = altOperatorIDs[x] == null + ? operatorIDs[x] + : altOperatorIDs[x]; + + TaskState operatorState = localStates.remove(operatorID); + if (operatorState == null) { + operatorState = new TaskState( + operatorID, + executionJobVertex.getParallelism(), + executionJobVertex.getMaxParallelism(), + 1); + } else { + statelessTask = false; + } + operatorStates.add(operatorState); + } + if (statelessTask) { // skip tasks where no operator has any state + continue; + } + + assignAttemptState(task.getValue(), operatorStates); + } + + return true; + } + + private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<TaskState> operatorStates) { + + JobVertexID[] operatorIDs = null;//executionJobVertex.getOperatorIDs(); + + //1. first compute the new parallelism + checkParallelismPreconditions(operatorStates, executionJobVertex); + + int newParallelism = executionJobVertex.getParallelism(); + + List<KeyGroupRange> keyGroupPartitions = null;//StateAssignmentOperationUtils.createKeyGroupPartitions( + //executionJobVertex.getMaxParallelism(), + //newParallelism); + + //2. Redistribute the operator state. + /** + * + * Redistribute ManagedOperatorStates and RawOperatorStates from old parallelism to new parallelism. + * + * The old ManagedOperatorStates with old parallelism 3: + * + * parallelism0 parallelism1 parallelism2 + * op0 states0,0 state0,1 state0,2 + * op1 + * op2 states2,0 state2,1 state1,2 + * op3 states3,0 state3,1 state3,2 + * + * The new ManagedOperatorStates with new parallelism 4: + * + * parallelism0 parallelism1 parallelism2 parallelism3 + * op0 state0,0 state0,1 state0,2 state0,3 + * op1 + * op2 state2,0 state2,1 state2,2 state2,3 + * op3 state3,0 state3,1 state3,2 state3,3 + */ + List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates = new ArrayList<>(); + List<List<Collection<OperatorStateHandle>>> newRawOperatorStates = new ArrayList<>(); + + reDistributePartitionableStates(operatorStates, newParallelism, newManagedOperatorStates, newRawOperatorStates); + + + //3. Compute TaskStateHandles of every subTask in the executionJobVertex + /** + * An executionJobVertex's all state handles needed to restore are something like a matrix + * + * parallelism0 parallelism1 parallelism2 parallelism3 + * op0 sh(0,0) sh(0,1) sh(0,2) sh(0,3) + * op1 sh(1,0) sh(1,1) sh(1,2) sh(1,3) + * op2 sh(2,0) sh(2,1) sh(2,2) sh(2,3) + * op3 sh(3,0) sh(3,1) sh(3,2) sh(3,3) + * + * we will compute the state handles column by column. + * + */ + for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) { + + Execution currentExecutionAttempt = executionJobVertex.getTaskVertices()[subTaskIndex] + .getCurrentExecutionAttempt(); + + List<StreamStateHandle> subNonPartitionableState = new ArrayList<>(); + + Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> subKeyedState = null; + + List<Collection<OperatorStateHandle>> subManagedOperatorState = new ArrayList<>(); + List<Collection<OperatorStateHandle>> subRawOperatorState = new ArrayList<>(); + + + for (int operatorIndex = 0; operatorIndex < operatorIDs.length; operatorIndex++) { + TaskState operatorState = operatorStates.get(operatorIndex); + int oldParallelism = operatorState.getParallelism(); + + // NonPartitioned State + + reAssignSubNonPartitionedStates( + operatorState, + subTaskIndex, + newParallelism, + oldParallelism, + subNonPartitionableState); + + // PartitionedState + reAssignSubPartitionableState(newManagedOperatorStates, + newRawOperatorStates, + subTaskIndex, + operatorIndex, + subManagedOperatorState, + subRawOperatorState); + + // KeyedState + if (operatorIndex == operatorIDs.length - 1) { + subKeyedState = reAssignSubKeyedStates(operatorState, + keyGroupPartitions, + subTaskIndex, + newParallelism, + oldParallelism); + + } + } + + + // check if a stateless task + if (!allElementsAreNull(subNonPartitionableState) || + !allElementsAreNull(subManagedOperatorState) || + !allElementsAreNull(subRawOperatorState) || + subKeyedState != null) { + + TaskStateHandles taskStateHandles = new TaskStateHandles( + + new ChainedStateHandle<>(subNonPartitionableState), + subManagedOperatorState, + subRawOperatorState, + subKeyedState != null ? subKeyedState.f0 : null, + subKeyedState != null ? subKeyedState.f1 : null); + + currentExecutionAttempt.setInitialState(taskStateHandles); + } + } + } + + + public void checkParallelismPreconditions(List<TaskState> operatorStates, ExecutionJobVertex executionJobVertex) { + + for (TaskState taskState : operatorStates) { + //StateAssignmentOperation.checkParallelismPreconditions(taskState, executionJobVertex, this.logger); + } + } + + + private void reAssignSubPartitionableState( + List<List<Collection<OperatorStateHandle>>> newMangedOperatorStates, + List<List<Collection<OperatorStateHandle>>> newRawOperatorStates, + int subTaskIndex, int operatorIndex, + List<Collection<OperatorStateHandle>> subManagedOperatorState, + List<Collection<OperatorStateHandle>> subRawOperatorState) { + + if (newMangedOperatorStates.get(operatorIndex) != null) { + subManagedOperatorState.add(newMangedOperatorStates.get(operatorIndex).get(subTaskIndex)); + } else { + subManagedOperatorState.add(null); + } + if (newRawOperatorStates.get(operatorIndex) != null) { + subRawOperatorState.add(newRawOperatorStates.get(operatorIndex).get(subTaskIndex)); + } else { + subRawOperatorState.add(null); + } + + + } + + private Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> reAssignSubKeyedStates( + TaskState operatorState, + List<KeyGroupRange> keyGroupPartitions, + int subTaskIndex, + int newParallelism, + int oldParallelism) { + + Collection<KeyedStateHandle> subManagedKeyedState; + Collection<KeyedStateHandle> subRawKeyedState; + + if (newParallelism == oldParallelism) { + if (operatorState.getState(subTaskIndex) != null) { + KeyedStateHandle oldSubManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState(); + KeyedStateHandle oldSubRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState(); + subManagedKeyedState = oldSubManagedKeyedState != null ? Collections.singletonList( + oldSubManagedKeyedState) : null; + subRawKeyedState = oldSubRawKeyedState != null ? Collections.singletonList( + oldSubRawKeyedState) : null; + } else { + subManagedKeyedState = null; + subRawKeyedState = null; + } + } else { + subManagedKeyedState = getManagedKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex)); + subRawKeyedState = getRawKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex)); + } + if (subManagedKeyedState == null && subRawKeyedState == null) { + return null; + } + return new Tuple2<>(subManagedKeyedState, subRawKeyedState); + } + + + private <X> boolean allElementsAreNull(List<X> nonPartitionableStates) { + for (Object streamStateHandle : nonPartitionableStates) { + if (streamStateHandle != null) { + return false; + } + } + return true; + } + + + private void reAssignSubNonPartitionedStates( + TaskState operatorState, + int subTaskIndex, + int newParallelism, + int oldParallelism, + List<StreamStateHandle> subNonPartitionableState) { + if (oldParallelism == newParallelism) { + if (operatorState.getState(subTaskIndex) != null && + !operatorState.getState(subTaskIndex).getLegacyOperatorState().isEmpty()) { + subNonPartitionableState.add(operatorState.getState(subTaskIndex).getLegacyOperatorState().get(0)); + } else { + subNonPartitionableState.add(null); + } + } else { + subNonPartitionableState.add(null); + } + } + + private void reDistributePartitionableStates( + List<TaskState> operatorStates, int newParallelism, + List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates, + List<List<Collection<OperatorStateHandle>>> newRawOperatorStates) { + + //collect the old partitionalbe state + List<List<OperatorStateHandle>> oldManagedOperatorStates = new ArrayList<>(); + List<List<OperatorStateHandle>> oldRawOperatorStates = new ArrayList<>(); + + collectPartionableStates(operatorStates, oldManagedOperatorStates, oldRawOperatorStates); + + + //redistribute + OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE; + + for (int operatorIndex = 0; operatorIndex < operatorStates.size(); operatorIndex++) { + int oldParallelism = operatorStates.get(operatorIndex).getParallelism(); + //newManagedOperatorStates.add(StateAssignmentOperationUtils.applyRepartitioner(opStateRepartitioner, + // oldManagedOperatorStates.get(operatorIndex), oldParallelism, newParallelism)); + //newRawOperatorStates.add(StateAssignmentOperationUtils.applyRepartitioner(opStateRepartitioner, + // oldRawOperatorStates.get(operatorIndex), oldParallelism, newParallelism)); + + } + } + + + private void collectPartionableStates( + List<TaskState> operatorStates, + List<List<OperatorStateHandle>> managedOperatorStates, + List<List<OperatorStateHandle>> rawOperatorStates) { + + for (TaskState operatorState : operatorStates) { + List<OperatorStateHandle> managedOperatorState = null; + List<OperatorStateHandle> rawOperatorState = null; + + for (int i = 0; i < operatorState.getParallelism(); i++) { + SubtaskState subtaskState = operatorState.getState(i); + if (subtaskState != null) { + if (subtaskState.getManagedOperatorState() != null && + subtaskState.getManagedOperatorState().getLength() > 0 && + subtaskState.getManagedOperatorState().get(0) != null) { + if (managedOperatorState == null) { + managedOperatorState = new ArrayList<>(); + } + managedOperatorState.add(subtaskState.getManagedOperatorState().get(0)); + } + + if (subtaskState.getRawOperatorState() != null && + subtaskState.getRawOperatorState().getLength() > 0 && + subtaskState.getRawOperatorState().get(0) != null) { + if (rawOperatorState == null) { + rawOperatorState = new ArrayList<>(); + } + rawOperatorState.add(subtaskState.getRawOperatorState().get(0)); + } + } + + } + managedOperatorStates.add(managedOperatorState); + rawOperatorStates.add(rawOperatorState); + } + } + + + /** + * Collect {@link KeyGroupsStateHandle managedKeyedStateHandles} which have intersection with given + * {@link KeyGroupRange} from {@link TaskState operatorState} + * + * @param operatorState all state handles of a operator + * @param subtaskKeyGroupRange the KeyGroupRange of a subtask + * @return all managedKeyedStateHandles which have intersection with given KeyGroupRange + */ + public static List<KeyedStateHandle> getManagedKeyedStateHandles( + TaskState operatorState, + KeyGroupRange subtaskKeyGroupRange) { + + List<KeyedStateHandle> subtaskKeyedStateHandles = null; + + for (int i = 0; i < operatorState.getParallelism(); i++) { + if (operatorState.getState(i) != null && operatorState.getState(i).getManagedKeyedState() != null) { + KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getManagedKeyedState().getIntersection(subtaskKeyGroupRange); + + if (intersectedKeyedStateHandle != null) { + if (subtaskKeyedStateHandles == null) { + subtaskKeyedStateHandles = new ArrayList<>(); + } + subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); + } + } + } + + return subtaskKeyedStateHandles; + } + + /** + * Collect {@link KeyGroupsStateHandle rawKeyedStateHandles} which have intersection with given + * {@link KeyGroupRange} from {@link TaskState operatorState} + * + * @param operatorState all state handles of a operator + * @param subtaskKeyGroupRange the KeyGroupRange of a subtask + * @return all rawKeyedStateHandles which have intersection with given KeyGroupRange + */ + public static List<KeyedStateHandle> getRawKeyedStateHandles( + TaskState operatorState, + KeyGroupRange subtaskKeyGroupRange) { + + List<KeyedStateHandle> subtaskKeyedStateHandles = null; + + for (int i = 0; i < operatorState.getParallelism(); i++) { + if (operatorState.getState(i) != null && operatorState.getState(i).getRawKeyedState() != null) { + KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getRawKeyedState().getIntersection(subtaskKeyGroupRange); + + if (intersectedKeyedStateHandle != null) { + if (subtaskKeyedStateHandles == null) { + subtaskKeyedStateHandles = new ArrayList<>(); + } + subtaskKeyedStateHandles.add(intersectedKeyedStateHandle); + } + } + } + + return subtaskKeyedStateHandles; + } +}