Repository: beam
Updated Branches:
  refs/heads/master 2c71354d0 -> 3d0fe8539


[BEAM-1456] Make UnboundedSourceWrapper snapshot to rescalable operator state in the Flink Runner


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/5e3fb8c3
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/5e3fb8c3
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/5e3fb8c3

Branch: refs/heads/master
Commit: 5e3fb8c353d40b17cd4ed7332a9ecb7425b64c86
Parents: 2c71354
Author: JingsongLi <lzljs3620...@aliyun.com>
Authored: Sat Feb 18 13:04:48 2017 +0800
Committer: Aljoscha Krettek <aljoscha.kret...@gmail.com>
Committed: Fri Feb 24 12:04:47 2017 +0100

----------------------------------------------------------------------
 .../streaming/io/UnboundedSourceWrapper.java    | 136 ++++++++++---------
 .../streaming/UnboundedSourceWrapperTest.java   | 116 +++++++++++++++-
 2 files changed, 185 insertions(+), 67 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/5e3fb8c3/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
index 237e5a3..2849464 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/io/UnboundedSourceWrapper.java
@@ -18,17 +18,14 @@
 package org.apache.beam.runners.flink.translation.wrappers.streaming.io;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Function;
-import com.google.common.collect.Lists;
-import java.io.ByteArrayInputStream;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
+import org.apache.beam.runners.flink.translation.types.CoderTypeInformation;
 import 
org.apache.beam.runners.flink.translation.utils.SerializedPipelineOptions;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.coders.ListCoder;
 import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.options.PipelineOptions;
@@ -37,11 +34,17 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.TypeDescriptor;
-import org.apache.commons.io.output.ByteArrayOutputStream;
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.StoppableFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.state.CheckpointListener;
-import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import 
org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
 import org.apache.flink.streaming.api.watermark.Watermark;
@@ -56,7 +59,8 @@ import org.slf4j.LoggerFactory;
 public class UnboundedSourceWrapper<
     OutputT, CheckpointMarkT extends UnboundedSource.CheckpointMark>
     extends RichParallelSourceFunction<WindowedValue<OutputT>>
-    implements ProcessingTimeCallback, StoppableFunction, 
Checkpointed<byte[]>, CheckpointListener {
+    implements ProcessingTimeCallback, StoppableFunction,
+    CheckpointListener, CheckpointedFunction {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(UnboundedSourceWrapper.class);
 
@@ -68,8 +72,8 @@ public class UnboundedSourceWrapper<
   /**
    * For snapshot and restore.
    */
-  private final ListCoder<
-      KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, 
CheckpointMarkT>> checkpointCoder;
+  private final KvCoder<
+      ? extends UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> 
checkpointCoder;
 
   /**
    * The split sources. We split them in the constructor to ensure that all 
parallel
@@ -117,12 +121,13 @@ public class UnboundedSourceWrapper<
    */
   private static final int MAX_NUMBER_PENDING_CHECKPOINTS = 32;
 
+  private transient ListState<KV<? extends
+      UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT>> 
stateForCheckpoint;
+
   /**
-   * When restoring from a snapshot we put the restored sources/checkpoint 
marks here
-   * and open in {@link #open(Configuration)}.
+   * false if checkpointCoder is null or no restore state by starting first.
    */
-  private transient List<
-      KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, 
CheckpointMarkT>> restoredState;
+  private transient boolean isRestored = false;
 
   @SuppressWarnings("unchecked")
   public UnboundedSourceWrapper(
@@ -145,7 +150,7 @@ public class UnboundedSourceWrapper<
           (Coder) SerializableCoder.of(new TypeDescriptor<UnboundedSource>() {
           });
 
-      checkpointCoder = (ListCoder) ListCoder.of(KvCoder.of(sourceCoder, 
checkpointMarkCoder));
+      checkpointCoder = KvCoder.of(sourceCoder, checkpointMarkCoder);
     }
 
     // get the splits early. we assume that the generated splits are stable,
@@ -171,30 +176,14 @@ public class UnboundedSourceWrapper<
 
     pendingCheckpoints = new LinkedHashMap<>();
 
-    if (restoredState != null) {
-
+    if (isRestored) {
       // restore the splitSources from the checkpoint to ensure consistent 
ordering
-      // do it using a transform because otherwise we would have to do
-      // unchecked casts
-      localSplitSources = Lists.transform(
-          restoredState,
-          new Function<
-              KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, 
CheckpointMarkT>,
-              UnboundedSource<OutputT, CheckpointMarkT>>() {
-            @Override
-            public UnboundedSource<OutputT, CheckpointMarkT> apply(
-                KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, 
CheckpointMarkT> input) {
-              return input.getKey();
-            }
-          });
-
       for (KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, 
CheckpointMarkT> restored:
-          restoredState) {
-        localReaders.add(
-            restored.getKey().createReader(
-                serializedOptions.getPipelineOptions(), restored.getValue()));
+          stateForCheckpoint.get()) {
+        localSplitSources.add(restored.getKey());
+        localReaders.add(restored.getKey().createReader(
+            serializedOptions.getPipelineOptions(), restored.getValue()));
       }
-      restoredState = null;
     } else {
       // initialize localReaders and localSources from scratch
       for (int i = 0; i < splitSources.size(); i++) {
@@ -342,37 +331,42 @@ public class UnboundedSourceWrapper<
     isRunning = false;
   }
 
+  // ------------------------------------------------------------------------
+  //  Checkpoint and restore
+  // ------------------------------------------------------------------------
+
   @Override
-  public byte[] snapshotState(long checkpointId, long checkpointTimestamp) 
throws Exception {
+  public void snapshotState(FunctionSnapshotContext functionSnapshotContext) 
throws Exception {
+    if (!isRunning) {
+      LOG.debug("snapshotState() called on closed source");
+    } else {
 
-    if (checkpointCoder == null) {
-      // no checkpoint coder available in this source
-      return null;
-    }
+      if (checkpointCoder == null) {
+        // no checkpoint coder available in this source
+        return;
+      }
 
-    // we checkpoint the sources along with the CheckpointMarkT to ensure
-    // than we have a correct mapping of checkpoints to sources when
-    // restoring
-    List<KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, 
CheckpointMarkT>> checkpoints =
-        new ArrayList<>(localSplitSources.size());
-    List<CheckpointMarkT> checkpointMarks = new 
ArrayList<>(localSplitSources.size());
-
-    for (int i = 0; i < localSplitSources.size(); i++) {
-      UnboundedSource<OutputT, CheckpointMarkT> source = 
localSplitSources.get(i);
-      UnboundedSource.UnboundedReader<OutputT> reader = localReaders.get(i);
-
-      @SuppressWarnings("unchecked")
-      CheckpointMarkT mark = (CheckpointMarkT) reader.getCheckpointMark();
-      checkpointMarks.add(mark);
-      KV<UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> kv =
-          KV.of(source, mark);
-      checkpoints.add(kv);
-    }
+      stateForCheckpoint.clear();
+
+      long checkpointId = functionSnapshotContext.getCheckpointId();
+
+      // we checkpoint the sources along with the CheckpointMarkT to ensure
+      // than we have a correct mapping of checkpoints to sources when
+      // restoring
+      List<CheckpointMarkT> checkpointMarks = new 
ArrayList<>(localSplitSources.size());
+
+      for (int i = 0; i < localSplitSources.size(); i++) {
+        UnboundedSource<OutputT, CheckpointMarkT> source = 
localSplitSources.get(i);
+        UnboundedSource.UnboundedReader<OutputT> reader = localReaders.get(i);
+
+        @SuppressWarnings("unchecked")
+        CheckpointMarkT mark = (CheckpointMarkT) reader.getCheckpointMark();
+        checkpointMarks.add(mark);
+        KV<UnboundedSource<OutputT, CheckpointMarkT>, CheckpointMarkT> kv =
+            KV.of(source, mark);
+        stateForCheckpoint.add(kv);
+      }
 
-    try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
-      checkpointCoder.encode(checkpoints, baos, Coder.Context.OUTER);
-      return baos.toByteArray();
-    } finally {
       // cleanup old pending checkpoints and add new checkpoint
       int diff = pendingCheckpoints.size() - MAX_NUMBER_PENDING_CHECKPOINTS;
       if (diff >= 0) {
@@ -384,18 +378,30 @@ public class UnboundedSourceWrapper<
         }
       }
       pendingCheckpoints.put(checkpointId, checkpointMarks);
+
     }
   }
 
   @Override
-  public void restoreState(byte[] bytes) throws Exception {
+  public void initializeState(FunctionInitializationContext context) throws 
Exception {
     if (checkpointCoder == null) {
       // no checkpoint coder available in this source
       return;
     }
 
-    try (ByteArrayInputStream bais = new ByteArrayInputStream(bytes)) {
-      restoredState = checkpointCoder.decode(bais, Coder.Context.OUTER);
+    OperatorStateStore stateStore = context.getOperatorStateStore();
+    CoderTypeInformation<
+        KV<? extends UnboundedSource<OutputT, CheckpointMarkT>, 
CheckpointMarkT>>
+        typeInformation = (CoderTypeInformation) new 
CoderTypeInformation<>(checkpointCoder);
+    stateForCheckpoint = stateStore.getOperatorState(
+        new 
ListStateDescriptor<>(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME,
+            typeInformation.createSerializer(new ExecutionConfig())));
+
+    if (context.isRestored()) {
+      isRestored = true;
+      LOG.info("Having restore state in the UnbounedSourceWrapper.");
+    } else {
+      LOG.info("No restore state for UnbounedSourceWrapper.");
     }
   }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/5e3fb8c3/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
index 5b3d088..90f95d6 100644
--- a/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
+++ b/runners/flink/runner/src/test/java/org/apache/beam/runners/flink/streaming/UnboundedSourceWrapperTest.java
@@ -28,17 +28,25 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
 import 
org.apache.beam.runners.flink.translation.wrappers.streaming.io.UnboundedSourceWrapper;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.accumulators.Accumulator;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContextSynchronousImpl;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
@@ -53,6 +61,7 @@ import org.junit.Test;
 import org.junit.experimental.runners.Enclosed;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
+import org.mockito.Matchers;
 
 /**
  * Tests for {@link UnboundedSourceWrapper}.
@@ -180,6 +189,22 @@ public class UnboundedSourceWrapperTest {
               KV<Integer, Integer>,
               TestCountingSource.CounterMark>> sourceOperator = new 
StreamSource<>(flinkWrapper);
 
+
+      OperatorStateStore backend = mock(OperatorStateStore.class);
+
+      TestingListState<KV<UnboundedSource, TestCountingSource.CounterMark>>
+          listState = new TestingListState<>();
+
+      when(backend.getOperatorState(Matchers.any(ListStateDescriptor.class)))
+          .thenReturn(listState);
+
+      StateInitializationContext initializationContext = 
mock(StateInitializationContext.class);
+
+      when(initializationContext.getOperatorStateStore()).thenReturn(backend);
+      when(initializationContext.isRestored()).thenReturn(false, true);
+
+      flinkWrapper.initializeState(initializationContext);
+
       setupSourceOperator(sourceOperator, numTasks);
 
       final Set<KV<Integer, Integer>> emittedElements = new HashSet<>();
@@ -224,7 +249,16 @@ public class UnboundedSourceWrapperTest {
       assertTrue("Did not successfully read first batch of elements.", 
readFirstBatchOfElements);
 
       // draw a snapshot
-      byte[] snapshot = flinkWrapper.snapshotState(0, 0);
+      flinkWrapper.snapshotState(new StateSnapshotContextSynchronousImpl(0, 
0));
+
+      // test snapshot offsets
+      assertEquals(flinkWrapper.getLocalSplitSources().size(),
+          listState.getList().size());
+      int totalEmit = 0;
+      for (KV<UnboundedSource, TestCountingSource.CounterMark> kv : 
listState.get()) {
+        totalEmit += kv.getValue().current + 1;
+      }
+      assertEquals(numElements / 2, totalEmit);
 
       // test that finalizeCheckpoint on CheckpointMark is called
       final ArrayList<Integer> finalizeList = new ArrayList<>();
@@ -250,7 +284,7 @@ public class UnboundedSourceWrapperTest {
       setupSourceOperator(restoredSourceOperator, numTasks);
 
       // restore snapshot
-      restoredFlinkWrapper.restoreState(snapshot);
+      restoredFlinkWrapper.initializeState(initializationContext);
 
       boolean readSecondBatchOfElements = false;
 
@@ -297,6 +331,58 @@ public class UnboundedSourceWrapperTest {
       assertTrue(emittedElements.size() == numElements);
     }
 
+    @Test
+    public void testNullCheckpoint() throws Exception {
+      final int numElements = 20;
+      PipelineOptions options = PipelineOptionsFactory.create();
+
+      TestCountingSource source = new TestCountingSource(numElements) {
+        @Override
+        public Coder<CounterMark> getCheckpointMarkCoder() {
+          return null;
+        }
+      };
+      UnboundedSourceWrapper<KV<Integer, Integer>, 
TestCountingSource.CounterMark> flinkWrapper =
+          new UnboundedSourceWrapper<>(options, source, numSplits);
+
+      OperatorStateStore backend = mock(OperatorStateStore.class);
+
+      TestingListState<KV<UnboundedSource, TestCountingSource.CounterMark>>
+          listState = new TestingListState<>();
+
+      when(backend.getOperatorState(Matchers.any(ListStateDescriptor.class)))
+          .thenReturn(listState);
+
+      StateInitializationContext initializationContext = 
mock(StateInitializationContext.class);
+
+      when(initializationContext.getOperatorStateStore()).thenReturn(backend);
+      when(initializationContext.isRestored()).thenReturn(false, true);
+
+      flinkWrapper.initializeState(initializationContext);
+
+      StreamSource sourceOperator = new StreamSource<>(flinkWrapper);
+      setupSourceOperator(sourceOperator, numTasks);
+      sourceOperator.open();
+
+      flinkWrapper.snapshotState(new StateSnapshotContextSynchronousImpl(0, 
0));
+
+      assertEquals(0, listState.getList().size());
+
+      UnboundedSourceWrapper<
+          KV<Integer, Integer>, TestCountingSource.CounterMark> 
restoredFlinkWrapper =
+          new UnboundedSourceWrapper<>(options, new 
TestCountingSource(numElements),
+              numSplits);
+
+      StreamSource restoredSourceOperator = new StreamSource<>(flinkWrapper);
+      setupSourceOperator(restoredSourceOperator, numTasks);
+      sourceOperator.open();
+
+      restoredFlinkWrapper.initializeState(initializationContext);
+
+      assertEquals(Math.max(1, numSplits / numTasks), 
flinkWrapper.getLocalSplitSources().size());
+
+    }
+
     @SuppressWarnings("unchecked")
     private static <T> void setupSourceOperator(StreamSource<T, ?> operator, 
int numSubTasks) {
       ExecutionConfig executionConfig = new ExecutionConfig();
@@ -349,4 +435,30 @@ public class UnboundedSourceWrapperTest {
     }
 
   }
+
+  private static final class TestingListState<T> implements ListState<T> {
+
+    private final List<T> list = new ArrayList<>();
+
+    @Override
+    public void clear() {
+      list.clear();
+    }
+
+    @Override
+    public Iterable<T> get() throws Exception {
+      return list;
+    }
+
+    @Override
+    public void add(T value) throws Exception {
+      list.add(value);
+    }
+
+    public List<T> getList() {
+      return list;
+    }
+
+  }
+
 }

Reply via email to