Repository: beam Updated Branches: refs/heads/master 2c71354d0 -> 3d0fe8539
[BEAM-1456] Make UnboundedSourceWrapper snapshot to rescalable operator state in Flink Runner Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5e3fb8c3 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5e3fb8c3 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5e3fb8c3 Branch: refs/heads/master Commit: 5e3fb8c353d40b17cd4ed7332a9ecb7425b64c86 Parents: 2c71354 Author: JingsongLi <lzljs3620...@aliyun.com> Authored: Sat Feb 18 13:04:48 2017 +0800 Committer: Aljoscha Krettek <aljoscha.kret...@gmail.com> Committed: Fri Feb 24 12:04:47 2017 +0100 ---------------------------------------------------------------------- .../streaming/io/UnboundedSourceWrapper.java | 136 ++++++++++--------- .../streaming/UnboundedSourceWrapperTest.java | 116 +++++++++++++++- 2 files changed, 185 insertions(+), 67 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/5e3fb8c3/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java index 237e5a3..2849464 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java @@ -18,17 +18,14 @@ package org.apache.beam.runners.flink.translation.wrappers.streaming.io; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Function; -import com.google.common.collect.Lists; -import java.io.ByteArrayInputStream; import java.util.ArrayList; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; +import org.apache.beam.runners.flink.translation.types.CoderTypeInformation; import org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.options.PipelineOptions; @@ -37,11 +34,17 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.commons.io.output.ByteArrayOutputStream; +import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.StoppableFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.OperatorStateStore; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.state.CheckpointListener; -import org.apache.flink.streaming.api.checkpoint.Checkpointed; +import org.apache.flink.runtime.state.DefaultOperatorStateBackend; +import org.apache.flink.runtime.state.FunctionInitializationContext; +import org.apache.flink.runtime.state.FunctionSnapshotContext; +import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; import org.apache.flink.streaming.api.operators.StreamingRuntimeContext; import org.apache.flink.streaming.api.watermark.Watermark; @@ -56,7 +59,8 @@ import org.slf4j.LoggerFactory; public class UnboundedSourceWrapper< OutputT, CheckpointMarkT extends UnboundedSource.CheckpointMark> extends RichParallelSourceFunction<WindowedValue<OutputT>> - implements ProcessingTimeCallback, StoppableFunction, Checkpointed<byte[]>, CheckpointListener { + implements ProcessingTimeCallback, StoppableFunction, + CheckpointListener, CheckpointedFunction { private static final Logger LOG = LoggerFactory.getLogger(UnboundedSourceWrapper.class); @@ -68,8 +72,8 @@ public class UnboundedSourceWrapper< /** * For snapshot and restore. */ - private final ListCoder< - KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT>> checkpointCoder; + private final KvCoder< + ? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> checkpointCoder; /** * The split sources. We split them in the constructor to ensure that all parallel @@ -117,12 +121,13 @@ public class UnboundedSourceWrapper< */ private static final int MAX_NUMBER_PENDING_CHECKPOINTS = 32; + private transient ListState<KV<? extends + UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT>> stateForCheckpoint; + /** - * When restoring from a snapshot we put the restored sources/checkpoint marks here - * and open in {@link #open(Configuration)}. + * false if checkpointCoder is null or no restore state by starting first. */ - private transient List< - KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT>> restoredState; + private transient boolean isRestored = false; @SuppressWarnings("unchecked") public UnboundedSourceWrapper( @@ -145,7 +150,7 @@ public class UnboundedSourceWrapper< (Coder) SerializableCoder.of(new TypeDescriptor<UnboundedSource>() { }); - checkpointCoder = (ListCoder) ListCoder.of(KvCoder.of(sourceCoder, checkpointMarkCoder)); + checkpointCoder = KvCoder.of(sourceCoder, checkpointMarkCoder); } // get the splits early. we assume that the generated splits are stable, @@ -171,30 +176,14 @@ public class UnboundedSourceWrapper< pendingCheckpoints = new LinkedHashMap<>(); - if (restoredState != null) { - + if (isRestored) { // restore the splitSources from the checkpoint to ensure consistent ordering - // do it using a transform because otherwise we would have to do - // unchecked casts - localSplitSources = Lists.transform( - restoredState, - new Function< - KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT>, - UnboundedSource<OutputT, CheckpointMarkT>>() { - @Override - public UnboundedSource<OutputT, CheckpointMarkT> apply( - KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> input) { - return input.getKey(); - } - }); - for (KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> restored: - restoredState) { - localReaders.add( - restored.getKey().createReader( - serializedOptions.getPipelineOptions(), restored.getValue())); + stateForCheckpoint.get()) { + localSplitSources.add(restored.getKey()); + localReaders.add(restored.getKey().createReader( + serializedOptions.getPipelineOptions(), restored.getValue())); } - restoredState = null; } else { // initialize localReaders and localSources from scratch for (int i = 0; i < splitSources.size(); i++) { @@ -342,37 +331,42 @@ public class UnboundedSourceWrapper< isRunning = false; } + // ------------------------------------------------------------------------ + // Checkpoint and restore + // ------------------------------------------------------------------------ + @Override - public byte[] snapshotState(long checkpointId, long checkpointTimestamp) throws Exception { + public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception { + if (!isRunning) { + LOG.debug("snapshotState() called on closed source"); + } else { - if (checkpointCoder == null) { - // no checkpoint coder available in this source - return null; - } + if (checkpointCoder == null) { + // no checkpoint coder available in this source + return; + } - // we checkpoint the sources along with the CheckpointMarkT to ensure - // than we have a correct mapping of checkpoints to sources when - // restoring - List<KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT>> checkpoints = - new ArrayList<>(localSplitSources.size()); - List<CheckpointMarkT> checkpointMarks = new ArrayList<>(localSplitSources.size()); - - for (int i = 0; i < localSplitSources.size(); i++) { - UnboundedSource<OutputT, CheckpointMarkT> source = localSplitSources.get(i); - UnboundedSource.UnboundedReader<OutputT> reader = localReaders.get(i); - - @SuppressWarnings("unchecked") - CheckpointMarkT mark = (CheckpointMarkT) reader.getCheckpointMark(); - checkpointMarks.add(mark); - KV<UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> kv = - KV.of(source, mark); - checkpoints.add(kv); - } + stateForCheckpoint.clear(); + + long checkpointId = functionSnapshotContext.getCheckpointId(); + + // we checkpoint the sources along with the CheckpointMarkT to ensure + // than we have a correct mapping of checkpoints to sources when + // restoring + List<CheckpointMarkT> checkpointMarks = new ArrayList<>(localSplitSources.size()); + + for (int i = 0; i < localSplitSources.size(); i++) { + UnboundedSource<OutputT, CheckpointMarkT> source = localSplitSources.get(i); + UnboundedSource.UnboundedReader<OutputT> reader = localReaders.get(i); + + @SuppressWarnings("unchecked") + CheckpointMarkT mark = (CheckpointMarkT) reader.getCheckpointMark(); + checkpointMarks.add(mark); + KV<UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> kv = + KV.of(source, mark); + stateForCheckpoint.add(kv); + } - try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { - checkpointCoder.encode(checkpoints, baos, Coder.Context.OUTER); - return baos.toByteArray(); - } finally { // cleanup old pending checkpoints and add new checkpoint int diff = pendingCheckpoints.size() - MAX_NUMBER_PENDING_CHECKPOINTS; if (diff >= 0) { @@ -384,18 +378,30 @@ public class UnboundedSourceWrapper< } } pendingCheckpoints.put(checkpointId, checkpointMarks); + } } @Override - public void restoreState(byte[] bytes) throws Exception { + public void initializeState(FunctionInitializationContext context) throws Exception { if (checkpointCoder == null) { // no checkpoint coder available in this source return; } - try (ByteArrayInputStream bais = new ByteArrayInputStream(bytes)) { - restoredState = checkpointCoder.decode(bais, Coder.Context.OUTER); + OperatorStateStore stateStore = context.getOperatorStateStore(); + CoderTypeInformation< + KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT>> + typeInformation = (CoderTypeInformation) new CoderTypeInformation<>(checkpointCoder); + stateForCheckpoint = stateStore.getOperatorState( + new ListStateDescriptor<>(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, + typeInformation.createSerializer(new ExecutionConfig()))); + + if (context.isRestored()) { + isRestored = true; + LOG.info("Having restore state in the UnbounedSourceWrapper."); + } else { + LOG.info("No restore state for UnbounedSourceWrapper."); } } http://git-wip-us.apache.org/repos/asf/beam/blob/5e3fb8c3/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java index 5b3d088..90f95d6 100644 --- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java +++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java @@ -28,17 +28,25 @@ import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; import org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedSourceWrapper; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.accumulators.Accumulator; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.state.OperatorStateStore; import org.apache.flink.configuration.Configuration; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.operators.testutils.DummyEnvironment; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.Output; @@ -53,6 +61,7 @@ import org.junit.Test; import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import org.mockito.Matchers; /** * Tests for {@link UnboundedSourceWrapper}. @@ -180,6 +189,22 @@ public class UnboundedSourceWrapperTest { KV<Integer, Integer>, TestCountingSource.CounterMark>> sourceOperator = new StreamSource<>(flinkWrapper); + + OperatorStateStore backend = mock(OperatorStateStore.class); + + TestingListState<KV<UnboundedSource, TestCountingSource.CounterMark>> + listState = new TestingListState<>(); + + when(backend.getOperatorState(Matchers.any(ListStateDescriptor.class))) + .thenReturn(listState); + + StateInitializationContext initializationContext = mock(StateInitializationContext.class); + + when(initializationContext.getOperatorStateStore()).thenReturn(backend); + when(initializationContext.isRestored()).thenReturn(false, true); + + flinkWrapper.initializeState(initializationContext); + setupSourceOperator(sourceOperator, numTasks); final Set<KV<Integer, Integer>> emittedElements = new HashSet<>(); @@ -224,7 +249,16 @@ public class UnboundedSourceWrapperTest { assertTrue("Did not successfully read first batch of elements.", readFirstBatchOfElements); // draw a snapshot - byte[] snapshot = flinkWrapper.snapshotState(0, 0); + flinkWrapper.snapshotState(new StateSnapshotContextSynchronousImpl(0, 0)); + + // test snapshot offsets + assertEquals(flinkWrapper.getLocalSplitSources().size(), + listState.getList().size()); + int totalEmit = 0; + for (KV<UnboundedSource, TestCountingSource.CounterMark> kv : listState.get()) { + totalEmit += kv.getValue().current + 1; + } + assertEquals(numElements / 2, totalEmit); // test that finalizeCheckpoint on CheckpointMark is called final ArrayList<Integer> finalizeList = new ArrayList<>(); @@ -250,7 +284,7 @@ public class UnboundedSourceWrapperTest { setupSourceOperator(restoredSourceOperator, numTasks); // restore snapshot - restoredFlinkWrapper.restoreState(snapshot); + restoredFlinkWrapper.initializeState(initializationContext); boolean readSecondBatchOfElements = false; @@ -297,6 +331,58 @@ public class UnboundedSourceWrapperTest { assertTrue(emittedElements.size() == numElements); } + @Test + public void testNullCheckpoint() throws Exception { + final int numElements = 20; + PipelineOptions options = PipelineOptionsFactory.create(); + + TestCountingSource source = new TestCountingSource(numElements) { + @Override + public Coder<CounterMark> getCheckpointMarkCoder() { + return null; + } + }; + UnboundedSourceWrapper<KV<Integer, Integer>, TestCountingSource.CounterMark> flinkWrapper = + new UnboundedSourceWrapper<>(options, source, numSplits); + + OperatorStateStore backend = mock(OperatorStateStore.class); + + TestingListState<KV<UnboundedSource, TestCountingSource.CounterMark>> + listState = new TestingListState<>(); + + when(backend.getOperatorState(Matchers.any(ListStateDescriptor.class))) + .thenReturn(listState); + + StateInitializationContext initializationContext = mock(StateInitializationContext.class); + + when(initializationContext.getOperatorStateStore()).thenReturn(backend); + when(initializationContext.isRestored()).thenReturn(false, true); + + flinkWrapper.initializeState(initializationContext); + + StreamSource sourceOperator = new StreamSource<>(flinkWrapper); + setupSourceOperator(sourceOperator, numTasks); + sourceOperator.open(); + + flinkWrapper.snapshotState(new StateSnapshotContextSynchronousImpl(0, 0)); + + assertEquals(0, listState.getList().size()); + + UnboundedSourceWrapper< + KV<Integer, Integer>, TestCountingSource.CounterMark> restoredFlinkWrapper = + new UnboundedSourceWrapper<>(options, new TestCountingSource(numElements), + numSplits); + + StreamSource restoredSourceOperator = new StreamSource<>(flinkWrapper); + setupSourceOperator(restoredSourceOperator, numTasks); + sourceOperator.open(); + + restoredFlinkWrapper.initializeState(initializationContext); + + assertEquals(Math.max(1, numSplits / numTasks), flinkWrapper.getLocalSplitSources().size()); + + } + @SuppressWarnings("unchecked") private static <T> void setupSourceOperator(StreamSource<T, ?> operator, int numSubTasks) { ExecutionConfig executionConfig = new ExecutionConfig(); @@ -349,4 +435,30 @@ public class UnboundedSourceWrapperTest { } } + + private static final class TestingListState<T> implements ListState<T> { + + private final List<T> list = new ArrayList<>(); + + @Override + public void clear() { + list.clear(); + } + + @Override + public Iterable<T> get() throws Exception { + return list; + } + + @Override + public void add(T value) throws Exception { + list.add(value); + } + + public List<T> getList() { + return list; + } + + } + }