http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index 24169f2..0d2e903 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.checkpoint; import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.time.Time; import org.apache.flink.api.java.tuple.Tuple2; @@ -33,6 +34,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; @@ -168,7 +170,14 @@ public class CheckpointCoordinatorTest { final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID(); ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1); - ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2, new JobVertexID(), 1, 1, ExecutionState.FINISHED); + JobVertexID jobVertexID2 = new JobVertexID(); + ExecutionVertex triggerVertex2 = mockExecutionVertex( + 
triggerAttemptID2, + jobVertexID2, + Lists.newArrayList(OperatorID.fromJobVertexID(jobVertexID2)), + 1, + 1, + ExecutionState.FINISHED); // create some mock Execution vertices that need to ack the checkpoint final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID(); @@ -317,7 +326,7 @@ public class CheckpointCoordinatorTest { assertEquals(jid, checkpoint.getJobId()); assertEquals(2, checkpoint.getNumberOfNonAcknowledgedTasks()); assertEquals(0, checkpoint.getNumberOfAcknowledgedTasks()); - assertEquals(0, checkpoint.getTaskStates().size()); + assertEquals(0, checkpoint.getOperatorStates().size()); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); @@ -426,7 +435,7 @@ public class CheckpointCoordinatorTest { assertEquals(jid, checkpoint1.getJobId()); assertEquals(2, checkpoint1.getNumberOfNonAcknowledgedTasks()); assertEquals(0, checkpoint1.getNumberOfAcknowledgedTasks()); - assertEquals(0, checkpoint1.getTaskStates().size()); + assertEquals(0, checkpoint1.getOperatorStates().size()); assertFalse(checkpoint1.isDiscarded()); assertFalse(checkpoint1.isFullyAcknowledged()); @@ -436,7 +445,7 @@ public class CheckpointCoordinatorTest { assertEquals(jid, checkpoint2.getJobId()); assertEquals(2, checkpoint2.getNumberOfNonAcknowledgedTasks()); assertEquals(0, checkpoint2.getNumberOfAcknowledgedTasks()); - assertEquals(0, checkpoint2.getTaskStates().size()); + assertEquals(0, checkpoint2.getOperatorStates().size()); assertFalse(checkpoint2.isDiscarded()); assertFalse(checkpoint2.isFullyAcknowledged()); @@ -471,7 +480,7 @@ public class CheckpointCoordinatorTest { assertEquals(jid, checkpointNew.getJobId()); assertEquals(2, checkpointNew.getNumberOfNonAcknowledgedTasks()); assertEquals(0, checkpointNew.getNumberOfAcknowledgedTasks()); - assertEquals(0, checkpointNew.getTaskStates().size()); + assertEquals(0, checkpointNew.getOperatorStates().size()); assertFalse(checkpointNew.isDiscarded()); 
assertFalse(checkpointNew.isFullyAcknowledged()); assertNotEquals(checkpoint1.getCheckpointId(), checkpointNew.getCheckpointId()); @@ -539,10 +548,20 @@ public class CheckpointCoordinatorTest { assertEquals(jid, checkpoint.getJobId()); assertEquals(2, checkpoint.getNumberOfNonAcknowledgedTasks()); assertEquals(0, checkpoint.getNumberOfAcknowledgedTasks()); - assertEquals(0, checkpoint.getTaskStates().size()); + assertEquals(0, checkpoint.getOperatorStates().size()); assertFalse(checkpoint.isDiscarded()); assertFalse(checkpoint.isFullyAcknowledged()); + OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); + OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); + + Map<OperatorID, OperatorState> operatorStates = checkpoint.getOperatorStates(); + + operatorStates.put(opID1, new SpyInjectingOperatorState( + opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); + operatorStates.put(opID2, new SpyInjectingOperatorState( + opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex2.getMaxParallelism())); + // check that the vertices received the trigger checkpoint message { verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class)); @@ -550,9 +569,9 @@ public class CheckpointCoordinatorTest { } // acknowledge from one of the tasks - SubtaskState subtaskState2 = mock(SubtaskState.class); - AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), subtaskState2); + AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1); + OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(1, checkpoint.getNumberOfAcknowledgedTasks()); assertEquals(1, 
checkpoint.getNumberOfNonAcknowledgedTasks()); assertFalse(checkpoint.isDiscarded()); @@ -566,8 +585,8 @@ public class CheckpointCoordinatorTest { verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class)); // acknowledge the other task. - SubtaskState subtaskState1 = mock(SubtaskState.class); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), subtaskState1)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); + OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -596,7 +615,7 @@ public class CheckpointCoordinatorTest { assertEquals(jid, success.getJobId()); assertEquals(timestamp, success.getTimestamp()); assertEquals(checkpoint.getCheckpointId(), success.getCheckpointID()); - assertEquals(2, success.getTaskStates().size()); + assertEquals(2, success.getOperatorStates().size()); // --------------- // trigger another checkpoint and see that this one replaces the other checkpoint @@ -606,7 +625,9 @@ public class CheckpointCoordinatorTest { long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey(); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew)); + subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew)); + subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -616,7 +637,7 @@ public class CheckpointCoordinatorTest { assertEquals(jid, successNew.getJobId()); 
assertEquals(timestampNew, successNew.getTimestamp()); assertEquals(checkpointIdNew, successNew.getCheckpointID()); - assertTrue(successNew.getTaskStates().isEmpty()); + assertTrue(successNew.getOperatorStates().isEmpty()); // validate that the subtask states in old savepoint have unregister their shared states { @@ -756,13 +777,13 @@ public class CheckpointCoordinatorTest { assertEquals(checkpointId1, sc1.getCheckpointID()); assertEquals(timestamp1, sc1.getTimestamp()); assertEquals(jid, sc1.getJobId()); - assertTrue(sc1.getTaskStates().isEmpty()); + assertTrue(sc1.getOperatorStates().isEmpty()); CompletedCheckpoint sc2 = scs.get(1); assertEquals(checkpointId2, sc2.getCheckpointID()); assertEquals(timestamp2, sc2.getTimestamp()); assertEquals(jid, sc2.getJobId()); - assertTrue(sc2.getTaskStates().isEmpty()); + assertTrue(sc2.getOperatorStates().isEmpty()); coord.shutdown(JobStatus.FINISHED); } @@ -830,10 +851,22 @@ public class CheckpointCoordinatorTest { verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class)); verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class)); - // acknowledge one of the three tasks - SubtaskState subtaskState1_2 = mock(SubtaskState.class); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), subtaskState1_2)); + OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); + OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId()); + OperatorID opID3 = OperatorID.fromJobVertexID(ackVertex3.getJobvertexId()); + + Map<OperatorID, OperatorState> operatorStates1 = pending1.getOperatorStates(); + + operatorStates1.put(opID1, new SpyInjectingOperatorState( + opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); + operatorStates1.put(opID2, 
new SpyInjectingOperatorState( + opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); + operatorStates1.put(opID3, new SpyInjectingOperatorState( + opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism())); + // acknowledge one of the three tasks + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class))); + OperatorSubtaskState subtaskState1_2 = operatorStates1.get(opID2).getState(ackVertex2.getParallelSubtaskIndex()); // start the second checkpoint // trigger the first checkpoint. this should succeed assertTrue(coord.triggerCheckpoint(timestamp2, false)); @@ -850,6 +883,15 @@ public class CheckpointCoordinatorTest { } long checkpointId2 = pending2.getCheckpointId(); + Map<OperatorID, OperatorState> operatorStates2 = pending2.getOperatorStates(); + + operatorStates2.put(opID1, new SpyInjectingOperatorState( + opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); + operatorStates2.put(opID2, new SpyInjectingOperatorState( + opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); + operatorStates2.put(opID3, new SpyInjectingOperatorState( + opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism())); + // trigger messages should have been sent verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class)); verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class)); @@ -857,17 +899,17 @@ public class CheckpointCoordinatorTest { // we acknowledge one more task from the first checkpoint and the second // checkpoint completely. 
The second checkpoint should then subsume the first checkpoint - SubtaskState subtaskState2_3 = mock(SubtaskState.class); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), subtaskState2_3)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); + OperatorSubtaskState subtaskState2_3 = operatorStates2.get(opID3).getState(ackVertex3.getParallelSubtaskIndex()); - SubtaskState subtaskState2_1 = mock(SubtaskState.class); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), subtaskState2_1)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); + OperatorSubtaskState subtaskState2_1 = operatorStates2.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); - SubtaskState subtaskState1_1 = mock(SubtaskState.class); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), subtaskState1_1)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class))); + OperatorSubtaskState subtaskState1_1 = operatorStates1.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); - SubtaskState subtaskState2_2 = mock(SubtaskState.class); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), subtaskState2_2)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class))); + OperatorSubtaskState subtaskState2_2 = operatorStates2.get(opID2).getState(ackVertex2.getParallelSubtaskIndex()); // now, the second checkpoint should be confirmed, and the first discarded // actually 
both pending checkpoints are discarded, and the second has been transformed @@ -896,7 +938,7 @@ public class CheckpointCoordinatorTest { assertEquals(checkpointId2, success.getCheckpointID()); assertEquals(timestamp2, success.getTimestamp()); assertEquals(jid, success.getJobId()); - assertEquals(3, success.getTaskStates().size()); + assertEquals(3, success.getOperatorStates().size()); // the first confirm message should be out verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2)); @@ -970,8 +1012,15 @@ public class CheckpointCoordinatorTest { PendingCheckpoint checkpoint = coord.getPendingCheckpoints().values().iterator().next(); assertFalse(checkpoint.isDiscarded()); - SubtaskState subtaskState = mock(SubtaskState.class); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), subtaskState)); + OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); + + Map<OperatorID, OperatorState> operatorStates = checkpoint.getOperatorStates(); + + operatorStates.put(opID1, new SpyInjectingOperatorState( + opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); + + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), mock(SubtaskState.class))); + OperatorSubtaskState subtaskState = operatorStates.get(opID1).getState(ackVertex1.getParallelSubtaskIndex()); // wait until the checkpoint must have expired. 
// we check every 250 msecs conservatively for 5 seconds @@ -1106,13 +1155,25 @@ public class CheckpointCoordinatorTest { long checkpointId = pendingCheckpoint.getCheckpointId(); - SubtaskState triggerSubtaskState = mock(SubtaskState.class); + OperatorID opIDtrigger = OperatorID.fromJobVertexID(triggerVertex.getJobvertexId()); + OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId()); + OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId()); + + Map<OperatorID, OperatorState> operatorStates = pendingCheckpoint.getOperatorStates(); + + operatorStates.put(opIDtrigger, new SpyInjectingOperatorState( + opIDtrigger, triggerVertex.getTotalNumberOfParallelSubtasks(), triggerVertex.getMaxParallelism())); + operatorStates.put(opID1, new SpyInjectingOperatorState( + opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism())); + operatorStates.put(opID2, new SpyInjectingOperatorState( + opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism())); // acknowledge the first trigger vertex - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); + OperatorSubtaskState storedTriggerSubtaskState = operatorStates.get(opIDtrigger).getState(triggerVertex.getParallelSubtaskIndex()); - // verify that the subtask state has registered its shared states at the registry - verify(triggerSubtaskState, never()).discardState(); + // verify that the subtask state has not been discarded + verify(storedTriggerSubtaskState, never()).discardState(); SubtaskState unknownSubtaskState = mock(SubtaskState.class); @@ -1131,20 +1192,20 @@ public class CheckpointCoordinatorTest { verify(differentJobSubtaskState, never()).discardState(); // duplicate acknowledge message for the 
trigger vertex - reset(triggerSubtaskState); + SubtaskState triggerSubtaskState = mock(SubtaskState.class); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState)); // duplicate acknowledge messages for a known vertex should not trigger discarding the state verify(triggerSubtaskState, never()).discardState(); // let the checkpoint fail at the first ack vertex - reset(triggerSubtaskState); + reset(storedTriggerSubtaskState); coord.receiveDeclineMessage(new DeclineCheckpoint(jobId, ackAttemptId1, checkpointId)); assertTrue(pendingCheckpoint.isDiscarded()); // check that we've cleaned up the already acknowledged state - verify(triggerSubtaskState, times(1)).discardState(); + verify(storedTriggerSubtaskState, times(1)).discardState(); SubtaskState ackSubtaskState = mock(SubtaskState.class); @@ -1411,15 +1472,25 @@ public class CheckpointCoordinatorTest { assertEquals(jid, pending.getJobId()); assertEquals(2, pending.getNumberOfNonAcknowledgedTasks()); assertEquals(0, pending.getNumberOfAcknowledgedTasks()); - assertEquals(0, pending.getTaskStates().size()); + assertEquals(0, pending.getOperatorStates().size()); assertFalse(pending.isDiscarded()); assertFalse(pending.isFullyAcknowledged()); assertFalse(pending.canBeSubsumed()); + OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId()); + OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId()); + + Map<OperatorID, OperatorState> operatorStates = pending.getOperatorStates(); + + operatorStates.put(opID1, new SpyInjectingOperatorState( + opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism())); + operatorStates.put(opID2, new SpyInjectingOperatorState( + opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex2.getMaxParallelism())); + // acknowledge from one of the tasks - SubtaskState subtaskState2 = mock(SubtaskState.class); - AcknowledgeCheckpoint acknowledgeCheckpoint2 = new
AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), subtaskState2); + AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)); coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2); + OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); assertEquals(1, pending.getNumberOfAcknowledgedTasks()); assertEquals(1, pending.getNumberOfNonAcknowledgedTasks()); assertFalse(pending.isDiscarded()); @@ -1433,8 +1504,8 @@ public class CheckpointCoordinatorTest { assertFalse(savepointFuture.isDone()); // acknowledge the other task. - SubtaskState subtaskState1 = mock(SubtaskState.class); - coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), subtaskState1)); + coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class))); + OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); // the checkpoint is internally converted to a successful checkpoint and the // pending checkpoint object is disposed @@ -1461,7 +1532,7 @@ public class CheckpointCoordinatorTest { assertEquals(jid, success.getJobId()); assertEquals(timestamp, success.getTimestamp()); assertEquals(pending.getCheckpointId(), success.getCheckpointID()); - assertEquals(2, success.getTaskStates().size()); + assertEquals(2, success.getOperatorStates().size()); // --------------- // trigger another checkpoint and see that this one replaces the other checkpoint @@ -1474,6 +1545,9 @@ public class CheckpointCoordinatorTest { coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew)); coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew)); + subtaskState1 = 
operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex()); + subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex()); + assertEquals(0, coord.getNumberOfPendingCheckpoints()); assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints()); @@ -1481,7 +1555,7 @@ public class CheckpointCoordinatorTest { assertEquals(jid, successNew.getJobId()); assertEquals(timestampNew, successNew.getTimestamp()); assertEquals(checkpointIdNew, successNew.getCheckpointID()); - assertTrue(successNew.getTaskStates().isEmpty()); + assertTrue(successNew.getOperatorStates().isEmpty()); assertTrue(savepointFuture.isDone()); // validate that the first savepoint does not discard its private states. @@ -1969,6 +2043,18 @@ public class CheckpointCoordinatorTest { List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1); List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2); + PendingCheckpoint pending = coord.getPendingCheckpoints().get(checkpointId); + + OperatorID opID1 = OperatorID.fromJobVertexID(jobVertexID1); + OperatorID opID2 = OperatorID.fromJobVertexID(jobVertexID2); + + Map<OperatorID, OperatorState> operatorStates = pending.getOperatorStates(); + + operatorStates.put(opID1, new SpyInjectingOperatorState( + opID1, jobVertex1.getParallelism(), jobVertex1.getMaxParallelism())); + operatorStates.put(opID2, new SpyInjectingOperatorState( + opID2, jobVertex2.getParallelism(), jobVertex2.getMaxParallelism())); + for (int index = 0; index < jobVertex1.getParallelism(); index++) { SubtaskState subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index)); @@ -2004,8 +2090,8 @@ public class CheckpointCoordinatorTest { // All shared states should be unregistered once the store is shut down for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { - for (TaskState taskState : 
completedCheckpoint.getTaskStates().values()) { - for (SubtaskState subtaskState : taskState.getStates()) { + for (OperatorState taskState : completedCheckpoint.getOperatorStates().values()) { + for (OperatorSubtaskState subtaskState : taskState.getStates()) { verify(subtaskState, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); } } @@ -2021,8 +2107,8 @@ public class CheckpointCoordinatorTest { // validate that all shared states are registered again after the recovery. for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) { - for (TaskState taskState : completedCheckpoint.getTaskStates().values()) { - for (SubtaskState subtaskState : taskState.getStates()) { + for (OperatorState taskState : completedCheckpoint.getOperatorStates().values()) { + for (OperatorSubtaskState subtaskState : taskState.getStates()) { verify(subtaskState, times(2)).registerSharedStates(any(SharedStateRegistry.class)); } } @@ -2432,7 +2518,11 @@ public class CheckpointCoordinatorTest { actualOpStatesBackend.add(opStateBackend); actualOpStatesRaw.add(opStateRaw); - assertNull(operatorState); + // the 'non partition state' is not null because it is recombined. 
+ assertNotNull(operatorState); + for (int index = 0; index < operatorState.getLength(); index++) { + assertNull(operatorState.get(index)); + } compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend); compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw); } @@ -2682,7 +2772,21 @@ public class CheckpointCoordinatorTest { } static ExecutionJobVertex mockExecutionJobVertex( + JobVertexID jobVertexID, + int parallelism, + int maxParallelism) { + + return mockExecutionJobVertex( + jobVertexID, + Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)), + parallelism, + maxParallelism + ); + } + + static ExecutionJobVertex mockExecutionJobVertex( JobVertexID jobVertexID, + List<OperatorID> jobVertexIDs, int parallelism, int maxParallelism) { final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class); @@ -2693,6 +2797,7 @@ public class CheckpointCoordinatorTest { executionVertices[i] = mockExecutionVertex( new ExecutionAttemptID(), jobVertexID, + jobVertexIDs, parallelism, maxParallelism, ExecutionState.RUNNING); @@ -2705,14 +2810,18 @@ public class CheckpointCoordinatorTest { when(executionJobVertex.getParallelism()).thenReturn(parallelism); when(executionJobVertex.getMaxParallelism()).thenReturn(maxParallelism); when(executionJobVertex.isMaxParallelismConfigured()).thenReturn(true); + when(executionJobVertex.getOperatorIDs()).thenReturn(jobVertexIDs); + when(executionJobVertex.getUserDefinedOperatorIDs()).thenReturn(Arrays.asList(new OperatorID[jobVertexIDs.size()])); return executionJobVertex; } static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) { + JobVertexID jobVertexID = new JobVertexID(); return mockExecutionVertex( attemptID, - new JobVertexID(), + jobVertexID, + Arrays.asList(OperatorID.fromJobVertexID(jobVertexID)), 1, 1, ExecutionState.RUNNING); @@ -2721,6 +2830,7 @@ public class CheckpointCoordinatorTest { private static ExecutionVertex 
mockExecutionVertex( ExecutionAttemptID attemptID, JobVertexID jobVertexID, + List<OperatorID> jobVertexIDs, int parallelism, int maxParallelism, ExecutionState state, @@ -2743,6 +2853,11 @@ public class CheckpointCoordinatorTest { when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism); when(vertex.getMaxParallelism()).thenReturn(maxParallelism); + ExecutionJobVertex jobVertex = mock(ExecutionJobVertex.class); + when(jobVertex.getOperatorIDs()).thenReturn(jobVertexIDs); + + when(vertex.getJobVertex()).thenReturn(jobVertex); + return vertex; } @@ -3135,8 +3250,16 @@ public class CheckpointCoordinatorTest { null, Executors.directExecutor()); - store.addCheckpoint( - new CompletedCheckpoint(new JobID(), 0, 0, 0, Collections.<JobVertexID, TaskState>emptyMap())); + store.addCheckpoint(new CompletedCheckpoint( + new JobID(), + 0, + 0, + 0, + Collections.<OperatorID, OperatorState>emptyMap(), + Collections.<MasterState>emptyList(), + CheckpointProperties.forStandardCheckpoint(), + null, + null)); CheckpointStatsTracker tracker = mock(CheckpointStatsTracker.class); coord.setCheckpointStatsTracker(tracker); @@ -3146,4 +3269,17 @@ public class CheckpointCoordinatorTest { verify(tracker, times(1)) .reportRestoredCheckpoint(any(RestoredCheckpointStats.class)); } + + private static final class SpyInjectingOperatorState extends OperatorState { + + private static final long serialVersionUID = -4004437428483663815L; + + public SpyInjectingOperatorState(OperatorID taskID, int parallelism, int maxParallelism) { + super(taskID, parallelism, maxParallelism); + } + + public void putState(int subtaskIndex, OperatorSubtaskState subtaskState) { + super.putState(subtaskIndex, spy(subtaskState)); + } + } }
http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java index 2fc1de5..7d24568 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java @@ -26,6 +26,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.state.ChainedStateHandle; @@ -209,6 +210,8 @@ public class CheckpointStateRestoreTest { JobVertexID jobVertexId1 = new JobVertexID(); JobVertexID jobVertexId2 = new JobVertexID(); + OperatorID operatorId1 = OperatorID.fromJobVertexID(jobVertexId1); + // 1st JobVertex ExecutionVertex vertex11 = mockExecutionVertex(mockExecution(), jobVertexId1, 0, 3); ExecutionVertex vertex12 = mockExecutionVertex(mockExecution(), jobVertexId1, 1, 3); @@ -239,20 +242,30 @@ public class CheckpointStateRestoreTest { null, Executors.directExecutor()); - ChainedStateHandle<StreamStateHandle> serializedState = CheckpointCoordinatorTest - .generateChainedStateHandle(new SerializableObject()); + StreamStateHandle serializedState = CheckpointCoordinatorTest + .generateChainedStateHandle(new SerializableObject()) + .get(0); // --- (2) 
Checkpoint misses state for a jobVertex (should work) --- - Map<JobVertexID, TaskState> checkpointTaskStates = new HashMap<>(); + Map<OperatorID, OperatorState> checkpointTaskStates = new HashMap<>(); { - TaskState taskState = new TaskState(jobVertexId1, 3, 3, 1); - taskState.putState(0, new SubtaskState(serializedState, null, null, null, null)); - taskState.putState(1, new SubtaskState(serializedState, null, null, null, null)); - taskState.putState(2, new SubtaskState(serializedState, null, null, null, null)); + OperatorState taskState = new OperatorState(operatorId1, 3, 3); + taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null)); + taskState.putState(1, new OperatorSubtaskState(serializedState, null, null, null, null)); + taskState.putState(2, new OperatorSubtaskState(serializedState, null, null, null, null)); - checkpointTaskStates.put(jobVertexId1, taskState); + checkpointTaskStates.put(operatorId1, taskState); } - CompletedCheckpoint checkpoint = new CompletedCheckpoint(new JobID(), 0, 1, 2, new HashMap<>(checkpointTaskStates)); + CompletedCheckpoint checkpoint = new CompletedCheckpoint( + new JobID(), + 0, + 1, + 2, + new HashMap<>(checkpointTaskStates), + Collections.<MasterState>emptyList(), + CheckpointProperties.forStandardCheckpoint(), + null, + null); coord.getCheckpointStore().addCheckpoint(checkpoint); @@ -261,16 +274,26 @@ public class CheckpointStateRestoreTest { // --- (3) JobVertex missing for task state that is part of the checkpoint --- JobVertexID newJobVertexID = new JobVertexID(); + OperatorID newOperatorID = OperatorID.fromJobVertexID(newJobVertexID); // There is no task for this { - TaskState taskState = new TaskState(jobVertexId1, 1, 1, 1); - taskState.putState(0, new SubtaskState(serializedState, null, null, null, null)); + OperatorState taskState = new OperatorState(newOperatorID, 1, 1); + taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null)); - 
checkpointTaskStates.put(newJobVertexID, taskState); + checkpointTaskStates.put(newOperatorID, taskState); } - checkpoint = new CompletedCheckpoint(new JobID(), 1, 2, 3, new HashMap<>(checkpointTaskStates)); + checkpoint = new CompletedCheckpoint( + new JobID(), + 1, + 2, + 3, + new HashMap<>(checkpointTaskStates), + Collections.<MasterState>emptyList(), + CheckpointProperties.forStandardCheckpoint(), + null, + null); coord.getCheckpointStore().addCheckpoint(checkpoint); @@ -314,6 +337,12 @@ public class CheckpointStateRestoreTest { when(vertex.getMaxParallelism()).thenReturn(vertices.length); when(vertex.getJobVertexId()).thenReturn(id); when(vertex.getTaskVertices()).thenReturn(vertices); + when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorID.fromJobVertexID(id))); + when(vertex.getUserDefinedOperatorIDs()).thenReturn(Collections.<OperatorID>singletonList(null)); + + for (ExecutionVertex v : vertices) { + when(v.getJobVertex()).thenReturn(vertex); + } return vertex; } } http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java index fc6e516..94bd12f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java @@ -20,8 +20,7 @@ package org.apache.flink.runtime.checkpoint; import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.jobgraph.JobStatus; -import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.KeyGroupRange; +import 
org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.util.TestLogger; import org.junit.Test; @@ -39,6 +38,7 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -105,7 +105,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints()); for (int i = 1; i < expected.length; i++) { - Collection<TaskState> taskStates = expected[i - 1].getTaskStates().values(); + Collection<OperatorState> taskStates = expected[i - 1].getOperatorStates().values(); checkpoints.addCheckpoint(expected[i]); @@ -114,8 +114,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { assertTrue(expected[i - 1].isDiscarded()); assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints()); - for (TaskState taskState : taskStates) { - for (SubtaskState subtaskState : taskState.getStates()) { + for (OperatorState operatorState : taskStates) { + for (OperatorSubtaskState subtaskState : operatorState.getStates()) { verify(subtaskState, times(1)).unregisterSharedStates(any(SharedStateRegistry.class)); } } @@ -198,51 +198,43 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { // --------------------------------------------------------------------------------------------- protected TestCompletedCheckpoint createCheckpoint(int id) throws IOException { - return createCheckpoint(id, 4); - } - - protected TestCompletedCheckpoint createCheckpoint(int id, int numberOfStates) - throws IOException { - return createCheckpoint(id, numberOfStates, CheckpointProperties.forStandardCheckpoint()); - } + int numberOfStates = 4; + CheckpointProperties props = 
CheckpointProperties.forStandardCheckpoint(); - protected TestCompletedCheckpoint createCheckpoint(int id, int numberOfStates, CheckpointProperties props) - throws IOException { + OperatorID operatorID = new OperatorID(); - JobVertexID jvid = new JobVertexID(); - - Map<JobVertexID, TaskState> taskGroupStates = new HashMap<>(); - TaskState taskState = new TaskState(jvid, numberOfStates, numberOfStates, 1); - taskGroupStates.put(jvid, taskState); + Map<OperatorID, OperatorState> operatorGroupState = new HashMap<>(); + OperatorState operatorState = new OperatorState(operatorID, numberOfStates, numberOfStates); + operatorGroupState.put(operatorID, operatorState); for (int i = 0; i < numberOfStates; i++) { - SubtaskState subtaskState = CheckpointCoordinatorTest.mockSubtaskState(jvid, i, new KeyGroupRange(i, i)); + OperatorSubtaskState subtaskState = mock(OperatorSubtaskState.class); - taskState.putState(i, subtaskState); + operatorState.putState(i, subtaskState); } - return new TestCompletedCheckpoint(new JobID(), id, 0, taskGroupStates, props); + return new TestCompletedCheckpoint(new JobID(), id, 0, operatorGroupState, props); } - protected void resetCheckpoint(Collection<TaskState> taskStates) { - for (TaskState taskState : taskStates) { - for (SubtaskState subtaskState : taskState.getStates()) { + protected void resetCheckpoint(Collection<OperatorState> operatorStates) { + for (OperatorState operatorState : operatorStates) { + for (OperatorSubtaskState subtaskState : operatorState.getStates()) { Mockito.reset(subtaskState); } } } - protected void verifyCheckpointRegistered(Collection<TaskState> taskStates, SharedStateRegistry registry) { - for (TaskState taskState : taskStates) { - for (SubtaskState subtaskState : taskState.getStates()) { + protected void verifyCheckpointRegistered(Collection<OperatorState> operatorStates, SharedStateRegistry registry) { + for (OperatorState operatorState : operatorStates) { + for (OperatorSubtaskState subtaskState : 
operatorState.getStates()) { verify(subtaskState, times(1)).registerSharedStates(eq(registry)); } } } - protected void verifyCheckpointDiscarded(Collection<TaskState> taskStates) { - for (TaskState taskState : taskStates) { - for (SubtaskState subtaskState : taskState.getStates()) { + protected void verifyCheckpointDiscarded(Collection<OperatorState> operatorStates) { + for (OperatorState operatorState : operatorStates) { + for (OperatorSubtaskState subtaskState : operatorState.getStates()) { verify(subtaskState, times(1)).discardState(); } } @@ -270,10 +262,9 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger { JobID jobId, long checkpointId, long timestamp, - Map<JobVertexID, TaskState> taskGroupStates, + Map<OperatorID, OperatorState> operatorGroupState, CheckpointProperties props) { - - super(jobId, checkpointId, timestamp, Long.MAX_VALUE, taskGroupStates, null, props); + super(jobId, checkpointId, timestamp, Long.MAX_VALUE, operatorGroupState, null, props, null, null); } @Override http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java index 652cc76..589ff46 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java @@ -23,7 +23,7 @@ import org.apache.flink.core.fs.Path; import org.apache.flink.core.testutils.CommonTestUtils; import org.apache.flink.runtime.jobgraph.JobStatus; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.SharedStateHandle; +import 
org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.SharedStateRegistry; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.junit.Rule; @@ -37,7 +37,6 @@ import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -55,9 +54,9 @@ public class CompletedCheckpointTest { File file = tmpFolder.newFile(); assertEquals(true, file.exists()); - TaskState state = mock(TaskState.class); - Map<JobVertexID, TaskState> taskStates = new HashMap<>(); - taskStates.put(new JobVertexID(), state); + OperatorState state = mock(OperatorState.class); + Map<OperatorID, OperatorState> taskStates = new HashMap<>(); + taskStates.put(new OperatorID(), state); // Verify discard call is forwarded to state CompletedCheckpoint checkpoint = new CompletedCheckpoint( @@ -78,18 +77,20 @@ public class CompletedCheckpointTest { */ @Test public void testCleanUpOnSubsume() throws Exception { - TaskState state = mock(TaskState.class); - Map<JobVertexID, TaskState> taskStates = new HashMap<>(); - taskStates.put(new JobVertexID(), state); + OperatorState state = mock(OperatorState.class); + Map<OperatorID, OperatorState> operatorStates = new HashMap<>(); + operatorStates.put(new OperatorID(), state); boolean discardSubsumed = true; CheckpointProperties props = new CheckpointProperties(false, false, discardSubsumed, true, true, true, true); CompletedCheckpoint checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, - taskStates, + operatorStates, Collections.<MasterState>emptyList(), - props); + props, + null, + null); SharedStateRegistry sharedStateRegistry = new SharedStateRegistry(); checkpoint.registerSharedStates(sharedStateRegistry); @@ -114,9 +115,9 @@ public class CompletedCheckpointTest { JobStatus.FINISHED, JobStatus.CANCELED, 
JobStatus.FAILED, JobStatus.SUSPENDED }; - TaskState state = mock(TaskState.class); - Map<JobVertexID, TaskState> taskStates = new HashMap<>(); - taskStates.put(new JobVertexID(), state); + OperatorState state = mock(OperatorState.class); + Map<OperatorID, OperatorState> operatorStates = new HashMap<>(); + operatorStates.put(new OperatorID(), state); for (JobStatus status : terminalStates) { Mockito.reset(state); @@ -125,7 +126,7 @@ public class CompletedCheckpointTest { CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false); CompletedCheckpoint checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, - new HashMap<>(taskStates), + new HashMap<>(operatorStates), Collections.<MasterState>emptyList(), props, new FileStateHandle(new Path(file.toURI()), file.length()), @@ -143,9 +144,11 @@ public class CompletedCheckpointTest { props = new CheckpointProperties(false, false, true, true, true, true, true); checkpoint = new CompletedCheckpoint( new JobID(), 0, 0, 1, - new HashMap<>(taskStates), + new HashMap<>(operatorStates), Collections.<MasterState>emptyList(), - props); + props, + null, + null); checkpoint.discardOnShutdown(status, sharedStateRegistry); verify(state, times(1)).discardState(); @@ -158,18 +161,20 @@ public class CompletedCheckpointTest { */ @Test public void testCompletedCheckpointStatsCallbacks() throws Exception { - TaskState state = mock(TaskState.class); - Map<JobVertexID, TaskState> taskStates = new HashMap<>(); - taskStates.put(new JobVertexID(), state); + OperatorState state = mock(OperatorState.class); + Map<OperatorID, OperatorState> operatorStates = new HashMap<>(); + operatorStates.put(new OperatorID(), state); CompletedCheckpoint completed = new CompletedCheckpoint( new JobID(), 0, 0, 1, - new HashMap<>(taskStates), + new HashMap<>(operatorStates), Collections.<MasterState>emptyList(), - CheckpointProperties.forStandardCheckpoint()); + CheckpointProperties.forStandardCheckpoint(), + null, + 
null); CompletedCheckpointStats.DiscardCallback callback = mock(CompletedCheckpointStats.DiscardCallback.class); completed.setDiscardCallback(callback); http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java index 2dd1803..6df01a0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java @@ -22,15 +22,15 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.concurrent.Future; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; -import org.apache.flink.runtime.state.SharedStateHandle; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.SharedStateRegistry; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; -import org.mockito.Mock; import org.mockito.Mockito; import java.io.File; @@ -50,9 +50,7 @@ import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static 
org.powermock.api.mockito.PowerMockito.when; @@ -63,9 +61,13 @@ public class PendingCheckpointTest { private static final ExecutionAttemptID ATTEMPT_ID = new ExecutionAttemptID(); static { + ExecutionJobVertex jobVertex = mock(ExecutionJobVertex.class); + when(jobVertex.getOperatorIDs()).thenReturn(Collections.singletonList(new OperatorID())); + ExecutionVertex vertex = mock(ExecutionVertex.class); when(vertex.getMaxParallelism()).thenReturn(128); when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(1); + when(vertex.getJobVertex()).thenReturn(jobVertex); ACK_TASKS.put(ATTEMPT_ID, vertex); } @@ -193,7 +195,7 @@ public class PendingCheckpointTest { CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false); QueueExecutor executor = new QueueExecutor(); - TaskState state = mock(TaskState.class); + OperatorState state = mock(OperatorState.class); doNothing().when(state).registerSharedStates(any(SharedStateRegistry.class)); doNothing().when(state).unregisterSharedStates(any(SharedStateRegistry.class)); @@ -311,7 +313,7 @@ public class PendingCheckpointTest { public void testNullSubtaskStateLeadsToStatelessTask() throws Exception { PendingCheckpoint pending = createPendingCheckpoint(CheckpointProperties.forStandardCheckpoint(), null); pending.acknowledgeTask(ATTEMPT_ID, null, mock(CheckpointMetrics.class)); - Assert.assertTrue(pending.getTaskStates().isEmpty()); + Assert.assertTrue(pending.getOperatorStates().isEmpty()); } /** @@ -324,7 +326,7 @@ public class PendingCheckpointTest { public void testNonNullSubtaskStateLeadsToStatefulTask() throws Exception { PendingCheckpoint pending = createPendingCheckpoint(CheckpointProperties.forStandardCheckpoint(), null); pending.acknowledgeTask(ATTEMPT_ID, mock(SubtaskState.class), mock(CheckpointMetrics.class)); - Assert.assertFalse(pending.getTaskStates().isEmpty()); + Assert.assertFalse(pending.getOperatorStates().isEmpty()); } @Test @@ -367,12 +369,12 @@ public class 
PendingCheckpointTest { } @SuppressWarnings("unchecked") - static void setTaskState(PendingCheckpoint pending, TaskState state) throws NoSuchFieldException, IllegalAccessException { - Field field = PendingCheckpoint.class.getDeclaredField("taskStates"); + static void setTaskState(PendingCheckpoint pending, OperatorState state) throws NoSuchFieldException, IllegalAccessException { + Field field = PendingCheckpoint.class.getDeclaredField("operatorStates"); field.setAccessible(true); - Map<JobVertexID, TaskState> taskStates = (Map<JobVertexID, TaskState>) field.get(pending); + Map<OperatorID, OperatorState> taskStates = (Map<OperatorID, OperatorState>) field.get(pending); - taskStates.put(new JobVertexID(), state); + taskStates.put(new OperatorID(), state); } private static final class QueueExecutor implements Executor { http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java index 64aeeba..be94762 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java @@ -54,16 +54,16 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS public void testShutdownDiscardsCheckpoints() throws Exception { AbstractCompletedCheckpointStore store = createCompletedCheckpoints(1); TestCompletedCheckpoint checkpoint = createCheckpoint(0); - Collection<TaskState> taskStates = checkpoint.getTaskStates().values(); + Collection<OperatorState> operatorStates = 
checkpoint.getOperatorStates().values(); store.addCheckpoint(checkpoint); assertEquals(1, store.getNumberOfRetainedCheckpoints()); - verifyCheckpointRegistered(taskStates, store.sharedStateRegistry); + verifyCheckpointRegistered(operatorStates, store.sharedStateRegistry); store.shutdown(JobStatus.FINISHED); assertEquals(0, store.getNumberOfRetainedCheckpoints()); assertTrue(checkpoint.isDiscarded()); - verifyCheckpointDiscarded(taskStates); + verifyCheckpointDiscarded(operatorStates); } /** @@ -74,7 +74,7 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS public void testSuspendDiscardsCheckpoints() throws Exception { AbstractCompletedCheckpointStore store = createCompletedCheckpoints(1); TestCompletedCheckpoint checkpoint = createCheckpoint(0); - Collection<TaskState> taskStates = checkpoint.getTaskStates().values(); + Collection<OperatorState> taskStates = checkpoint.getOperatorStates().values(); store.addCheckpoint(checkpoint); assertEquals(1, store.getNumberOfRetainedCheckpoints()); @@ -99,7 +99,7 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS for (long i = 0; i <= numCheckpointsToRetain; ++i) { CompletedCheckpoint checkpointToAdd = mock(CompletedCheckpoint.class); doReturn(i).when(checkpointToAdd).getCheckpointID(); - doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates(); + doReturn(Collections.emptyMap()).when(checkpointToAdd).getOperatorStates(); doThrow(new IOException()).when(checkpointToAdd).discardOnSubsume(any(SharedStateRegistry.class)); try { http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java index 73fcf78..73e0ed9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java @@ -90,17 +90,17 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint checkpoints.addCheckpoint(expected[1]); checkpoints.addCheckpoint(expected[2]); - verifyCheckpointRegistered(expected[0].getTaskStates().values(), checkpoints.sharedStateRegistry); - verifyCheckpointRegistered(expected[1].getTaskStates().values(), checkpoints.sharedStateRegistry); - verifyCheckpointRegistered(expected[2].getTaskStates().values(), checkpoints.sharedStateRegistry); + verifyCheckpointRegistered(expected[0].getOperatorStates().values(), checkpoints.sharedStateRegistry); + verifyCheckpointRegistered(expected[1].getOperatorStates().values(), checkpoints.sharedStateRegistry); + verifyCheckpointRegistered(expected[2].getOperatorStates().values(), checkpoints.sharedStateRegistry); // All three should be in ZK assertEquals(3, ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size()); assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints()); - resetCheckpoint(expected[0].getTaskStates().values()); - resetCheckpoint(expected[1].getTaskStates().values()); - resetCheckpoint(expected[2].getTaskStates().values()); + resetCheckpoint(expected[0].getOperatorStates().values()); + resetCheckpoint(expected[1].getOperatorStates().values()); + resetCheckpoint(expected[2].getOperatorStates().values()); // Recover TODO!!! clear registry! 
checkpoints.recover(); @@ -121,7 +121,7 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint assertEquals(expectedCheckpoints, actualCheckpoints); for (CompletedCheckpoint actualCheckpoint : actualCheckpoints) { - verifyCheckpointRegistered(actualCheckpoint.getTaskStates().values(), checkpoints.sharedStateRegistry); + verifyCheckpointRegistered(actualCheckpoint.getOperatorStates().values(), checkpoints.sharedStateRegistry); } } http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java index 66ef232..8fc0f02 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java @@ -227,7 +227,7 @@ public class ZooKeeperCompletedCheckpointStoreTest extends TestLogger { for (long i = 0; i <= numCheckpointsToRetain; ++i) { CompletedCheckpoint checkpointToAdd = mock(CompletedCheckpoint.class); doReturn(i).when(checkpointToAdd).getCheckpointID(); - doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates(); + doReturn(Collections.emptyMap()).when(checkpointToAdd).getOperatorStates(); try { zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd); http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java ---------------------------------------------------------------------- diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java index 7d9874e..ba77dbc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java @@ -20,9 +20,12 @@ package org.apache.flink.runtime.checkpoint.savepoint; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.OperatorState; +import org.apache.flink.runtime.checkpoint.OperatorSubtaskState; import org.apache.flink.runtime.checkpoint.SubtaskState; import org.apache.flink.runtime.checkpoint.TaskState; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.ChainedStateHandle; import org.apache.flink.runtime.state.KeyGroupRangeOffsets; import org.apache.flink.runtime.state.KeyGroupsStateHandle; @@ -49,6 +52,94 @@ import static org.junit.Assert.assertEquals; public class CheckpointTestUtils { /** + * Creates a random collection of OperatorState objects containing various types of state handles. + */ + public static Collection<OperatorState> createOperatorStates(int numTaskStates, int numSubtasksPerTask) { + return createOperatorStates(new Random(), numTaskStates, numSubtasksPerTask); + } + + /** + * Creates a random collection of OperatorState objects containing various types of state handles. 
+ */ + public static Collection<OperatorState> createOperatorStates( + Random random, + int numTaskStates, + int numSubtasksPerTask) { + + List<OperatorState> taskStates = new ArrayList<>(numTaskStates); + + for (int stateIdx = 0; stateIdx < numTaskStates; ++stateIdx) { + + OperatorState taskState = new OperatorState(new OperatorID(), numSubtasksPerTask, 128); + + boolean hasNonPartitionableState = random.nextBoolean(); + boolean hasOperatorStateBackend = random.nextBoolean(); + boolean hasOperatorStateStream = random.nextBoolean(); + + boolean hasKeyedBackend = random.nextInt(4) != 0; + boolean hasKeyedStream = random.nextInt(4) != 0; + + for (int subtaskIdx = 0; subtaskIdx < numSubtasksPerTask; subtaskIdx++) { + + StreamStateHandle nonPartitionableState = null; + StreamStateHandle operatorStateBackend = + new TestByteStreamStateHandleDeepCompare("b", ("Beautiful").getBytes(ConfigConstants.DEFAULT_CHARSET)); + StreamStateHandle operatorStateStream = + new TestByteStreamStateHandleDeepCompare("b", ("Beautiful").getBytes(ConfigConstants.DEFAULT_CHARSET)); + + OperatorStateHandle operatorStateHandleBackend = null; + OperatorStateHandle operatorStateHandleStream = null; + + Map<String, StateMetaInfo> offsetsMap = new HashMap<>(); + offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)); + offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST)); + + if (hasNonPartitionableState) { + nonPartitionableState = + new TestByteStreamStateHandleDeepCompare("a", ("Hi").getBytes(ConfigConstants.DEFAULT_CHARSET)); + } + + if (hasOperatorStateBackend) { + operatorStateHandleBackend = new OperatorStateHandle(offsetsMap, operatorStateBackend); + } + + if (hasOperatorStateStream) { + operatorStateHandleStream = new 
OperatorStateHandle(offsetsMap, operatorStateStream); + } + + KeyGroupsStateHandle keyedStateBackend = null; + KeyGroupsStateHandle keyedStateStream = null; + + if (hasKeyedBackend) { + keyedStateBackend = new KeyGroupsStateHandle( + new KeyGroupRangeOffsets(1, 1, new long[]{42}), + new TestByteStreamStateHandleDeepCompare("c", "Hello" + .getBytes(ConfigConstants.DEFAULT_CHARSET))); + } + + if (hasKeyedStream) { + keyedStateStream = new KeyGroupsStateHandle( + new KeyGroupRangeOffsets(1, 1, new long[]{23}), + new TestByteStreamStateHandleDeepCompare("d", "World" + .getBytes(ConfigConstants.DEFAULT_CHARSET))); + } + + taskState.putState(subtaskIdx, new OperatorSubtaskState( + nonPartitionableState, + operatorStateHandleBackend, + operatorStateHandleStream, + keyedStateStream, + keyedStateBackend)); + } + + taskStates.add(taskState); + } + + return taskStates; + } + + /** * Creates a random collection of TaskState objects containing various types of state handles. */ public static Collection<TaskState> createTaskStates(int numTaskStates, int numSubtasksPerTask) { @@ -88,7 +179,7 @@ public class CheckpointTestUtils { StreamStateHandle nonPartitionableState = new TestByteStreamStateHandleDeepCompare("a-" + chainIdx, ("Hi-" + chainIdx).getBytes( - ConfigConstants.DEFAULT_CHARSET)); + ConfigConstants.DEFAULT_CHARSET)); StreamStateHandle operatorStateBackend = new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET)); StreamStateHandle operatorStateStream = @@ -122,14 +213,14 @@ public class CheckpointTestUtils { keyedStateBackend = new KeyGroupsStateHandle( new KeyGroupRangeOffsets(1, 1, new long[]{42}), new TestByteStreamStateHandleDeepCompare("c", "Hello" - .getBytes(ConfigConstants.DEFAULT_CHARSET))); + .getBytes(ConfigConstants.DEFAULT_CHARSET))); } if (hasKeyedStream) { keyedStateStream = new KeyGroupsStateHandle( new KeyGroupRangeOffsets(1, 1, new long[]{23}), new 
TestByteStreamStateHandleDeepCompare("d", "World" - .getBytes(ConfigConstants.DEFAULT_CHARSET))); + .getBytes(ConfigConstants.DEFAULT_CHARSET))); } taskState.putState(subtaskIdx, new SubtaskState( http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java index 20b1e57..331621d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointLoaderTest.java @@ -20,14 +20,17 @@ package org.apache.flink.runtime.checkpoint.savepoint; import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.checkpoint.CompletedCheckpoint; -import org.apache.flink.runtime.checkpoint.TaskState; +import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.OperatorState; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; import java.io.File; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -53,29 +56,30 @@ public class SavepointLoaderTest { int parallelism = 128128; long checkpointId = Integer.MAX_VALUE + 123123L; - JobVertexID vertexId = new JobVertexID(); + JobVertexID jobVertexID = new JobVertexID(); + OperatorID operatorID = OperatorID.fromJobVertexID(jobVertexID); - TaskState state = mock(TaskState.class); + OperatorState state = mock(OperatorState.class); 
when(state.getParallelism()).thenReturn(parallelism); - when(state.getJobVertexID()).thenReturn(vertexId); + when(state.getOperatorID()).thenReturn(operatorID); when(state.getMaxParallelism()).thenReturn(parallelism); - when(state.getChainLength()).thenReturn(1); - Map<JobVertexID, TaskState> taskStates = new HashMap<>(); - taskStates.put(vertexId, state); + Map<OperatorID, OperatorState> taskStates = new HashMap<>(); + taskStates.put(operatorID, state); JobID jobId = new JobID(); // Store savepoint - SavepointV2 savepoint = new SavepointV2(checkpointId, taskStates.values()); + SavepointV2 savepoint = new SavepointV2(checkpointId, taskStates.values(), Collections.<MasterState>emptyList()); String path = SavepointStore.storeSavepoint(tmp.getAbsolutePath(), savepoint); ExecutionJobVertex vertex = mock(ExecutionJobVertex.class); when(vertex.getParallelism()).thenReturn(parallelism); when(vertex.getMaxParallelism()).thenReturn(parallelism); + when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(operatorID)); Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>(); - tasks.put(vertexId, vertex); + tasks.put(jobVertexID, vertex); ClassLoader ucl = Thread.currentThread().getContextClassLoader(); @@ -97,7 +101,7 @@ public class SavepointLoaderTest { } // 3) Load and validate: missing vertex - assertNotNull(tasks.remove(vertexId)); + assertNotNull(tasks.remove(jobVertexID)); try { SavepointLoader.loadAndValidateSavepoint(jobId, tasks, path, ucl, false); http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java index cf79282..391102c 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointStoreTest.java @@ -24,6 +24,7 @@ import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.core.fs.FileSystem; import org.apache.flink.core.fs.Path; import org.apache.flink.runtime.checkpoint.MasterState; +import org.apache.flink.runtime.checkpoint.OperatorState; import org.apache.flink.runtime.checkpoint.TaskState; import org.apache.flink.runtime.state.filesystem.FileStateHandle; import org.junit.Rule; @@ -69,7 +70,10 @@ public class SavepointStoreTest { // Store String savepointDirectory = SavepointStore.createSavepointDirectory(root, new JobID()); - SavepointV2 stored = new SavepointV2(1929292, CheckpointTestUtils.createTaskStates(4, 24)); + SavepointV2 stored = new SavepointV2( + 1929292, + CheckpointTestUtils.createOperatorStates(4, 24), + Collections.<MasterState>emptyList()); String path = SavepointStore.storeSavepoint(savepointDirectory, stored); list = rootFile.listFiles(); @@ -80,7 +84,7 @@ public class SavepointStoreTest { Savepoint loaded = SavepointStore.loadSavepoint(path, Thread.currentThread().getContextClassLoader()); assertEquals(stored.getCheckpointId(), loaded.getCheckpointId()); - assertEquals(stored.getTaskStates(), loaded.getTaskStates()); + assertEquals(stored.getOperatorStates(), loaded.getOperatorStates()); assertEquals(stored.getMasterStates(), loaded.getMasterStates()); loaded.dispose(); @@ -147,7 +151,10 @@ public class SavepointStoreTest { // Savepoint v0 String savepointDirectory2 = SavepointStore.createSavepointDirectory(root, new JobID()); - SavepointV2 savepoint = new SavepointV2(checkpointId, CheckpointTestUtils.createTaskStates(4, 32)); + SavepointV2 savepoint = new SavepointV2( + checkpointId, + CheckpointTestUtils.createOperatorStates(4, 32), + Collections.<MasterState>emptyList()); String pathSavepoint = 
SavepointStore.storeSavepoint(savepointDirectory2, savepoint); list = rootFile.listFiles(); @@ -205,7 +212,10 @@ public class SavepointStoreTest { FileSystem fs = FileSystem.get(new Path(root).toUri()); // Store - SavepointV2 savepoint = new SavepointV2(1929292, CheckpointTestUtils.createTaskStates(4, 24)); + SavepointV2 savepoint = new SavepointV2( + 1929292, + CheckpointTestUtils.createOperatorStates(4, 24), + Collections.<MasterState>emptyList()); FileStateHandle store1 = SavepointStore.storeExternalizedCheckpointToHandle(root, savepoint); fs.exists(store1.getFilePath()); @@ -266,6 +276,11 @@ public class SavepointStoreTest { } @Override + public Collection<OperatorState> getOperatorStates() { + return null; + } + + @Override public void dispose() { } http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java index deb14dd..154d761 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2SerializerTest.java @@ -23,7 +23,7 @@ import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos; import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.runtime.checkpoint.MasterState; -import org.apache.flink.runtime.checkpoint.TaskState; +import org.apache.flink.runtime.checkpoint.OperatorState; import org.junit.Test; @@ -48,7 +48,7 @@ public class SavepointV2SerializerTest { for (int i = 0; i < 100; ++i) { final long checkpointId = 
rnd.nextLong() & 0x7fffffffffffffffL; - final Collection<TaskState> taskStates = Collections.emptyList(); + final Collection<OperatorState> taskStates = Collections.emptyList(); final Collection<MasterState> masterStates = Collections.emptyList(); testCheckpointSerialization(checkpointId, taskStates, masterStates); @@ -63,13 +63,13 @@ public class SavepointV2SerializerTest { for (int i = 0; i < 100; ++i) { final long checkpointId = rnd.nextLong() & 0x7fffffffffffffffL; - final Collection<TaskState> taskStates = Collections.emptyList(); + final Collection<OperatorState> operatorStates = Collections.emptyList(); final int numMasterStates = rnd.nextInt(maxNumMasterStates) + 1; final Collection<MasterState> masterStates = CheckpointTestUtils.createRandomMasterStates(rnd, numMasterStates); - testCheckpointSerialization(checkpointId, taskStates, masterStates); + testCheckpointSerialization(checkpointId, operatorStates, masterStates); } } @@ -84,8 +84,8 @@ public class SavepointV2SerializerTest { final int numTasks = rnd.nextInt(maxTaskStates) + 1; final int numSubtasks = rnd.nextInt(maxNumSubtasks) + 1; - final Collection<TaskState> taskStates = - CheckpointTestUtils.createTaskStates(rnd, numTasks, numSubtasks); + final Collection<OperatorState> taskStates = + CheckpointTestUtils.createOperatorStates(rnd, numTasks, numSubtasks); final Collection<MasterState> masterStates = Collections.emptyList(); @@ -106,8 +106,8 @@ public class SavepointV2SerializerTest { final int numTasks = rnd.nextInt(maxTaskStates) + 1; final int numSubtasks = rnd.nextInt(maxNumSubtasks) + 1; - final Collection<TaskState> taskStates = - CheckpointTestUtils.createTaskStates(rnd, numTasks, numSubtasks); + final Collection<OperatorState> taskStates = + CheckpointTestUtils.createOperatorStates(rnd, numTasks, numSubtasks); final int numMasterStates = rnd.nextInt(maxNumMasterStates) + 1; final Collection<MasterState> masterStates = @@ -119,7 +119,7 @@ public class SavepointV2SerializerTest { private void 
testCheckpointSerialization( long checkpointId, - Collection<TaskState> taskStates, + Collection<OperatorState> operatorStates, Collection<MasterState> masterStates) throws IOException { SavepointV2Serializer serializer = SavepointV2Serializer.INSTANCE; @@ -127,7 +127,7 @@ public class SavepointV2SerializerTest { ByteArrayOutputStreamWithPos baos = new ByteArrayOutputStreamWithPos(); DataOutputStream out = new DataOutputViewStreamWrapper(baos); - serializer.serialize(new SavepointV2(checkpointId, taskStates, masterStates), out); + serializer.serialize(new SavepointV2(checkpointId, operatorStates, masterStates), out); out.close(); byte[] bytes = baos.toByteArray(); @@ -136,7 +136,7 @@ public class SavepointV2SerializerTest { SavepointV2 deserialized = serializer.deserialize(in, getClass().getClassLoader()); assertEquals(checkpointId, deserialized.getCheckpointId()); - assertEquals(taskStates, deserialized.getTaskStates()); + assertEquals(operatorStates, deserialized.getOperatorStates()); assertEquals(masterStates.size(), deserialized.getMasterStates().size()); for (Iterator<MasterState> a = masterStates.iterator(), b = deserialized.getMasterStates().iterator(); http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java index 428a62a..6b6a6d4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV2Test.java @@ -19,7 +19,7 @@ package org.apache.flink.runtime.checkpoint.savepoint; import org.apache.flink.runtime.checkpoint.MasterState; -import 
org.apache.flink.runtime.checkpoint.TaskState; +import org.apache.flink.runtime.checkpoint.OperatorState; import org.junit.Test; @@ -36,7 +36,7 @@ public class SavepointV2Test { * Simple test of savepoint methods. */ @Test - public void testSavepointV1() throws Exception { + public void testSavepointV2() throws Exception { final Random rnd = new Random(); final long checkpointId = rnd.nextInt(Integer.MAX_VALUE) + 1; @@ -44,8 +44,8 @@ public class SavepointV2Test { final int numSubtaskStates = 16; final int numMasterStates = 7; - Collection<TaskState> taskStates = - CheckpointTestUtils.createTaskStates(rnd, numTaskStates, numSubtaskStates); + Collection<OperatorState> taskStates = + CheckpointTestUtils.createOperatorStates(rnd, numTaskStates, numSubtaskStates); Collection<MasterState> masterStates = CheckpointTestUtils.createRandomMasterStates(rnd, numMasterStates); @@ -54,15 +54,15 @@ public class SavepointV2Test { assertEquals(2, checkpoint.getVersion()); assertEquals(checkpointId, checkpoint.getCheckpointId()); - assertEquals(taskStates, checkpoint.getTaskStates()); + assertEquals(taskStates, checkpoint.getOperatorStates()); assertEquals(masterStates, checkpoint.getMasterStates()); - assertFalse(checkpoint.getTaskStates().isEmpty()); + assertFalse(checkpoint.getOperatorStates().isEmpty()); assertFalse(checkpoint.getMasterStates().isEmpty()); checkpoint.dispose(); - assertTrue(checkpoint.getTaskStates().isEmpty()); + assertTrue(checkpoint.getOperatorStates().isEmpty()); assertTrue(checkpoint.getMasterStates().isEmpty()); } } http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/LegacyJobVertexIdTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/LegacyJobVertexIdTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/LegacyJobVertexIdTest.java index 
89db3a1..b5a67fd 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/LegacyJobVertexIdTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/LegacyJobVertexIdTest.java @@ -25,11 +25,13 @@ import org.apache.flink.runtime.executiongraph.restart.RestartStrategy; import org.apache.flink.runtime.instance.SlotProvider; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.util.SerializedValue; import org.junit.Assert; import org.junit.Test; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -46,7 +48,7 @@ public class LegacyJobVertexIdTest { JobVertexID legacyId1 = new JobVertexID(); JobVertexID legacyId2 = new JobVertexID(); - JobVertex jobVertex = new JobVertex("test", defaultId, Arrays.asList(legacyId1, legacyId2)); + JobVertex jobVertex = new JobVertex("test", defaultId, Arrays.asList(legacyId1, legacyId2), new ArrayList<OperatorID>(), new ArrayList<OperatorID>()); jobVertex.setInvokableClass(AbstractInvokable.class); ExecutionGraph executionGraph = new ExecutionGraph( http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java index a932c18..11a03cc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java @@ 
-76,7 +76,7 @@ public class RecoverableCompletedCheckpointStore extends AbstractCompletedCheckp suspended.clear(); for (CompletedCheckpoint checkpoint : checkpoints) { - sharedStateRegistry.unregisterAll(checkpoint.getTaskStates().values()); + sharedStateRegistry.unregisterAll(checkpoint.getOperatorStates().values()); suspended.add(checkpoint); } http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/api/graph/StreamGraphHasherV1.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/api/graph/StreamGraphHasherV1.java b/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/api/graph/StreamGraphHasherV1.java index 2fbfe1c..f468c93 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/api/graph/StreamGraphHasherV1.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/migration/streaming/api/graph/StreamGraphHasherV1.java @@ -153,19 +153,6 @@ public class StreamGraphHasherV1 implements StreamGraphHasher { return true; } else { - // Check that this node is not part of a chain. This is currently - // not supported, because the runtime takes the snapshots by the - // operator ID of the first vertex in a chain. It's OK if the node - // has chained outputs. - for (StreamEdge inEdge : node.getInEdges()) { - if (isChainable(inEdge, isChainingEnabled)) { - throw new UnsupportedOperationException("Cannot assign user-specified hash " - + "to intermediate node in chain. This will be supported in future " - + "versions of Flink. 
As a work around start new chain at task " - + node.getOperatorName() + "."); - } - } - Hasher hasher = hashFunction.newHasher(); byte[] hash = generateUserSpecifiedHash(node, hasher); http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphHasherV2.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphHasherV2.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphHasherV2.java index 3772e58..7c2416e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphHasherV2.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphHasherV2.java @@ -169,19 +169,6 @@ public class StreamGraphHasherV2 implements StreamGraphHasher { return true; } else { - // Check that this node is not part of a chain. This is currently - // not supported, because the runtime takes the snapshots by the - // operator ID of the first vertex in a chain. It's OK if the node - // has chained outputs. - for (StreamEdge inEdge : node.getInEdges()) { - if (isChainable(inEdge, isChainingEnabled)) { - throw new UnsupportedOperationException("Cannot assign user-specified hash " - + "to intermediate node in chain. This will be supported in future " - + "versions of Flink. 
As a work around start new chain at task " - + node.getOperatorName() + "."); - } - } - Hasher hasher = hashFunction.newHasher(); byte[] hash = generateUserSpecifiedHash(node, hasher); http://git-wip-us.apache.org/repos/asf/flink/blob/f7980a7e/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphUserHashHasher.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphUserHashHasher.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphUserHashHasher.java index c1750a1..8a8c8b0 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphUserHashHasher.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamGraphUserHashHasher.java @@ -38,15 +38,6 @@ public class StreamGraphUserHashHasher implements StreamGraphHasher { String userHash = streamNode.getUserHash(); if (null != userHash) { - for (StreamEdge inEdge : streamNode.getInEdges()) { - if (StreamingJobGraphGenerator.isChainable(inEdge, streamGraph)) { - throw new UnsupportedOperationException("Cannot assign user-specified hash " - + "to intermediate node in chain. This will be supported in future " - + "versions of Flink. As a work around start new chain at task " - + streamNode.getOperatorName() + "."); - } - } - hashResult.put(streamNode.getId(), StringUtils.hexStringToByte(userHash)); } }