[Flink-5892] Restore state on operator level

Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/f7980a7e
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/f7980a7e
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/f7980a7e

Branch: refs/heads/table-retraction
Commit: f7980a7e29457753eb3c5b975f3bb4b59d2014f8
Parents: 8045fab
Author: zentol <ches...@apache.org>
Authored: Fri Apr 28 19:40:20 2017 +0200
Committer: zentol <ches...@apache.org>
Committed: Fri Apr 28 20:11:35 2017 +0200

----------------------------------------------------------------------
 docs/ops/upgrading.md                           |   4 +-
 docs/setup/savepoints.md                        |   5 -
 .../checkpoint/savepoint/SavepointV0.java       |   6 +
 .../checkpoint/CheckpointCoordinator.java       |   9 +-
 .../runtime/checkpoint/CompletedCheckpoint.java |  57 +-
 .../flink/runtime/checkpoint/OperatorState.java | 183 ++++++
 .../checkpoint/OperatorSubtaskState.java        | 229 +++++++
 .../runtime/checkpoint/PendingCheckpoint.java   |  87 +--
 .../checkpoint/StateAssignmentOperation.java    | 611 ++++++++++++-------
 .../checkpoint/StateAssignmentOperationV2.java  | 458 --------------
 .../flink/runtime/checkpoint/TaskState.java     |   4 +
 .../runtime/checkpoint/savepoint/Savepoint.java |  12 +
 .../checkpoint/savepoint/SavepointLoader.java   |  54 +-
 .../checkpoint/savepoint/SavepointV1.java       |   6 +
 .../savepoint/SavepointV1Serializer.java        |   4 +-
 .../checkpoint/savepoint/SavepointV2.java       | 155 ++++-
 .../savepoint/SavepointV2Serializer.java        | 114 ++--
 .../executiongraph/ExecutionJobVertex.java      |  63 ++
 .../runtime/jobgraph/InputFormatVertex.java     |   4 +-
 .../flink/runtime/jobgraph/JobVertex.java       |  25 +-
 .../flink/runtime/jobgraph/OperatorID.java      |  45 ++
 ...tCoordinatorExternalizedCheckpointsTest.java |   5 +-
 .../CheckpointCoordinatorFailureTest.java       |  35 +-
 .../CheckpointCoordinatorMasterHooksTest.java   |   5 +-
 .../checkpoint/CheckpointCoordinatorTest.java   | 236 +++++--
 .../checkpoint/CheckpointStateRestoreTest.java  |  55 +-
 .../CompletedCheckpointStoreTest.java           |  59 +-
 .../checkpoint/CompletedCheckpointTest.java     |  47 +-
 .../checkpoint/PendingCheckpointTest.java       |  24 +-
 .../StandaloneCompletedCheckpointStoreTest.java |  10 +-
 ...ZooKeeperCompletedCheckpointStoreITCase.java |  14 +-
 .../ZooKeeperCompletedCheckpointStoreTest.java  |   2 +-
 .../savepoint/CheckpointTestUtils.java          |  97 ++-
 .../savepoint/SavepointLoaderTest.java          |  24 +-
 .../savepoint/SavepointStoreTest.java           |  23 +-
 .../savepoint/SavepointV2SerializerTest.java    |  22 +-
 .../checkpoint/savepoint/SavepointV2Test.java   |  14 +-
 .../executiongraph/LegacyJobVertexIdTest.java   |   4 +-
 .../RecoverableCompletedCheckpointStore.java    |   2 +-
 .../api/graph/StreamGraphHasherV1.java          |  13 -
 .../api/graph/StreamGraphHasherV2.java          |  13 -
 .../api/graph/StreamGraphUserHashHasher.java    |   9 -
 .../api/graph/StreamingJobGraphGenerator.java   |  46 +-
 .../StreamingJobGraphGeneratorNodeHashTest.java |  18 +-
 .../test/checkpointing/SavepointITCase.java     |  54 +-
 .../AbstractOperatorRestoreTestBase.java        | 261 ++++++++
 .../state/operator/restore/ExecutionMode.java   |  31 +
 .../restore/keyed/KeyedComplexChainTest.java    |  61 ++
 .../state/operator/restore/keyed/KeyedJob.java  | 243 ++++++++
 ...AbstractNonKeyedOperatorRestoreTestBase.java |  59 ++
 .../restore/unkeyed/ChainBreakTest.java         |  55 ++
 .../unkeyed/ChainLengthDecreaseTest.java        |  51 ++
 .../unkeyed/ChainLengthIncreaseTest.java        |  56 ++
 .../restore/unkeyed/ChainOrderTest.java         |  54 ++
 .../restore/unkeyed/ChainUnionTest.java         |  53 ++
 .../operator/restore/unkeyed/NonKeyedJob.java   | 198 ++++++
 .../operatorstate/complexKeyed/_metadata        | Bin 0 -> 137490 bytes
 .../resources/operatorstate/nonKeyed/_metadata  | Bin 0 -> 3212 bytes
 58 files changed, 2971 insertions(+), 1117 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/docs/ops/upgrading.md
----------------------------------------------------------------------
diff --git a/docs/ops/upgrading.md b/docs/ops/upgrading.md
index 194d8af..7259a6b 100644
--- a/docs/ops/upgrading.md
+++ b/docs/ops/upgrading.md
@@ -73,6 +73,8 @@ val mappedEvents: DataStream[(Int, Long)] = events
 
 **Note:** Since the operator IDs stored in a savepoint and IDs of operators in the application to start must be equal, it is highly recommended to assign unique IDs to all operators of an application that might be upgraded in the future. This advice applies to all operators, i.e., operators with and without explicitly declared operator state, because some operators have internal state that is not visible to the user. Upgrading an application without assigned operator IDs is significantly more difficult and may only be possible via a low-level workaround using the `setUidHash()` method.
 
+**Important:** As of 1.3.0 this also applies to operators that are part of a chain.
+
 By default all state stored in a savepoint must be matched to the operators of a starting application. However, users can explicitly agree to skip (and thereby discard) state that cannot be matched to an operator when starting a application from a savepoint. Stateful operators for which no state is found in the savepoint are initialized with their default state.
 
 ### Stateful Operators and User Functions
@@ -105,7 +107,7 @@ When upgrading an application by changing its topology, a few things need to be
 * **Adding a stateful operator:** The state of the operator will be initialized with the default state unless it takes over the state of another operator.
 * **Removing a stateful operator:** The state of the removed operator is lost unless another operator takes it over. When starting the upgraded application, you have to explicitly agree to discard the state.
 * **Changing of input and output types of operators:** When adding a new operator before or behind an operator with internal state, you have to ensure that the input or output type of the stateful operator is not modified to preserve the data type of the internal operator state (see above for details).
-* **Changing operator chaining:** Operators can be chained together for improved performance. However, chaining can limit the ability of an application to be upgraded if a chain contains a stateful operator that is not the first operator of the chain. In such a case, it is not possible to break the chain such that the stateful operator is moved out of the chain. It is also not possible to append or inject an existing stateful operator into a chain. The chaining behavior can be changed by modifying the parallelism of a chained operator or by adding or removing explicit operator chaining instructions.
+* **Changing operator chaining:** Operators can be chained together for improved performance. When restoring from a savepoint taken with Flink 1.3.0 or later, it is possible to modify chains while preserving state consistency. It is possible to break the chain such that a stateful operator is moved out of the chain. It is also possible to append or inject a new or existing stateful operator into a chain, or to modify the operator order within a chain. However, when upgrading a savepoint to 1.3.0, it is paramount that the topology did not change with regard to chaining. All operators that are part of a chain should be assigned an ID as described in the [Matching Operator State](#matching-operator-state) section above.
 
 ## Upgrading the Flink Framework Version
 

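Why operator-level keying lifts the old chaining restrictions can be sketched with a deliberately simplified model. Everything in it is hypothetical: operators are plain strings, a chain is a list whose first element is its head, "vertex-level" state is looked up by the head's ID and by position inside the chain, and "operator-level" state is looked up by each operator's own ID. It is an illustration of the idea, not Flink's restore code.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

// Hypothetical model: a chain is the ordered list of operators fused into one task;
// element 0 is the chain head.
final class ChainRestoreSketch {

    private ChainRestoreSketch() {
    }

    // Vertex-level keying: all state of a chain sits under the head's ID, by position.
    // Returns the operators of the new topology whose state cannot be found.
    static Set<String> unrestoredByVertex(List<List<String>> oldChains, List<List<String>> newChains) {
        Map<String, List<String>> byHead = new HashMap<>();
        for (List<String> chain : oldChains) {
            byHead.put(chain.get(0), chain);
        }
        Set<String> missing = new TreeSet<>();
        for (List<String> chain : newChains) {
            List<String> old = byHead.get(chain.get(0));
            for (int i = 0; i < chain.size(); i++) {
                // Found only if the same operator sits at the same position of a chain
                // that has the same head.
                boolean found = old != null && i < old.size() && old.get(i).equals(chain.get(i));
                if (!found) {
                    missing.add(chain.get(i));
                }
            }
        }
        return missing;
    }

    // Operator-level keying: every operator's state sits under its own ID, so the
    // chain layout plays no role in the lookup.
    static Set<String> unrestoredByOperator(List<List<String>> oldChains, List<List<String>> newChains) {
        Set<String> known = new HashSet<>();
        for (List<String> chain : oldChains) {
            known.addAll(chain);
        }
        Set<String> missing = new TreeSet<>();
        for (List<String> chain : newChains) {
            for (String op : chain) {
                if (!known.contains(op)) {
                    missing.add(op);
                }
            }
        }
        return missing;
    }
}
```

For example, splitting the chain `[a -> b -> c]` into `[a -> b]` and `[c]` leaves `c` unrestorable in the vertex-level model, while the operator-level model finds state for all three operators. Reordering the chain to `[b -> a -> c]` loses every operator in the first model and none in the second.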
http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/docs/setup/savepoints.md
----------------------------------------------------------------------
diff --git a/docs/setup/savepoints.md b/docs/setup/savepoints.md
index 4bdc43f..eada9b4 100644
--- a/docs/setup/savepoints.md
+++ b/docs/setup/savepoints.md
@@ -185,8 +185,3 @@ If you did not assign IDs, the auto generated IDs of the stateful operators will
 If the savepoint was triggered with Flink >= 1.2.0 and using no deprecated state API like `Checkpointed`, you can simply restore the program from a savepoint and specify a new parallelism.
 
 If you are resuming from a savepoint triggered with Flink < 1.2.0 or using now deprecated APIs you first have to migrate your job and savepoint to Flink 1.2.0 before being able to change the parallelism. See the [upgrading jobs and Flink versions guide]({{ site.baseurl }}/ops/upgrading.html).
-
-## Current limitations
-
-- **Chaining**: Chained operators are identified by the ID of the first task. It's not possible to manually assign an ID to an intermediate chained task, e.g. in the chain `[  a -> b -> c ]` only **a** can have its ID assigned manually, but not **b** or **c**. To work around this, you can [manually define the task chains](index.html#task-chaining-and-resource-groups). If you rely on the automatic ID assignment, a change in the chaining behaviour will also change the IDs.
-

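Restoring keyed state with a new parallelism works because keyed state is partitioned into a fixed number of key groups, bounded by the maximum parallelism (the `maxParallelism` that `OperatorState` records), and each parallel subtask owns a contiguous range of them. A minimal sketch of such a split follows, assuming an even contiguous partitioning and that the maximum parallelism stays the same across restores. Flink has its own implementation in its runtime (`KeyGroupRangeAssignment`); the class and method names below are hypothetical and independent of it.

```java
// Hypothetical sketch: split maxParallelism key groups into contiguous, near-equal
// ranges, one per parallel subtask.
final class KeyGroupRangeSketch {

    private KeyGroupRangeSketch() {
    }

    // First key group (inclusive) of subtask `index`: ceil(index * maxParallelism / parallelism).
    static int start(int maxParallelism, int parallelism, int index) {
        checkParallelism(maxParallelism, parallelism);
        return (index * maxParallelism + parallelism - 1) / parallelism;
    }

    // Last key group (inclusive) of subtask `index`: one before the next subtask's start.
    static int end(int maxParallelism, int parallelism, int index) {
        return start(maxParallelism, parallelism, index + 1) - 1;
    }

    // Inverse mapping: the subtask whose range contains `keyGroup`.
    static int ownerOf(int maxParallelism, int parallelism, int keyGroup) {
        checkParallelism(maxParallelism, parallelism);
        return keyGroup * parallelism / maxParallelism;
    }

    private static void checkParallelism(int maxParallelism, int parallelism) {
        if (parallelism <= 0 || parallelism > maxParallelism) {
            throw new IllegalArgumentException(
                "Parallelism " + parallelism + " must be in [1, " + maxParallelism + "]");
        }
    }
}
```

With `maxParallelism = 128`, three subtasks own key groups `[0, 42]`, `[43, 85]` and `[86, 127]`. After rescaling to two subtasks the same key groups are regrouped as `[0, 63]` and `[64, 127]`: whole key groups move between subtasks, but no key group is ever split.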
http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java
index f3ec1cf..7888d2f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java
+++ b/flink-runtime/src/main/java/org/apache/flink/migration/runtime/checkpoint/savepoint/SavepointV0.java
@@ -20,6 +20,7 @@ package org.apache.flink.migration.runtime.checkpoint.savepoint;
 
 import org.apache.flink.migration.runtime.checkpoint.TaskState;
 import org.apache.flink.runtime.checkpoint.MasterState;
+import org.apache.flink.runtime.checkpoint.OperatorState;
 import org.apache.flink.runtime.checkpoint.savepoint.Savepoint;
 import org.apache.flink.util.Preconditions;
 
@@ -72,6 +73,11 @@ public class SavepointV0 implements Savepoint {
        }
 
        @Override
+       public Collection<OperatorState> getOperatorStates() {
+               return null;
+       }
+
+       @Override
        public void dispose() throws Exception {
                //NOP
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index fb6cc72..96add06 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -36,6 +36,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.executiongraph.JobStatusListener;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
@@ -892,7 +893,7 @@ public class CheckpointCoordinator {
                if (LOG.isDebugEnabled()) {
                        StringBuilder builder = new StringBuilder();
                        builder.append("Checkpoint state: ");
-                       for (TaskState state : completedCheckpoint.getTaskStates().values()) {
+                       for (OperatorState state : completedCheckpoint.getOperatorStates().values()) {
                                builder.append(state);
                                builder.append(", ");
                        }
@@ -1017,11 +1018,11 @@ public class CheckpointCoordinator {
                        LOG.info("Restoring from latest valid checkpoint: {}.", latest);
 
                        // re-assign the task states
-
-                       final Map<JobVertexID, TaskState> taskStates = latest.getTaskStates();
+                       final Map<OperatorID, OperatorState> operatorStates = latest.getOperatorStates();
 
                        StateAssignmentOperation stateAssignmentOperation =
-                                       new StateAssignmentOperation(LOG, tasks, taskStates, allowNonRestoredState);
+                                       new StateAssignmentOperation(tasks, operatorStates, allowNonRestoredState);
+
                        stateAssignmentOperation.assignStates();
 
                        // call master hooks for restore

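The restore path above hands `allowNonRestoredState` to `StateAssignmentOperation`. The flag corresponds to the rule described in the upgrading guide: all state in the checkpoint must map to an operator of the new job unless the user agrees to skip it, and operators without stored state start from their default state. A minimal sketch of just that matching rule, assuming plain string IDs in place of `OperatorID`; the class and method names are hypothetical, and the real operation additionally redistributes subtask state.

```java
import java.util.Set;
import java.util.TreeSet;

// Hypothetical sketch of the savepoint matching rule; operator IDs are plain strings.
final class StateMatchingSketch {

    private StateMatchingSketch() {
    }

    // Returns the IDs of operators in the new job that receive checkpointed state.
    // Job operators absent from the checkpoint are not returned: they start with
    // their default state. Checkpointed state without a matching operator fails the
    // restore unless allowNonRestoredState is set, in which case it is skipped.
    static Set<String> match(
            Set<String> checkpointedOperatorIds,
            Set<String> jobOperatorIds,
            boolean allowNonRestoredState) {
        Set<String> restored = new TreeSet<>();
        for (String operatorId : checkpointedOperatorIds) {
            if (jobOperatorIds.contains(operatorId)) {
                restored.add(operatorId);
            } else if (!allowNonRestoredState) {
                throw new IllegalStateException(
                    "Cannot map state of operator " + operatorId + " to the new job");
            }
        }
        return restored;
    }
}
```

Because the lookup is keyed by operator ID, an operator that moved into or out of a chain is still matched, as long as its ID is stable.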
http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
index bb49b45..1ab5b41 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
@@ -18,10 +18,9 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.jobgraph.JobStatus;
-import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -94,8 +93,8 @@ public class CompletedCheckpoint implements Serializable {
        /** The duration of the checkpoint (completion timestamp - trigger timestamp). */
        private final long duration;
 
-       /** States of the different task groups belonging to this checkpoint */
-       private final HashMap<JobVertexID, TaskState> taskStates;
+       /** States of the different operator groups belonging to this checkpoint */
+       private final Map<OperatorID, OperatorState> operatorStates;
 
        /** Properties for this checkpoint. */
        private final CheckpointProperties props;
@@ -117,38 +116,12 @@ public class CompletedCheckpoint implements Serializable {
 
        // ------------------------------------------------------------------------
 
-       @VisibleForTesting
-       CompletedCheckpoint(
-                       JobID job,
-                       long checkpointID,
-                       long timestamp,
-                       long completionTimestamp,
-                       Map<JobVertexID, TaskState> taskStates) {
-
-               this(job, checkpointID, timestamp, completionTimestamp, taskStates,
-                               Collections.<MasterState>emptyList(),
-                               CheckpointProperties.forStandardCheckpoint());
-       }
-
-       public CompletedCheckpoint(
-                       JobID job,
-                       long checkpointID,
-                       long timestamp,
-                       long completionTimestamp,
-                       Map<JobVertexID, TaskState> taskStates,
-                       @Nullable Collection<MasterState> masterHookStates,
-                       CheckpointProperties props) {
-
-               this(job, checkpointID, timestamp, completionTimestamp, taskStates,
-                               masterHookStates, props, null, null);
-       }
-
        public CompletedCheckpoint(
                        JobID job,
                        long checkpointID,
                        long timestamp,
                        long completionTimestamp,
-                       Map<JobVertexID, TaskState> taskStates,
+                       Map<OperatorID, OperatorState> operatorStates,
                        @Nullable Collection<MasterState> masterHookStates,
                        CheckpointProperties props,
                        @Nullable StreamStateHandle externalizedMetadata,
@@ -171,7 +144,7 @@ public class CompletedCheckpoint implements Serializable {
 
                // we create copies here, to make sure we have no shared mutable
                // data structure with the "outside world"
-               this.taskStates = new HashMap<>(checkNotNull(taskStates));
+               this.operatorStates = new HashMap<>(checkNotNull(operatorStates));
                this.masterHookStates = masterHookStates == null || masterHookStates.isEmpty() ?
                                Collections.<MasterState>emptyList() :
                                new ArrayList<>(masterHookStates);
@@ -239,19 +212,15 @@ public class CompletedCheckpoint implements Serializable {
        public long getStateSize() {
                long result = 0L;
 
-               for (TaskState taskState : taskStates.values()) {
-                       result += taskState.getStateSize();
+               for (OperatorState operatorState : operatorStates.values()) {
+                       result += operatorState.getStateSize();
                }
 
                return result;
        }
 
-       public Map<JobVertexID, TaskState> getTaskStates() {
-               return Collections.unmodifiableMap(taskStates);
-       }
-
-       public TaskState getTaskState(JobVertexID jobVertexID) {
-               return taskStates.get(jobVertexID);
+       public Map<OperatorID, OperatorState> getOperatorStates() {
+               return operatorStates;
        }
 
        public Collection<MasterState> getMasterHookStates() {
@@ -288,7 +257,7 @@ public class CompletedCheckpoint implements Serializable {
         * @param sharedStateRegistry The registry where shared states are registered
         */
        public void registerSharedStates(SharedStateRegistry sharedStateRegistry) {
-               sharedStateRegistry.registerAll(taskStates.values());
+               sharedStateRegistry.registerAll(operatorStates.values());
        }
 
        // --------------------------------------------------------------------------------------------
@@ -338,7 +307,7 @@ public class CompletedCheckpoint implements Serializable {
                protected void doDiscardPrivateState() {
                        // discard private state objects
                        try {
-                               StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
+                               StateUtil.bestEffortDiscardAllStateObjects(operatorStates.values());
                        } catch (Exception e) {
                                storedException = ExceptionUtils.firstOrSuppressed(e, storedException);
                        }
@@ -353,7 +322,7 @@ public class CompletedCheckpoint implements Serializable {
                }
 
                protected void clearTaskStatesAndNotifyDiscardCompleted() {
-                       taskStates.clear();
+                       operatorStates.clear();
                        // to be null-pointer safe, copy reference to stack
                        CompletedCheckpointStats.DiscardCallback discardCallback =
                                CompletedCheckpoint.this.discardCallback;
@@ -392,7 +361,7 @@ public class CompletedCheckpoint implements Serializable {
 
                @Override
                protected void doDiscardSharedState() {
-                       sharedStateRegistry.unregisterAll(taskStates.values());
+                       sharedStateRegistry.unregisterAll(operatorStates.values());
                }
        }
 }

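`CompletedCheckpoint` now stores its states per `OperatorID`: the constructor copies the incoming map and `getStateSize()` sums the per-operator sizes. The hypothetical `CheckpointSizeSketch` below reproduces that pattern with string IDs and plain byte counts in place of `OperatorState` objects, and additionally exposes the map only as a read-only view.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Hypothetical, simplified stand-in for CompletedCheckpoint: operator IDs are strings
// and each operator's state is reduced to its size in bytes.
final class CheckpointSizeSketch {

    private final Map<String, Long> operatorStateSizes;

    CheckpointSizeSketch(Map<String, Long> operatorStateSizes) {
        // Copy, so later changes to the caller's map do not leak into the checkpoint.
        this.operatorStateSizes = new HashMap<>(operatorStateSizes);
    }

    // Total size is the sum over all operators.
    long getStateSize() {
        long result = 0L;
        for (long size : operatorStateSizes.values()) {
            result += size;
        }
        return result;
    }

    // Read-only view: callers can inspect but not add or drop operator states.
    Map<String, Long> getOperatorStates() {
        return Collections.unmodifiableMap(operatorStateSizes);
    }
}
```

The removed `getTaskStates()` wrapped its map in `Collections.unmodifiableMap`, whereas the new `getOperatorStates()` in the diff returns the internal map directly; the sketch keeps the read-only view as a design choice.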
http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java
new file mode 100644
index 0000000..aa676e7
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorState.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.CompositeStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.util.Preconditions;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * Simple container class which contains the raw/managed/legacy operator state and key-group state handles for the sub
+ * tasks of an operator.
+ */
+public class OperatorState implements CompositeStateHandle {
+
+       private static final long serialVersionUID = -4845578005863201810L;
+
+       /** id of the operator */
+       private final OperatorID operatorID;
+
+       /** handles to non-partitioned states, subtaskindex -> subtaskstate */
+       private final Map<Integer, OperatorSubtaskState> operatorSubtaskStates;
+
+       /** parallelism of the operator when it was checkpointed */
+       private final int parallelism;
+
+       /** maximum parallelism of the operator when the job was first created */
+       private final int maxParallelism;
+
+       public OperatorState(OperatorID operatorID, int parallelism, int maxParallelism) {
+               Preconditions.checkArgument(
+                       parallelism <= maxParallelism,
+                       "Parallelism " + parallelism + " is not smaller or equal to max parallelism " + maxParallelism + ".");
+
+               this.operatorID = operatorID;
+
+               this.operatorSubtaskStates = new HashMap<>(parallelism);
+
+               this.parallelism = parallelism;
+               this.maxParallelism = maxParallelism;
+       }
+
+       public OperatorID getOperatorID() {
+               return operatorID;
+       }
+
+       public void putState(int subtaskIndex, OperatorSubtaskState subtaskState) {
+               Preconditions.checkNotNull(subtaskState);
+
+               if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
+                       throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
+                               " exceeds the maximum number of sub tasks " + parallelism);
+               } else {
+                       operatorSubtaskStates.put(subtaskIndex, subtaskState);
+               }
+       }
+
+       public OperatorSubtaskState getState(int subtaskIndex) {
+               if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
+                       throw new IndexOutOfBoundsException("The given sub task index " + subtaskIndex +
+                               " exceeds the maximum number of sub tasks " + parallelism);
+               } else {
+                       return operatorSubtaskStates.get(subtaskIndex);
+               }
+       }
+
+       public Collection<OperatorSubtaskState> getStates() {
+               return operatorSubtaskStates.values();
+       }
+
+       public int getNumberCollectedStates() {
+               return operatorSubtaskStates.size();
+       }
+
+       public int getParallelism() {
+               return parallelism;
+       }
+
+       public int getMaxParallelism() {
+               return maxParallelism;
+       }
+
+       public boolean hasNonPartitionedState() {
+               for (OperatorSubtaskState sts : operatorSubtaskStates.values()) {
+                       if (sts != null && sts.getLegacyOperatorState() != null) {
+                               return true;
+                       }
+               }
+               return false;
+       }
+
+       @Override
+       public void discardState() throws Exception {
+               for (OperatorSubtaskState operatorSubtaskState : operatorSubtaskStates.values()) {
+                       operatorSubtaskState.discardState();
+               }
+       }
+
+       @Override
+       public void registerSharedStates(SharedStateRegistry sharedStateRegistry) {
+               for (OperatorSubtaskState operatorSubtaskState : operatorSubtaskStates.values()) {
+                       operatorSubtaskState.registerSharedStates(sharedStateRegistry);
+               }
+       }
+
+       @Override
+       public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) {
+               for (OperatorSubtaskState operatorSubtaskState : operatorSubtaskStates.values()) {
+                       operatorSubtaskState.unregisterSharedStates(sharedStateRegistry);
+               }
+       }
+
+       @Override
+       public long getStateSize() {
+               long result = 0L;
+
+               for (int i = 0; i < parallelism; i++) {
+                       OperatorSubtaskState operatorSubtaskState = operatorSubtaskStates.get(i);
+                       if (operatorSubtaskState != null) {
+                               result += operatorSubtaskState.getStateSize();
+                       }
+               }
+
+               return result;
+       }
+
+       @Override
+       public boolean equals(Object obj) {
+               if (obj instanceof OperatorState) {
+                       OperatorState other = (OperatorState) obj;
+
+                       return operatorID.equals(other.operatorID)
+                               && parallelism == other.parallelism
+                               && operatorSubtaskStates.equals(other.operatorSubtaskStates);
+               } else {
+                       return false;
+               }
+       }
+
+       @Override
+       public int hashCode() {
+               return parallelism + 31 * Objects.hash(operatorID, operatorSubtaskStates);
+       }
+
+       public Map<Integer, OperatorSubtaskState> getSubtaskStates() {
+               return Collections.unmodifiableMap(operatorSubtaskStates);
+       }
+
+       @Override
+       public String toString() {
+               // KvStates are always null in 1.1. Don't print this as it might
+               // confuse users that don't care about how we store it internally.
+               return "OperatorState(" +
+                       "operatorID: " + operatorID +
+                       ", parallelism: " + parallelism +
+                       ", maxParallelism: " + maxParallelism +
+                       ", sub task states: " + operatorSubtaskStates.size() +
+                       ", total size (bytes): " + getStateSize() +
+                       ')';
+       }
+}

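The bookkeeping in `OperatorState` (a constructor that rejects `parallelism > maxParallelism`, a bounds-checked `putState`, and a `getStateSize()` that skips subtasks without state), together with the null-safe size sum in `OperatorSubtaskState`, can be illustrated with a small stand-alone model. The names `MiniOperatorState` and `MiniSubtaskState` are hypothetical, and every state handle is reduced to an optional byte count.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical miniature of OperatorSubtaskState: each argument is the size of one
// optional state handle (null = handle absent).
final class MiniSubtaskState {
    final long stateSize;

    MiniSubtaskState(Long... handleSizes) {
        long total = 0L;
        for (Long size : handleSizes) {
            // Null-safe sum: a missing handle contributes 0 bytes.
            total += size != null ? size : 0L;
        }
        this.stateSize = total;
    }
}

// Hypothetical miniature of OperatorState: subtask index -> subtask state.
final class MiniOperatorState {
    private final int parallelism;
    private final int maxParallelism;
    private final Map<Integer, MiniSubtaskState> subtaskStates = new HashMap<>();

    MiniOperatorState(int parallelism, int maxParallelism) {
        if (parallelism > maxParallelism) {
            throw new IllegalArgumentException("Parallelism " + parallelism
                + " is not smaller or equal to max parallelism " + maxParallelism + ".");
        }
        this.parallelism = parallelism;
        this.maxParallelism = maxParallelism;
    }

    int getMaxParallelism() {
        return maxParallelism;
    }

    void putState(int subtaskIndex, MiniSubtaskState state) {
        if (subtaskIndex < 0 || subtaskIndex >= parallelism) {
            // The valid index range is bounded by the parallelism, not by the number
            // of states collected so far.
            throw new IndexOutOfBoundsException(
                "Subtask index " + subtaskIndex + " is outside [0, " + parallelism + ")");
        }
        subtaskStates.put(subtaskIndex, state);
    }

    // Sums over all subtask indices, skipping subtasks that reported no state.
    long getStateSize() {
        long result = 0L;
        for (int i = 0; i < parallelism; i++) {
            MiniSubtaskState state = subtaskStates.get(i);
            if (state != null) {
                result += state.stateSize;
            }
        }
        return result;
    }
}
```

For two subtasks whose handle sizes are `(10, -, 5, -, -)` and `(-, 7, -, -, 3)` bytes, `getStateSize()` reports 25 bytes, and `putState(2, ...)` is rejected because index 2 is outside a parallelism of 2.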
http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
new file mode 100644
index 0000000..863816a
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.state.CompositeStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.StateUtil;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
+
+/**
+ * Container for the state of one parallel subtask of an operator. This is part of the {@link OperatorState}.
+ */
+public class OperatorSubtaskState implements CompositeStateHandle {
+
+       private static final Logger LOG = LoggerFactory.getLogger(OperatorSubtaskState.class);
+
+       private static final long serialVersionUID = -2394696997971923995L;
+
+       /**
+        * Legacy (non-repartitionable) operator state.
+        *
+        * @deprecated Non-repartitionable operator state that has been deprecated.
+        * Can be removed when we remove the APIs for non-repartitionable operator state.
+        */
+       @Deprecated
+       private final StreamStateHandle legacyOperatorState;
+
+       /**
+        * Snapshot from the {@link org.apache.flink.runtime.state.OperatorStateBackend}.
+        */
+       private final OperatorStateHandle managedOperatorState;
+
+       /**
+        * Snapshot written using {@link org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream}.
+        */
+       private final OperatorStateHandle rawOperatorState;
+
+       /**
+        * Snapshot from {@link org.apache.flink.runtime.state.KeyedStateBackend}.
+        */
+       private final KeyedStateHandle managedKeyedState;
+
+       /**
+        * Snapshot written using {@link org.apache.flink.runtime.state.KeyedStateCheckpointOutputStream}.
+        */
+       private final KeyedStateHandle rawKeyedState;
+
+       /**
+        * The state size. This is also part of the deserialized state handle.
+        * We store it here in order to not deserialize the state handle when
+        * gathering stats.
+        */
+       private final long stateSize;
+
+       public OperatorSubtaskState(
+               StreamStateHandle legacyOperatorState,
+               OperatorStateHandle managedOperatorState,
+               OperatorStateHandle rawOperatorState,
+               KeyedStateHandle managedKeyedState,
+               KeyedStateHandle rawKeyedState) {
+
+               this.legacyOperatorState = legacyOperatorState;
+               this.managedOperatorState = managedOperatorState;
+               this.rawOperatorState = rawOperatorState;
+               this.managedKeyedState = managedKeyedState;
+               this.rawKeyedState = rawKeyedState;
+
+               try {
+                       long calculateStateSize = getSizeNullSafe(legacyOperatorState);
+                       calculateStateSize += getSizeNullSafe(managedOperatorState);
+                       calculateStateSize += getSizeNullSafe(rawOperatorState);
+                       calculateStateSize += getSizeNullSafe(managedKeyedState);
+                       calculateStateSize += getSizeNullSafe(rawKeyedState);
+                       stateSize = calculateStateSize;
+               } catch (Exception e) {
+                       throw new RuntimeException("Failed to get state size.", e);
+               }
+       }
+
+       private static long getSizeNullSafe(StateObject stateObject) throws Exception {
+               return stateObject != null ? stateObject.getStateSize() : 0L;
+       }
+
+       // --------------------------------------------------------------------------------------------
+
+       /**
+        * @deprecated Non-repartitionable operator state that has been deprecated.
+        * Can be removed when we remove the APIs for non-repartitionable operator state.
+        */
+       @Deprecated
+       public StreamStateHandle getLegacyOperatorState() {
+               return legacyOperatorState;
+       }
+
+       public OperatorStateHandle getManagedOperatorState() {
+               return managedOperatorState;
+       }
+
+       public OperatorStateHandle getRawOperatorState() {
+               return rawOperatorState;
+       }
+
+       public KeyedStateHandle getManagedKeyedState() {
+               return managedKeyedState;
+       }
+
+       public KeyedStateHandle getRawKeyedState() {
+               return rawKeyedState;
+       }
+
+       @Override
+       public void discardState() {
+               try {
+                       StateUtil.bestEffortDiscardAllStateObjects(
+                               Arrays.asList(
+                                       legacyOperatorState,
+                                       managedOperatorState,
+                                       rawOperatorState,
+                                       managedKeyedState,
+                                       rawKeyedState));
+               } catch (Exception e) {
+                       LOG.warn("Error while discarding operator states.", e);
+               }
+       }
+
+       @Override
+       public void registerSharedStates(SharedStateRegistry sharedStateRegistry) {
+               // No shared states
+       }
+
+       @Override
+       public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) {
+               // No shared states
+       }
+
+       @Override
+       public long getStateSize() {
+               return stateSize;
+       }
+
+       // --------------------------------------------------------------------------------------------
+
+       @Override
+       public boolean equals(Object o) {
+               if (this == o) {
+                       return true;
+               }
+               if (o == null || getClass() != o.getClass()) {
+                       return false;
+               }
+
+               OperatorSubtaskState that = (OperatorSubtaskState) o;
+
+               if (stateSize != that.stateSize) {
+                       return false;
+               }
+
+               if (legacyOperatorState != null ?
+                       !legacyOperatorState.equals(that.legacyOperatorState)
+                       : that.legacyOperatorState != null) {
+                       return false;
+               }
+               if (managedOperatorState != null ?
+                       !managedOperatorState.equals(that.managedOperatorState)
+                       : that.managedOperatorState != null) {
+                       return false;
+               }
+               if (rawOperatorState != null ?
+                       !rawOperatorState.equals(that.rawOperatorState)
+                       : that.rawOperatorState != null) {
+                       return false;
+               }
+               if (managedKeyedState != null ?
+                       !managedKeyedState.equals(that.managedKeyedState)
+                       : that.managedKeyedState != null) {
+                       return false;
+               }
+               return rawKeyedState != null ?
+                       rawKeyedState.equals(that.rawKeyedState)
+                       : that.rawKeyedState == null;
+
+       }
+
+       @Override
+       public int hashCode() {
+               int result = legacyOperatorState != null ? legacyOperatorState.hashCode() : 0;
+               result = 31 * result + (managedOperatorState != null ? managedOperatorState.hashCode() : 0);
+               result = 31 * result + (rawOperatorState != null ? rawOperatorState.hashCode() : 0);
+               result = 31 * result + (managedKeyedState != null ? managedKeyedState.hashCode() : 0);
+               result = 31 * result + (rawKeyedState != null ? rawKeyedState.hashCode() : 0);
+               result = 31 * result + (int) (stateSize ^ (stateSize >>> 32));
+               return result;
+       }
+
+       @Override
+       public String toString() {
+               return "SubtaskState{" +
+                       "legacyState=" + legacyOperatorState +
+                       ", operatorStateFromBackend=" + managedOperatorState +
+                       ", operatorStateFromStream=" + rawOperatorState +
+                       ", keyedStateFromBackend=" + managedKeyedState +
+                       ", keyedStateFromStream=" + rawKeyedState +
+                       ", stateSize=" + stateSize +
+                       '}';
+       }
+}
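
OperatorSubtaskState computes the combined size of its nullable handles once, in the constructor, and compares the handles field by field with explicit null checks. The following is a minimal sketch of that same pattern, not Flink code: SubtaskStateSketch and its nested StateObject interface are hypothetical stand-ins, the deprecated legacy handle is omitted, and java.util.Objects is used for the null-safe comparisons in place of the hand-written ternary chains.

```java
import java.util.Objects;

// Hypothetical sketch, not Flink code: SubtaskStateSketch and StateObject are
// illustrative stand-ins for OperatorSubtaskState and Flink's StateObject.
final class SubtaskStateSketch {

    /** Hypothetical stand-in for a state handle that can report its size in bytes. */
    interface StateObject {
        long getStateSize();
    }

    private final StateObject managedOperatorState;
    private final StateObject rawOperatorState;
    private final StateObject managedKeyedState;
    private final StateObject rawKeyedState;

    // Cached once so that gathering stats never has to touch the handles again.
    private final long stateSize;

    SubtaskStateSketch(
            StateObject managedOperatorState,
            StateObject rawOperatorState,
            StateObject managedKeyedState,
            StateObject rawKeyedState) {
        this.managedOperatorState = managedOperatorState;
        this.rawOperatorState = rawOperatorState;
        this.managedKeyedState = managedKeyedState;
        this.rawKeyedState = rawKeyedState;
        this.stateSize = sizeOf(managedOperatorState)
                + sizeOf(rawOperatorState)
                + sizeOf(managedKeyedState)
                + sizeOf(rawKeyedState);
    }

    /** A missing (null) handle contributes zero bytes, like getSizeNullSafe above. */
    private static long sizeOf(StateObject handle) {
        return handle != null ? handle.getStateSize() : 0L;
    }

    long getStateSize() {
        return stateSize;
    }

    // Objects.equals and Objects.hash are null-safe, so no explicit null branches are needed.
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        SubtaskStateSketch that = (SubtaskStateSketch) o;
        return stateSize == that.stateSize
                && Objects.equals(managedOperatorState, that.managedOperatorState)
                && Objects.equals(rawOperatorState, that.rawOperatorState)
                && Objects.equals(managedKeyedState, that.managedKeyedState)
                && Objects.equals(rawKeyedState, that.rawKeyedState);
    }

    @Override
    public int hashCode() {
        return Objects.hash(
                managedOperatorState, rawOperatorState, managedKeyedState, rawKeyedState, stateSize);
    }
}
```

Objects.hash may yield different hash values than the hand-rolled 31-multiplier code in the diff; both satisfy the equals/hashCode contract, so the difference only matters if hash values are persisted somewhere.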

http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index cc3dce2..370032a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -26,8 +26,9 @@ import org.apache.flink.runtime.concurrent.Future;
 import org.apache.flink.runtime.concurrent.impl.FlinkCompletableFuture;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
-import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -87,7 +88,7 @@ public class PendingCheckpoint {
 
        private final long checkpointTimestamp;
 
-       private final Map<JobVertexID, TaskState> taskStates;
+       private final Map<OperatorID, OperatorState> operatorStates;
 
        private final Map<ExecutionAttemptID, ExecutionVertex> notYetAcknowledgedTasks;
 
@@ -146,7 +147,7 @@ public class PendingCheckpoint {
                this.targetDirectory = targetDirectory;
                this.executor = Preconditions.checkNotNull(executor);
 
-               this.taskStates = new HashMap<>();
+               this.operatorStates = new HashMap<>();
                this.masterState = new ArrayList<>();
                this.acknowledgedTasks = new HashSet<>(verticesToConfirm.size());
                this.onCompletionPromise = new FlinkCompletableFuture<>();
@@ -178,8 +179,8 @@ public class PendingCheckpoint {
                return numAcknowledgedTasks;
        }
 
-       public Map<JobVertexID, TaskState> getTaskStates() {
-               return taskStates;
+       public Map<OperatorID, OperatorState> getOperatorStates() {
+               return operatorStates;
        }
 
        public boolean isFullyAcknowledged() {
@@ -261,7 +262,7 @@ public class PendingCheckpoint {
                        // make sure we fulfill the promise with an exception if something fails
                        try {
                                // externalize the metadata
-                               final Savepoint savepoint = new SavepointV2(checkpointId, taskStates.values());
+                               final Savepoint savepoint = new SavepointV2(checkpointId, operatorStates.values(), masterState);
 
                                // TEMP FIX - The savepoint store is strictly typed to file systems currently
                                //            but the checkpoints think more generic. we need to work with file handles
@@ -326,7 +327,7 @@ public class PendingCheckpoint {
                                checkpointId,
                                checkpointTimestamp,
                                System.currentTimeMillis(),
-                               taskStates,
+                               operatorStates,
                                masterState,
                                props,
                                externalMetadata,
@@ -380,41 +381,53 @@ public class PendingCheckpoint {
                                acknowledgedTasks.add(executionAttemptId);
                        }
 
-                       JobVertexID jobVertexID = vertex.getJobvertexId();
+                       List<OperatorID> operatorIDs = vertex.getJobVertex().getOperatorIDs();
                        int subtaskIndex = vertex.getParallelSubtaskIndex();
                        long ackTimestamp = System.currentTimeMillis();
 
                        long stateSize = 0;
-                       if (null != subtaskState) {
-                               TaskState taskState = taskStates.get(jobVertexID);
-
-                               if (null == taskState) {
-                                       @SuppressWarnings("deprecation")
-                                       ChainedStateHandle<StreamStateHandle> nonPartitionedState =
-                                                       subtaskState.getLegacyOperatorState();
-                                       ChainedStateHandle<OperatorStateHandle> partitioneableState =
-                                                       subtaskState.getManagedOperatorState();
-                                       //TODO this should go away when we remove chained state, assigning state to operators directly instead
-                                       int chainLength;
-                                       if (nonPartitionedState != null) {
-                                               chainLength = nonPartitionedState.getLength();
-                                       } else if (partitioneableState != null) {
-                                               chainLength = partitioneableState.getLength();
-                                       } else {
-                                               chainLength = 1;
-                                       }
+                       if (subtaskState != null) {
+                               stateSize = subtaskState.getStateSize();
 
-                                       taskState = new TaskState(
-                                                       jobVertexID,
+                               @SuppressWarnings("deprecation")
+                               ChainedStateHandle<StreamStateHandle> nonPartitionedState =
+                                       subtaskState.getLegacyOperatorState();
+                               ChainedStateHandle<OperatorStateHandle> partitioneableState =
+                                       subtaskState.getManagedOperatorState();
+                               ChainedStateHandle<OperatorStateHandle> rawOperatorState =
+                                       subtaskState.getRawOperatorState();
+
+                               // break task state apart into separate operator states
+                               for (int x = 0; x < operatorIDs.size(); x++) {
+                                       OperatorID operatorID = operatorIDs.get(x);
+                                       OperatorState operatorState = operatorStates.get(operatorID);
+
+                                       if (operatorState == null) {
+                                               operatorState = new OperatorState(
+                                                       operatorID,
                                                        vertex.getTotalNumberOfParallelSubtasks(),
-                                                       vertex.getMaxParallelism(),
-                                                       chainLength);
+                                                       vertex.getMaxParallelism());
+                                               operatorStates.put(operatorID, operatorState);
+                                       }
 
-                                       taskStates.put(jobVertexID, taskState);
-                               }
+                                       KeyedStateHandle managedKeyedState = null;
+                                       KeyedStateHandle rawKeyedState = null;
 
-                               taskState.putState(subtaskIndex, subtaskState);
-                               stateSize = subtaskState.getStateSize();
+                                       // only the head operator retains the keyed state
+                                       if (x == operatorIDs.size() - 1) {
+                                               managedKeyedState = subtaskState.getManagedKeyedState();
+                                               rawKeyedState = subtaskState.getRawKeyedState();
+                                       }
+
+                                       OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(
+                                                       nonPartitionedState != null ? nonPartitionedState.get(x) : null,
+                                                       partitioneableState != null ? partitioneableState.get(x) : null,
+                                                       rawOperatorState != null ? rawOperatorState.get(x) : null,
+                                                       managedKeyedState,
+                                                       rawKeyedState);
+
+                                       operatorState.putState(subtaskIndex, operatorSubtaskState);
+                               }
                        }
 
                        ++numAcknowledgedTasks;
@@ -435,7 +448,7 @@ public class PendingCheckpoint {
                                        metrics.getBytesBufferedInAlignment(),
                                        alignmentDurationMillis);
 
-                               statsCallback.reportSubtaskStats(jobVertexID, subtaskStateStats);
+                               statsCallback.reportSubtaskStats(vertex.getJobvertexId(), subtaskStateStats);
                        }
 
                        return TaskAcknowledgeResult.SUCCESS;
@@ -530,12 +543,12 @@ public class PendingCheckpoint {
                                                        // discard the private states.
                                                        // unregistered shared states are still considered private at this point.
                                                        try {
-                                                               StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
+                                                               StateUtil.bestEffortDiscardAllStateObjects(operatorStates.values());
                                                        } catch (Throwable t) {
                                                                LOG.warn("Could not properly dispose the private states in the pending checkpoint {} of job {}.",
                                                                        checkpointId, jobId, t);
                                                        } finally {
-                                                               taskStates.clear();
+                                                               operatorStates.clear();
                                                        }
                                                }
                                        });
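
In the acknowledgement path above, PendingCheckpoint no longer stores one TaskState per JobVertexID; it breaks each subtask's chained snapshot into one OperatorSubtaskState per OperatorID, and only the operator at the last index of getOperatorIDs() (which the diff's comment calls the head operator) receives the keyed state. The following is a simplified sketch of that splitting step under stated assumptions: Strings stand in for state handles and operator IDs, and ChainSplitSketch and OperatorSlice are hypothetical names, not Flink classes.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch, not Flink code: handles and operator IDs are plain Strings,
// and OperatorSlice stands in for OperatorSubtaskState.
final class ChainSplitSketch {

    /** One operator's share of a subtask snapshot. */
    static final class OperatorSlice {
        final String operatorState; // null if the operator has no operator state
        final String keyedState;    // non-null only for the operator that keeps keyed state

        OperatorSlice(String operatorState, String keyedState) {
            this.operatorState = operatorState;
            this.keyedState = keyedState;
        }
    }

    /**
     * Splits one acknowledged subtask snapshot into per-operator entries.
     *
     * @param operatorStates       operator ID -> (subtask index -> slice), filled lazily
     * @param operatorIds          IDs of the chained operators, in the same order as chainedOperatorState
     * @param subtaskIndex         index of the acknowledging subtask
     * @param chainedOperatorState one handle per chained operator, or null if the task has none
     * @param keyedState           keyed state of the whole task, or null
     */
    static void acknowledge(
            Map<String, Map<Integer, OperatorSlice>> operatorStates,
            List<String> operatorIds,
            int subtaskIndex,
            List<String> chainedOperatorState,
            String keyedState) {
        for (int x = 0; x < operatorIds.size(); x++) {
            // Create the per-operator container the first time any subtask reports this operator.
            Map<Integer, OperatorSlice> perSubtask =
                    operatorStates.computeIfAbsent(operatorIds.get(x), id -> new HashMap<>());

            // Mirrors the diff: only the last entry of the ID list is given the keyed state.
            String keyed = x == operatorIds.size() - 1 ? keyedState : null;
            String opState = chainedOperatorState != null ? chainedOperatorState.get(x) : null;

            perSubtask.put(subtaskIndex, new OperatorSlice(opState, keyed));
        }
    }
}
```

Here computeIfAbsent replaces the explicit get, null check, and put sequence in the diff; for a map that is only accessed under the checkpoint lock, as PendingCheckpoint's HashMap is, the effect is the same.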

http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index ac70e1a..1042d5a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -18,9 +18,11 @@
 
 package org.apache.flink.runtime.checkpoint;
 
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
@@ -31,277 +33,400 @@ import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 /**
  * This class encapsulates the operation of assigning restored state when restoring from a checkpoint.
  */
 public class StateAssignmentOperation {
 
-       private final Logger logger;
+       private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);
+
        private final Map<JobVertexID, ExecutionJobVertex> tasks;
-       private final Map<JobVertexID, TaskState> taskStates;
+       private final Map<OperatorID, OperatorState> operatorStates;
        private final boolean allowNonRestoredState;
 
        public StateAssignmentOperation(
-                       Logger logger,
                        Map<JobVertexID, ExecutionJobVertex> tasks,
-                       Map<JobVertexID, TaskState> taskStates,
+                       Map<OperatorID, OperatorState> operatorStates,
                        boolean allowNonRestoredState) {
 
-               this.logger = Preconditions.checkNotNull(logger);
                this.tasks = Preconditions.checkNotNull(tasks);
-               this.taskStates = Preconditions.checkNotNull(taskStates);
+               this.operatorStates = Preconditions.checkNotNull(operatorStates);
                this.allowNonRestoredState = allowNonRestoredState;
        }
 
        public boolean assignStates() throws Exception {
-
-               // this tracks if we find missing node hash ids and already use secondary mappings
-               boolean expandedToLegacyIds = false;
-
+               Map<OperatorID, OperatorState> localOperators = new HashMap<>(operatorStates);
                Map<JobVertexID, ExecutionJobVertex> localTasks = this.tasks;
 
-               for (Map.Entry<JobVertexID, TaskState> taskGroupStateEntry : taskStates.entrySet()) {
-
-                       TaskState taskState = taskGroupStateEntry.getValue();
-
-                       //----------------------------------------find vertex for state---------------------------------------------
-
-                       ExecutionJobVertex executionJobVertex = localTasks.get(taskGroupStateEntry.getKey());
-
-                       // on the first time we can not find the execution job vertex for an id, we also consider alternative ids,
-                       // for example as generated from older flink versions, to provide backwards compatibility.
-                       if (executionJobVertex == null && !expandedToLegacyIds) {
-                               localTasks = ExecutionJobVertex.includeLegacyJobVertexIDs(localTasks);
-                               executionJobVertex = localTasks.get(taskGroupStateEntry.getKey());
-                               expandedToLegacyIds = true;
-                               logger.info("Could not find ExecutionJobVertex. Including legacy JobVertexIDs in search.");
-                       }
-
-                       if (executionJobVertex == null) {
-                               if (allowNonRestoredState) {
-                                       logger.info("Skipped checkpoint state for operator {}.", taskState.getJobVertexID());
-                                       continue;
+               checkStateMappingCompleteness(allowNonRestoredState, operatorStates, tasks);
+
+               for (Map.Entry<JobVertexID, ExecutionJobVertex> task : localTasks.entrySet()) {
+                       final ExecutionJobVertex executionJobVertex = task.getValue();
+
+                       // find the states of all operators belonging to this task
+                       List<OperatorID> operatorIDs = executionJobVertex.getOperatorIDs();
+                       List<OperatorID> altOperatorIDs = executionJobVertex.getUserDefinedOperatorIDs();
+                       List<OperatorState> operatorStates = new ArrayList<>();
+                       boolean statelessTask = true;
+                       for (int x = 0; x < operatorIDs.size(); x++) {
+                               OperatorID operatorID = altOperatorIDs.get(x) == null
+                                       ? operatorIDs.get(x)
+                                       : altOperatorIDs.get(x);
+
+                               OperatorState operatorState = localOperators.remove(operatorID);
+                               if (operatorState == null) {
+                                       operatorState = new OperatorState(
+                                               operatorID,
+                                               executionJobVertex.getParallelism(),
+                                               executionJobVertex.getMaxParallelism());
                                } else {
-                                       throw new IllegalStateException("There is no execution job vertex for the job" +
-                                                       " vertex ID " + taskGroupStateEntry.getKey());
+                                       statelessTask = false;
                                }
+                               operatorStates.add(operatorState);
+                       }
+                       if (statelessTask) { // skip tasks where no operator has any state
+                               continue;
                        }
 
-                       checkParallelismPreconditions(taskState, executionJobVertex);
-
-                       assignTaskStatesToOperatorInstances(taskState, executionJobVertex);
+                       assignAttemptState(task.getValue(), operatorStates);
                }
 
                return true;
        }
 
-       private void checkParallelismPreconditions(TaskState taskState, ExecutionJobVertex executionJobVertex) {
-               //----------------------------------------max parallelism preconditions-------------------------------------
+       private void assignAttemptState(ExecutionJobVertex executionJobVertex, List<OperatorState> operatorStates) {
 
-               // check that the number of key groups have not changed or if we need to override it to satisfy the restored state
-               if (taskState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
+               List<OperatorID> operatorIDs = executionJobVertex.getOperatorIDs();
 
-                       if (!executionJobVertex.isMaxParallelismConfigured()) {
-                               // if the max parallelism was not explicitly specified by the user, we derive it from the state
+               //1. first compute the new parallelism
+               checkParallelismPreconditions(operatorStates, executionJobVertex);
+
+               int newParallelism = executionJobVertex.getParallelism();
+
+               List<KeyGroupRange> keyGroupPartitions = createKeyGroupPartitions(
+                       executionJobVertex.getMaxParallelism(),
+                       newParallelism);
+
+               //2. Redistribute the operator state.
+               /**
+                *
+                * Redistribute ManagedOperatorStates and RawOperatorStates from old parallelism to new parallelism.
+                *
+                * The old ManagedOperatorStates with old parallelism 3:
+                *
+                *              parallelism0 parallelism1 parallelism2
+                * op0          state0,0     state0,1     state0,2
+                * op1
+                * op2          state2,0     state2,1     state2,2
+                * op3          state3,0     state3,1     state3,2
+                *
+                * The new ManagedOperatorStates with new parallelism 4:
+                *
+                *              parallelism0 parallelism1 parallelism2 parallelism3
+                * op0          state0,0     state0,1     state0,2     state0,3
+                * op1
+                * op2          state2,0     state2,1     state2,2     state2,3
+                * op3          state3,0     state3,1     state3,2     state3,3
+                */
+               List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates = new ArrayList<>();
+               List<List<Collection<OperatorStateHandle>>> newRawOperatorStates = new ArrayList<>();
+
+               reDistributePartitionableStates(operatorStates, newParallelism, newManagedOperatorStates, newRawOperatorStates);
+
+
+               //3. Compute TaskStateHandles of every subTask in the executionJobVertex
+               /**
+                *  All state handles needed to restore an executionJobVertex form a matrix:
+                *
+                *              parallelism0 parallelism1 parallelism2 parallelism3
+                * op0          sh(0,0)      sh(0,1)      sh(0,2)      sh(0,3)
+                * op1          sh(1,0)      sh(1,1)      sh(1,2)      sh(1,3)
+                * op2          sh(2,0)      sh(2,1)      sh(2,2)      sh(2,3)
+                * op3          sh(3,0)      sh(3,1)      sh(3,2)      sh(3,3)
+                *
+                * We compute the state handles column by column, i.e. one subtask at a time.
+                *
+                */
+               for (int subTaskIndex = 0; subTaskIndex < newParallelism; subTaskIndex++) {
+
+                       Execution currentExecutionAttempt = executionJobVertex.getTaskVertices()[subTaskIndex]
+                               .getCurrentExecutionAttempt();
+
+                       List<StreamStateHandle> subNonPartitionableState = new ArrayList<>();
+
+                       Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> subKeyedState = null;
+
+                       List<Collection<OperatorStateHandle>> subManagedOperatorState = new ArrayList<>();
+                       List<Collection<OperatorStateHandle>> subRawOperatorState = new ArrayList<>();
+
+
+                       for (int operatorIndex = 0; operatorIndex < operatorIDs.size(); operatorIndex++) {
+                               OperatorState operatorState = operatorStates.get(operatorIndex);
+                               int oldParallelism = operatorState.getParallelism();
+
+                               // NonPartitioned State
+
+                               reAssignSubNonPartitionedStates(
+                                       operatorState,
+                                       subTaskIndex,
+                                       newParallelism,
+                                       oldParallelism,
+                                       subNonPartitionableState);
+
+                               // PartitionedState
+                               reAssignSubPartitionableState(newManagedOperatorStates,
+                                       newRawOperatorStates,
+                                       subTaskIndex,
+                                       operatorIndex,
+                                       subManagedOperatorState,
+                                       subRawOperatorState);
+
+                               // KeyedState
+                               if (operatorIndex == operatorIDs.size() - 1) {
+                                       subKeyedState = reAssignSubKeyedStates(operatorState,
+                                               keyGroupPartitions,
+                                               subTaskIndex,
+                                               newParallelism,
+                                               oldParallelism);
 
-                               if (logger.isDebugEnabled()) {
-                                       logger.debug("Overriding maximum parallelism for JobVertex " + executionJobVertex.getJobVertexId()
-                                                       + " from " + executionJobVertex.getMaxParallelism() + " to " + taskState.getMaxParallelism());
                                }
+                       }
 
-                               executionJobVertex.setMaxParallelism(taskState.getMaxParallelism());
-                       } else {
-                               // if the max parallelism was explicitly specified, we complain on mismatch
-                               throw new IllegalStateException("The maximum parallelism (" +
-                                               taskState.getMaxParallelism() + ") with which the latest " +
-                                               "checkpoint of the execution job vertex " + executionJobVertex +
-                                               " has been taken and the current maximum parallelism (" +
-                                               executionJobVertex.getMaxParallelism() + ") changed. This " +
-                                               "is currently not supported.");
+
+                       // only assign initial state if this subtask has any state to restore
+                       if (!allElementsAreNull(subNonPartitionableState) ||
+                               !allElementsAreNull(subManagedOperatorState) ||
+                               !allElementsAreNull(subRawOperatorState) ||
+                               subKeyedState != null) {
+
+                               TaskStateHandles taskStateHandles = new TaskStateHandles(
+
+                                       new ChainedStateHandle<>(subNonPartitionableState),
+                                       subManagedOperatorState,
+                                       subRawOperatorState,
+                                       subKeyedState != null ? subKeyedState.f0 : null,
+                                       subKeyedState != null ? subKeyedState.f1 : null);
+
+                               currentExecutionAttempt.setInitialState(taskStateHandles);
                        }
                }
+       }
 
-               //----------------------------------------parallelism preconditions-----------------------------------------
 
-               final int oldParallelism = taskState.getParallelism();
-               final int newParallelism = executionJobVertex.getParallelism();
+       public void checkParallelismPreconditions(List<OperatorState> operatorStates, ExecutionJobVertex executionJobVertex) {
 
-               if (taskState.hasNonPartitionedState() && (oldParallelism != newParallelism)) {
-                       throw new IllegalStateException("Cannot restore the latest checkpoint because " +
-                                       "the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " +
-                                       "state and its parallelism changed. The operator " + executionJobVertex.getJobVertexId() +
-                                       " has parallelism " + newParallelism + " whereas the corresponding " +
-                                       "state object has a parallelism of " + oldParallelism);
+               for (OperatorState operatorState : operatorStates) {
+                       checkParallelismPreconditions(operatorState, executionJobVertex);
                }
        }
 
-       private static void assignTaskStatesToOperatorInstances(
-                       TaskState taskState, ExecutionJobVertex executionJobVertex) {
 
-               final int oldParallelism = taskState.getParallelism();
-               final int newParallelism = executionJobVertex.getParallelism();
+       private void reAssignSubPartitionableState(
+                       List<List<Collection<OperatorStateHandle>>> newMangedOperatorStates,
+                       List<List<Collection<OperatorStateHandle>>> newRawOperatorStates,
+                       int subTaskIndex, int operatorIndex,
+                       List<Collection<OperatorStateHandle>> subManagedOperatorState,
+                       List<Collection<OperatorStateHandle>> subRawOperatorState) {
 
-               List<KeyGroupRange> keyGroupPartitions = createKeyGroupPartitions(
-                               executionJobVertex.getMaxParallelism(),
-                               newParallelism);
-
-               final int chainLength = taskState.getChainLength();
+               if (newMangedOperatorStates.get(operatorIndex) != null) {
+                       subManagedOperatorState.add(newMangedOperatorStates.get(operatorIndex).get(subTaskIndex));
+               } else {
+                       subManagedOperatorState.add(null);
+               }
+               if (newRawOperatorStates.get(operatorIndex) != null) {
+                       subRawOperatorState.add(newRawOperatorStates.get(operatorIndex).get(subTaskIndex));
+               } else {
+                       subRawOperatorState.add(null);
+               }
 
-               // operator chain idx -> list of the stored op states from all parallel instances for this chain idx
-               @SuppressWarnings("unchecked")
-               List<OperatorStateHandle>[] parallelOpStatesBackend = new List[chainLength];
-               @SuppressWarnings("unchecked")
-               List<OperatorStateHandle>[] parallelOpStatesStream = new List[chainLength];
 
-               List<KeyedStateHandle> parallelKeyedStatesBackend = new ArrayList<>(oldParallelism);
-               List<KeyedStateHandle> parallelKeyedStateStream = new ArrayList<>(oldParallelism);
+       }
 
-               for (int p = 0; p < oldParallelism; ++p) {
-                       SubtaskState subtaskState = taskState.getState(p);
+       private Tuple2<Collection<KeyedStateHandle>, Collection<KeyedStateHandle>> reAssignSubKeyedStates(
+                       OperatorState operatorState,
+                       List<KeyGroupRange> keyGroupPartitions,
+                       int subTaskIndex,
+                       int newParallelism,
+                       int oldParallelism) {
+
+               Collection<KeyedStateHandle> subManagedKeyedState;
+               Collection<KeyedStateHandle> subRawKeyedState;
+
+               if (newParallelism == oldParallelism) {
+                       if (operatorState.getState(subTaskIndex) != null) {
+                               KeyedStateHandle oldSubManagedKeyedState = operatorState.getState(subTaskIndex).getManagedKeyedState();
+                               KeyedStateHandle oldSubRawKeyedState = operatorState.getState(subTaskIndex).getRawKeyedState();
+                               subManagedKeyedState = oldSubManagedKeyedState != null ? Collections.singletonList(
+                                       oldSubManagedKeyedState) : null;
+                               subRawKeyedState = oldSubRawKeyedState != null ? Collections.singletonList(
+                                       oldSubRawKeyedState) : null;
+                       } else {
+                               subManagedKeyedState = null;
+                               subRawKeyedState = null;
+                       }
+               } else {
+                       subManagedKeyedState = getManagedKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex));
+                       subRawKeyedState = getRawKeyedStateHandles(operatorState, keyGroupPartitions.get(subTaskIndex));
+               }
+               if (subManagedKeyedState == null && subRawKeyedState == null) {
+                       return null;
+               }
+               return new Tuple2<>(subManagedKeyedState, subRawKeyedState);
+       }
 
-                       if (null != subtaskState) {
-                               collectParallelStatesByChainOperator(
-                                               parallelOpStatesBackend, subtaskState.getManagedOperatorState());
 
-                               collectParallelStatesByChainOperator(
-                                               parallelOpStatesStream, subtaskState.getRawOperatorState());
+       private <X> boolean allElementsAreNull(List<X> nonPartitionableStates) {
+               for (Object streamStateHandle : nonPartitionableStates) {
+                       if (streamStateHandle != null) {
+                               return false;
+                       }
+               }
+               return true;
+       }
 
-                               KeyedStateHandle keyedStateBackend = subtaskState.getManagedKeyedState();
-                               if (null != keyedStateBackend) {
-                                       parallelKeyedStatesBackend.add(keyedStateBackend);
-                               }
 
-                               KeyedStateHandle keyedStateStream = subtaskState.getRawKeyedState();
-                               if (null != keyedStateStream) {
-                                       parallelKeyedStateStream.add(keyedStateStream);
-                               }
+       private void reAssignSubNonPartitionedStates(
+                       OperatorState operatorState,
+                       int subTaskIndex,
+                       int newParallelism,
+                       int oldParallelism,
+               List<StreamStateHandle> subNonPartitionableState) {
+               if (oldParallelism == newParallelism) {
+                       if (operatorState.getState(subTaskIndex) != null) {
+                               subNonPartitionableState.add(operatorState.getState(subTaskIndex).getLegacyOperatorState());
+                       } else {
+                               subNonPartitionableState.add(null);
                        }
+               } else {
+                       subNonPartitionableState.add(null);
                }
+       }
 
-               // operator chain index -> lists with collected states (one collection for each parallel subtasks)
-               @SuppressWarnings("unchecked")
-               List<Collection<OperatorStateHandle>>[] partitionedParallelStatesBackend = new List[chainLength];
+       private void reDistributePartitionableStates(
+                       List<OperatorState> operatorStates, int newParallelism,
+                       List<List<Collection<OperatorStateHandle>>> newManagedOperatorStates,
+                       List<List<Collection<OperatorStateHandle>>> newRawOperatorStates) {
 
-               @SuppressWarnings("unchecked")
-               List<Collection<OperatorStateHandle>>[] partitionedParallelStatesStream = new List[chainLength];
+               //collect the old partitionable state
+               List<List<OperatorStateHandle>> oldManagedOperatorStates = new ArrayList<>();
+               List<List<OperatorStateHandle>> oldRawOperatorStates = new ArrayList<>();
 
-               //TODO here we can employ different redistribution strategies for state, e.g. union state.
-               // For now we only offer round robin as the default.
-               OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
+               collectPartionableStates(operatorStates, oldManagedOperatorStates, oldRawOperatorStates);
 
-               for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) {
 
-                       List<OperatorStateHandle> chainOpParallelStatesBackend = parallelOpStatesBackend[chainIdx];
-                       List<OperatorStateHandle> chainOpParallelStatesStream = parallelOpStatesStream[chainIdx];
+               //redistribute
+               OperatorStateRepartitioner opStateRepartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
 
-                       partitionedParallelStatesBackend[chainIdx] = applyRepartitioner(
-                                       opStateRepartitioner,
-                                       chainOpParallelStatesBackend,
-                                       oldParallelism,
-                                       newParallelism);
+               for (int operatorIndex = 0; operatorIndex < operatorStates.size(); operatorIndex++) {
+                       int oldParallelism = operatorStates.get(operatorIndex).getParallelism();
+                       newManagedOperatorStates.add(applyRepartitioner(opStateRepartitioner,
+                               oldManagedOperatorStates.get(operatorIndex), oldParallelism, newParallelism));
+                       newRawOperatorStates.add(applyRepartitioner(opStateRepartitioner,
+                               oldRawOperatorStates.get(operatorIndex), oldParallelism, newParallelism));
 
-                       partitionedParallelStatesStream[chainIdx] = applyRepartitioner(
-                                       opStateRepartitioner,
-                                       chainOpParallelStatesStream,
-                                       oldParallelism,
-                                       newParallelism);
                }
+       }
 
-               for (int subTaskIdx = 0; subTaskIdx < newParallelism; ++subTaskIdx) {
-                       // non-partitioned state
-                       ChainedStateHandle<StreamStateHandle> nonPartitionableState = null;
-
-                       if (oldParallelism == newParallelism) {
-                               if (taskState.getState(subTaskIdx) != null) {
-                                       nonPartitionableState = taskState.getState(subTaskIdx).getLegacyOperatorState();
-                               }
-                       }
 
-                       // partitionable state
-                       @SuppressWarnings("unchecked")
-                       Collection<OperatorStateHandle>[] iab = new Collection[chainLength];
-                       @SuppressWarnings("unchecked")
-                       Collection<OperatorStateHandle>[] ias = new Collection[chainLength];
-                       List<Collection<OperatorStateHandle>> operatorStateFromBackend = Arrays.asList(iab);
-                       List<Collection<OperatorStateHandle>> operatorStateFromStream = Arrays.asList(ias);
+       private void collectPartionableStates(
+                       List<OperatorState> operatorStates,
+                       List<List<OperatorStateHandle>> managedOperatorStates,
+                       List<List<OperatorStateHandle>> rawOperatorStates) {
 
-                       for (int chainIdx = 0; chainIdx < partitionedParallelStatesBackend.length; ++chainIdx) {
-                               List<Collection<OperatorStateHandle>> redistributedOpStateBackend =
-                                               partitionedParallelStatesBackend[chainIdx];
+               for (OperatorState operatorState : operatorStates) {
+                       List<OperatorStateHandle> managedOperatorState = null;
+                       List<OperatorStateHandle> rawOperatorState = null;
 
-                               List<Collection<OperatorStateHandle>> redistributedOpStateStream =
-                                               partitionedParallelStatesStream[chainIdx];
+                       for (int i = 0; i < operatorState.getParallelism(); i++) {
+                               OperatorSubtaskState operatorSubtaskState = operatorState.getState(i);
+                               if (operatorSubtaskState != null) {
+                                       if (operatorSubtaskState.getManagedOperatorState() != null) {
+                                               if (managedOperatorState == null) {
+                                                       managedOperatorState = new ArrayList<>();
+                                               }
+                                               managedOperatorState.add(operatorSubtaskState.getManagedOperatorState());
+                                       }
 
-                               if (redistributedOpStateBackend != null) {
-                                       operatorStateFromBackend.set(chainIdx, redistributedOpStateBackend.get(subTaskIdx));
+                                       if (operatorSubtaskState.getRawOperatorState() != null) {
+                                               if (rawOperatorState == null) {
+                                                       rawOperatorState = new ArrayList<>();
+                                               }
+                                               rawOperatorState.add(operatorSubtaskState.getRawOperatorState());
+                                       }
                                }
 
-                               if (redistributedOpStateStream != null) {
-                                       operatorStateFromStream.set(chainIdx, redistributedOpStateStream.get(subTaskIdx));
-                               }
                        }
+                       managedOperatorStates.add(managedOperatorState);
+                       rawOperatorStates.add(rawOperatorState);
+               }
+       }
 
-                       Execution currentExecutionAttempt = executionJobVertex
-                                       .getTaskVertices()[subTaskIdx]
-                                       .getCurrentExecutionAttempt();
-
-                       List<KeyedStateHandle> newKeyedStatesBackend;
-                       List<KeyedStateHandle> newKeyedStateStream;
-                       if (oldParallelism == newParallelism) {
-                               SubtaskState subtaskState = taskState.getState(subTaskIdx);
-                               if (subtaskState != null) {
-                                       KeyedStateHandle oldKeyedStatesBackend = subtaskState.getManagedKeyedState();
-                                       KeyedStateHandle oldKeyedStatesStream = subtaskState.getRawKeyedState();
-                                       newKeyedStatesBackend = oldKeyedStatesBackend != null ? Collections.singletonList(
-                                                       oldKeyedStatesBackend) : null;
-                                       newKeyedStateStream = oldKeyedStatesStream != null ? Collections.singletonList(
-                                                       oldKeyedStatesStream) : null;
-                               } else {
-                                       newKeyedStatesBackend = null;
-                                       newKeyedStateStream = null;
-                               }
-                       } else {
-                               KeyGroupRange subtaskKeyGroupIds = keyGroupPartitions.get(subTaskIdx);
-                               newKeyedStatesBackend = getKeyedStateHandles(parallelKeyedStatesBackend, subtaskKeyGroupIds);
-                               newKeyedStateStream = getKeyedStateHandles(parallelKeyedStateStream, subtaskKeyGroupIds);
-                       }
 
-                       TaskStateHandles taskStateHandles = new TaskStateHandles(
-                                       nonPartitionableState,
-                                       operatorStateFromBackend,
-                                       operatorStateFromStream,
-                                       newKeyedStatesBackend,
-                                       newKeyedStateStream);
+       /**
+        * Collects the {@link KeyGroupsStateHandle managedKeyedStateHandles} that intersect the given
+        * {@link KeyGroupRange} from {@link OperatorState operatorState}.
+        *
+        * @param operatorState        all state handles of an operator
+        * @param subtaskKeyGroupRange the KeyGroupRange of a subtask
+        * @return all managedKeyedStateHandles that intersect the given KeyGroupRange, or {@code null} if there are none
+        */
+       public static List<KeyedStateHandle> getManagedKeyedStateHandles(
+                       OperatorState operatorState,
+                       KeyGroupRange subtaskKeyGroupRange) {
+
+               List<KeyedStateHandle> subtaskKeyedStateHandles = null;
+
+               for (int i = 0; i < operatorState.getParallelism(); i++) {
+                       if (operatorState.getState(i) != null && operatorState.getState(i).getManagedKeyedState() != null) {
+                               KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getManagedKeyedState().getIntersection(subtaskKeyGroupRange);
 
-                       currentExecutionAttempt.setInitialState(taskStateHandles);
+                               if (intersectedKeyedStateHandle != null) {
+                                       if (subtaskKeyedStateHandles == null) {
+                                               subtaskKeyedStateHandles = new ArrayList<>();
+                                       }
+                                       subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
+                               }
+                       }
                }
+
+               return subtaskKeyedStateHandles;
        }
 
        /**
-        * Determine the subset of {@link KeyGroupsStateHandle KeyGroupsStateHandles} with correct
-        * key group index for the given subtask {@link KeyGroupRange}.
-        * <p>
-        * <p>This is publicly visible to be used in tests.
+        * Collects the {@link KeyGroupsStateHandle rawKeyedStateHandles} that intersect the given
+        * {@link KeyGroupRange} from {@link OperatorState operatorState}.
+        *
+        * @param operatorState        all state handles of an operator
+        * @param subtaskKeyGroupRange the KeyGroupRange of a subtask
+        * @return all rawKeyedStateHandles that intersect the given KeyGroupRange
         */
-       public static List<KeyedStateHandle> getKeyedStateHandles(
-                       Collection<? extends KeyedStateHandle> keyedStateHandles,
-                       KeyGroupRange subtaskKeyGroupRange) {
+       public static List<KeyedStateHandle> getRawKeyedStateHandles(
+               OperatorState operatorState,
+               KeyGroupRange subtaskKeyGroupRange) {
 
-               List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>();
+               List<KeyedStateHandle> subtaskKeyedStateHandles = null;
 
-               for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
-                       KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange);
+               for (int i = 0; i < operatorState.getParallelism(); i++) {
+                       if (operatorState.getState(i) != null && operatorState.getState(i).getRawKeyedState() != null) {
+                               KeyedStateHandle intersectedKeyedStateHandle = operatorState.getState(i).getRawKeyedState().getIntersection(subtaskKeyGroupRange);
 
-                       if (intersectedKeyedStateHandle != null) {
-                               subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
+                               if (intersectedKeyedStateHandle != null) {
+                                       if (subtaskKeyedStateHandles == null) {
+                                               subtaskKeyedStateHandles = new ArrayList<>();
+                                       }
+                                       subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
+                               }
                        }
                }
 
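When the parallelism changes, `reAssignSubKeyedStates` no longer passes a subtask its single old keyed state handle; instead, `getManagedKeyedStateHandles` and `getRawKeyedStateHandles` collect every old handle that intersects the new subtask's `KeyGroupRange`. The ranges themselves come from `createKeyGroupPartitions`, whose body is not part of this excerpt. The standalone sketch below illustrates the idea under two assumptions of ours: key groups `[0, maxParallelism)` are split into contiguous, near-equal ranges using the formula in `partition`, and an intersection is the overlap of two inclusive ranges. `KeyGroupRescaleSketch`, `Range`, `partition`, and `assign` are hypothetical names, not Flink API.

```java
import java.util.ArrayList;
import java.util.List;

public class KeyGroupRescaleSketch {

    /** Hypothetical stand-in for a key group range with inclusive bounds [start, end]. */
    static final class Range {
        final int start;
        final int end;

        Range(int start, int end) {
            this.start = start;
            this.end = end;
        }

        boolean isEmpty() {
            return start > end;
        }

        /** Overlap of two inclusive ranges; empty if they are disjoint. */
        Range intersect(Range other) {
            return new Range(Math.max(start, other.start), Math.min(end, other.end));
        }

        @Override
        public boolean equals(Object o) {
            if (!(o instanceof Range)) {
                return false;
            }
            Range r = (Range) o;
            return r.start == start && r.end == end;
        }

        @Override
        public int hashCode() {
            return 31 * start + end;
        }

        @Override
        public String toString() {
            return "[" + start + ", " + end + "]";
        }
    }

    /** Assumed split of key groups [0, maxParallelism) into `parallelism` contiguous ranges. */
    static List<Range> partition(int maxParallelism, int parallelism) {
        List<Range> ranges = new ArrayList<>(parallelism);
        for (int i = 0; i < parallelism; i++) {
            int start = (i * maxParallelism + parallelism - 1) / parallelism; // ceil(i * max / p)
            int end = ((i + 1) * maxParallelism - 1) / parallelism;
            ranges.add(new Range(start, end));
        }
        return ranges;
    }

    /**
     * Collects the parts of the old subtasks' ranges that fall into a new subtask's range.
     * In the diff, KeyedStateHandle#getIntersection signals "no overlap" with null;
     * this sketch checks for an empty range instead.
     */
    static List<Range> assign(List<Range> oldRanges, Range newRange) {
        List<Range> result = new ArrayList<>();
        for (Range old : oldRanges) {
            Range overlap = old.intersect(newRange);
            if (!overlap.isEmpty()) {
                result.add(overlap);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<Range> oldRanges = partition(128, 2); // checkpoint taken with parallelism 2
        List<Range> newRanges = partition(128, 3); // restored with parallelism 3
        for (int i = 0; i < newRanges.size(); i++) {
            System.out.println("subtask " + i + " " + newRanges.get(i) + " <- " + assign(oldRanges, newRanges.get(i)));
        }
    }
}
```

Under these assumptions, rescaling 128 key groups from parallelism 2 to 3 gives new subtask 1 (key groups 43 to 85) a slice of both old handles. This is why the restored keyed state for a subtask is a collection of handles rather than a single one.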
@@ -331,37 +456,90 @@ public class StateAssignmentOperation {
        }
 
        /**
-        * @param chainParallelOpStates array = chain ops, array[idx] = parallel states for this chain op.
-        * @param chainOpState the operator chain
+        * Verifies conditions with regard to parallelism and maxParallelism that must be met when restoring state.
+        *
+        * @param operatorState      state to restore
+        * @param executionJobVertex task for which the state should be restored
         */
-       private static void collectParallelStatesByChainOperator(
-                       List<OperatorStateHandle>[] chainParallelOpStates, ChainedStateHandle<OperatorStateHandle> chainOpState) {
+       private static void checkParallelismPreconditions(OperatorState operatorState, ExecutionJobVertex executionJobVertex) {
+               //----------------------------------------max parallelism preconditions-------------------------------------
 
-               if (null != chainOpState) {
+               // check that the number of key groups has not changed, or whether we need to override it to satisfy the restored state
+               if (operatorState.getMaxParallelism() != executionJobVertex.getMaxParallelism()) {
 
-                       int chainLength = chainOpState.getLength();
-                       Preconditions.checkState(chainLength >= chainParallelOpStates.length,
-                                       "Found more states than operators in the chain. Chain length: " + chainLength +
-                                                       ", States: " + chainParallelOpStates.length);
+                       if (!executionJobVertex.isMaxParallelismConfigured()) {
+                               // if the max parallelism was not explicitly specified by the user, we derive it from the state
 
-                       for (int chainIdx = 0; chainIdx < chainParallelOpStates.length; ++chainIdx) {
-                               OperatorStateHandle operatorState = chainOpState.get(chainIdx);
+                               LOG.debug("Overriding maximum parallelism for JobVertex {} from {} to {}",
+                                       executionJobVertex.getJobVertexId(), executionJobVertex.getMaxParallelism(), operatorState.getMaxParallelism());
 
-                               if (null != operatorState) {
+                               executionJobVertex.setMaxParallelism(operatorState.getMaxParallelism());
+                       } else {
+                               // if the max parallelism was explicitly specified, we complain on mismatch
+                               throw new IllegalStateException("The maximum parallelism (" +
+                                       operatorState.getMaxParallelism() + ") with which the latest " +
+                                       "checkpoint of the execution job vertex " + executionJobVertex +
+                                       " has been taken and the current maximum parallelism (" +
+                                       executionJobVertex.getMaxParallelism() + ") changed. This " +
+                                       "is currently not supported.");
+                       }
+               }
 
-                                       List<OperatorStateHandle> opParallelStatesForOneChainOp = chainParallelOpStates[chainIdx];
+               //----------------------------------------parallelism preconditions-----------------------------------------
 
-                                       if (null == opParallelStatesForOneChainOp) {
-                                               opParallelStatesForOneChainOp = new ArrayList<>();
-                                               chainParallelOpStates[chainIdx] = opParallelStatesForOneChainOp;
-                                       }
-                                       opParallelStatesForOneChainOp.add(operatorState);
+               final int oldParallelism = operatorState.getParallelism();
+               final int newParallelism = executionJobVertex.getParallelism();
+
+               if (operatorState.hasNonPartitionedState() && (oldParallelism != newParallelism)) {
+                       throw new IllegalStateException("Cannot restore the latest checkpoint because " +
+                               "the operator " + executionJobVertex.getJobVertexId() + " has non-partitioned " +
+                               "state and its parallelism changed. The operator " + executionJobVertex.getJobVertexId() +
+                               " has parallelism " + newParallelism + " whereas the corresponding " +
+                               "state object has a parallelism of " + oldParallelism);
+               }
+       }
+
+       /**
+        * Verifies that all operator states can be mapped to an execution job vertex.
+        *
+        * @param allowNonRestoredState if false, an exception will be thrown if a state could not be mapped
+        * @param operatorStates operator states to map
+        * @param tasks tasks to map to
+        */
+       private static void checkStateMappingCompleteness(
+                       boolean allowNonRestoredState,
+                       Map<OperatorID, OperatorState> operatorStates,
+                       Map<JobVertexID, ExecutionJobVertex> tasks) {
+
+               Set<OperatorID> allOperatorIDs = new HashSet<>();
+               for (ExecutionJobVertex executionJobVertex : tasks.values()) {
+                       allOperatorIDs.addAll(executionJobVertex.getOperatorIDs());
+               }
+               for (Map.Entry<OperatorID, OperatorState> operatorGroupStateEntry : operatorStates.entrySet()) {
+                       OperatorState operatorState = operatorGroupStateEntry.getValue();
+                       //----------------------------------------find operator for state---------------------------------------------
+
+                       if (!allOperatorIDs.contains(operatorGroupStateEntry.getKey())) {
+                               if (allowNonRestoredState) {
+                                       LOG.info("Skipped checkpoint state for operator {}.", operatorState.getOperatorID());
+                               } else {
+                                       throw new IllegalStateException("There is no operator for the state " + operatorState.getOperatorID());
                                }
                        }
                }
        }
 
-       private static List<Collection<OperatorStateHandle>> applyRepartitioner(
+       /**
+        * Repartitions the given operator state using the given {@link OperatorStateRepartitioner} with respect to the new
+        * parallelism.
+        *
+        * @param opStateRepartitioner  partitioner to use
+        * @param chainOpParallelStates state to repartition
+        * @param oldParallelism        parallelism with which the state is currently partitioned
+        * @param newParallelism        parallelism with which the state should be partitioned
+        * @return repartitioned state
+        */
+       public static List<Collection<OperatorStateHandle>> applyRepartitioner(
                        OperatorStateRepartitioner opStateRepartitioner,
                        List<OperatorStateHandle> chainOpParallelStates,
                        int oldParallelism,
@@ -399,4 +577,27 @@ public class StateAssignmentOperation {
                        return repackStream;
                }
        }
-}
\ No newline at end of file
+
+       /**
+        * Determine the subset of {@link KeyGroupsStateHandle KeyGroupsStateHandles} with correct
+        * key group index for the given subtask {@link KeyGroupRange}.
+        * <p>
+        * <p>This is publicly visible to be used in tests.
+        */
+       public static List<KeyedStateHandle> getKeyedStateHandles(
+               Collection<? extends KeyedStateHandle> keyedStateHandles,
+               KeyGroupRange subtaskKeyGroupRange) {
+
+               List<KeyedStateHandle> subtaskKeyedStateHandles = new ArrayList<>();
+
+               for (KeyedStateHandle keyedStateHandle : keyedStateHandles) {
+                       KeyedStateHandle intersectedKeyedStateHandle = keyedStateHandle.getIntersection(subtaskKeyGroupRange);
+
+                       if (intersectedKeyedStateHandle != null) {
+                               subtaskKeyedStateHandles.add(intersectedKeyedStateHandle);
+                       }
+               }
+
+               return subtaskKeyedStateHandles;
+       }
+}
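`reDistributePartitionableStates` gathers each operator's managed and raw operator state handles from all old subtasks (leaving `null` when no subtask had any) and passes them, per operator, through `applyRepartitioner` with `RoundRobinOperatorStateRepartitioner.INSTANCE`. The repartitioner itself is not part of this diff and operates on `OperatorStateHandle`s rather than on individual elements, so the following is only a simplified sketch of round-robin redistribution over plain list elements. `RoundRobinSketch` and `redistribute` are hypothetical names; a `null` entry stands for an old subtask without state, mirroring the `null` checks in `collectPartionableStates`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class RoundRobinSketch {

    /**
     * Redistributes list-style state, collected from all old subtasks, over
     * newParallelism subtasks: the k-th element overall goes to subtask k % newParallelism.
     * A null entry in oldSubtaskStates means that old subtask had no state.
     */
    static <T> List<List<T>> redistribute(List<List<T>> oldSubtaskStates, int newParallelism) {
        List<List<T>> result = new ArrayList<>(newParallelism);
        for (int i = 0; i < newParallelism; i++) {
            result.add(new ArrayList<>());
        }
        int next = 0;
        for (List<T> subtaskState : oldSubtaskStates) {
            if (subtaskState == null) {
                continue; // nothing to redistribute from this old subtask
            }
            for (T element : subtaskState) {
                result.get(next).add(element);
                next = (next + 1) % newParallelism;
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<List<String>> oldStates = Arrays.asList(
                Arrays.asList("a", "b", "c"), // old subtask 0
                null,                          // old subtask 1: no state
                Arrays.asList("d", "e"));      // old subtask 2
        // scale from 3 subtasks down to 2
        System.out.println(redistribute(oldStates, 2)); // prints [[a, c, e], [b, d]]
    }
}
```

Round robin keeps the per-subtask element counts within one of each other, but gives no control over which subtask receives which element; the TODO in the removed code already points at alternatives such as union state.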

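The per-operator `checkParallelismPreconditions` in this diff applies two rules. A differing max parallelism (the number of key groups) is adopted from the restored state only if the user did not configure one explicitly. Non-partitioned (legacy) state can only be restored at unchanged parallelism, since `reAssignSubNonPartitionedStates` otherwise hands each subtask `null`. A minimal sketch of that decision logic, assuming a hypothetical `Vertex` class as a stand-in for `ExecutionJobVertex` and plain values in place of `OperatorState`; the class and method names are ours, not Flink's:

```java
public class RestorePreconditionsSketch {

    /** Hypothetical, simplified stand-in for an ExecutionJobVertex. */
    static final class Vertex {
        int maxParallelism;
        final boolean maxParallelismConfigured; // explicitly set by the user?
        final int parallelism;

        Vertex(int maxParallelism, boolean maxParallelismConfigured, int parallelism) {
            this.maxParallelism = maxParallelism;
            this.maxParallelismConfigured = maxParallelismConfigured;
            this.parallelism = parallelism;
        }
    }

    static void check(int stateMaxParallelism, int stateParallelism,
                      boolean hasNonPartitionedState, Vertex vertex) {
        // Rule 1: the number of key groups must match, unless we may adopt the state's value.
        if (stateMaxParallelism != vertex.maxParallelism) {
            if (!vertex.maxParallelismConfigured) {
                vertex.maxParallelism = stateMaxParallelism;
            } else {
                throw new IllegalStateException("max parallelism changed from "
                        + stateMaxParallelism + " to " + vertex.maxParallelism);
            }
        }
        // Rule 2: non-partitioned state cannot be redistributed, so parallelism must not change.
        if (hasNonPartitionedState && stateParallelism != vertex.parallelism) {
            throw new IllegalStateException("non-partitioned state written with parallelism "
                    + stateParallelism + " cannot be restored with parallelism " + vertex.parallelism);
        }
    }

    public static void main(String[] args) {
        Vertex derived = new Vertex(128, false, 4);
        check(256, 2, false, derived);
        System.out.println(derived.maxParallelism); // prints 256

        try {
            check(128, 2, true, new Vertex(128, true, 4));
        } catch (IllegalStateException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```

In the diff, the public overload simply runs this check for every `OperatorState` of the vertex, so a single operator with legacy state is enough to block rescaling the whole chain.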