http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index 344b340..88b95f5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -23,14 +23,15 @@ import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
-import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.TestLogger;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.powermock.core.classloader.annotations.PrepareForTest;
@@ -42,8 +43,8 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.mockito.Matchers.anyInt;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -89,29 +90,26 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
                assertFalse(pendingCheckpoint.isDiscarded());
 
                final long checkpointId = 
coord.getPendingCheckpoints().keySet().iterator().next();
-               
-               SubtaskState subtaskState = mock(SubtaskState.class);
+
 
                StreamStateHandle legacyHandle = mock(StreamStateHandle.class);
-               ChainedStateHandle<StreamStateHandle> chainedLegacyHandle = 
mock(ChainedStateHandle.class);
-               
when(chainedLegacyHandle.get(anyInt())).thenReturn(legacyHandle);
-               
when(subtaskState.getLegacyOperatorState()).thenReturn(chainedLegacyHandle);
+               KeyedStateHandle managedKeyedHandle = 
mock(KeyedStateHandle.class);
+               KeyedStateHandle rawKeyedHandle = mock(KeyedStateHandle.class);
+               OperatorStateHandle managedOpHandle = 
mock(OperatorStateHandle.class);
+               OperatorStateHandle rawOpHandle = 
mock(OperatorStateHandle.class);
 
-               OperatorStateHandle managedHandle = 
mock(OperatorStateHandle.class);
-               ChainedStateHandle<OperatorStateHandle> chainedManagedHandle = 
mock(ChainedStateHandle.class);
-               
when(chainedManagedHandle.get(anyInt())).thenReturn(managedHandle);
-               
when(subtaskState.getManagedOperatorState()).thenReturn(chainedManagedHandle);
+               final OperatorSubtaskState operatorSubtaskState = spy(new 
OperatorSubtaskState(
+                       legacyHandle,
+                       managedOpHandle,
+                       rawOpHandle,
+                       managedKeyedHandle,
+                       rawKeyedHandle));
 
-               OperatorStateHandle rawHandle = mock(OperatorStateHandle.class);
-               ChainedStateHandle<OperatorStateHandle> chainedRawHandle = 
mock(ChainedStateHandle.class);
-               when(chainedRawHandle.get(anyInt())).thenReturn(rawHandle);
-               
when(subtaskState.getRawOperatorState()).thenReturn(chainedRawHandle);
+               TaskStateSnapshot subtaskState = spy(new TaskStateSnapshot());
+               subtaskState.putSubtaskStateByOperatorID(new OperatorID(), 
operatorSubtaskState);
+
+               
when(subtaskState.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(vertex.getJobvertexId()))).thenReturn(operatorSubtaskState);
 
-               KeyedStateHandle managedKeyedHandle = 
mock(KeyedStateHandle.class);
-               
when(subtaskState.getRawKeyedState()).thenReturn(managedKeyedHandle);
-               KeyedStateHandle managedRawHandle = 
mock(KeyedStateHandle.class);
-               
when(subtaskState.getManagedKeyedState()).thenReturn(managedRawHandle);
-               
                AcknowledgeCheckpoint acknowledgeMessage = new 
AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId, new 
CheckpointMetrics(), subtaskState);
                
                try {
@@ -126,11 +124,12 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
                assertTrue(pendingCheckpoint.isDiscarded());
 
                // make sure that the subtask state has been discarded after we 
could not complete it.
-               
verify(subtaskState.getLegacyOperatorState().get(0)).discardState();
-               
verify(subtaskState.getManagedOperatorState().get(0)).discardState();
-               
verify(subtaskState.getRawOperatorState().get(0)).discardState();
-               verify(subtaskState.getManagedKeyedState()).discardState();
-               verify(subtaskState.getRawKeyedState()).discardState();
+               verify(operatorSubtaskState).discardState();
+               
verify(operatorSubtaskState.getLegacyOperatorState()).discardState();
+               
verify(operatorSubtaskState.getManagedOperatorState().iterator().next()).discardState();
+               
verify(operatorSubtaskState.getRawOperatorState().iterator().next()).discardState();
+               
verify(operatorSubtaskState.getManagedKeyedState().iterator().next()).discardState();
+               
verify(operatorSubtaskState.getRawKeyedState().iterator().next()).discardState();
        }
 
        private static final class FailingCompletedCheckpointStore implements 
CompletedCheckpointStore {

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index cb92df6..d9af879 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -44,7 +44,6 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
@@ -93,7 +92,6 @@ import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -102,7 +100,6 @@ import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
-import static org.mockito.Mockito.withSettings;
 
 /**
  * Tests for the checkpoint coordinator.
@@ -555,31 +552,29 @@ public class CheckpointCoordinatorTest extends TestLogger {
                        assertFalse(checkpoint.isDiscarded());
                        assertFalse(checkpoint.isFullyAcknowledged());
 
-                       OperatorID opID1 = 
OperatorID.fromJobVertexID(vertex1.getJobvertexId());
-                       OperatorID opID2 = 
OperatorID.fromJobVertexID(vertex2.getJobvertexId());
-
-                       Map<OperatorID, OperatorState> operatorStates = 
checkpoint.getOperatorStates();
-
-                       operatorStates.put(opID1, new SpyInjectingOperatorState(
-                               opID1, 
vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism()));
-                       operatorStates.put(opID2, new SpyInjectingOperatorState(
-                               opID2, 
vertex2.getTotalNumberOfParallelSubtasks(), vertex2.getMaxParallelism()));
-
                        // check that the vertices received the trigger 
checkpoint message
                        {
                                verify(vertex1.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), 
any(CheckpointOptions.class));
                                verify(vertex2.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), 
any(CheckpointOptions.class));
                        }
 
+                       OperatorID opID1 = 
OperatorID.fromJobVertexID(vertex1.getJobvertexId());
+                       OperatorID opID2 = 
OperatorID.fromJobVertexID(vertex2.getJobvertexId());
+                       TaskStateSnapshot taskOperatorSubtaskStates1 = 
mock(TaskStateSnapshot.class);
+                       TaskStateSnapshot taskOperatorSubtaskStates2 = 
mock(TaskStateSnapshot.class);
+                       OperatorSubtaskState subtaskState1 = 
mock(OperatorSubtaskState.class);
+                       OperatorSubtaskState subtaskState2 = 
mock(OperatorSubtaskState.class);
+                       
when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1);
+                       
when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2);
+
                        // acknowledge from one of the tasks
-                       AcknowledgeCheckpoint acknowledgeCheckpoint1 = new 
AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), 
mock(SubtaskState.class));
+                       AcknowledgeCheckpoint acknowledgeCheckpoint1 = new 
AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), 
taskOperatorSubtaskStates2);
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1);
-                       OperatorSubtaskState subtaskState2 = 
operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
                        assertEquals(1, 
checkpoint.getNumberOfAcknowledgedTasks());
                        assertEquals(1, 
checkpoint.getNumberOfNonAcknowledgedTasks());
                        assertFalse(checkpoint.isDiscarded());
                        assertFalse(checkpoint.isFullyAcknowledged());
-                       verify(subtaskState2, 
never()).registerSharedStates(any(SharedStateRegistry.class));
+                       verify(taskOperatorSubtaskStates2, 
never()).registerSharedStates(any(SharedStateRegistry.class));
 
                        // acknowledge the same task again (should not matter)
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1);
@@ -588,8 +583,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
                        verify(subtaskState2, 
never()).registerSharedStates(any(SharedStateRegistry.class));
 
                        // acknowledge the other task.
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), 
mock(SubtaskState.class)));
-                       OperatorSubtaskState subtaskState1 = 
operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), 
taskOperatorSubtaskStates1));
 
                        // the checkpoint is internally converted to a 
successful checkpoint and the
                        // pending checkpoint object is disposed
@@ -628,9 +622,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
                        long checkpointIdNew = 
coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew));
-                       subtaskState1 = 
operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew));
-                       subtaskState2 = 
operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
 
                        assertEquals(0, coord.getNumberOfPendingCheckpoints());
                        assertEquals(1, 
coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -852,18 +844,20 @@ public class CheckpointCoordinatorTest extends TestLogger {
                        OperatorID opID2 = 
OperatorID.fromJobVertexID(ackVertex2.getJobvertexId());
                        OperatorID opID3 = 
OperatorID.fromJobVertexID(ackVertex3.getJobvertexId());
 
-                       Map<OperatorID, OperatorState> operatorStates1 = 
pending1.getOperatorStates();
+                       TaskStateSnapshot taskOperatorSubtaskStates1_1 = 
spy(new TaskStateSnapshot());
+                       TaskStateSnapshot taskOperatorSubtaskStates1_2 = 
spy(new TaskStateSnapshot());
+                       TaskStateSnapshot taskOperatorSubtaskStates1_3 = 
spy(new TaskStateSnapshot());
 
-                       operatorStates1.put(opID1, new 
SpyInjectingOperatorState(
-                               opID1, 
ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism()));
-                       operatorStates1.put(opID2, new 
SpyInjectingOperatorState(
-                               opID2, 
ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism()));
-                       operatorStates1.put(opID3, new 
SpyInjectingOperatorState(
-                               opID3, 
ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism()));
+                       OperatorSubtaskState subtaskState1_1 = 
mock(OperatorSubtaskState.class);
+                       OperatorSubtaskState subtaskState1_2 = 
mock(OperatorSubtaskState.class);
+                       OperatorSubtaskState subtaskState1_3 = 
mock(OperatorSubtaskState.class);
+                       
taskOperatorSubtaskStates1_1.putSubtaskStateByOperatorID(opID1, 
subtaskState1_1);
+                       
taskOperatorSubtaskStates1_2.putSubtaskStateByOperatorID(opID2, 
subtaskState1_2);
+                       
taskOperatorSubtaskStates1_3.putSubtaskStateByOperatorID(opID3, 
subtaskState1_3);
 
                        // acknowledge one of the three tasks
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new 
CheckpointMetrics(), mock(SubtaskState.class)));
-                       OperatorSubtaskState subtaskState1_2 = 
operatorStates1.get(opID2).getState(ackVertex2.getParallelSubtaskIndex());
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new 
CheckpointMetrics(), taskOperatorSubtaskStates1_2));
+
                        // start the second checkpoint
                        // trigger the first checkpoint. this should succeed
                        assertTrue(coord.triggerCheckpoint(timestamp2, false));
@@ -880,14 +874,17 @@ public class CheckpointCoordinatorTest extends TestLogger {
                        }
                        long checkpointId2 = pending2.getCheckpointId();
 
-                       Map<OperatorID, OperatorState> operatorStates2 = 
pending2.getOperatorStates();
+                       TaskStateSnapshot taskOperatorSubtaskStates2_1 = 
spy(new TaskStateSnapshot());
+                       TaskStateSnapshot taskOperatorSubtaskStates2_2 = 
spy(new TaskStateSnapshot());
+                       TaskStateSnapshot taskOperatorSubtaskStates2_3 = 
spy(new TaskStateSnapshot());
+
+                       OperatorSubtaskState subtaskState2_1 = 
mock(OperatorSubtaskState.class);
+                       OperatorSubtaskState subtaskState2_2 = 
mock(OperatorSubtaskState.class);
+                       OperatorSubtaskState subtaskState2_3 = 
mock(OperatorSubtaskState.class);
 
-                       operatorStates2.put(opID1, new 
SpyInjectingOperatorState(
-                               opID1, 
ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism()));
-                       operatorStates2.put(opID2, new 
SpyInjectingOperatorState(
-                               opID2, 
ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism()));
-                       operatorStates2.put(opID3, new 
SpyInjectingOperatorState(
-                               opID3, 
ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism()));
+                       
taskOperatorSubtaskStates2_1.putSubtaskStateByOperatorID(opID1, 
subtaskState2_1);
+                       
taskOperatorSubtaskStates2_2.putSubtaskStateByOperatorID(opID2, 
subtaskState2_2);
+                       
taskOperatorSubtaskStates2_3.putSubtaskStateByOperatorID(opID3, 
subtaskState2_3);
 
                        // trigger messages should have been sent
                        verify(triggerVertex1.getCurrentExecutionAttempt(), 
times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), 
any(CheckpointOptions.class));
@@ -896,17 +893,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
                        // we acknowledge one more task from the first 
checkpoint and the second
                        // checkpoint completely. The second checkpoint should 
then subsume the first checkpoint
 
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new 
CheckpointMetrics(), mock(SubtaskState.class)));
-                       OperatorSubtaskState subtaskState2_3 = 
operatorStates2.get(opID3).getState(ackVertex3.getParallelSubtaskIndex());
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new 
CheckpointMetrics(), taskOperatorSubtaskStates2_3));
 
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new 
CheckpointMetrics(), mock(SubtaskState.class)));
-                       OperatorSubtaskState subtaskState2_1 = 
operatorStates2.get(opID1).getState(ackVertex1.getParallelSubtaskIndex());
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new 
CheckpointMetrics(), taskOperatorSubtaskStates2_1));
 
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new 
CheckpointMetrics(), mock(SubtaskState.class)));
-                       OperatorSubtaskState subtaskState1_1 = 
operatorStates1.get(opID1).getState(ackVertex1.getParallelSubtaskIndex());
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new 
CheckpointMetrics(), taskOperatorSubtaskStates1_1));
 
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new 
CheckpointMetrics(), mock(SubtaskState.class)));
-                       OperatorSubtaskState subtaskState2_2 = 
operatorStates2.get(opID2).getState(ackVertex2.getParallelSubtaskIndex());
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new 
CheckpointMetrics(), taskOperatorSubtaskStates2_2));
 
                        // now, the second checkpoint should be confirmed, and 
the first discarded
                        // actually both pending checkpoints are discarded, and 
the second has been transformed
@@ -938,8 +931,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
                        verify(commitVertex.getCurrentExecutionAttempt(), 
times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2));
 
                        // send the last remaining ack for the first 
checkpoint. This should not do anything
-                       SubtaskState subtaskState1_3 = mock(SubtaskState.class);
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new 
CheckpointMetrics(), subtaskState1_3));
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new 
CheckpointMetrics(), taskOperatorSubtaskStates1_3));
                        verify(subtaskState1_3, times(1)).discardState();
 
                        coord.shutdown(JobStatus.FINISHED);
@@ -1005,13 +997,11 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
                        OperatorID opID1 = 
OperatorID.fromJobVertexID(ackVertex1.getJobvertexId());
 
-                       Map<OperatorID, OperatorState> operatorStates = 
checkpoint.getOperatorStates();
+                       TaskStateSnapshot taskOperatorSubtaskStates1 = spy(new 
TaskStateSnapshot());
+                       OperatorSubtaskState subtaskState1 = 
mock(OperatorSubtaskState.class);
+                       
taskOperatorSubtaskStates1.putSubtaskStateByOperatorID(opID1, subtaskState1);
 
-                       operatorStates.put(opID1, new SpyInjectingOperatorState(
-                               opID1, 
ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism()));
-
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new 
CheckpointMetrics(), mock(SubtaskState.class)));
-                       OperatorSubtaskState subtaskState = 
operatorStates.get(opID1).getState(ackVertex1.getParallelSubtaskIndex());
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new 
CheckpointMetrics(), taskOperatorSubtaskStates1));
 
                        // wait until the checkpoint must have expired.
                        // we check every 250 msecs conservatively for 5 seconds
@@ -1029,7 +1019,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
                        assertEquals(0, 
coord.getNumberOfRetainedSuccessfulCheckpoints());
 
                        // validate that the received states have been discarded
-                       verify(subtaskState, times(1)).discardState();
+                       verify(subtaskState1, times(1)).discardState();
 
                        // no confirm message must have been sent
                        verify(commitVertex.getCurrentExecutionAttempt(), 
times(0)).notifyCheckpointComplete(anyLong(), anyLong());
@@ -1147,26 +1137,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
                long checkpointId = pendingCheckpoint.getCheckpointId();
 
                OperatorID opIDtrigger = 
OperatorID.fromJobVertexID(triggerVertex.getJobvertexId());
-               OperatorID opID1 = 
OperatorID.fromJobVertexID(ackVertex1.getJobvertexId());
-               OperatorID opID2 = 
OperatorID.fromJobVertexID(ackVertex2.getJobvertexId());
-
-               Map<OperatorID, OperatorState> operatorStates = 
pendingCheckpoint.getOperatorStates();
 
-               operatorStates.put(opIDtrigger, new SpyInjectingOperatorState(
-                       opIDtrigger, 
triggerVertex.getTotalNumberOfParallelSubtasks(), 
triggerVertex.getMaxParallelism()));
-               operatorStates.put(opID1, new SpyInjectingOperatorState(
-                       opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), 
ackVertex1.getMaxParallelism()));
-               operatorStates.put(opID2, new SpyInjectingOperatorState(
-                       opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), 
ackVertex2.getMaxParallelism()));
+               TaskStateSnapshot taskOperatorSubtaskStatesTrigger = spy(new 
TaskStateSnapshot());
+               OperatorSubtaskState subtaskStateTrigger = 
mock(OperatorSubtaskState.class);
+               
taskOperatorSubtaskStatesTrigger.putSubtaskStateByOperatorID(opIDtrigger, 
subtaskStateTrigger);
 
                // acknowledge the first trigger vertex
-               coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new 
CheckpointMetrics(), mock(SubtaskState.class)));
-               OperatorSubtaskState storedTriggerSubtaskState = 
operatorStates.get(opIDtrigger).getState(triggerVertex.getParallelSubtaskIndex());
+               coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new 
CheckpointMetrics(), taskOperatorSubtaskStatesTrigger));
 
                // verify that the subtask state has not been discarded
-               verify(storedTriggerSubtaskState, never()).discardState();
+               verify(subtaskStateTrigger, never()).discardState();
 
-               SubtaskState unknownSubtaskState = mock(SubtaskState.class);
+               TaskStateSnapshot unknownSubtaskState = 
mock(TaskStateSnapshot.class);
 
                // receive an acknowledge message for an unknown vertex
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new 
CheckpointMetrics(), unknownSubtaskState));
@@ -1174,7 +1156,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
                // we should discard acknowledge messages from an unknown 
vertex belonging to our job
                verify(unknownSubtaskState, times(1)).discardState();
 
-               SubtaskState differentJobSubtaskState = 
mock(SubtaskState.class);
+               TaskStateSnapshot differentJobSubtaskState = 
mock(TaskStateSnapshot.class);
 
                // receive an acknowledge message from an unknown job
                coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new 
JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), 
differentJobSubtaskState));
@@ -1183,22 +1165,22 @@ public class CheckpointCoordinatorTest extends TestLogger {
                verify(differentJobSubtaskState, never()).discardState();
 
                // duplicate acknowledge message for the trigger vertex
-               SubtaskState triggerSubtaskState = mock(SubtaskState.class);
+               TaskStateSnapshot triggerSubtaskState = 
mock(TaskStateSnapshot.class);
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new 
CheckpointMetrics(), triggerSubtaskState));
 
                // duplicate acknowledge messages for a known vertex should not 
trigger discarding the state
                verify(triggerSubtaskState, never()).discardState();
 
                // let the checkpoint fail at the first ack vertex
-               reset(storedTriggerSubtaskState);
+               reset(subtaskStateTrigger);
                coord.receiveDeclineMessage(new DeclineCheckpoint(jobId, 
ackAttemptId1, checkpointId));
 
                assertTrue(pendingCheckpoint.isDiscarded());
 
                // check that we've cleaned up the already acknowledged state
-               verify(storedTriggerSubtaskState, times(1)).discardState();
+               verify(subtaskStateTrigger, times(1)).discardState();
 
-               SubtaskState ackSubtaskState = mock(SubtaskState.class);
+               TaskStateSnapshot ackSubtaskState = 
mock(TaskStateSnapshot.class);
 
                // late acknowledge message from the second ack vertex
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new 
CheckpointMetrics(), ackSubtaskState));
@@ -1213,7 +1195,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
                // we should not interfere with different jobs
                verify(differentJobSubtaskState, never()).discardState();
 
-               SubtaskState unknownSubtaskState2 = mock(SubtaskState.class);
+               TaskStateSnapshot unknownSubtaskState2 = 
mock(TaskStateSnapshot.class);
 
                // receive an acknowledge message for an unknown vertex
                coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new 
CheckpointMetrics(), unknownSubtaskState2));
@@ -1470,18 +1452,16 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
                OperatorID opID1 = 
OperatorID.fromJobVertexID(vertex1.getJobvertexId());
                OperatorID opID2 = 
OperatorID.fromJobVertexID(vertex2.getJobvertexId());
-
-               Map<OperatorID, OperatorState> operatorStates = 
pending.getOperatorStates();
-
-               operatorStates.put(opID1, new SpyInjectingOperatorState(
-                       opID1, vertex1.getTotalNumberOfParallelSubtasks(), 
vertex1.getMaxParallelism()));
-               operatorStates.put(opID2, new SpyInjectingOperatorState(
-                       opID2, vertex2.getTotalNumberOfParallelSubtasks(), 
vertex1.getMaxParallelism()));
+               TaskStateSnapshot taskOperatorSubtaskStates1 = 
mock(TaskStateSnapshot.class);
+               TaskStateSnapshot taskOperatorSubtaskStates2 = 
mock(TaskStateSnapshot.class);
+               OperatorSubtaskState subtaskState1 = 
mock(OperatorSubtaskState.class);
+               OperatorSubtaskState subtaskState2 = 
mock(OperatorSubtaskState.class);
+               
when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1);
+               
when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2);
 
                // acknowledge from one of the tasks
-               AcknowledgeCheckpoint acknowledgeCheckpoint2 = new 
AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), 
mock(SubtaskState.class));
+               AcknowledgeCheckpoint acknowledgeCheckpoint2 = new 
AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), 
taskOperatorSubtaskStates2);
                coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2);
-               OperatorSubtaskState subtaskState2 = 
operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
                assertEquals(1, pending.getNumberOfAcknowledgedTasks());
                assertEquals(1, pending.getNumberOfNonAcknowledgedTasks());
                assertFalse(pending.isDiscarded());
@@ -1495,8 +1475,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
                assertFalse(savepointFuture.isDone());
 
                // acknowledge the other task.
-               coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, 
attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)));
-               OperatorSubtaskState subtaskState1 = 
operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
+               coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, 
attemptID1, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates1));
 
                // the checkpoint is internally converted to a successful 
checkpoint and the
                // pending checkpoint object is disposed
@@ -1536,9 +1515,6 @@ public class CheckpointCoordinatorTest extends TestLogger 
{
                coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, 
attemptID1, checkpointIdNew));
                coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, 
attemptID2, checkpointIdNew));
 
-               subtaskState1 = 
operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
-               subtaskState2 = 
operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
-
                assertEquals(0, coord.getNumberOfPendingCheckpoints());
                assertEquals(0, 
coord.getNumberOfRetainedSuccessfulCheckpoints());
 
@@ -2037,20 +2013,8 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                List<KeyGroupRange> keyGroupPartitions1 = 
StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, 
parallelism1);
                List<KeyGroupRange> keyGroupPartitions2 = 
StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, 
parallelism2);
 
-               PendingCheckpoint pending = 
coord.getPendingCheckpoints().get(checkpointId);
-
-               OperatorID opID1 = OperatorID.fromJobVertexID(jobVertexID1);
-               OperatorID opID2 = OperatorID.fromJobVertexID(jobVertexID2);
-
-               Map<OperatorID, OperatorState> operatorStates = 
pending.getOperatorStates();
-
-               operatorStates.put(opID1, new SpyInjectingOperatorState(
-                       opID1, jobVertex1.getParallelism(), 
jobVertex1.getMaxParallelism()));
-               operatorStates.put(opID2, new SpyInjectingOperatorState(
-                       opID2, jobVertex2.getParallelism(), 
jobVertex2.getMaxParallelism()));
-
                for (int index = 0; index < jobVertex1.getParallelism(); 
index++) {
-                       SubtaskState subtaskState = 
mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index));
+                       TaskStateSnapshot subtaskState = 
mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index));
 
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
@@ -2063,7 +2027,7 @@ public class CheckpointCoordinatorTest extends TestLogger 
{
                }
 
                for (int index = 0; index < jobVertex2.getParallelism(); 
index++) {
-                       SubtaskState subtaskState = 
mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index));
+                       TaskStateSnapshot subtaskState = 
mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index));
 
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
@@ -2165,30 +2129,34 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                List<KeyGroupRange> keyGroupPartitions2 = 
StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, 
parallelism2);
 
                for (int index = 0; index < jobVertex1.getParallelism(); 
index++) {
-                       ChainedStateHandle<StreamStateHandle> valueSizeTuple = 
generateStateForVertex(jobVertexID1, index);
+                       StreamStateHandle valueSizeTuple = 
generateStateForVertex(jobVertexID1, index);
                        KeyGroupsStateHandle keyGroupState = 
generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
-                       SubtaskState checkpointStateHandles = new 
SubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+                       OperatorSubtaskState operatorSubtaskState = new 
OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+                       TaskStateSnapshot taskOperatorSubtaskStates = new 
TaskStateSnapshot();
+                       
taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1),
 operatorSubtaskState);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
                                        
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
                                        checkpointId,
                                        new CheckpointMetrics(),
-                                       checkpointStateHandles);
+                               taskOperatorSubtaskStates);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
 
 
                for (int index = 0; index < jobVertex2.getParallelism(); 
index++) {
-                       ChainedStateHandle<StreamStateHandle> valueSizeTuple = 
generateStateForVertex(jobVertexID2, index);
+                       StreamStateHandle valueSizeTuple = 
generateStateForVertex(jobVertexID2, index);
                        KeyGroupsStateHandle keyGroupState = 
generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
-                       SubtaskState checkpointStateHandles = new 
SubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+                       OperatorSubtaskState operatorSubtaskState = new 
OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+                       TaskStateSnapshot taskOperatorSubtaskStates = new 
TaskStateSnapshot();
+                       
taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2),
 operatorSubtaskState);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
                                        
jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
                                        checkpointId,
                                        new CheckpointMetrics(),
-                                       checkpointStateHandles);
+                                       taskOperatorSubtaskStates);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -2284,17 +2252,20 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                                
StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, 
parallelism2);
 
                for (int index = 0; index < jobVertex1.getParallelism(); 
index++) {
-                       ChainedStateHandle<StreamStateHandle> valueSizeTuple = 
generateStateForVertex(jobVertexID1, index);
+                       StreamStateHandle valueSizeTuple = 
generateStateForVertex(jobVertexID1, index);
                        KeyGroupsStateHandle keyGroupState = 
generateKeyGroupState(
                                        jobVertexID1, 
keyGroupPartitions1.get(index), false);
 
-                       SubtaskState checkpointStateHandles = new 
SubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+                       OperatorSubtaskState operatorSubtaskState = new 
OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+                       TaskStateSnapshot taskOperatorSubtaskStates = new 
TaskStateSnapshot();
+                       
taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1),
 operatorSubtaskState);
+
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
                                        
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
                                        checkpointId,
                                        new CheckpointMetrics(),
-                                       checkpointStateHandles);
+                                       taskOperatorSubtaskStates);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -2302,17 +2273,19 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
 
                for (int index = 0; index < jobVertex2.getParallelism(); 
index++) {
 
-                       ChainedStateHandle<StreamStateHandle> state = 
generateStateForVertex(jobVertexID2, index);
+                       StreamStateHandle state = 
generateStateForVertex(jobVertexID2, index);
                        KeyGroupsStateHandle keyGroupState = 
generateKeyGroupState(
                                        jobVertexID2, 
keyGroupPartitions2.get(index), false);
 
-                       SubtaskState checkpointStateHandles = new 
SubtaskState(state, null, null, keyGroupState, null);
+                       OperatorSubtaskState operatorSubtaskState = new 
OperatorSubtaskState(state, null, null, keyGroupState, null);
+                       TaskStateSnapshot taskOperatorSubtaskStates = new 
TaskStateSnapshot();
+                       
taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2),
 operatorSubtaskState);
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
                                        
jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
                                        checkpointId,
                                        new CheckpointMetrics(),
-                                       checkpointStateHandles);
+                                       taskOperatorSubtaskStates);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -2438,18 +2411,21 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
 
                //vertex 1
                for (int index = 0; index < jobVertex1.getParallelism(); 
index++) {
-                       ChainedStateHandle<StreamStateHandle> valueSizeTuple = 
generateStateForVertex(jobVertexID1, index);
-                       ChainedStateHandle<OperatorStateHandle> opStateBackend 
= generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false);
+                       StreamStateHandle valueSizeTuple = 
generateStateForVertex(jobVertexID1, index);
+                       OperatorStateHandle opStateBackend = 
generatePartitionableStateHandle(jobVertexID1, index, 2, 8, false);
                        KeyGroupsStateHandle keyedStateBackend = 
generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
                        KeyGroupsStateHandle keyedStateRaw = 
generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true);
 
-                       SubtaskState checkpointStateHandles = new 
SubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, 
keyedStateRaw);
+                       OperatorSubtaskState operatorSubtaskState = new 
OperatorSubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, 
keyedStateRaw);
+                       TaskStateSnapshot taskOperatorSubtaskStates = new 
TaskStateSnapshot();
+                       
taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1),
 operatorSubtaskState);
+
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
                                        
jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
                                        checkpointId,
                                        new CheckpointMetrics(),
-                                       checkpointStateHandles);
+                                       taskOperatorSubtaskStates);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -2460,19 +2436,21 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                for (int index = 0; index < jobVertex2.getParallelism(); 
index++) {
                        KeyGroupsStateHandle keyedStateBackend = 
generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
                        KeyGroupsStateHandle keyedStateRaw = 
generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true);
-                       ChainedStateHandle<OperatorStateHandle> opStateBackend 
= generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false);
-                       ChainedStateHandle<OperatorStateHandle> opStateRaw = 
generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, true);
-                       expectedOpStatesBackend.add(opStateBackend);
-                       expectedOpStatesRaw.add(opStateRaw);
-                       SubtaskState checkpointStateHandles =
-                                       new SubtaskState(new 
ChainedStateHandle<>(
-                                                       
Collections.<StreamStateHandle>singletonList(null)), opStateBackend, 
opStateRaw, keyedStateBackend, keyedStateRaw);
+                       OperatorStateHandle opStateBackend = 
generatePartitionableStateHandle(jobVertexID2, index, 2, 8, false);
+                       OperatorStateHandle opStateRaw = 
generatePartitionableStateHandle(jobVertexID2, index, 2, 8, true);
+                       expectedOpStatesBackend.add(new 
ChainedStateHandle<>(Collections.singletonList(opStateBackend)));
+                       expectedOpStatesRaw.add(new 
ChainedStateHandle<>(Collections.singletonList(opStateRaw)));
+
+                       OperatorSubtaskState operatorSubtaskState = new 
OperatorSubtaskState(null, opStateBackend, opStateRaw, keyedStateBackend, 
keyedStateRaw);
+                       TaskStateSnapshot taskOperatorSubtaskStates = new 
TaskStateSnapshot();
+                       
taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2),
 operatorSubtaskState);
+
                        AcknowledgeCheckpoint acknowledgeCheckpoint = new 
AcknowledgeCheckpoint(
                                        jid,
                                        
jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
                                        checkpointId,
                                        new CheckpointMetrics(),
-                                       checkpointStateHandles);
+                                       taskOperatorSubtaskStates);
 
                        coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
                }
@@ -2506,27 +2484,37 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                List<List<Collection<OperatorStateHandle>>> 
actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism());
                List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = 
new ArrayList<>(newJobVertex2.getParallelism());
                for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
-                       KeyGroupsStateHandle originalKeyedStateBackend = 
generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false);
-                       KeyGroupsStateHandle originalKeyedStateRaw = 
generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true);
 
-                       TaskStateHandles taskStateHandles = 
newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
+                       List<OperatorID> operatorIDs = 
newJobVertex2.getOperatorIDs();
 
-                       ChainedStateHandle<StreamStateHandle> operatorState = 
taskStateHandles.getLegacyOperatorState();
-                       List<Collection<OperatorStateHandle>> opStateBackend = 
taskStateHandles.getManagedOperatorState();
-                       List<Collection<OperatorStateHandle>> opStateRaw = 
taskStateHandles.getRawOperatorState();
-                       Collection<KeyedStateHandle> keyedStateBackend = 
taskStateHandles.getManagedKeyedState();
-                       Collection<KeyedStateHandle> keyGroupStateRaw = 
taskStateHandles.getRawKeyedState();
+                       KeyGroupsStateHandle originalKeyedStateBackend = 
generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false);
+                       KeyGroupsStateHandle originalKeyedStateRaw = 
generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true);
 
-                       actualOpStatesBackend.add(opStateBackend);
-                       actualOpStatesRaw.add(opStateRaw);
-                       // the 'non partition state' is not null because it is 
recombined.
-                       assertNotNull(operatorState);
-                       for (int index = 0; index < operatorState.getLength(); 
index++) {
-                               assertNull(operatorState.get(index));
+                       TaskStateSnapshot taskStateHandles = 
newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
+
+                       final int headOpIndex = operatorIDs.size() - 1;
+                       List<Collection<OperatorStateHandle>> 
allParallelManagedOpStates = new ArrayList<>(operatorIDs.size());
+                       List<Collection<OperatorStateHandle>> 
allParallelRawOpStates = new ArrayList<>(operatorIDs.size());
+
+                       for (int idx = 0; idx < operatorIDs.size(); ++idx) {
+                               OperatorID operatorID = operatorIDs.get(idx);
+                               OperatorSubtaskState opState = 
taskStateHandles.getSubtaskStateByOperatorID(operatorID);
+                               
Assert.assertNull(opState.getLegacyOperatorState());
+                               Collection<OperatorStateHandle> opStateBackend 
= opState.getManagedOperatorState();
+                               Collection<OperatorStateHandle> opStateRaw = 
opState.getRawOperatorState();
+                               allParallelManagedOpStates.add(opStateBackend);
+                               allParallelRawOpStates.add(opStateRaw);
+                               if (idx == headOpIndex) {
+                                       Collection<KeyedStateHandle> 
keyedStateBackend = opState.getManagedKeyedState();
+                                       Collection<KeyedStateHandle> 
keyGroupStateRaw = opState.getRawKeyedState();
+                                       
compareKeyedState(Collections.singletonList(originalKeyedStateBackend), 
keyedStateBackend);
+                                       
compareKeyedState(Collections.singletonList(originalKeyedStateRaw), 
keyGroupStateRaw);
+                               }
                        }
-                       
compareKeyedState(Collections.singletonList(originalKeyedStateBackend), 
keyedStateBackend);
-                       
compareKeyedState(Collections.singletonList(originalKeyedStateRaw), 
keyGroupStateRaw);
+                       actualOpStatesBackend.add(allParallelManagedOpStates);
+                       actualOpStatesRaw.add(allParallelRawOpStates);
                }
+
                comparePartitionableState(expectedOpStatesBackend, 
actualOpStatesBackend);
                comparePartitionableState(expectedOpStatesRaw, 
actualOpStatesRaw);
        }
@@ -2578,14 +2566,11 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                        operatorStates.put(id.f1, taskState);
                        for (int index = 0; index < taskState.getParallelism(); 
index++) {
                                StreamStateHandle subNonPartitionedState = 
-                                       generateStateForVertex(id.f0, index)
-                                               .get(0);
+                                       generateStateForVertex(id.f0, index);
                                OperatorStateHandle subManagedOperatorState =
-                                       
generateChainedPartitionableStateHandle(id.f0, index, 2, 8, false)
-                                               .get(0);
+                                       generatePartitionableStateHandle(id.f0, 
index, 2, 8, false);
                                OperatorStateHandle subRawOperatorState =
-                                       
generateChainedPartitionableStateHandle(id.f0, index, 2, 8, true)
-                                               .get(0);
+                                       generatePartitionableStateHandle(id.f0, 
index, 2, 8, true);
 
                                OperatorSubtaskState subtaskState = new 
OperatorSubtaskState(subNonPartitionedState,
                                        subManagedOperatorState,
@@ -2707,57 +2692,75 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
 
                for (int i = 0; i < newJobVertex1.getParallelism(); i++) {
 
-                       TaskStateHandles taskStateHandles = 
newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
-                       ChainedStateHandle<StreamStateHandle> 
actualSubNonPartitionedState = taskStateHandles.getLegacyOperatorState();
-                       List<Collection<OperatorStateHandle>> 
actualSubManagedOperatorState = taskStateHandles.getManagedOperatorState();
-                       List<Collection<OperatorStateHandle>> 
actualSubRawOperatorState = taskStateHandles.getRawOperatorState();
+                       final List<OperatorID> operatorIds = 
newJobVertex1.getOperatorIDs();
 
-                       assertNull(taskStateHandles.getManagedKeyedState());
-                       assertNull(taskStateHandles.getRawKeyedState());
+                       TaskStateSnapshot stateSnapshot = 
newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
+
+                       OperatorSubtaskState headOpState = 
stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 
1));
+                       
assertTrue(headOpState.getManagedKeyedState().isEmpty());
+                       assertTrue(headOpState.getRawKeyedState().isEmpty());
 
                        // operator5
                        {
                                int operatorIndexInChain = 2;
-                               
assertNull(actualSubNonPartitionedState.get(operatorIndexInChain));
-                               
assertNull(actualSubManagedOperatorState.get(operatorIndexInChain));
-                               
assertNull(actualSubRawOperatorState.get(operatorIndexInChain));
+                               OperatorSubtaskState opState =
+                                       
stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+                               assertNull(opState.getLegacyOperatorState());
+                               
assertTrue(opState.getManagedOperatorState().isEmpty());
+                               
assertTrue(opState.getRawOperatorState().isEmpty());
                        }
                        // operator1
                        {
                                int operatorIndexInChain = 1;
-                               ChainedStateHandle<StreamStateHandle> 
expectSubNonPartitionedState = generateStateForVertex(id1.f0, i);
-                               ChainedStateHandle<OperatorStateHandle> 
expectedManagedOpState = generateChainedPartitionableStateHandle(
+                               OperatorSubtaskState opState =
+                                       
stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+                               StreamStateHandle expectSubNonPartitionedState 
= generateStateForVertex(id1.f0, i);
+                               OperatorStateHandle expectedManagedOpState = 
generatePartitionableStateHandle(
                                        id1.f0, i, 2, 8, false);
-                               ChainedStateHandle<OperatorStateHandle> 
expectedRawOpState = generateChainedPartitionableStateHandle(
+                               OperatorStateHandle expectedRawOpState = 
generatePartitionableStateHandle(
                                        id1.f0, i, 2, 8, true);
 
                                assertTrue(CommonTestUtils.isSteamContentEqual(
-                                       
expectSubNonPartitionedState.get(0).openInputStream(),
-                                       
actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream()));
-
-                               
assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(),
-                                       
actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
-
-                               
assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(),
-                                       
actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
+                                       
expectSubNonPartitionedState.openInputStream(),
+                                       
opState.getLegacyOperatorState().openInputStream()));
+
+                               Collection<OperatorStateHandle> 
managedOperatorState = opState.getManagedOperatorState();
+                               assertEquals(1, managedOperatorState.size());
+                               
assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(),
+                                       
managedOperatorState.iterator().next().openInputStream()));
+
+                               Collection<OperatorStateHandle> 
rawOperatorState = opState.getRawOperatorState();
+                               assertEquals(1, rawOperatorState.size());
+                               
assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(),
+                                       
rawOperatorState.iterator().next().openInputStream()));
                        }
                        // operator2
                        {
                                int operatorIndexInChain = 0;
-                               ChainedStateHandle<StreamStateHandle> 
expectSubNonPartitionedState = generateStateForVertex(id2.f0, i);
-                               ChainedStateHandle<OperatorStateHandle> 
expectedManagedOpState = generateChainedPartitionableStateHandle(
+                               OperatorSubtaskState opState =
+                                       
stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+                               StreamStateHandle expectSubNonPartitionedState 
= generateStateForVertex(id2.f0, i);
+                               OperatorStateHandle expectedManagedOpState = 
generatePartitionableStateHandle(
                                        id2.f0, i, 2, 8, false);
-                               ChainedStateHandle<OperatorStateHandle> 
expectedRawOpState = generateChainedPartitionableStateHandle(
+                               OperatorStateHandle expectedRawOpState = 
generatePartitionableStateHandle(
                                        id2.f0, i, 2, 8, true);
 
-                               
assertTrue(CommonTestUtils.isSteamContentEqual(expectSubNonPartitionedState.get(0).openInputStream(),
-                                       
actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream()));
-
-                               
assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(),
-                                       
actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
-
-                               
assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(),
-                                       
actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
+                               assertTrue(CommonTestUtils.isSteamContentEqual(
+                                       
expectSubNonPartitionedState.openInputStream(),
+                                       
opState.getLegacyOperatorState().openInputStream()));
+
+                               Collection<OperatorStateHandle> 
managedOperatorState = opState.getManagedOperatorState();
+                               assertEquals(1, managedOperatorState.size());
+                               
assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(),
+                                       
managedOperatorState.iterator().next().openInputStream()));
+
+                               Collection<OperatorStateHandle> 
rawOperatorState = opState.getRawOperatorState();
+                               assertEquals(1, rawOperatorState.size());
+                               
assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(),
+                                       
rawOperatorState.iterator().next().openInputStream()));
                        }
                }
 
@@ -2765,38 +2768,48 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                List<List<Collection<OperatorStateHandle>>> 
actualRawOperatorStates = new ArrayList<>(newJobVertex2.getParallelism());
 
                for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
-                       TaskStateHandles taskStateHandles = 
newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
+
+                       final List<OperatorID> operatorIds = 
newJobVertex2.getOperatorIDs();
+
+                       TaskStateSnapshot stateSnapshot = 
newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
 
                        // operator 3
                        {
                                int operatorIndexInChain = 1;
+                               OperatorSubtaskState opState =
+                                       
stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
                                List<Collection<OperatorStateHandle>> 
actualSubManagedOperatorState = new ArrayList<>(1);
-                               
actualSubManagedOperatorState.add(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain));
+                               
actualSubManagedOperatorState.add(opState.getManagedOperatorState());
 
                                List<Collection<OperatorStateHandle>> 
actualSubRawOperatorState = new ArrayList<>(1);
-                               
actualSubRawOperatorState.add(taskStateHandles.getRawOperatorState().get(operatorIndexInChain));
+                               
actualSubRawOperatorState.add(opState.getRawOperatorState());
 
                                
actualManagedOperatorStates.add(actualSubManagedOperatorState);
                                
actualRawOperatorStates.add(actualSubRawOperatorState);
 
-                               
assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain));
+                               assertNull(opState.getLegacyOperatorState());
                        }
 
                        // operator 6
                        {
                                int operatorIndexInChain = 0;
-                               
assertNull(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain));
-                               
assertNull(taskStateHandles.getRawOperatorState().get(operatorIndexInChain));
-                               
assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain));
+                               OperatorSubtaskState opState =
+                                       
stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+                               assertNull(opState.getLegacyOperatorState());
+                               
assertTrue(opState.getManagedOperatorState().isEmpty());
+                               
assertTrue(opState.getRawOperatorState().isEmpty());
 
                        }
 
                        KeyGroupsStateHandle originalKeyedStateBackend = 
generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), false);
                        KeyGroupsStateHandle originalKeyedStateRaw = 
generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), true);
 
+                       OperatorSubtaskState headOpState =
+                               
stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 
1));
 
-                       Collection<KeyedStateHandle> keyedStateBackend = 
taskStateHandles.getManagedKeyedState();
-                       Collection<KeyedStateHandle> keyGroupStateRaw = 
taskStateHandles.getRawKeyedState();
+                       Collection<KeyedStateHandle> keyedStateBackend = 
headOpState.getManagedKeyedState();
+                       Collection<KeyedStateHandle> keyGroupStateRaw = 
headOpState.getRawKeyedState();
 
 
                        
compareKeyedState(Collections.singletonList(originalKeyedStateBackend), 
keyedStateBackend);
@@ -2974,19 +2987,50 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                return new Tuple2<>(allSerializedValuesConcatenated, offsets);
        }
 
-       public static ChainedStateHandle<StreamStateHandle> 
generateStateForVertex(
+       public static StreamStateHandle generateStateForVertex(
                        JobVertexID jobVertexID,
                        int index) throws IOException {
 
                Random random = new Random(jobVertexID.hashCode() + index);
                int value = random.nextInt();
-               return generateChainedStateHandle(value);
+               return generateStreamStateHandle(value);
+       }
+
+       public static StreamStateHandle generateStreamStateHandle(Serializable 
value) throws IOException {
+               return 
TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()),
 value);
        }
 
        public static ChainedStateHandle<StreamStateHandle> 
generateChainedStateHandle(
                        Serializable value) throws IOException {
                return ChainedStateHandle.wrapSingleHandle(
-                               
TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()),
 value));
+                               generateStreamStateHandle(value));
+       }
+
+       public static OperatorStateHandle generatePartitionableStateHandle(
+               JobVertexID jobVertexID,
+               int index,
+               int namedStates,
+               int partitionsPerState,
+               boolean rawState) throws IOException {
+
+               Map<String, List<? extends Serializable>> statesListsMap = new 
HashMap<>(namedStates);
+
+               for (int i = 0; i < namedStates; ++i) {
+                       List<Integer> testStatesLists = new 
ArrayList<>(partitionsPerState);
+                       // generate state
+                       int seed = jobVertexID.hashCode() * index + i * 
namedStates;
+                       if (rawState) {
+                               seed = (seed + 1) * 31;
+                       }
+                       Random random = new Random(seed);
+                       for (int j = 0; j < partitionsPerState; ++j) {
+                               int simulatedStateValue = random.nextInt();
+                               testStatesLists.add(simulatedStateValue);
+                       }
+                       statesListsMap.put("state-" + i, testStatesLists);
+               }
+
+               return generatePartitionableStateHandle(statesListsMap);
        }
 
        public static ChainedStateHandle<OperatorStateHandle> 
generateChainedPartitionableStateHandle(
@@ -3013,11 +3057,11 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                        statesListsMap.put("state-" + i, testStatesLists);
                }
 
-               return generateChainedPartitionableStateHandle(statesListsMap);
+               return 
ChainedStateHandle.wrapSingleHandle(generatePartitionableStateHandle(statesListsMap));
        }
 
-       private static ChainedStateHandle<OperatorStateHandle> 
generateChainedPartitionableStateHandle(
-                       Map<String, List<? extends Serializable>> states) 
throws IOException {
+       private static OperatorStateHandle generatePartitionableStateHandle(
+               Map<String, List<? extends Serializable>> states) throws 
IOException {
 
                List<List<? extends Serializable>> namedStateSerializables = 
new ArrayList<>(states.size());
 
@@ -3032,20 +3076,18 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                int idx = 0;
                for (Map.Entry<String, List<? extends Serializable>> entry : 
states.entrySet()) {
                        offsetsMap.put(
-                                       entry.getKey(),
-                                       new OperatorStateHandle.StateMetaInfo(
-                                                       
serializationWithOffsets.f1.get(idx),
-                                                       
OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+                               entry.getKey(),
+                               new OperatorStateHandle.StateMetaInfo(
+                                       serializationWithOffsets.f1.get(idx),
+                                       
OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
                        ++idx;
                }
 
                ByteStreamStateHandle streamStateHandle = new 
TestByteStreamStateHandleDeepCompare(
-                               String.valueOf(UUID.randomUUID()),
-                               serializationWithOffsets.f0);
+                       String.valueOf(UUID.randomUUID()),
+                       serializationWithOffsets.f0);
 
-               OperatorStateHandle operatorStateHandle =
-                               new OperatorStateHandle(offsetsMap, 
streamStateHandle);
-               return ChainedStateHandle.wrapSingleHandle(operatorStateHandle);
+               return new OperatorStateHandle(offsetsMap, streamStateHandle);
        }
 
        static ExecutionJobVertex mockExecutionJobVertex(
@@ -3139,24 +3181,23 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                return vertex;
        }
 
-       static SubtaskState mockSubtaskState(
+       static TaskStateSnapshot mockSubtaskState(
                JobVertexID jobVertexID,
                int index,
                KeyGroupRange keyGroupRange) throws IOException {
 
-               ChainedStateHandle<StreamStateHandle> nonPartitionedState = 
generateStateForVertex(jobVertexID, index);
-               ChainedStateHandle<OperatorStateHandle> partitionableState = 
generateChainedPartitionableStateHandle(jobVertexID, index, 2, 8, false);
+               StreamStateHandle nonPartitionedState = 
generateStateForVertex(jobVertexID, index);
+               OperatorStateHandle partitionableState = 
generatePartitionableStateHandle(jobVertexID, index, 2, 8, false);
                KeyGroupsStateHandle partitionedKeyGroupState = 
generateKeyGroupState(jobVertexID, keyGroupRange, false);
 
-               SubtaskState subtaskState = mock(SubtaskState.class, 
withSettings().serializable());
+               TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot());
+               OperatorSubtaskState subtaskState = spy(new 
OperatorSubtaskState(
+                       nonPartitionedState, partitionableState, null, 
partitionedKeyGroupState, null)
+               );
 
-               
doReturn(nonPartitionedState).when(subtaskState).getLegacyOperatorState();
-               
doReturn(partitionableState).when(subtaskState).getManagedOperatorState();
-               doReturn(null).when(subtaskState).getRawOperatorState();
-               
doReturn(partitionedKeyGroupState).when(subtaskState).getManagedKeyedState();
-               doReturn(null).when(subtaskState).getRawKeyedState();
+               
subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID),
 subtaskState);
 
-               return subtaskState;
+               return subtaskStates;
        }
 
        public static void verifyStateRestore(
@@ -3165,27 +3206,27 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
 
                for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
 
-                       TaskStateHandles taskStateHandles = 
executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
+                       final List<OperatorID> operatorIds = 
executionJobVertex.getOperatorIDs();
 
-                       ChainedStateHandle<StreamStateHandle> 
expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
-                       ChainedStateHandle<StreamStateHandle> 
actualNonPartitionedState = taskStateHandles.getLegacyOperatorState();
+                       TaskStateSnapshot stateSnapshot = 
executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
+
+                       OperatorSubtaskState operatorState = 
stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID));
+
+                       StreamStateHandle expectNonPartitionedState = 
generateStateForVertex(jobVertexID, i);
                        assertTrue(CommonTestUtils.isSteamContentEqual(
-                                       
expectNonPartitionedState.get(0).openInputStream(),
-                                       
actualNonPartitionedState.get(0).openInputStream()));
+                                       
expectNonPartitionedState.openInputStream(),
+                               
operatorState.getLegacyOperatorState().openInputStream()));
 
                        ChainedStateHandle<OperatorStateHandle> 
expectedOpStateBackend =
                                        
generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);
 
-                       List<Collection<OperatorStateHandle>> 
actualPartitionableState = taskStateHandles.getManagedOperatorState();
-
                        assertTrue(CommonTestUtils.isSteamContentEqual(
                                        
expectedOpStateBackend.get(0).openInputStream(),
-                                       
actualPartitionableState.get(0).iterator().next().openInputStream()));
+                                       
operatorState.getManagedOperatorState().iterator().next().openInputStream()));
 
                        KeyGroupsStateHandle expectPartitionedKeyGroupState = 
generateKeyGroupState(
                                        jobVertexID, keyGroupPartitions.get(i), 
false);
-                       Collection<KeyedStateHandle> 
actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
-                       
compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), 
actualPartitionedKeyGroupState);
+                       
compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), 
operatorState.getManagedKeyedState());
                }
        }
 
@@ -3632,17 +3673,4 @@ public class CheckpointCoordinatorTest extends 
TestLogger {
                        "The latest completed (proper) checkpoint should have 
been added to the completed checkpoint store.",
                        
completedCheckpointStore.getLatestCheckpoint().getCheckpointID() == 
checkpointIDCounter.getLast());
        }
-
-       private static final class SpyInjectingOperatorState extends 
OperatorState {
-
-               private static final long serialVersionUID = 
-4004437428483663815L;
-
-               public SpyInjectingOperatorState(OperatorID taskID, int 
parallelism, int maxParallelism) {
-                       super(taskID, parallelism, maxParallelism);
-               }
-
-               public void putState(int subtaskIndex, OperatorSubtaskState 
subtaskState) {
-                       super.putState(subtaskIndex, spy(subtaskState));
-               }
-       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 7d24568..6ce071b 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -34,18 +34,18 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.util.SerializableObject;
+
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
 import org.junit.Test;
 import org.mockito.Mockito;
 
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
@@ -118,10 +118,20 @@ public class CheckpointStateRestoreTest {
                        PendingCheckpoint pending = 
coord.getPendingCheckpoints().values().iterator().next();
                        final long checkpointId = pending.getCheckpointId();
 
-                       SubtaskState checkpointStateHandles = new 
SubtaskState(serializedState, null, null, serializedKeyGroupStates, null);
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new 
CheckpointMetrics(), checkpointStateHandles));
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new 
CheckpointMetrics(), checkpointStateHandles));
-                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new 
CheckpointMetrics(), checkpointStateHandles));
+                       final TaskStateSnapshot subtaskStates = new 
TaskStateSnapshot();
+
+                       subtaskStates.putSubtaskStateByOperatorID(
+                               OperatorID.fromJobVertexID(statefulId),
+                               new OperatorSubtaskState(
+                                       serializedState.get(0),
+                                       
Collections.<OperatorStateHandle>emptyList(),
+                                       
Collections.<OperatorStateHandle>emptyList(),
+                                       
Collections.singletonList(serializedKeyGroupStates),
+                                       
Collections.<KeyedStateHandle>emptyList()));
+
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new 
CheckpointMetrics(), subtaskStates));
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new 
CheckpointMetrics(), subtaskStates));
+                       coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new 
CheckpointMetrics(), subtaskStates));
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
                        coord.receiveAcknowledgeMessage(new 
AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));
 
@@ -133,33 +143,26 @@ public class CheckpointStateRestoreTest {
 
                        // verify that each stateful vertex got the state
 
-                       final TaskStateHandles taskStateHandles = new 
TaskStateHandles(
-                                       serializedState,
-                                       
Collections.<Collection<OperatorStateHandle>>singletonList(null),
-                                       
Collections.<Collection<OperatorStateHandle>>singletonList(null),
-                                       
Collections.singletonList(serializedKeyGroupStates),
-                                       null);
-
-                       BaseMatcher<TaskStateHandles> matcher = new 
BaseMatcher<TaskStateHandles>() {
+                       BaseMatcher<TaskStateSnapshot> matcher = new 
BaseMatcher<TaskStateSnapshot>() {
                                @Override
                                public boolean matches(Object o) {
-                                       if (o instanceof TaskStateHandles) {
-                                               return 
o.equals(taskStateHandles);
+                                       if (o instanceof TaskStateSnapshot) {
+                                               return Objects.equals(o, 
subtaskStates);
                                        }
                                        return false;
                                }
 
                                @Override
                                public void describeTo(Description description) 
{
-                                       
description.appendValue(taskStateHandles);
+                                       description.appendValue(subtaskStates);
                                }
                        };
 
                        verify(statefulExec1, 
times(1)).setInitialState(Mockito.argThat(matcher));
                        verify(statefulExec2, 
times(1)).setInitialState(Mockito.argThat(matcher));
                        verify(statefulExec3, 
times(1)).setInitialState(Mockito.argThat(matcher));
-                       verify(statelessExec1, 
times(0)).setInitialState(Mockito.<TaskStateHandles>any());
-                       verify(statelessExec2, 
times(0)).setInitialState(Mockito.<TaskStateHandles>any());
+                       verify(statelessExec1, 
times(0)).setInitialState(Mockito.<TaskStateSnapshot>any());
+                       verify(statelessExec2, 
times(0)).setInitialState(Mockito.<TaskStateSnapshot>any());
                }
                catch (Exception e) {
                        e.printStackTrace();
@@ -250,9 +253,9 @@ public class CheckpointStateRestoreTest {
                Map<OperatorID, OperatorState> checkpointTaskStates = new 
HashMap<>();
                {
                        OperatorState taskState = new 
OperatorState(operatorId1, 3, 3);
-                       taskState.putState(0, new 
OperatorSubtaskState(serializedState, null, null, null, null));
-                       taskState.putState(1, new 
OperatorSubtaskState(serializedState, null, null, null, null));
-                       taskState.putState(2, new 
OperatorSubtaskState(serializedState, null, null, null, null));
+                       taskState.putState(0, new 
OperatorSubtaskState(serializedState));
+                       taskState.putState(1, new 
OperatorSubtaskState(serializedState));
+                       taskState.putState(2, new 
OperatorSubtaskState(serializedState));
 
                        checkpointTaskStates.put(operatorId1, taskState);
                }
@@ -279,7 +282,7 @@ public class CheckpointStateRestoreTest {
                // There is no task for this
                {
                        OperatorState taskState = new 
OperatorState(newOperatorID, 1, 1);
-                       taskState.putState(0, new 
OperatorSubtaskState(serializedState, null, null, null, null));
+                       taskState.putState(0, new 
OperatorSubtaskState(serializedState));
 
                        checkpointTaskStates.put(newOperatorID, taskState);
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index 1fe4e65..320dc2d 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -331,7 +331,7 @@ public abstract class CompletedCheckpointStoreTest extends 
TestLogger {
                boolean discarded;
 
                public TestOperatorSubtaskState() {
-                       super(null, null, null, null, null);
+                       super();
                        this.registered = false;
                        this.discarded = false;
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
index 7d103d0..7ebb49a 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
@@ -324,7 +324,7 @@ public class PendingCheckpointTest {
        @Test
        public void testNonNullSubtaskStateLeadsToStatefulTask() throws 
Exception {
                PendingCheckpoint pending = 
createPendingCheckpoint(CheckpointProperties.forStandardCheckpoint(), null);
-               pending.acknowledgeTask(ATTEMPT_ID, mock(SubtaskState.class), 
mock(CheckpointMetrics.class));
+               pending.acknowledgeTask(ATTEMPT_ID, 
mock(TaskStateSnapshot.class), mock(CheckpointMetrics.class));
                Assert.assertFalse(pending.getOperatorStates().isEmpty());
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
index 36c9cad..9ed4851 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.blob.BlobKey;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.JobInformation;
@@ -30,7 +31,6 @@ import 
org.apache.flink.runtime.executiongraph.TaskInformation;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.operators.BatchTask;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.SerializedValue;
 
 import org.junit.Test;
@@ -73,7 +73,7 @@ public class TaskDeploymentDescriptorTest {
                        final SerializedValue<TaskInformation> 
serializedJobVertexInformation = new SerializedValue<>(new TaskInformation(
                                vertexID, taskName, currentNumberOfSubtasks, 
numberOfKeyGroups, invokableClass.getName(), taskConfiguration));
                        final int targetSlotNumber = 47;
-                       final TaskStateHandles taskStateHandles = new 
TaskStateHandles();
+                       final TaskStateSnapshot taskStateHandles = new 
TaskStateSnapshot();
 
                        final TaskDeploymentDescriptor orig = new 
TaskDeploymentDescriptor(
                                serializedJobInformation,

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
index 0eed90d..c9b7a40 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.time.Time;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
 import org.apache.flink.runtime.checkpoint.StandaloneCheckpointRecoveryFactory;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
 import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
@@ -38,7 +39,6 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobmanager.slots.AllocatedSlot;
 import org.apache.flink.runtime.jobmanager.slots.SlotOwner;
 import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
@@ -51,8 +51,10 @@ import java.net.InetAddress;
 import java.util.Iterator;
 import java.util.concurrent.TimeUnit;
 
-import static org.mockito.Mockito.*;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
 
 /**
  * Tests that the execution vertex handles locality preferences well.
@@ -169,7 +171,7 @@ public class ExecutionVertexLocalityTest extends TestLogger 
{
 
                        // target state
                        ExecutionVertex target = 
graph.getAllVertices().get(targetVertexId).getTaskVertices()[i];
-                       
target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateHandles.class));
+                       
target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateSnapshot.class));
                }
 
                // validate that the target vertices have the state's location 
as the location preference

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index a63b02d..23f0a38 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -18,16 +18,6 @@
 
 package org.apache.flink.runtime.jobmanager;
 
-import akka.actor.ActorRef;
-import akka.actor.ActorSystem;
-import akka.actor.Identify;
-import akka.actor.PoisonPill;
-import akka.actor.Props;
-import akka.japi.pf.FI;
-import akka.japi.pf.ReceiveBuilder;
-import akka.pattern.Patterns;
-import akka.testkit.CallingThreadDispatcher;
-import akka.testkit.JavaTestKit;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
@@ -44,8 +34,9 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
 import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager;
 import 
org.apache.flink.runtime.executiongraph.restart.FixedDelayRestartStrategy;
@@ -59,6 +50,7 @@ import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings;
@@ -69,9 +61,6 @@ import 
org.apache.flink.runtime.leaderelection.TestingLeaderElectionService;
 import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService;
 import org.apache.flink.runtime.messages.JobManagerMessages;
 import org.apache.flink.runtime.metrics.MetricRegistry;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManager;
 import org.apache.flink.runtime.testingUtils.TestingJobManager;
@@ -83,23 +72,24 @@ import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
 import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.util.InstantiationUtil;
-
 import org.apache.flink.util.TestLogger;
+
+import akka.actor.ActorRef;
+import akka.actor.ActorSystem;
+import akka.actor.Identify;
+import akka.actor.PoisonPill;
+import akka.actor.Props;
+import akka.japi.pf.FI;
+import akka.japi.pf.ReceiveBuilder;
+import akka.pattern.Patterns;
+import akka.testkit.CallingThreadDispatcher;
+import akka.testkit.JavaTestKit;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
-import scala.Int;
-import scala.Option;
-import scala.PartialFunction;
-import scala.concurrent.Await;
-import scala.concurrent.Future;
-import scala.concurrent.duration.Deadline;
-import scala.concurrent.duration.FiniteDuration;
-import scala.runtime.BoxedUnit;
-
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -113,6 +103,15 @@ import java.util.concurrent.Executor;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 
+import scala.Int;
+import scala.Option;
+import scala.PartialFunction;
+import scala.concurrent.Await;
+import scala.concurrent.Future;
+import scala.concurrent.duration.Deadline;
+import scala.concurrent.duration.FiniteDuration;
+import scala.runtime.BoxedUnit;
+
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThat;
@@ -552,10 +551,10 @@ public class JobManagerHARecoveryTest extends TestLogger {
 
                @Override
                public void setInitialState(
-                               TaskStateHandles taskStateHandles) throws 
Exception {
+                       TaskStateSnapshot taskStateHandles) throws Exception {
                        int subtaskIndex = getIndexInSubtaskGroup();
                        if (subtaskIndex < recoveredStates.length) {
-                               try (FSDataInputStream in = 
taskStateHandles.getLegacyOperatorState().get(0).openInputStream()) {
+                               try (FSDataInputStream in = 
taskStateHandles.getSubtaskStateMappings().iterator().next().getValue().getLegacyOperatorState().openInputStream())
 {
                                        recoveredStates[subtaskIndex] = 
InstantiationUtil.deserializeObject(in, getUserCodeClassLoader());
                                }
                        }
@@ -567,10 +566,11 @@ public class JobManagerHARecoveryTest extends TestLogger {
                                        String.valueOf(UUID.randomUUID()),
                                        
InstantiationUtil.serializeObject(checkpointMetaData.getCheckpointId()));
 
-                       ChainedStateHandle<StreamStateHandle> 
chainedStateHandle =
-                                       new 
ChainedStateHandle<StreamStateHandle>(Collections.singletonList(byteStreamStateHandle));
-                       SubtaskState checkpointStateHandles =
-                                       new SubtaskState(chainedStateHandle, 
null, null, null, null);
+                       TaskStateSnapshot checkpointStateHandles = new 
TaskStateSnapshot();
+                       checkpointStateHandles.putSubtaskStateByOperatorID(
+                               
OperatorID.fromJobVertexID(getEnvironment().getJobVertexId()),
+                               new OperatorSubtaskState(byteStreamStateHandle)
+                       );
 
                        getEnvironment().acknowledgeCheckpoint(
                                        checkpointMetaData.getCheckpointId(),

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index bc420cc..d022cdc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -24,14 +24,17 @@ import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.StreamStateHandle;
+
 import org.junit.Test;
 
 import java.io.IOException;
@@ -68,13 +71,17 @@ public class CheckpointMessagesTest {
 
                        KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42);
 
-                       SubtaskState checkpointStateHandles =
-                                       new SubtaskState(
-                                                       
CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()),
-                                                       
CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new 
JobVertexID(), 0, 2, 8, false),
-                                                       null,
-                                                       
CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, 
Collections.singletonList(new MyHandle())),
-                                                       null);
+                       TaskStateSnapshot checkpointStateHandles = new 
TaskStateSnapshot();
+                       checkpointStateHandles.putSubtaskStateByOperatorID(
+                               new OperatorID(),
+                               new OperatorSubtaskState(
+                                       
CheckpointCoordinatorTest.generateStreamStateHandle(new MyHandle()),
+                                       
CheckpointCoordinatorTest.generatePartitionableStateHandle(new JobVertexID(), 
0, 2, 8, false),
+                                       null,
+                                       
CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, 
Collections.singletonList(new MyHandle())),
+                                       null
+                               )
+                       );
 
                        AcknowledgeCheckpoint withState = new 
AcknowledgeCheckpoint(
                                        new JobID(),

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index 851fa96..8ed06b2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -26,7 +26,7 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -156,7 +156,7 @@ public class DummyEnvironment implements Environment {
        }
 
        @Override
-       public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics 
checkpointMetrics, SubtaskState subtaskState) {
+       public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics 
checkpointMetrics, TaskStateSnapshot subtaskState) {
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index 4f0242e..7514cc4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -27,7 +27,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -50,8 +50,8 @@ import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 import org.apache.flink.types.Record;
 import org.apache.flink.util.MutableObjectIterator;
-
 import org.apache.flink.util.Preconditions;
+
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
@@ -354,7 +354,7 @@ public class MockEnvironment implements Environment {
        }
 
        @Override
-       public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics 
checkpointMetrics, SubtaskState subtaskState) {
+       public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics 
checkpointMetrics, TaskStateSnapshot subtaskState) {
                throw new UnsupportedOperationException();
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index c6d2fec..085a386 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -49,7 +50,6 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 import org.apache.flink.util.SerializedValue;
 
@@ -187,7 +187,7 @@ public class TaskAsyncCallTest {
                        
Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
                        Collections.<InputGateDeploymentDescriptor>emptyList(),
                        0,
-                       new TaskStateHandles(),
+                       new TaskStateSnapshot(),
                        mock(MemoryManager.class),
                        mock(IOManager.class),
                        networkEnvironment,
@@ -228,7 +228,7 @@ public class TaskAsyncCallTest {
                }
 
                @Override
-               public void setInitialState(TaskStateHandles taskStateHandles) 
throws Exception {}
+               public void setInitialState(TaskStateSnapshot taskStateHandles) 
throws Exception {}
 
                @Override
                public boolean triggerCheckpoint(CheckpointMetaData 
checkpointMetaData, CheckpointOptions checkpointOptions) {

Reply via email to