This is an automated email from the ASF dual-hosted git repository. srichter pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new ee60846 [FLINK-12296][StateBackend] Fix local state directory collision with state loss for chained keyed operators ee60846 is described below commit ee60846dc588b1a832a497ff9522d7a3a282c350 Author: klion26 <qcx978132...@gmail.com> AuthorDate: Wed Apr 24 10:52:03 2019 +0800 [FLINK-12296][StateBackend] Fix local state directory collision with state loss for chained keyed operators - Change: will change the local data path from `.../local_state_root/allocation_id/job_id/jobvertex_id_subtask_id/chk_id/rocksdb` to `.../local_state_root/allocation_id/job_id/jobvertex_id_subtask_id/chk_id/operator_id` When preparing the local directory Flink deletes the local directory for each subtask if it already exists. If more than one stateful operator is chained in a single task, they'll share the same local directory path, then the local directory will be deleted unexpectedly, and we'll get data loss. This closes #8263. --- .../CheckpointStreamWithResultProviderTest.java | 3 + .../state/StateSnapshotCompressionTest.java | 2 +- .../ttl/mock/MockKeyedStateBackendBuilder.java | 1 + .../runtime/state/ttl/mock/MockStateBackend.java | 2 +- .../state/RocksDBKeyedStateBackendBuilder.java | 1 + .../snapshot/RocksIncrementalSnapshotStrategy.java | 17 +- .../tasks/OneInputStreamTaskTestHarness.java | 50 +++- .../runtime/tasks/StreamConfigChainer.java | 23 +- .../runtime/tasks/StreamMockEnvironment.java | 8 +- .../runtime/tasks/StreamTaskTestHarness.java | 21 +- .../state/StatefulOperatorChainedTaskTest.java | 260 +++++++++++++++++++++ 11 files changed, 369 insertions(+), 19 deletions(-) diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java index 2af25d9..57653e2 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java @@ -35,6 +35,9 @@ import java.io.Closeable; import java.io.File; import java.io.IOException; +/** + * Test for CheckpointStreamWithResultProvider. + */ public class CheckpointStreamWithResultProviderTest extends TestLogger { private static TemporaryFolder temporaryFolder; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java index a10be26..de687ff 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java @@ -18,7 +18,6 @@ package org.apache.flink.runtime.state; -import org.apache.commons.io.IOUtils; import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.base.StringSerializer; @@ -34,6 +33,7 @@ import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory; import org.apache.flink.runtime.state.ttl.TtlTimeProvider; import org.apache.flink.util.TestLogger; +import org.apache.commons.io.IOUtils; import org.junit.Assert; import org.junit.Test; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackendBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackendBuilder.java index 3ffe183..8ec9b4d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackendBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackendBuilder.java @@ -31,6 +31,7 @@ import 
org.apache.flink.runtime.state.heap.InternalKeyContextImpl; import org.apache.flink.runtime.state.ttl.TtlTimeProvider; import javax.annotation.Nonnull; + import java.util.Collection; import java.util.HashMap; import java.util.Map; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java index bdf07bf..f50f1b6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java @@ -35,8 +35,8 @@ import org.apache.flink.runtime.state.CheckpointStorageLocationReference; import org.apache.flink.runtime.state.CheckpointStreamFactory; import org.apache.flink.runtime.state.CheckpointedStateScope; import org.apache.flink.runtime.state.CompletedCheckpointStorageLocation; -import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateBackend; import org.apache.flink.runtime.state.OperatorStateHandle; import org.apache.flink.runtime.state.ttl.TtlTimeProvider; diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java index b515c94..ddd55c8 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java @@ -233,6 +233,7 @@ public class RocksDBKeyedStateBackendBuilder<K> extends 
AbstractKeyedStateBacken } } + @Override public RocksDBKeyedStateBackend<K> build() throws BackendBuildingException { RocksDBWriteBatchWrapper writeBatchWrapper = null; ColumnFamilyHandle defaultColumnFamilyHandle = null; diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java index 889b18d..38d5e7a 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksIncrementalSnapshotStrategy.java @@ -106,6 +106,9 @@ public class RocksIncrementalSnapshotStrategy<K> extends RocksDBSnapshotStrategy /** The help class used to upload state files. */ private final RocksDBStateUploader stateUploader; + /** The local directory name of the current snapshot strategy. 
*/ + private final String localDirectoryName; + public RocksIncrementalSnapshotStrategy( @Nonnull RocksDB db, @Nonnull ResourceGuard rocksDBResourceGuard, @@ -137,6 +140,7 @@ public class RocksIncrementalSnapshotStrategy<K> extends RocksDBSnapshotStrategy this.materializedSstFiles = materializedSstFiles; this.lastCompletedCheckpointId = lastCompletedCheckpointId; this.stateUploader = new RocksDBStateUploader(numberOfTransferingThreads); + this.localDirectoryName = backendUID.toString().replaceAll("[\\-]", ""); } @Nonnull @@ -184,17 +188,18 @@ public class RocksIncrementalSnapshotStrategy<K> extends RocksDBSnapshotStrategy LocalRecoveryDirectoryProvider directoryProvider = localRecoveryConfig.getLocalStateDirectoryProvider(); File directory = directoryProvider.subtaskSpecificCheckpointDirectory(checkpointId); - if (directory.exists()) { - FileUtils.deleteDirectory(directory); - } - - if (!directory.mkdirs()) { + if (!directory.exists() && !directory.mkdirs()) { throw new IOException("Local state base directory for checkpoint " + checkpointId + " already exists: " + directory); } // introduces an extra directory because RocksDB wants a non-existing directory for native checkpoints. - File rdbSnapshotDir = new File(directory, "rocks_db"); + // append localDirectoryName here to solve directory collision problem when two stateful operators chained in one task. + File rdbSnapshotDir = new File(directory, localDirectoryName); + if (rdbSnapshotDir.exists()) { + FileUtils.deleteDirectory(rdbSnapshotDir); + } + Path path = new Path(rdbSnapshotDir.toURI()); // create a "permanent" snapshot directory because local recovery is active. 
try { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java index 89a4f81..7ac0cf3 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java @@ -25,7 +25,10 @@ import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate; +import org.apache.flink.runtime.state.LocalRecoveryConfig; +import org.apache.flink.runtime.state.TestLocalRecoveryConfig; +import java.io.File; import java.io.IOException; import java.util.function.Function; @@ -56,16 +59,48 @@ public class OneInputStreamTaskTestHarness<IN, OUT> extends StreamTaskTestHarnes /** * Creates a test harness with the specified number of input gates and specified number - * of channels per input gate. + * of channels per input gate and local recovery disabled. + */ + public OneInputStreamTaskTestHarness( + Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory, + int numInputGates, + int numInputChannelsPerGate, + TypeInformation<IN> inputType, + TypeInformation<OUT> outputType) { + this(taskFactory, numInputGates, numInputChannelsPerGate, inputType, outputType, TestLocalRecoveryConfig.disabled()); + } + + public OneInputStreamTaskTestHarness( + Function<Environment, ? 
extends StreamTask<OUT, ?>> taskFactory, + int numInputGates, + int numInputChannelsPerGate, + TypeInformation<IN> inputType, + TypeInformation<OUT> outputType, + File localRootDir) { + super(taskFactory, outputType, localRootDir); + + this.inputType = inputType; + inputSerializer = inputType.createSerializer(executionConfig); + + this.numInputGates = numInputGates; + this.numInputChannelsPerGate = numInputChannelsPerGate; + + streamConfig.setStateKeySerializer(inputSerializer); + } + + /** + * Creates a test harness with the specified number of input gates and specified number + * of channels per input gate and specified localRecoveryConfig. */ public OneInputStreamTaskTestHarness( Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory, int numInputGates, int numInputChannelsPerGate, TypeInformation<IN> inputType, - TypeInformation<OUT> outputType) { + TypeInformation<OUT> outputType, + LocalRecoveryConfig localRecoveryConfig) { - super(taskFactory, outputType); + super(taskFactory, outputType, localRecoveryConfig); this.inputType = inputType; inputSerializer = inputType.createSerializer(executionConfig); @@ -78,11 +113,10 @@ public class OneInputStreamTaskTestHarness<IN, OUT> extends StreamTaskTestHarnes * Creates a test harness with one input gate that has one input channel. */ public OneInputStreamTaskTestHarness( - Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory, - TypeInformation<IN> inputType, - TypeInformation<OUT> outputType) { - - this(taskFactory, 1, 1, inputType, outputType); + Function<Environment, ? 
extends StreamTask<OUT, ?>> taskFactory, + TypeInformation<IN> inputType, + TypeInformation<OUT> outputType) { + this(taskFactory, 1, 1, inputType, outputType, TestLocalRecoveryConfig.disabled()); } @Override diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java index 10e50ce..747468e 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java @@ -61,6 +61,14 @@ public class StreamConfigChainer { } public <T> StreamConfigChainer chain( + OperatorID operatorID, + OneInputStreamOperator<T, T> operator, + TypeSerializer<T> typeSerializer, + boolean createKeyedStateBackend) { + return chain(operatorID, operator, typeSerializer, typeSerializer, createKeyedStateBackend); + } + + public <T> StreamConfigChainer chain( OperatorID operatorID, OneInputStreamOperator<T, T> operator, TypeSerializer<T> typeSerializer) { @@ -68,10 +76,19 @@ public class StreamConfigChainer { } public <IN, OUT> StreamConfigChainer chain( + OperatorID operatorID, + OneInputStreamOperator<IN, OUT> operator, + TypeSerializer<IN> inputSerializer, + TypeSerializer<OUT> outputSerializer) { + return chain(operatorID, operator, inputSerializer, outputSerializer, false); + } + + public <IN, OUT> StreamConfigChainer chain( OperatorID operatorID, OneInputStreamOperator<IN, OUT> operator, TypeSerializer<IN> inputSerializer, - TypeSerializer<OUT> outputSerializer) { + TypeSerializer<OUT> outputSerializer, + boolean createKeyedStateBackend) { chainIndex++; tailConfig.setChainedOutputs(Collections.singletonList( @@ -87,6 +104,10 @@ public class StreamConfigChainer { tailConfig.setOperatorID(checkNotNull(operatorID)); tailConfig.setTypeSerializerIn1(inputSerializer); 
tailConfig.setTypeSerializerOut(outputSerializer); + if (createKeyedStateBackend) { + // used to test multiple stateful operators chained in a single task. + tailConfig.setStateKeySerializer(inputSerializer); + } tailConfig.setChainIndex(chainIndex); chainedConfigs.put(chainIndex, tailConfig); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java index 6cd7617..134218a 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java @@ -106,6 +106,8 @@ public class StreamMockEnvironment implements Environment { private TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); + private TaskManagerRuntimeInfo taskManagerRuntimeInfo = new TestingTaskManagerRuntimeInfo(); + public StreamMockEnvironment( Configuration jobConfig, Configuration taskConfig, @@ -332,7 +334,11 @@ public class StreamMockEnvironment implements Environment { @Override public TaskManagerRuntimeInfo getTaskManagerInfo() { - return new TestingTaskManagerRuntimeInfo(); + return this.taskManagerRuntimeInfo; + } + + public void setTaskManagerInfo(TaskManagerRuntimeInfo taskManagerRuntimeInfo) { + this.taskManagerRuntimeInfo = taskManagerRuntimeInfo; } @Override diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java index 36f0fb7..be31923 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java @@ -19,6 +19,7 @@ package 
org.apache.flink.streaming.runtime.tasks; import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.JobID; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.Configuration; @@ -26,10 +27,14 @@ import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate; +import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.memory.MemoryManager; import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; +import org.apache.flink.runtime.state.LocalRecoveryConfig; +import org.apache.flink.runtime.state.LocalRecoveryDirectoryProviderImpl; +import org.apache.flink.runtime.state.TestLocalRecoveryConfig; import org.apache.flink.runtime.state.TestTaskStateManager; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.collector.selector.OutputSelector; @@ -47,6 +52,7 @@ import org.apache.flink.util.Preconditions; import org.junit.Assert; +import java.io.File; import java.io.IOException; import java.util.Collections; import java.util.LinkedList; @@ -109,7 +115,20 @@ public class StreamTaskTestHarness<OUT> { public StreamTaskTestHarness( Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory, TypeInformation<OUT> outputType) { + this(taskFactory, outputType, TestLocalRecoveryConfig.disabled()); + } + public StreamTaskTestHarness( + Function<Environment, ? 
extends StreamTask<OUT, ?>> taskFactory, + TypeInformation<OUT> outputType, + File localRootDir) { + this(taskFactory, outputType, new LocalRecoveryConfig(true, new LocalRecoveryDirectoryProviderImpl(localRootDir, new JobID(), new JobVertexID(), 0))); + } + + public StreamTaskTestHarness( + Function<Environment, ? extends StreamTask<OUT, ?>> taskFactory, + TypeInformation<OUT> outputType, + LocalRecoveryConfig localRecoveryConfig) { this.taskFactory = checkNotNull(taskFactory); this.memorySize = DEFAULT_MEMORY_MANAGER_SIZE; this.bufferSize = DEFAULT_NETWORK_BUFFER_SIZE; @@ -123,7 +142,7 @@ public class StreamTaskTestHarness<OUT> { outputSerializer = outputType.createSerializer(executionConfig); outputStreamRecordSerializer = new StreamElementSerializer<OUT>(outputSerializer); - this.taskStateManager = new TestTaskStateManager(); + this.taskStateManager = new TestTaskStateManager(localRecoveryConfig); } public ProcessingTimeService getProcessingTimeService() { diff --git a/flink-tests/src/test/java/org/apache/flink/test/state/StatefulOperatorChainedTaskTest.java b/flink-tests/src/test/java/org/apache/flink/test/state/StatefulOperatorChainedTaskTest.java new file mode 100644 index 0000000..5651929 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/state/StatefulOperatorChainedTaskTest.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.test.state; + +import org.apache.flink.api.common.state.ValueState; +import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.api.common.typeutils.base.StringSerializer; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.testutils.OneShotLatch; +import org.apache.flink.runtime.checkpoint.CheckpointMetaData; +import org.apache.flink.runtime.checkpoint.CheckpointOptions; +import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore; +import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; +import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.runtime.state.TestTaskStateManager; +import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask; +import org.apache.flink.streaming.runtime.tasks.OneInputStreamTaskTestHarness; +import org.apache.flink.streaming.runtime.tasks.StreamMockEnvironment; +import 
org.apache.flink.streaming.util.TestHarnessUtil; + +import org.junit.Before; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; + +import static org.apache.flink.configuration.CheckpointingOptions.CHECKPOINTS_DIRECTORY; +import static org.apache.flink.configuration.CheckpointingOptions.INCREMENTAL_CHECKPOINTS; +import static org.apache.flink.configuration.CheckpointingOptions.STATE_BACKEND; +import static org.junit.Assert.assertEquals; + +/** + * Test for StatefulOperatorChainedTaskTest. + */ +public class StatefulOperatorChainedTaskTest { + + private static final Set<OperatorID> RESTORED_OPERATORS = ConcurrentHashMap.newKeySet(); + private TemporaryFolder temporaryFolder; + + @Before + public void setup() throws IOException { + RESTORED_OPERATORS.clear(); + temporaryFolder = new TemporaryFolder(); + temporaryFolder.create(); + } + + @Test + public void testMultipleStatefulOperatorChainedSnapshotAndRestore() throws Exception { + + OperatorID headOperatorID = new OperatorID(42L, 42L); + OperatorID tailOperatorID = new OperatorID(44L, 44L); + + JobManagerTaskRestore restore = createRunAndCheckpointOperatorChain( + headOperatorID, + new CounterOperator("head"), + tailOperatorID, + new CounterOperator("tail"), + Optional.empty()); + + TaskStateSnapshot stateHandles = restore.getTaskStateSnapshot(); + + assertEquals(2, stateHandles.getSubtaskStateMappings().size()); + + createRunAndCheckpointOperatorChain( + headOperatorID, + new CounterOperator("head"), + tailOperatorID, + new CounterOperator("tail"), + Optional.of(restore)); + + assertEquals(new HashSet<>(Arrays.asList(headOperatorID, tailOperatorID)), RESTORED_OPERATORS); + } + + private JobManagerTaskRestore createRunAndCheckpointOperatorChain( + 
OperatorID headId, + OneInputStreamOperator<String, String> headOperator, + OperatorID tailId, + OneInputStreamOperator<String, String> tailOperator, + Optional<JobManagerTaskRestore> restore) throws Exception { + + File localRootDir = temporaryFolder.newFolder(); + final OneInputStreamTaskTestHarness<String, String> testHarness = + new OneInputStreamTaskTestHarness<>( + OneInputStreamTask::new, + 1, 1, + BasicTypeInfo.STRING_TYPE_INFO, + BasicTypeInfo.STRING_TYPE_INFO, + localRootDir); + + testHarness.setupOperatorChain(headId, headOperator) + .chain(tailId, tailOperator, StringSerializer.INSTANCE, true) + .finish(); + + if (restore.isPresent()) { + JobManagerTaskRestore taskRestore = restore.get(); + testHarness.setTaskStateSnapshot( + taskRestore.getRestoreCheckpointId(), + taskRestore.getTaskStateSnapshot()); + } + + StreamMockEnvironment environment = new StreamMockEnvironment( + testHarness.jobConfig, + testHarness.taskConfig, + testHarness.getExecutionConfig(), + testHarness.memorySize, + new MockInputSplitProvider(), + testHarness.bufferSize, + testHarness.getTaskStateManager()); + + Configuration configuration = new Configuration(); + configuration.setString(STATE_BACKEND.key(), "rocksdb"); + File file = temporaryFolder.newFolder(); + configuration.setString(CHECKPOINTS_DIRECTORY.key(), file.toURI().toString()); + configuration.setString(INCREMENTAL_CHECKPOINTS.key(), "true"); + environment.setTaskManagerInfo( + new TestingTaskManagerRuntimeInfo( + configuration, + System.getProperty("java.io.tmpdir").split(",|" + File.pathSeparator))); + testHarness.invoke(environment); + testHarness.waitForTaskRunning(); + + OneInputStreamTask<String, String> streamTask = testHarness.getTask(); + + processRecords(testHarness); + triggerCheckpoint(testHarness, streamTask); + + TestTaskStateManager taskStateManager = testHarness.getTaskStateManager(); + + JobManagerTaskRestore jobManagerTaskRestore = new JobManagerTaskRestore( + taskStateManager.getReportedCheckpointId(), 
+ taskStateManager.getLastJobManagerTaskStateSnapshot()); + + testHarness.endInput(); + testHarness.waitForTaskCompletion(); + return jobManagerTaskRestore; + } + + private void triggerCheckpoint( + OneInputStreamTaskTestHarness<String, String> testHarness, + OneInputStreamTask<String, String> streamTask) throws Exception { + + long checkpointId = 1L; + CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 1L); + + testHarness.getTaskStateManager().setWaitForReportLatch(new OneShotLatch()); + + while (!streamTask.triggerCheckpoint(checkpointMetaData, CheckpointOptions.forCheckpointWithDefaultLocation(), false)) {} + + testHarness.getTaskStateManager().getWaitForReportLatch().await(); + long reportedCheckpointId = testHarness.getTaskStateManager().getReportedCheckpointId(); + + assertEquals(checkpointId, reportedCheckpointId); + } + + private void processRecords(OneInputStreamTaskTestHarness<String, String> testHarness) throws Exception { + ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>(); + + testHarness.processElement(new StreamRecord<>("10"), 0, 0); + testHarness.processElement(new StreamRecord<>("20"), 0, 0); + testHarness.processElement(new StreamRecord<>("30"), 0, 0); + + testHarness.waitForInputProcessing(); + + expectedOutput.add(new StreamRecord<>("10")); + expectedOutput.add(new StreamRecord<>("20")); + expectedOutput.add(new StreamRecord<>("30")); + TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); + } + + private abstract static class RestoreWatchOperator<IN, OUT> + extends AbstractStreamOperator<OUT> + implements OneInputStreamOperator<IN, OUT> { + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + if (context.isRestored()) { + RESTORED_OPERATORS.add(getOperatorID()); + } + } + } + + /** + * Operator that counts processed messages and keeps result on state. 
+ */ + private static class CounterOperator extends RestoreWatchOperator<String, String> { + private static final long serialVersionUID = 2048954179291813243L; + + private static long snapshotOutData = 0L; + private ValueState<Long> counterState; + private long counter = 0; + private String prefix; + + CounterOperator(String prefix) { + this.prefix = prefix; + } + + @Override + public void processElement(StreamRecord<String> element) throws Exception { + counter++; + output.collect(element); + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + counterState = context + .getKeyedStateStore() + .getState(new ValueStateDescriptor<>(prefix + "counter-state", LongSerializer.INSTANCE)); + + // set key manually to make RocksDBListState get the serialized key. + setCurrentKey("10"); + + if (context.isRestored()) { + counter = counterState.value(); + assertEquals(snapshotOutData, counter); + counterState.clear(); + } + } + + @Override + public void snapshotState(StateSnapshotContext context) throws Exception { + counterState.update(counter); + snapshotOutData = counter; + } + } +} +