Github user tzulitai commented on a diff in the pull request: https://github.com/apache/flink/pull/3001#discussion_r112123126 --- Diff: flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java --- @@ -559,48 +699,298 @@ public void testFetcherShouldNotBeRestoringFromFailureIfNotRestoringFromCheckpoi @Test @SuppressWarnings("unchecked") + public void testFetcherShouldBeCorrectlySeededIfRestoringFromLegacyCheckpoint() throws Exception { + HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all"); + + KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class); + List<KinesisStreamShard> shards = new ArrayList<>(); + shards.addAll(fakeRestoredState.keySet()); + when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards); + PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher); + + // assume the given config is correct + PowerMockito.mockStatic(KinesisConfigUtil.class); + PowerMockito.doNothing().when(KinesisConfigUtil.class); + + TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer( + "fakeStream", new Properties(), 10, 2); + consumer.restoreState(fakeRestoredState); + consumer.open(new Configuration()); + consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); + + Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true); + for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) { + Mockito.verify(mockedFetcher).advanceLastDiscoveredShardOfStream( + restoredShard.getKey().getStreamName(), restoredShard.getKey().getShard().getShardId()); + Mockito.verify(mockedFetcher).registerNewSubscribedShardState( + new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); + } + } + + @Test + @SuppressWarnings("unchecked") public void testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws Exception { + // 
---------------------------------------------------------------------- + // setting initial state + // ---------------------------------------------------------------------- + HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all"); + + // ---------------------------------------------------------------------- + // mock operator state backend and initial state for initializeState() + // ---------------------------------------------------------------------- + TestingListState<Serializable> listState = new TestingListState<>(); + for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) { + listState.add(Tuple2.of(state.getKey(), state.getValue())); + } + + OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); + when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState); + + StateInitializationContext initializationContext = mock(StateInitializationContext.class); + when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore); + when(initializationContext.isRestored()).thenReturn(true); + + // ---------------------------------------------------------------------- + // mock fetcher + // ---------------------------------------------------------------------- KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class); + List<KinesisStreamShard> shards = new ArrayList<>(); + shards.addAll(fakeRestoredState.keySet()); + when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards); PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher); // assume the given config is correct PowerMockito.mockStatic(KinesisConfigUtil.class); PowerMockito.doNothing().when(KinesisConfigUtil.class); - HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = new HashMap<>(); - fakeRestoredState.put( - new KinesisStreamShard("fakeStream1", - new 
Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), - new SequenceNumber(UUID.randomUUID().toString())); - fakeRestoredState.put( - new KinesisStreamShard("fakeStream1", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), - new SequenceNumber(UUID.randomUUID().toString())); - fakeRestoredState.put( - new KinesisStreamShard("fakeStream1", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))), - new SequenceNumber(UUID.randomUUID().toString())); - fakeRestoredState.put( - new KinesisStreamShard("fakeStream2", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))), - new SequenceNumber(UUID.randomUUID().toString())); - fakeRestoredState.put( - new KinesisStreamShard("fakeStream2", - new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))), - new SequenceNumber(UUID.randomUUID().toString())); + // ---------------------------------------------------------------------- + // start to test seed initial state to fetcher + // ---------------------------------------------------------------------- + TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer( + "fakeStream", new Properties(), 10, 2); + consumer.initializeState(initializationContext); + consumer.open(new Configuration()); + consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); + + Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true); + for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) { + Mockito.verify(mockedFetcher).advanceLastDiscoveredShardOfStream( + restoredShard.getKey().getStreamName(), restoredShard.getKey().getShard().getShardId()); + Mockito.verify(mockedFetcher).registerNewSubscribedShardState( + new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws Exception { 
+ // ---------------------------------------------------------------------- + // setting initial state + // ---------------------------------------------------------------------- + HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("fakeStream1"); + + HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredStateForOthers = getFakeRestoredStore("fakeStream2"); + + // ---------------------------------------------------------------------- + // mock operator state backend and initial state for initializeState() + // ---------------------------------------------------------------------- + TestingListState<Serializable> listState = new TestingListState<>(); + for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) { + listState.add(Tuple2.of(state.getKey(), state.getValue())); + } + for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredStateForOthers.entrySet()) { + listState.add(Tuple2.of(state.getKey(), state.getValue())); + } + OperatorStateStore operatorStateStore = mock(OperatorStateStore.class); + when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState); + + StateInitializationContext initializationContext = mock(StateInitializationContext.class); + when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore); + when(initializationContext.isRestored()).thenReturn(true); + + // ---------------------------------------------------------------------- + // mock fetcher + // ---------------------------------------------------------------------- + KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class); + List<KinesisStreamShard> shards = new ArrayList<>(); + shards.addAll(fakeRestoredState.keySet()); + when(mockedFetcher.discoverNewShardsToSubscribe()).thenReturn(shards); + PowerMockito.whenNew(KinesisDataFetcher.class).withAnyArguments().thenReturn(mockedFetcher); + + // assume the given config is 
correct + PowerMockito.mockStatic(KinesisConfigUtil.class); + PowerMockito.doNothing().when(KinesisConfigUtil.class); + + // ---------------------------------------------------------------------- + // start to test seed initial state to fetcher + // ---------------------------------------------------------------------- TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer( "fakeStream", new Properties(), 10, 2); - consumer.restoreState(fakeRestoredState); + consumer.initializeState(initializationContext); consumer.open(new Configuration()); consumer.run(Mockito.mock(SourceFunction.SourceContext.class)); Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true); + for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredStateForOthers.entrySet()) { + // should never get restored state not belonging to itself + Mockito.verify(mockedFetcher, never()).advanceLastDiscoveredShardOfStream( + restoredShard.getKey().getStreamName(), restoredShard.getKey().getShard().getShardId()); + Mockito.verify(mockedFetcher, never()).registerNewSubscribedShardState( + new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); + } for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) { + // should get restored state belonging to itself Mockito.verify(mockedFetcher).advanceLastDiscoveredShardOfStream( restoredShard.getKey().getStreamName(), restoredShard.getKey().getShard().getShardId()); Mockito.verify(mockedFetcher).registerNewSubscribedShardState( new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue())); } } + + @Test + @SuppressWarnings("unchecked") --- End diff -- These annotations should be placed after the comment block, directly above the method declaration. I believe that is the usual convention.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes to enable it, or if the feature is enabled but not working, please contact the infrastructure team at infrastructure@apache.org or file a JIRA ticket with INFRA. ---