[FLINK-7541][runtime] Refactor StateAssignmentOperation and use OperatorID

This is not a complete refactor; some methods still rely on the order of the
new and old operators.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/f1b2b83d
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/f1b2b83d
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/f1b2b83d

Branch: refs/heads/master
Commit: f1b2b83d639658907bbc521f967e2673c4c866a1
Parents: 79b916c
Author: Piotr Nowojski <piotr.nowoj...@gmail.com>
Authored: Fri Aug 25 15:23:15 2017 +0200
Committer: Stefan Richter <s.rich...@data-artisans.com>
Committed: Mon Sep 25 17:55:55 2017 +0200

----------------------------------------------------------------------
 .../checkpoint/StateAssignmentOperation.java    | 270 +++++++++++--------
 .../runtime/jobgraph/OperatorInstanceID.java    |  73 +++++
 2 files changed, 234 insertions(+), 109 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/f1b2b83d/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index cc9f9cd..76db912 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -23,6 +23,7 @@ import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
@@ -30,6 +31,9 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.util.Preconditions;
 
+import 
org.apache.flink.shaded.guava18.com.google.common.collect.ArrayListMultimap;
+import org.apache.flink.shaded.guava18.com.google.common.collect.Multimap;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -42,6 +46,9 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
 /**
  * This class encapsulates the operation of assigning restored state when 
restoring from a checkpoint.
  */
@@ -116,9 +123,7 @@ public class StateAssignmentOperation {
                        executionJobVertex.getMaxParallelism(),
                        newParallelism);
 
-               //2. Redistribute the operator state.
                /**
-                *
                 * Redistribute ManagedOperatorStates and RawOperatorStates 
from old parallelism to new parallelism.
                 *
                 * The old ManagedOperatorStates with old parallelism 3:
@@ -137,13 +142,27 @@ public class StateAssignmentOperation {
                 * op2   state2,0         state2,1         state2,2             
state2,3
                 * op3   state3,0         state3,1         state3,2             
state3,3
                 */
-               List<List<Collection<OperatorStateHandle>>> 
newManagedOperatorStates = new ArrayList<>();
-               List<List<Collection<OperatorStateHandle>>> 
newRawOperatorStates = new ArrayList<>();
-
-               reDistributePartitionableStates(operatorStates, newParallelism, 
newManagedOperatorStates, newRawOperatorStates);
-
+               Multimap<OperatorInstanceID, OperatorStateHandle> 
newManagedOperatorStates = ArrayListMultimap.create();
+               Multimap<OperatorInstanceID, OperatorStateHandle> 
newRawOperatorStates = ArrayListMultimap.create();
+
+               reDistributePartitionableStates(
+                       operatorStates,
+                       newParallelism,
+                       operatorIDs,
+                       newManagedOperatorStates,
+                       newRawOperatorStates);
+
+               Multimap<OperatorInstanceID, KeyedStateHandle> 
newManagedKeyedState = ArrayListMultimap.create();
+               Multimap<OperatorInstanceID, KeyedStateHandle> newRawKeyedState 
= ArrayListMultimap.create();
+
+               reDistributeKeyedStates(
+                       operatorStates,
+                       newParallelism,
+                       operatorIDs,
+                       keyGroupPartitions,
+                       newManagedKeyedState,
+                       newRawKeyedState);
 
-               //3. Compute TaskStateHandles of every subTask in the 
executionJobVertex
                /**
                 *  An executionJobVertex's all state handles needed to restore 
are something like a matrix
                 *
@@ -153,113 +172,122 @@ public class StateAssignmentOperation {
                 * op2   sh(2,0)         sh(2,1)           sh(2,2)              
sh(2,3)
                 * op3   sh(3,0)         sh(3,1)           sh(3,2)              
sh(3,3)
                 *
-                * we will compute the state handles column by column.
-                *
                 */
+               assignTaskStateToExecutionJobVertices(
+                       executionJobVertex,
+                       newManagedOperatorStates,
+                       newRawOperatorStates,
+                       newManagedKeyedState,
+                       newRawKeyedState,
+                       newParallelism);
+       }
+
+       private void assignTaskStateToExecutionJobVertices(
+                       ExecutionJobVertex executionJobVertex,
+                       Multimap<OperatorInstanceID, OperatorStateHandle> 
subManagedOperatorState,
+                       Multimap<OperatorInstanceID, OperatorStateHandle> 
subRawOperatorState,
+                       Multimap<OperatorInstanceID, KeyedStateHandle> 
subManagedKeyedState,
+                       Multimap<OperatorInstanceID, KeyedStateHandle> 
subRawKeyedState,
+                       int newParallelism) {
+
+               List<OperatorID> operatorIDs = 
executionJobVertex.getOperatorIDs();
+
                for (int subTaskIndex = 0; subTaskIndex < newParallelism; 
subTaskIndex++) {
 
                        Execution currentExecutionAttempt = 
executionJobVertex.getTaskVertices()[subTaskIndex]
                                .getCurrentExecutionAttempt();
 
-                       Tuple2<Collection<KeyedStateHandle>, 
Collection<KeyedStateHandle>> subKeyedState = null;
-
-                       List<Collection<OperatorStateHandle>> 
subManagedOperatorState = new ArrayList<>();
-                       List<Collection<OperatorStateHandle>> 
subRawOperatorState = new ArrayList<>();
-
+                       TaskStateSnapshot taskState = new TaskStateSnapshot();
+                       boolean statelessTask = true;
 
-                       for (int operatorIndex = 0; operatorIndex < 
operatorIDs.size(); operatorIndex++) {
-                               OperatorState operatorState = 
operatorStates.get(operatorIndex);
-                               int oldParallelism = 
operatorState.getParallelism();
+                       for (OperatorID operatorID : operatorIDs) {
+                               OperatorInstanceID instanceID = 
OperatorInstanceID.of(subTaskIndex, operatorID);
 
-                               // PartitionedState
-                               reAssignSubPartitionableState(
-                                       newManagedOperatorStates,
-                                       newRawOperatorStates,
-                                       subTaskIndex,
-                                       operatorIndex,
+                               OperatorSubtaskState operatorSubtaskState = 
operatorSubtaskStateFrom(
+                                       instanceID,
                                        subManagedOperatorState,
-                                       subRawOperatorState);
+                                       subRawOperatorState,
+                                       subManagedKeyedState,
+                                       subRawKeyedState);
 
-                               // KeyedState
-                               if (isHeadOperator(operatorIndex, operatorIDs)) 
{
-                                       subKeyedState = reAssignSubKeyedStates(
-                                               operatorState,
-                                               keyGroupPartitions,
-                                               subTaskIndex,
-                                               newParallelism,
-                                               oldParallelism);
+                               if (operatorSubtaskState.hasState()) {
+                                       statelessTask = false;
                                }
+                               
taskState.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState);
                        }
 
-                       // check if a stateless task
-                       if (!allElementsAreNull(subManagedOperatorState) ||
-                               !allElementsAreNull(subRawOperatorState) ||
-                               subKeyedState != null) {
-
-                               TaskStateSnapshot taskState = new 
TaskStateSnapshot();
-
-                               for (int i = 0; i < operatorIDs.size(); ++i) {
-
-                                       OperatorID operatorID = 
operatorIDs.get(i);
-
-                                       Collection<KeyedStateHandle> rawKeyed = 
Collections.emptyList();
-                                       Collection<KeyedStateHandle> 
managedKeyed = Collections.emptyList();
-
-                                       // keyed state case
-                                       if (subKeyedState != null) {
-                                               managedKeyed = subKeyedState.f0;
-                                               rawKeyed = subKeyedState.f1;
-                                       }
-
-                                       OperatorSubtaskState 
operatorSubtaskState =
-                                               new OperatorSubtaskState(
-                                                       
subManagedOperatorState.get(i),
-                                                       
subRawOperatorState.get(i),
-                                                       managedKeyed,
-                                                       rawKeyed
-                                               );
-
-                                       
taskState.putSubtaskStateByOperatorID(operatorID, operatorSubtaskState);
-                               }
-
+                       if (!statelessTask) {
                                
currentExecutionAttempt.setInitialState(taskState);
                        }
                }
        }
 
+       private static OperatorSubtaskState operatorSubtaskStateFrom(
+                       OperatorInstanceID instanceID,
+                       Multimap<OperatorInstanceID, OperatorStateHandle> 
subManagedOperatorState,
+                       Multimap<OperatorInstanceID, OperatorStateHandle> 
subRawOperatorState,
+                       Multimap<OperatorInstanceID, KeyedStateHandle> 
subManagedKeyedState,
+                       Multimap<OperatorInstanceID, KeyedStateHandle> 
subRawKeyedState) {
+
+               if (!subManagedOperatorState.containsKey(instanceID) &&
+                       !subRawOperatorState.containsKey(instanceID) &&
+                       !subManagedKeyedState.containsKey(instanceID) &&
+                       !subRawKeyedState.containsKey(instanceID)) {
+                       
+                       return new OperatorSubtaskState();
+               }
+               if (!subManagedKeyedState.containsKey(instanceID)) {
+                       checkState(!subRawKeyedState.containsKey(instanceID));
+               }
+               return new OperatorSubtaskState(
+                       subManagedOperatorState.get(instanceID),
+                       subRawOperatorState.get(instanceID),
+                       subManagedKeyedState.get(instanceID),
+                       subRawKeyedState.get(instanceID));
+       }
+
        private static boolean isHeadOperator(int opIdx, List<OperatorID> 
operatorIDs) {
                return opIdx == operatorIDs.size() - 1;
        }
 
        public void checkParallelismPreconditions(List<OperatorState> 
operatorStates, ExecutionJobVertex executionJobVertex) {
-
                for (OperatorState operatorState : operatorStates) {
                        checkParallelismPreconditions(operatorState, 
executionJobVertex);
                }
        }
 
-
-       private void reAssignSubPartitionableState(
-                       List<List<Collection<OperatorStateHandle>>> 
newMangedOperatorStates,
-                       List<List<Collection<OperatorStateHandle>>> 
newRawOperatorStates,
-                       int subTaskIndex, int operatorIndex,
-                       List<Collection<OperatorStateHandle>> 
subManagedOperatorState,
-                       List<Collection<OperatorStateHandle>> 
subRawOperatorState) {
-
-               if (newMangedOperatorStates.get(operatorIndex) != null && 
!newMangedOperatorStates.get(operatorIndex).isEmpty()) {
-                       Collection<OperatorStateHandle> operatorStateHandles = 
newMangedOperatorStates.get(operatorIndex).get(subTaskIndex);
-                       subManagedOperatorState.add(operatorStateHandles != 
null ? operatorStateHandles : Collections.<OperatorStateHandle>emptyList());
-               } else {
-                       
subManagedOperatorState.add(Collections.<OperatorStateHandle>emptyList());
-               }
-               if (newRawOperatorStates.get(operatorIndex) != null && 
!newRawOperatorStates.get(operatorIndex).isEmpty()) {
-                       Collection<OperatorStateHandle> operatorStateHandles = 
newRawOperatorStates.get(operatorIndex).get(subTaskIndex);
-                       subRawOperatorState.add(operatorStateHandles != null ? 
operatorStateHandles : Collections.<OperatorStateHandle>emptyList());
-               } else {
-                       
subRawOperatorState.add(Collections.<OperatorStateHandle>emptyList());
+       private void reDistributeKeyedStates(
+                       List<OperatorState> oldOperatorStates,
+                       int newParallelism,
+                       List<OperatorID> newOperatorIDs,
+                       List<KeyGroupRange> newKeyGroupPartitions,
+                       Multimap<OperatorInstanceID, KeyedStateHandle> 
newManagedKeyedState,
+                       Multimap<OperatorInstanceID, KeyedStateHandle> 
newRawKeyedState) {
+               //TODO: rewrite this method to only use OperatorID
+               checkState(newOperatorIDs.size() == oldOperatorStates.size(),
+                       "This method still depends on the order of the new and 
old operators");
+
+               for (int operatorIndex = 0; operatorIndex < 
newOperatorIDs.size(); operatorIndex++) {
+                       OperatorState operatorState = 
oldOperatorStates.get(operatorIndex);
+                       int oldParallelism = operatorState.getParallelism();
+
+                       for (int subTaskIndex = 0; subTaskIndex < 
newParallelism; subTaskIndex++) {
+                               OperatorInstanceID instanceID = 
OperatorInstanceID.of(subTaskIndex, newOperatorIDs.get(operatorIndex));
+                               if (isHeadOperator(operatorIndex, 
newOperatorIDs)) {
+                                       Tuple2<Collection<KeyedStateHandle>, 
Collection<KeyedStateHandle>> subKeyedStates = reAssignSubKeyedStates(
+                                               operatorState,
+                                               newKeyGroupPartitions,
+                                               subTaskIndex,
+                                               newParallelism,
+                                               oldParallelism);
+                                       newManagedKeyedState.putAll(instanceID, 
subKeyedStates.f0);
+                                       newRawKeyedState.putAll(instanceID, 
subKeyedStates.f1);
+                               }
+                       }
                }
        }
 
+       // TODO rewrite based on operator id
        private Tuple2<Collection<KeyedStateHandle>, 
Collection<KeyedStateHandle>> reAssignSubKeyedStates(
                        OperatorState operatorState,
                        List<KeyGroupRange> keyGroupPartitions,
@@ -284,48 +312,50 @@ public class StateAssignmentOperation {
                }
 
                if (subManagedKeyedState.isEmpty() && 
subRawKeyedState.isEmpty()) {
-                       return null;
+                       return new Tuple2<>(Collections.emptyList(), 
Collections.emptyList());
                } else {
                        return new Tuple2<>(subManagedKeyedState, 
subRawKeyedState);
                }
        }
 
-
-       private <X> boolean allElementsAreNull(List<X> nonPartitionableStates) {
-               for (Object streamStateHandle : nonPartitionableStates) {
-                       if (streamStateHandle != null) {
-                               return false;
-                       }
-               }
-               return true;
-       }
-
        private void reDistributePartitionableStates(
-                       List<OperatorState> operatorStates, int newParallelism,
-                       List<List<Collection<OperatorStateHandle>>> 
newManagedOperatorStates,
-                       List<List<Collection<OperatorStateHandle>>> 
newRawOperatorStates) {
+                       List<OperatorState> oldOperatorStates,
+                       int newParallelism,
+                       List<OperatorID> newOperatorIDs,
+                       Multimap<OperatorInstanceID, OperatorStateHandle> 
newManagedOperatorStates,
+                       Multimap<OperatorInstanceID, OperatorStateHandle> 
newRawOperatorStates) {
+
+               //TODO: rewrite this method to only use OperatorID
+               checkState(newOperatorIDs.size() == oldOperatorStates.size(),
+                       "This method still depends on the order of the new and 
old operators");
 
                //collect the old partitionable state
                List<List<OperatorStateHandle>> oldManagedOperatorStates = new 
ArrayList<>();
                List<List<OperatorStateHandle>> oldRawOperatorStates = new 
ArrayList<>();
 
-               collectPartionableStates(operatorStates, 
oldManagedOperatorStates, oldRawOperatorStates);
-
+               collectPartionableStates(oldOperatorStates, 
oldManagedOperatorStates, oldRawOperatorStates);
 
                //redistribute
                OperatorStateRepartitioner opStateRepartitioner = 
RoundRobinOperatorStateRepartitioner.INSTANCE;
 
-               for (int operatorIndex = 0; operatorIndex < 
operatorStates.size(); operatorIndex++) {
-                       int oldParallelism = 
operatorStates.get(operatorIndex).getParallelism();
-                       
newManagedOperatorStates.add(applyRepartitioner(opStateRepartitioner,
-                               oldManagedOperatorStates.get(operatorIndex), 
oldParallelism, newParallelism));
-                       
newRawOperatorStates.add(applyRepartitioner(opStateRepartitioner,
-                               oldRawOperatorStates.get(operatorIndex), 
oldParallelism, newParallelism));
-
+               for (int operatorIndex = 0; operatorIndex < 
oldOperatorStates.size(); operatorIndex++) {
+                       OperatorID operatorID = 
newOperatorIDs.get(operatorIndex);
+                       int oldParallelism = 
oldOperatorStates.get(operatorIndex).getParallelism();
+                       newManagedOperatorStates.putAll(applyRepartitioner(
+                               operatorID,
+                               opStateRepartitioner,
+                               oldManagedOperatorStates.get(operatorIndex),
+                               oldParallelism,
+                               newParallelism));
+                       newRawOperatorStates.putAll(applyRepartitioner(
+                               operatorID,
+                               opStateRepartitioner,
+                               oldRawOperatorStates.get(operatorIndex),
+                               oldParallelism,
+                               newParallelism));
                }
        }
 
-
        private void collectPartionableStates(
                        List<OperatorState> operatorStates,
                        List<List<OperatorStateHandle>> managedOperatorStates,
@@ -356,7 +386,6 @@ public class StateAssignmentOperation {
                }
        }
 
-
        /**
         * Collect {@link KeyGroupsStateHandle  managedKeyedStateHandles} which 
have intersection with given
         * {@link KeyGroupRange} from {@link TaskState operatorState}
@@ -524,6 +553,28 @@ public class StateAssignmentOperation {
                }
        }
 
+       /**
+        * Repartitions one operator's state handles via the list-returning {@code applyRepartitioner}
+        * overload and keys each new subtask's handles by {@link OperatorInstanceID}.
+        * @throws NullPointerException if the repartitioned list holds a null entry for a subtask
+        */
+       public static Multimap<OperatorInstanceID, OperatorStateHandle> applyRepartitioner(
+                       OperatorID operatorID,
+                       OperatorStateRepartitioner opStateRepartitioner,
+                       List<OperatorStateHandle> chainOpParallelStates,
+                       int oldParallelism,
+                       int newParallelism) {
+               Multimap<OperatorInstanceID, OperatorStateHandle> result = ArrayListMultimap.create();
+               List<Collection<OperatorStateHandle>> states = applyRepartitioner(
+                       opStateRepartitioner, chainOpParallelStates, oldParallelism, newParallelism);
+               for (int subtaskIndex = 0; subtaskIndex < states.size(); subtaskIndex++) {
+                       // Pass the element itself: a boolean expression would be autoboxed and could never be null.
+                       Collection<OperatorStateHandle> subtaskStates = checkNotNull(states.get(subtaskIndex), "states.get(subtaskIndex) is null");
+                       result.putAll(OperatorInstanceID.of(subtaskIndex, operatorID), subtaskStates);
+               }
+               return result;
+       }
+
        /**
         * Repartitions the given operator state using the given {@link 
OperatorStateRepartitioner} with respect to the new
         * parallelism.
@@ -534,6 +585,7 @@ public class StateAssignmentOperation {
         * @param newParallelism        parallelism with which the state should 
be partitioned
         * @return repartitioned state
         */
+       // TODO rewrite based on operator id
        public static List<Collection<OperatorStateHandle>> applyRepartitioner(
                        OperatorStateRepartitioner opStateRepartitioner,
                        List<OperatorStateHandle> chainOpParallelStates,

http://git-wip-us.apache.org/repos/asf/flink/blob/f1b2b83d/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OperatorInstanceID.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OperatorInstanceID.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OperatorInstanceID.java
new file mode 100644
index 0000000..76bcdbf
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/OperatorInstanceID.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.jobgraph;
+
+import java.util.Objects;
+
+/**
+ * An ID for physical instance of the operator.
+ */
+public class OperatorInstanceID  {
+
+       private final int subtaskId;
+       private final OperatorID operatorId;
+
+       public static OperatorInstanceID of(int subtaskId, OperatorID 
operatorID) {
+               return new OperatorInstanceID(subtaskId, operatorID);
+       }
+
+       public OperatorInstanceID(int subtaskId, OperatorID operatorId) {
+               this.subtaskId = subtaskId;
+               this.operatorId = operatorId;
+       }
+
+       public int getSubtaskId() {
+               return subtaskId;
+       }
+
+       public OperatorID getOperatorId() {
+               return operatorId;
+       }
+
+       @Override
+       public int hashCode() {
+               return Objects.hash(subtaskId, operatorId);
+       }
+
+       @Override
+       public boolean equals(Object obj) {
+               if (obj == this) {
+                       return true;
+               }
+               if (obj == null) {
+                       return false;
+               }
+               if (!(obj instanceof OperatorInstanceID)) {
+                       return false;
+               }
+               OperatorInstanceID other = (OperatorInstanceID) obj;
+               return this.subtaskId == other.subtaskId &&
+                       Objects.equals(this.operatorId, other.operatorId);
+       }
+
+       @Override
+       public String toString() {
+               return String.format("<%d, %s>", subtaskId, operatorId);
+       }
+}

Reply via email to