This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 237d07c76b51c171f0f41f9f82d777df26da1dd4
Author: Yu Li <l...@apache.org>
AuthorDate: Fri Mar 1 03:02:35 2019 +0800

    [FLINK-11730] [State Backends] Make HeapKeyedStateBackend follow the builder pattern
    
    This closes #7866.
---
 .../KVStateRequestSerializerRocksDBTest.java       |   5 +-
 .../network/KvStateRequestSerializerTest.java      |  57 +--
 .../flink/runtime/state/AbstractStateBackend.java  |   9 +
 .../flink/runtime/state/RestoreOperation.java      |  10 +-
 .../runtime/state/filesystem/FsStateBackend.java   |  34 +-
 ...AsyncSnapshotStrategySynchronicityBehavior.java |  26 +-
 .../runtime/state/heap/HeapKeyedStateBackend.java  | 546 ++-------------------
 .../state/heap/HeapKeyedStateBackendBuilder.java   | 152 ++++++
 .../runtime/state/heap/HeapRestoreOperation.java   | 293 +++++++++++
 .../runtime/state/heap/HeapSnapshotStrategy.java   | 271 ++++++++++
 .../SnapshotStrategySynchronicityBehavior.java     |  24 +-
 .../apache/flink/runtime/state/heap/StateUID.java  |  73 +++
 .../SyncSnapshotStrategySynchronicityBehavior.java |  32 +-
 .../runtime/state/memory/MemoryStateBackend.java   |  34 +-
 .../state/StateBackendMigrationTestBase.java       |   2 -
 .../flink/runtime/state/StateBackendTestBase.java  |  14 +-
 .../state/StateSnapshotCompressionTest.java        |  89 ++--
 ...HeapKeyedStateBackendSnapshotMigrationTest.java |  25 +-
 .../state/heap/HeapStateBackendTestBase.java       |  23 +-
 .../runtime/state/ttl/StateBackendTestContext.java |  13 -
 .../flink/runtime/state/ttl/TtlStateTestBase.java  |   3 -
 .../state/ttl/mock/MockKeyedStateBackend.java      |  24 +-
 .../ttl/mock/MockKeyedStateBackendBuilder.java     |  85 ++++
 .../state/ttl/mock/MockRestoreOperation.java       |  53 ++
 .../runtime/state/ttl/mock/MockStateBackend.java   |   8 +-
 .../streaming/state/RocksDBStateBackend.java       |  10 -
 .../state/restore/RocksDBRestoreOperation.java     |   4 +-
 .../streaming/state/RocksDBStateBackendTest.java   |   5 +-
 28 files changed, 1188 insertions(+), 736 deletions(-)

diff --git 
a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
 
b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
index 3431199..a5df958 100644
--- 
a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
+++ 
b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KVStateRequestSerializerRocksDBTest.java
@@ -32,6 +32,7 @@ import 
org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
 import org.apache.flink.queryablestate.client.VoidNamespace;
 import org.apache.flink.queryablestate.client.VoidNamespaceSerializer;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.TestLocalRecoveryConfig;
 import org.apache.flink.runtime.state.internal.InternalListState;
@@ -88,7 +89,7 @@ public final class KVStateRequestSerializerRocksDBTest {
                                TtlTimeProvider.DEFAULT,
                                new UnregisteredMetricsGroup(),
                                Collections.emptyList(),
-                               
RocksDBStateBackend.getCompressionDecorator(executionConfig),
+                               
AbstractStateBackend.getCompressionDecorator(executionConfig),
                                new CloseableRegistry()
                        ).build();
                longHeapKeyedStateBackend.setCurrentKey(key);
@@ -132,7 +133,7 @@ public final class KVStateRequestSerializerRocksDBTest {
                                TtlTimeProvider.DEFAULT,
                                new UnregisteredMetricsGroup(),
                                Collections.emptyList(),
-                               
RocksDBStateBackend.getCompressionDecorator(executionConfig),
+                               
AbstractStateBackend.getCompressionDecorator(executionConfig),
                                new CloseableRegistry()
                        ).build();
                longHeapKeyedStateBackend.setCurrentKey(key);
diff --git 
a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java
 
b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java
index aac3394..2ad202f 100644
--- 
a/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java
+++ 
b/flink-queryable-state/flink-queryable-state-runtime/src/test/java/org/apache/flink/queryablestate/network/KvStateRequestSerializerTest.java
@@ -30,9 +30,12 @@ import org.apache.flink.queryablestate.client.VoidNamespace;
 import org.apache.flink.queryablestate.client.VoidNamespaceSerializer;
 import 
org.apache.flink.queryablestate.client.state.serialization.KvStateSerializer;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractStateBackend;
+import org.apache.flink.runtime.state.BackendBuildingException;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.TestLocalRecoveryConfig;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.runtime.state.heap.HeapKeyedStateBackendBuilder;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import org.apache.flink.runtime.state.internal.InternalKvState;
 import org.apache.flink.runtime.state.internal.InternalListState;
@@ -48,6 +51,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -188,22 +192,7 @@ public class KvStateRequestSerializerTest {
        @Test
        public void testListSerialization() throws Exception {
                final long key = 0L;
-               final KeyGroupRange keyGroupRange = new KeyGroupRange(0, 0);
-               // objects for heap state list serialisation
-               final HeapKeyedStateBackend<Long> longHeapKeyedStateBackend =
-                       new HeapKeyedStateBackend<>(
-                               mock(TaskKvStateRegistry.class),
-                               LongSerializer.INSTANCE,
-                               ClassLoader.getSystemClassLoader(),
-                               keyGroupRange.getNumberOfKeyGroups(),
-                               keyGroupRange,
-                               async,
-                               new ExecutionConfig(),
-                               TestLocalRecoveryConfig.disabled(),
-                               new HeapPriorityQueueSetFactory(keyGroupRange, 
keyGroupRange.getNumberOfKeyGroups(), 128),
-                               TtlTimeProvider.DEFAULT,
-                               new CloseableRegistry());
-               longHeapKeyedStateBackend.setCurrentKey(key);
+               final HeapKeyedStateBackend<Long> longHeapKeyedStateBackend = 
getLongHeapKeyedStateBackend(key);
 
                final InternalListState<Long, VoidNamespace, Long> listState = 
longHeapKeyedStateBackend.createInternalState(
                        VoidNamespaceSerializer.INSTANCE,
@@ -297,31 +286,39 @@ public class KvStateRequestSerializerTest {
        @Test
        public void testMapSerialization() throws Exception {
                final long key = 0L;
+               final HeapKeyedStateBackend<Long> longHeapKeyedStateBackend = 
getLongHeapKeyedStateBackend(key);
+
+               final InternalMapState<Long, VoidNamespace, Long, String> 
mapState =
+                               (InternalMapState<Long, VoidNamespace, Long, 
String>)
+                                               
longHeapKeyedStateBackend.getPartitionedState(
+                                                               
VoidNamespace.INSTANCE,
+                                                               
VoidNamespaceSerializer.INSTANCE,
+                                                               new 
MapStateDescriptor<>("test", LongSerializer.INSTANCE, 
StringSerializer.INSTANCE));
+
+               testMapSerialization(key, mapState);
+       }
+
+       private HeapKeyedStateBackend<Long> getLongHeapKeyedStateBackend(final 
long key) throws BackendBuildingException {
                final KeyGroupRange keyGroupRange = new KeyGroupRange(0, 0);
+               ExecutionConfig executionConfig = new ExecutionConfig();
                // objects for heap state list serialisation
                final HeapKeyedStateBackend<Long> longHeapKeyedStateBackend =
-                       new HeapKeyedStateBackend<>(
+                       new HeapKeyedStateBackendBuilder<>(
                                mock(TaskKvStateRegistry.class),
                                LongSerializer.INSTANCE,
                                ClassLoader.getSystemClassLoader(),
                                keyGroupRange.getNumberOfKeyGroups(),
                                keyGroupRange,
-                               async,
-                               new ExecutionConfig(),
+                               executionConfig,
+                               TtlTimeProvider.DEFAULT,
+                               Collections.emptyList(),
+                               
AbstractStateBackend.getCompressionDecorator(executionConfig),
                                TestLocalRecoveryConfig.disabled(),
                                new HeapPriorityQueueSetFactory(keyGroupRange, 
keyGroupRange.getNumberOfKeyGroups(), 128),
-                               TtlTimeProvider.DEFAULT,
-                               new CloseableRegistry());
+                               async,
+                               new CloseableRegistry()).build();
                longHeapKeyedStateBackend.setCurrentKey(key);
-
-               final InternalMapState<Long, VoidNamespace, Long, String> 
mapState =
-                               (InternalMapState<Long, VoidNamespace, Long, 
String>)
-                                               
longHeapKeyedStateBackend.getPartitionedState(
-                                                               
VoidNamespace.INSTANCE,
-                                                               
VoidNamespaceSerializer.INSTANCE,
-                                                               new 
MapStateDescriptor<>("test", LongSerializer.INSTANCE, 
StringSerializer.INSTANCE));
-
-               testMapSerialization(key, mapState);
+               return longHeapKeyedStateBackend;
        }
 
        /**
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
index 3ebf09b..2343d83 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractStateBackend.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.state;
 
 import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.fs.CloseableRegistry;
@@ -42,6 +43,14 @@ public abstract class AbstractStateBackend implements 
StateBackend, java.io.Seri
 
        private static final long serialVersionUID = 4620415814639230247L;
 
+       public static StreamCompressionDecorator 
getCompressionDecorator(ExecutionConfig executionConfig) {
+               if (executionConfig != null && 
executionConfig.isUseSnapshotCompression()) {
+                       return SnappyStreamCompressionDecorator.INSTANCE;
+               } else {
+                       return UncompressedStreamCompressionDecorator.INSTANCE;
+               }
+       }
+
        // 
------------------------------------------------------------------------
        //  State Backend - State-Holding Backends
        // 
------------------------------------------------------------------------
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RestoreOperation.java
similarity index 82%
copy from 
flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
copy to 
flink-runtime/src/main/java/org/apache/flink/runtime/state/RestoreOperation.java
index ff70199..4de4208 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RestoreOperation.java
@@ -16,14 +16,16 @@
  * limitations under the License.
  */
 
-package org.apache.flink.contrib.streaming.state.restore;
+package org.apache.flink.runtime.state;
 
 /**
- * Interface for RocksDB restore.
+ * Interface for restore operation.
+ *
+ * @param <R> Generic type of the restore result.
  */
-public interface RocksDBRestoreOperation {
+public interface RestoreOperation<R> {
        /**
         * Restores state that was previously snapshot-ed from the provided 
state handles.
         */
-       RocksDBRestoreResult restore() throws Exception;
+       R restore() throws Exception;
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
index 594e526..0511dd3 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
@@ -30,15 +30,17 @@ import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.AbstractStateBackend;
+import org.apache.flink.runtime.state.BackendBuildingException;
 import org.apache.flink.runtime.state.CheckpointStorage;
 import org.apache.flink.runtime.state.ConfigurableStateBackend;
 import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
-import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.TaskStateManager;
-import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.runtime.state.heap.HeapKeyedStateBackendBuilder;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.util.TernaryBoolean;
@@ -463,25 +465,27 @@ public class FsStateBackend extends 
AbstractFileStateBackend implements Configur
                TtlTimeProvider ttlTimeProvider,
                MetricGroup metricGroup,
                @Nonnull Collection<KeyedStateHandle> stateHandles,
-               CloseableRegistry cancelStreamRegistry) {
+               CloseableRegistry cancelStreamRegistry) throws 
BackendBuildingException {
 
                TaskStateManager taskStateManager = env.getTaskStateManager();
                LocalRecoveryConfig localRecoveryConfig = 
taskStateManager.createLocalRecoveryConfig();
                HeapPriorityQueueSetFactory priorityQueueSetFactory =
                        new HeapPriorityQueueSetFactory(keyGroupRange, 
numberOfKeyGroups, 128);
 
-               return new HeapKeyedStateBackend<>(
-                               kvStateRegistry,
-                               keySerializer,
-                               env.getUserClassLoader(),
-                               numberOfKeyGroups,
-                               keyGroupRange,
-                               isUsingAsynchronousSnapshots(),
-                               env.getExecutionConfig(),
-                               localRecoveryConfig,
-                               priorityQueueSetFactory,
-                               ttlTimeProvider,
-                       cancelStreamRegistry);
+               return new HeapKeyedStateBackendBuilder<>(
+                       kvStateRegistry,
+                       keySerializer,
+                       env.getUserClassLoader(),
+                       numberOfKeyGroups,
+                       keyGroupRange,
+                       env.getExecutionConfig(),
+                       ttlTimeProvider,
+                       stateHandles,
+                       
AbstractStateBackend.getCompressionDecorator(env.getExecutionConfig()),
+                       localRecoveryConfig,
+                       priorityQueueSetFactory,
+                       isUsingAsynchronousSnapshots(),
+                       cancelStreamRegistry).build();
        }
 
        @Override
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AsyncSnapshotStrategySynchronicityBehavior.java
similarity index 56%
copy from 
flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
copy to 
flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AsyncSnapshotStrategySynchronicityBehavior.java
index ff70199..05d6475 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AsyncSnapshotStrategySynchronicityBehavior.java
@@ -16,14 +16,26 @@
  * limitations under the License.
  */
 
-package org.apache.flink.contrib.streaming.state.restore;
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 
 /**
- * Interface for RocksDB restore.
+ * Asynchronous behavior for heap snapshot strategy.
+ *
+ * @param <K> The data type that the serializer serializes.
  */
-public interface RocksDBRestoreOperation {
-       /**
-        * Restores state that was previously snapshot-ed from the provided 
state handles.
-        */
-       RocksDBRestoreResult restore() throws Exception;
+class AsyncSnapshotStrategySynchronicityBehavior<K> implements 
SnapshotStrategySynchronicityBehavior<K> {
+
+       @Override
+       public boolean isAsynchronous() {
+               return true;
+       }
+
+       @Override
+       public <N, V> StateTable<K, N, V> newStateTable(
+               InternalKeyContext<K> keyContext,
+               RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo) {
+               return new CopyOnWriteStateTable<>(keyContext, newMetaInfo);
+       }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 7614a55..a3a0c29 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -32,64 +32,39 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.CloseableRegistry;
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
-import org.apache.flink.runtime.state.AbstractSnapshotStrategy;
-import org.apache.flink.runtime.state.AsyncSnapshotCallable;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
-import org.apache.flink.runtime.state.CheckpointedStateScope;
-import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyExtractorFunction;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.Keyed;
-import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
 import org.apache.flink.runtime.state.KeyedStateFunction;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
 import org.apache.flink.runtime.state.PriorityComparable;
 import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import 
org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo;
-import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotResult;
-import org.apache.flink.runtime.state.StateSnapshot;
-import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
+import org.apache.flink.runtime.state.StateSerializerProvider;
 import org.apache.flink.runtime.state.StateSnapshotRestore;
 import 
org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.StateSnapshotTransformers;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
-import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
-import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.util.FlinkRuntimeException;
-import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.StateMigrationException;
-import org.apache.flink.util.function.SupplierWithException;
 
-import org.apache.commons.io.IOUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnull;
 
 import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.ArrayList;
 import java.util.Collection;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
-import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -132,7 +107,7 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
        /**
         * The snapshot strategy for this backend. This determines, e.g., if 
snapshots are synchronous or asynchronous.
         */
-       private final HeapSnapshotStrategy snapshotStrategy;
+       private final HeapSnapshotStrategy<K> snapshotStrategy;
 
        /**
         * Factory for state that is organized as priority queue.
@@ -141,31 +116,28 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
 
        public HeapKeyedStateBackend(
                TaskKvStateRegistry kvStateRegistry,
-               TypeSerializer<K> keySerializer,
+               StateSerializerProvider<K> keySerializerProvider,
                ClassLoader userCodeClassLoader,
                int numberOfKeyGroups,
                KeyGroupRange keyGroupRange,
-               boolean asynchronousSnapshots,
                ExecutionConfig executionConfig,
+               TtlTimeProvider ttlTimeProvider,
+               CloseableRegistry cancelStreamRegistry,
+               StreamCompressionDecorator keyGroupCompressionDecorator,
+               Map<String, StateTable<K, ?, ?>> registeredKVStates,
+               Map<String, HeapPriorityQueueSnapshotRestoreWrapper> 
registeredPQStates,
                LocalRecoveryConfig localRecoveryConfig,
                HeapPriorityQueueSetFactory priorityQueueSetFactory,
-               TtlTimeProvider ttlTimeProvider,
-               CloseableRegistry cancelStreamRegistry) {
-
-               super(kvStateRegistry, keySerializer, userCodeClassLoader,
-                       numberOfKeyGroups, keyGroupRange, executionConfig, 
ttlTimeProvider, new CloseableRegistry());
-
-               this.registeredKVStates = new HashMap<>();
-               this.registeredPQStates = new HashMap<>();
-               this.localRecoveryConfig = 
Preconditions.checkNotNull(localRecoveryConfig);
-
-               SnapshotStrategySynchronicityBehavior<K> synchronicityTrait = 
asynchronousSnapshots ?
-                       new AsyncSnapshotStrategySynchronicityBehavior() :
-                       new SyncSnapshotStrategySynchronicityBehavior();
-
-               this.snapshotStrategy = new 
HeapSnapshotStrategy(synchronicityTrait);
+               HeapSnapshotStrategy<K> snapshotStrategy
+       ) {
+               super(kvStateRegistry, keySerializerProvider, 
userCodeClassLoader, numberOfKeyGroups,
+                       keyGroupRange, executionConfig, ttlTimeProvider, 
cancelStreamRegistry, keyGroupCompressionDecorator);
+               this.registeredKVStates = registeredKVStates;
+               this.registeredPQStates = registeredPQStates;
+               this.localRecoveryConfig = localRecoveryConfig;
                LOG.info("Initializing heap keyed state backend with stream 
factory.");
                this.priorityQueueSetFactory = priorityQueueSetFactory;
+               this.snapshotStrategy = snapshotStrategy;
        }
 
        // 
------------------------------------------------------------------------
@@ -227,9 +199,9 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
        }
 
        private <N, V> StateTable<K, N, V> tryRegisterStateTable(
-                       TypeSerializer<N> namespaceSerializer,
-                       StateDescriptor<?, V> stateDesc,
-                       @Nonnull StateSnapshotTransformFactory<V> 
snapshotTransformFactory) throws StateMigrationException {
+               TypeSerializer<N> namespaceSerializer,
+               StateDescriptor<?, V> stateDesc,
+               @Nonnull StateSnapshotTransformFactory<V> 
snapshotTransformFactory) throws StateMigrationException {
 
                @SuppressWarnings("unchecked")
                StateTable<K, N, V> stateTable = (StateTable<K, N, V>) 
registeredKVStates.get(stateDesc.getName());
@@ -265,7 +237,7 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                                newStateSerializer,
                                snapshotTransformFactory);
 
-                       stateTable = 
snapshotStrategy.newStateTable(newMetaInfo);
+                       stateTable = snapshotStrategy.newStateTable(this, 
newMetaInfo);
                        registeredKVStates.put(stateDesc.getName(), stateTable);
                }
 
@@ -284,10 +256,6 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                return table.getKeys(namespace);
        }
 
-       private boolean hasRegisteredState() {
-               return !(registeredKVStates.isEmpty() && 
registeredPQStates.isEmpty());
-       }
-
        @Override
        @Nonnull
        public <N, SV, SEV, S extends State, IS extends S> IS 
createInternalState(
@@ -337,189 +305,8 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
        }
 
        @SuppressWarnings("deprecation")
-       public void restore(Collection<KeyedStateHandle> restoredState) throws 
Exception {
-               if (restoredState == null || restoredState.isEmpty()) {
-                       return;
-               }
-
-               LOG.info("Initializing heap keyed state backend from 
snapshot.");
-
-               if (LOG.isDebugEnabled()) {
-                       LOG.debug("Restoring snapshot from state handles: {}.", 
restoredState);
-               }
-
-               restorePartitionedState(restoredState);
-       }
-
-       @SuppressWarnings({"unchecked"})
-       private void restorePartitionedState(Collection<KeyedStateHandle> 
state) throws Exception {
-
-               final Map<Integer, StateMetaInfoSnapshot> kvStatesById = new 
HashMap<>();
-               registeredKVStates.clear();
-               registeredPQStates.clear();
-
-               boolean keySerializerRestored = false;
-
-               for (KeyedStateHandle keyedStateHandle : state) {
-
-                       if (keyedStateHandle == null) {
-                               continue;
-                       }
-
-                       if (!(keyedStateHandle instanceof 
KeyGroupsStateHandle)) {
-                               throw new IllegalStateException("Unexpected 
state handle type, " +
-                                               "expected: " + 
KeyGroupsStateHandle.class +
-                                               ", but found: " + 
keyedStateHandle.getClass());
-                       }
-
-                       KeyGroupsStateHandle keyGroupsStateHandle = 
(KeyGroupsStateHandle) keyedStateHandle;
-                       FSDataInputStream fsDataInputStream = 
keyGroupsStateHandle.openInputStream();
-                       
cancelStreamRegistry.registerCloseable(fsDataInputStream);
-
-                       try {
-                               DataInputViewStreamWrapper inView = new 
DataInputViewStreamWrapper(fsDataInputStream);
-
-                               KeyedBackendSerializationProxy<K> 
serializationProxy =
-                                               new 
KeyedBackendSerializationProxy<>(userCodeClassLoader);
-
-                               serializationProxy.read(inView);
-
-                               if (!keySerializerRestored) {
-                                       // check for key serializer 
compatibility; this also reconfigures the
-                                       // key serializer to be compatible, if 
it is required and is possible
-                                       TypeSerializerSchemaCompatibility<K> 
keySerializerSchemaCompat =
-                                               
checkKeySerializerSchemaCompatibility(serializationProxy.getKeySerializerSnapshot());
-                                       if 
(keySerializerSchemaCompat.isCompatibleAfterMigration() || 
keySerializerSchemaCompat.isIncompatible()) {
-                                               throw new 
StateMigrationException("The new key serializer must be compatible.");
-                                       }
-
-                                       keySerializerRestored = true;
-                               }
-
-                               List<StateMetaInfoSnapshot> restoredMetaInfos =
-                                       
serializationProxy.getStateMetaInfoSnapshots();
-
-                               
createOrCheckStateForMetaInfo(restoredMetaInfos, kvStatesById);
-
-                               readStateHandleStateData(
-                                       fsDataInputStream,
-                                       inView,
-                                       
keyGroupsStateHandle.getGroupRangeOffsets(),
-                                       kvStatesById, restoredMetaInfos.size(),
-                                       serializationProxy.getReadVersion(),
-                                       
serializationProxy.isUsingKeyGroupCompression());
-                       } finally {
-                               if 
(cancelStreamRegistry.unregisterCloseable(fsDataInputStream)) {
-                                       IOUtils.closeQuietly(fsDataInputStream);
-                               }
-                       }
-               }
-       }
-
-       private void readStateHandleStateData(
-               FSDataInputStream fsDataInputStream,
-               DataInputViewStreamWrapper inView,
-               KeyGroupRangeOffsets keyGroupOffsets,
-               Map<Integer, StateMetaInfoSnapshot> kvStatesById,
-               int numStates,
-               int readVersion,
-               boolean isCompressed) throws IOException {
-
-               final StreamCompressionDecorator streamCompressionDecorator = 
isCompressed ?
-                       SnappyStreamCompressionDecorator.INSTANCE : 
UncompressedStreamCompressionDecorator.INSTANCE;
-
-               for (Tuple2<Integer, Long> groupOffset : keyGroupOffsets) {
-                       int keyGroupIndex = groupOffset.f0;
-                       long offset = groupOffset.f1;
-
-                       // Check that restored key groups all belong to the 
backend.
-                       
Preconditions.checkState(keyGroupRange.contains(keyGroupIndex), "The key group 
must belong to the backend.");
-
-                       fsDataInputStream.seek(offset);
-
-                       int writtenKeyGroupIndex = inView.readInt();
-                       Preconditions.checkState(writtenKeyGroupIndex == 
keyGroupIndex,
-                               "Unexpected key-group in restore.");
-
-                       try (InputStream kgCompressionInStream =
-                                        
streamCompressionDecorator.decorateWithCompression(fsDataInputStream)) {
-
-                               readKeyGroupStateData(
-                                       kgCompressionInStream,
-                                       kvStatesById,
-                                       keyGroupIndex,
-                                       numStates,
-                                       readVersion);
-                       }
-               }
-       }
-
-       private void readKeyGroupStateData(
-               InputStream inputStream,
-               Map<Integer, StateMetaInfoSnapshot> kvStatesById,
-               int keyGroupIndex,
-               int numStates,
-               int readVersion) throws IOException {
-
-               DataInputViewStreamWrapper inView =
-                       new DataInputViewStreamWrapper(inputStream);
-
-               for (int i = 0; i < numStates; i++) {
-
-                       final int kvStateId = inView.readShort();
-                       final StateMetaInfoSnapshot stateMetaInfoSnapshot = 
kvStatesById.get(kvStateId);
-                       final StateSnapshotRestore registeredState;
-
-                       switch (stateMetaInfoSnapshot.getBackendStateType()) {
-                               case KEY_VALUE:
-                                       registeredState = 
registeredKVStates.get(stateMetaInfoSnapshot.getName());
-                                       break;
-                               case PRIORITY_QUEUE:
-                                       registeredState = 
registeredPQStates.get(stateMetaInfoSnapshot.getName());
-                                       break;
-                               default:
-                                       throw new 
IllegalStateException("Unexpected state type: " +
-                                               
stateMetaInfoSnapshot.getBackendStateType() + ".");
-                       }
-
-                       StateSnapshotKeyGroupReader keyGroupReader = 
registeredState.keyGroupReader(readVersion);
-                       keyGroupReader.readMappingsInKeyGroup(inView, 
keyGroupIndex);
-               }
-       }
-
-       private void createOrCheckStateForMetaInfo(
-               List<StateMetaInfoSnapshot> restoredMetaInfo,
-               Map<Integer, StateMetaInfoSnapshot> kvStatesById) {
-
-               for (StateMetaInfoSnapshot metaInfoSnapshot : restoredMetaInfo) 
{
-                       final StateSnapshotRestore registeredState;
-
-                       switch (metaInfoSnapshot.getBackendStateType()) {
-                               case KEY_VALUE:
-                                       registeredState = 
registeredKVStates.get(metaInfoSnapshot.getName());
-                                       if (registeredState == null) {
-                                               
RegisteredKeyValueStateBackendMetaInfo<?, ?> 
registeredKeyedBackendStateMetaInfo =
-                                                       new 
RegisteredKeyValueStateBackendMetaInfo<>(metaInfoSnapshot);
-                                               registeredKVStates.put(
-                                                       
metaInfoSnapshot.getName(),
-                                                       
snapshotStrategy.newStateTable(registeredKeyedBackendStateMetaInfo));
-                                       }
-                                       break;
-                               case PRIORITY_QUEUE:
-                                       registeredState = 
registeredPQStates.get(metaInfoSnapshot.getName());
-                                       if (registeredState == null) {
-                                               createInternal(new 
RegisteredPriorityQueueStateBackendMetaInfo<>(metaInfoSnapshot));
-                                       }
-                                       break;
-                               default:
-                                       throw new 
IllegalStateException("Unexpected state type: " +
-                                               
metaInfoSnapshot.getBackendStateType() + ".");
-                       }
-
-                       if (registeredState == null) {
-                               kvStatesById.put(kvStatesById.size(), 
metaInfoSnapshot);
-                       }
-               }
+       public void restore(Collection<KeyedStateHandle> restoredState) {
+               // All restore work is done by the builder, so there is nothing to do here.
        }
 
        @Override
@@ -529,10 +316,10 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
 
        @Override
        public <N, S extends State, T> void applyToAllKeys(
-                       final N namespace,
-                       final TypeSerializer<N> namespaceSerializer,
-                       final StateDescriptor<S, T> stateDescriptor,
-                       final KeyedStateFunction<K, S> function) throws 
Exception {
+               final N namespace,
+               final TypeSerializer<N> namespaceSerializer,
+               final StateDescriptor<S, T> stateDescriptor,
+               final KeyedStateFunction<K, S> function) throws Exception {
 
                try (Stream<K> keyStream = getKeys(stateDescriptor.getName(), 
namespace)) {
 
@@ -541,9 +328,9 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                        final List<K> keys = 
keyStream.collect(Collectors.toList());
 
                        final S state = getPartitionedState(
-                                       namespace,
-                                       namespaceSerializer,
-                                       stateDescriptor);
+                               namespace,
+                               namespaceSerializer,
+                               stateDescriptor);
 
                        for (K key : keys) {
                                setCurrentKey(key);
@@ -593,233 +380,6 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                return localRecoveryConfig;
        }
 
-       private interface SnapshotStrategySynchronicityBehavior<K> {
-
-               default void finalizeSnapshotBeforeReturnHook(Runnable 
runnable) {
-
-               }
-
-
-               boolean isAsynchronous();
-
-               <N, V> StateTable<K, N, V> 
newStateTable(RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo);
-       }
-
-       private class AsyncSnapshotStrategySynchronicityBehavior implements 
SnapshotStrategySynchronicityBehavior<K> {
-
-               @Override
-               public boolean isAsynchronous() {
-                       return true;
-               }
-
-               @Override
-               public <N, V> StateTable<K, N, V> 
newStateTable(RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo) {
-                       return new 
CopyOnWriteStateTable<>(HeapKeyedStateBackend.this, newMetaInfo);
-               }
-       }
-
-       private class SyncSnapshotStrategySynchronicityBehavior implements 
SnapshotStrategySynchronicityBehavior<K> {
-
-               @Override
-               public void finalizeSnapshotBeforeReturnHook(Runnable runnable) 
{
-                       // this triggers a synchronous execution from the main 
checkpointing thread.
-                       runnable.run();
-               }
-
-               @Override
-               public boolean isAsynchronous() {
-                       return false;
-               }
-
-               @Override
-               public <N, V> StateTable<K, N, V> 
newStateTable(RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo) {
-                       return new 
NestedMapsStateTable<>(HeapKeyedStateBackend.this, newMetaInfo);
-               }
-       }
-
-       /**
-        * Base class for the snapshots of the heap backend that outlines the 
algorithm and offers some hooks to realize
-        * the concrete strategies. Subclasses must be threadsafe.
-        */
-       private class HeapSnapshotStrategy
-               extends AbstractSnapshotStrategy<KeyedStateHandle> implements 
SnapshotStrategySynchronicityBehavior<K> {
-
-               private final SnapshotStrategySynchronicityBehavior<K> 
snapshotStrategySynchronicityTrait;
-
-               HeapSnapshotStrategy(
-                       SnapshotStrategySynchronicityBehavior<K> 
snapshotStrategySynchronicityTrait) {
-                       super("Heap backend snapshot");
-                       this.snapshotStrategySynchronicityTrait = 
snapshotStrategySynchronicityTrait;
-               }
-
-               @Nonnull
-               @Override
-               public RunnableFuture<SnapshotResult<KeyedStateHandle>> 
snapshot(
-                       long checkpointId,
-                       long timestamp,
-                       @Nonnull CheckpointStreamFactory primaryStreamFactory,
-                       @Nonnull CheckpointOptions checkpointOptions) throws 
IOException {
-
-                       if (!hasRegisteredState()) {
-                               return DoneFuture.of(SnapshotResult.empty());
-                       }
-
-                       int numStates = registeredKVStates.size() + 
registeredPQStates.size();
-
-                       Preconditions.checkState(numStates <= Short.MAX_VALUE,
-                               "Too many states: " + numStates +
-                                       ". Currently at most " + 
Short.MAX_VALUE + " states are supported");
-
-                       final List<StateMetaInfoSnapshot> metaInfoSnapshots = 
new ArrayList<>(numStates);
-                       final Map<StateUID, Integer> stateNamesToId =
-                               new HashMap<>(numStates);
-                       final Map<StateUID, StateSnapshot> 
cowStateStableSnapshots =
-                               new HashMap<>(numStates);
-
-                       processSnapshotMetaInfoForAllStates(
-                               metaInfoSnapshots,
-                               cowStateStableSnapshots,
-                               stateNamesToId,
-                               registeredKVStates,
-                               
StateMetaInfoSnapshot.BackendStateType.KEY_VALUE);
-
-                       processSnapshotMetaInfoForAllStates(
-                               metaInfoSnapshots,
-                               cowStateStableSnapshots,
-                               stateNamesToId,
-                               registeredPQStates,
-                               
StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE);
-
-                       final KeyedBackendSerializationProxy<K> 
serializationProxy =
-                               new KeyedBackendSerializationProxy<>(
-                                       // TODO: this code assumes that writing 
a serializer is threadsafe, we should support to
-                                       // get a serialized form already at 
state registration time in the future
-                                       getKeySerializer(),
-                                       metaInfoSnapshots,
-                                       
!Objects.equals(UncompressedStreamCompressionDecorator.INSTANCE, 
keyGroupCompressionDecorator));
-
-                       final 
SupplierWithException<CheckpointStreamWithResultProvider, Exception> 
checkpointStreamSupplier =
-
-                               localRecoveryConfig.isLocalRecoveryEnabled() ?
-
-                                       () -> 
CheckpointStreamWithResultProvider.createDuplicatingStream(
-                                               checkpointId,
-                                               
CheckpointedStateScope.EXCLUSIVE,
-                                               primaryStreamFactory,
-                                               
localRecoveryConfig.getLocalStateDirectoryProvider()) :
-
-                                       () -> 
CheckpointStreamWithResultProvider.createSimpleStream(
-                                               
CheckpointedStateScope.EXCLUSIVE,
-                                               primaryStreamFactory);
-
-                       //--------------------------------------------------- 
this becomes the end of sync part
-
-                       final 
AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>> asyncSnapshotCallable =
-                               new 
AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>>() {
-                                       @Override
-                                       protected 
SnapshotResult<KeyedStateHandle> callInternal() throws Exception {
-
-                                               final 
CheckpointStreamWithResultProvider streamWithResultProvider =
-                                                       
checkpointStreamSupplier.get();
-
-                                               
snapshotCloseableRegistry.registerCloseable(streamWithResultProvider);
-
-                                               final 
CheckpointStreamFactory.CheckpointStateOutputStream localStream =
-                                                       
streamWithResultProvider.getCheckpointOutputStream();
-
-                                               final 
DataOutputViewStreamWrapper outView = new 
DataOutputViewStreamWrapper(localStream);
-                                               
serializationProxy.write(outView);
-
-                                               final long[] 
keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()];
-
-                                               for (int keyGroupPos = 0; 
keyGroupPos < keyGroupRange.getNumberOfKeyGroups(); ++keyGroupPos) {
-                                                       int keyGroupId = 
keyGroupRange.getKeyGroupId(keyGroupPos);
-                                                       
keyGroupRangeOffsets[keyGroupPos] = localStream.getPos();
-                                                       
outView.writeInt(keyGroupId);
-
-                                                       for 
(Map.Entry<StateUID, StateSnapshot> stateSnapshot :
-                                                               
cowStateStableSnapshots.entrySet()) {
-                                                               
StateSnapshot.StateKeyGroupWriter partitionedSnapshot =
-
-                                                                       
stateSnapshot.getValue().getKeyGroupWriter();
-                                                               try (
-                                                                       
OutputStream kgCompressionOut =
-                                                                               
keyGroupCompressionDecorator.decorateWithCompression(localStream)) {
-                                                                       
DataOutputViewStreamWrapper kgCompressionView =
-                                                                               
new DataOutputViewStreamWrapper(kgCompressionOut);
-                                                                       
kgCompressionView.writeShort(stateNamesToId.get(stateSnapshot.getKey()));
-                                                                       
partitionedSnapshot.writeStateInKeyGroup(kgCompressionView, keyGroupId);
-                                                               } // this will 
just close the outer compression stream
-                                                       }
-                                               }
-
-                                               if 
(snapshotCloseableRegistry.unregisterCloseable(streamWithResultProvider)) {
-                                                       KeyGroupRangeOffsets 
kgOffs = new KeyGroupRangeOffsets(keyGroupRange, keyGroupRangeOffsets);
-                                                       
SnapshotResult<StreamStateHandle> result =
-                                                               
streamWithResultProvider.closeAndFinalizeCheckpointStreamResult();
-                                                       return 
CheckpointStreamWithResultProvider.toKeyedStateHandleSnapshotResult(result, 
kgOffs);
-                                               } else {
-                                                       throw new 
IOException("Stream already unregistered.");
-                                               }
-                                       }
-
-                                       @Override
-                                       protected void 
cleanupProvidedResources() {
-                                               for (StateSnapshot 
tableSnapshot : cowStateStableSnapshots.values()) {
-                                                       tableSnapshot.release();
-                                               }
-                                       }
-
-                                       @Override
-                                       protected void 
logAsyncSnapshotComplete(long startTime) {
-                                               if 
(snapshotStrategySynchronicityTrait.isAsynchronous()) {
-                                                       
logAsyncCompleted(primaryStreamFactory, startTime);
-                                               }
-                                       }
-                               };
-
-                       final FutureTask<SnapshotResult<KeyedStateHandle>> task 
=
-                               
asyncSnapshotCallable.toAsyncSnapshotFutureTask(cancelStreamRegistry);
-                       finalizeSnapshotBeforeReturnHook(task);
-
-                       return task;
-               }
-
-               @Override
-               public void finalizeSnapshotBeforeReturnHook(Runnable runnable) 
{
-                       
snapshotStrategySynchronicityTrait.finalizeSnapshotBeforeReturnHook(runnable);
-               }
-
-               @Override
-               public boolean isAsynchronous() {
-                       return 
snapshotStrategySynchronicityTrait.isAsynchronous();
-               }
-
-               @Override
-               public <N, V> StateTable<K, N, V> 
newStateTable(RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo) {
-                       return 
snapshotStrategySynchronicityTrait.newStateTable(newMetaInfo);
-               }
-
-               private void processSnapshotMetaInfoForAllStates(
-                       List<StateMetaInfoSnapshot> metaInfoSnapshots,
-                       Map<StateUID, StateSnapshot> cowStateStableSnapshots,
-                       Map<StateUID, Integer> stateNamesToId,
-                       Map<String, ? extends StateSnapshotRestore> 
registeredStates,
-                       StateMetaInfoSnapshot.BackendStateType stateType) {
-
-                       for (Map.Entry<String, ? extends StateSnapshotRestore> 
kvState : registeredStates.entrySet()) {
-                               final StateUID stateUid = 
StateUID.of(kvState.getKey(), stateType);
-                               stateNamesToId.put(stateUid, 
stateNamesToId.size());
-                               StateSnapshotRestore state = kvState.getValue();
-                               if (null != state) {
-                                       final StateSnapshot stateSnapshot = 
state.stateSnapshot();
-                                       
metaInfoSnapshots.add(stateSnapshot.getMetaInfoSnapshot());
-                                       cowStateStableSnapshots.put(stateUid, 
stateSnapshot);
-                               }
-                       }
-               }
-       }
-
        private interface StateFactory {
                <K, N, SV, S extends State, IS extends S> IS createState(
                        StateDescriptor<S, SV> stateDesc,
@@ -827,52 +387,4 @@ public class HeapKeyedStateBackend<K> extends 
AbstractKeyedStateBackend<K> {
                        TypeSerializer<K> keySerializer) throws Exception;
        }
 
-       /**
-        * Unique identifier for registered state in this backend.
-        */
-       private static final class StateUID {
-
-               @Nonnull
-               private final String stateName;
-
-               @Nonnull
-               private final StateMetaInfoSnapshot.BackendStateType stateType;
-
-               StateUID(@Nonnull String stateName, @Nonnull 
StateMetaInfoSnapshot.BackendStateType stateType) {
-                       this.stateName = stateName;
-                       this.stateType = stateType;
-               }
-
-               @Nonnull
-               public String getStateName() {
-                       return stateName;
-               }
-
-               @Nonnull
-               public StateMetaInfoSnapshot.BackendStateType getStateType() {
-                       return stateType;
-               }
-
-               @Override
-               public boolean equals(Object o) {
-                       if (this == o) {
-                               return true;
-                       }
-                       if (o == null || getClass() != o.getClass()) {
-                               return false;
-                       }
-                       StateUID uid = (StateUID) o;
-                       return Objects.equals(getStateName(), 
uid.getStateName()) &&
-                               getStateType() == uid.getStateType();
-               }
-
-               @Override
-               public int hashCode() {
-                       return Objects.hash(getStateName(), getStateType());
-               }
-
-               public static StateUID of(@Nonnull String stateName, @Nonnull 
StateMetaInfoSnapshot.BackendStateType stateType) {
-                       return new StateUID(stateName, stateType);
-               }
-       }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendBuilder.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendBuilder.java
new file mode 100644
index 0000000..3047fd0
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendBuilder.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackendBuilder;
+import org.apache.flink.runtime.state.BackendBuildingException;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.LocalRecoveryConfig;
+import org.apache.flink.runtime.state.StreamCompressionDecorator;
+import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
+
+import javax.annotation.Nonnull;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Builder class for {@link HeapKeyedStateBackend} which handles all necessary 
initializations and cleanups.
+ *
+ * @param <K> The data type that the key serializer serializes.
+ */
+public class HeapKeyedStateBackendBuilder<K> extends 
AbstractKeyedStateBackendBuilder<K> {
+       /**
+        * The configuration of local recovery.
+        */
+       private final LocalRecoveryConfig localRecoveryConfig;
+       /**
+        * Factory for state that is organized as priority queue.
+        */
+       private final HeapPriorityQueueSetFactory priorityQueueSetFactory;
+       /**
+        * Whether asynchronous snapshots are enabled.
+        */
+       private final boolean asynchronousSnapshots;
+
+
+       public HeapKeyedStateBackendBuilder(
+               TaskKvStateRegistry kvStateRegistry,
+               TypeSerializer<K> keySerializer,
+               ClassLoader userCodeClassLoader,
+               int numberOfKeyGroups,
+               KeyGroupRange keyGroupRange,
+               ExecutionConfig executionConfig,
+               TtlTimeProvider ttlTimeProvider,
+               @Nonnull Collection<KeyedStateHandle> stateHandles,
+               StreamCompressionDecorator keyGroupCompressionDecorator,
+               LocalRecoveryConfig localRecoveryConfig,
+               HeapPriorityQueueSetFactory priorityQueueSetFactory,
+               boolean asynchronousSnapshots,
+               CloseableRegistry cancelStreamRegistry) {
+               super(
+                       kvStateRegistry,
+                       keySerializer,
+                       userCodeClassLoader,
+                       numberOfKeyGroups,
+                       keyGroupRange,
+                       executionConfig,
+                       ttlTimeProvider,
+                       stateHandles,
+                       keyGroupCompressionDecorator,
+                       cancelStreamRegistry);
+               this.localRecoveryConfig = localRecoveryConfig;
+               this.priorityQueueSetFactory = priorityQueueSetFactory;
+               this.asynchronousSnapshots = asynchronousSnapshots;
+       }
+
+       @Override
+       public HeapKeyedStateBackend<K> build() throws BackendBuildingException 
{
+               // Map of registered Key/Value states
+               Map<String, StateTable<K, ?, ?>> registeredKVStates = new 
HashMap<>();
+               // Map of registered priority queue set states
+               Map<String, HeapPriorityQueueSnapshotRestoreWrapper> 
registeredPQStates = new HashMap<>();
+               CloseableRegistry cancelStreamRegistryForBackend = new 
CloseableRegistry();
+               HeapSnapshotStrategy<K> snapshotStrategy = initSnapshotStrategy(
+                       asynchronousSnapshots, registeredKVStates, 
registeredPQStates, cancelStreamRegistryForBackend);
+               HeapKeyedStateBackend<K> backend = new HeapKeyedStateBackend<>(
+                       kvStateRegistry,
+                       keySerializerProvider,
+                       userCodeClassLoader,
+                       numberOfKeyGroups,
+                       keyGroupRange,
+                       executionConfig,
+                       ttlTimeProvider,
+                       cancelStreamRegistryForBackend,
+                       keyGroupCompressionDecorator,
+                       registeredKVStates,
+                       registeredPQStates,
+                       localRecoveryConfig,
+                       priorityQueueSetFactory,
+                       snapshotStrategy
+               );
+               HeapRestoreOperation<K> restoreOperation = new 
HeapRestoreOperation<>(
+                       restoreStateHandles,
+                       keySerializerProvider,
+                       userCodeClassLoader,
+                       registeredKVStates,
+                       registeredPQStates,
+                       cancelStreamRegistry,
+                       priorityQueueSetFactory,
+                       keyGroupRange,
+                       numberOfKeyGroups,
+                       snapshotStrategy,
+                       backend);
+               try {
+                       restoreOperation.restore();
+               } catch (Exception e) {
+                       throw new BackendBuildingException("Failed when trying 
to restore heap backend", e);
+               }
+               return backend;
+       }
+
+       private HeapSnapshotStrategy<K> initSnapshotStrategy(
+               boolean asynchronousSnapshots,
+               Map<String, StateTable<K, ?, ?>> registeredKVStates,
+               Map<String, HeapPriorityQueueSnapshotRestoreWrapper> 
registeredPQStates,
+               CloseableRegistry cancelStreamRegistry) {
+               SnapshotStrategySynchronicityBehavior<K> synchronicityTrait = 
asynchronousSnapshots ?
+                       new AsyncSnapshotStrategySynchronicityBehavior<>() :
+                       new SyncSnapshotStrategySynchronicityBehavior<>();
+               return new HeapSnapshotStrategy<>(
+                       synchronicityTrait,
+                       registeredKVStates,
+                       registeredPQStates,
+                       keyGroupCompressionDecorator,
+                       localRecoveryConfig,
+                       keyGroupRange,
+                       cancelStreamRegistry,
+                       keySerializerProvider);
+       }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapRestoreOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapRestoreOperation.java
new file mode 100644
index 0000000..35966ba
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapRestoreOperation.java
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.state.KeyExtractorFunction;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.Keyed;
+import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.PriorityComparable;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
+import 
org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.RestoreOperation;
+import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
+import org.apache.flink.runtime.state.StateSerializerProvider;
+import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
+import org.apache.flink.runtime.state.StateSnapshotRestore;
+import org.apache.flink.runtime.state.StreamCompressionDecorator;
+import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.StateMigrationException;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Restore operation of the heap keyed state backend. It reads the given key-groups state handles
+ * and fills the registered key/value state tables and priority queue states from their content.
+ *
+ * @param <K> The type of the key.
+ */
+public class HeapRestoreOperation<K> implements RestoreOperation<Void> {
+       private final Collection<KeyedStateHandle> restoreStateHandles;
+       private final StateSerializerProvider<K> keySerializerProvider;
+       private final ClassLoader userCodeClassLoader;
+       private final Map<String, StateTable<K, ?, ?>> registeredKVStates;
+       private final Map<String, HeapPriorityQueueSnapshotRestoreWrapper> registeredPQStates;
+       private final CloseableRegistry cancelStreamRegistry;
+       private final HeapPriorityQueueSetFactory priorityQueueSetFactory;
+       @Nonnull
+       private final KeyGroupRange keyGroupRange;
+       @Nonnegative
+       private final int numberOfKeyGroups;
+       private final HeapSnapshotStrategy<K> snapshotStrategy;
+       private final HeapKeyedStateBackend<K> backend;
+
+       HeapRestoreOperation(
+               @Nonnull Collection<KeyedStateHandle> restoreStateHandles,
+               StateSerializerProvider<K> keySerializerProvider,
+               ClassLoader userCodeClassLoader,
+               Map<String, StateTable<K, ?, ?>> registeredKVStates,
+               Map<String, HeapPriorityQueueSnapshotRestoreWrapper> registeredPQStates,
+               CloseableRegistry cancelStreamRegistry,
+               HeapPriorityQueueSetFactory priorityQueueSetFactory,
+               @Nonnull KeyGroupRange keyGroupRange,
+               @Nonnegative int numberOfKeyGroups,
+               HeapSnapshotStrategy<K> snapshotStrategy,
+               HeapKeyedStateBackend<K> backend) {
+               this.restoreStateHandles = restoreStateHandles;
+               this.keySerializerProvider = keySerializerProvider;
+               this.userCodeClassLoader = userCodeClassLoader;
+               this.registeredKVStates = registeredKVStates;
+               this.registeredPQStates = registeredPQStates;
+               this.cancelStreamRegistry = cancelStreamRegistry;
+               this.priorityQueueSetFactory = priorityQueueSetFactory;
+               this.keyGroupRange = keyGroupRange;
+               this.numberOfKeyGroups = numberOfKeyGroups;
+               this.snapshotStrategy = snapshotStrategy;
+               this.backend = backend;
+       }
+
+       /**
+        * Clears the registered states and restores them from all non-null state handles.
+        *
+        * @return always {@code null}.
+        * @throws IllegalStateException if a handle is not a {@link KeyGroupsStateHandle}, or if the
+        *     restored data refers to an unexpected state type, key-group or state id.
+        * @throws StateMigrationException if the new key serializer is incompatible with the restored one
+        *     or would require migration.
+        * @throws Exception on any failure while opening or reading a state handle.
+        */
+       @Override
+       public Void restore() throws Exception {
+
+               registeredKVStates.clear();
+               registeredPQStates.clear();
+
+               boolean keySerializerRestored = false;
+
+               for (KeyedStateHandle keyedStateHandle : restoreStateHandles) {
+
+                       if (keyedStateHandle == null) {
+                               continue;
+                       }
+
+                       if (!(keyedStateHandle instanceof KeyGroupsStateHandle)) {
+                               throw new IllegalStateException("Unexpected state handle type, " +
+                                       "expected: " + KeyGroupsStateHandle.class +
+                                       ", but found: " + keyedStateHandle.getClass());
+                       }
+
+                       KeyGroupsStateHandle keyGroupsStateHandle = (KeyGroupsStateHandle) keyedStateHandle;
+                       FSDataInputStream fsDataInputStream = keyGroupsStateHandle.openInputStream();
+                       // registering the stream lets a cancellation close it while we are blocked reading
+                       cancelStreamRegistry.registerCloseable(fsDataInputStream);
+
+                       try {
+                               DataInputViewStreamWrapper inView = new DataInputViewStreamWrapper(fsDataInputStream);
+
+                               KeyedBackendSerializationProxy<K> serializationProxy =
+                                       new KeyedBackendSerializationProxy<>(userCodeClassLoader);
+
+                               serializationProxy.read(inView);
+
+                               if (!keySerializerRestored) {
+                                       // check for key serializer compatibility; this also reconfigures the
+                                       // key serializer to be compatible, if it is required and is possible
+                                       TypeSerializerSchemaCompatibility<K> keySerializerSchemaCompat =
+                                               keySerializerProvider.setPreviousSerializerSnapshotForRestoredState(
+                                                       serializationProxy.getKeySerializerSnapshot());
+                                       if (keySerializerSchemaCompat.isCompatibleAfterMigration() ||
+                                               keySerializerSchemaCompat.isIncompatible()) {
+                                               throw new StateMigrationException("The new key serializer must be compatible.");
+                                       }
+
+                                       keySerializerRestored = true;
+                               }
+
+                               List<StateMetaInfoSnapshot> restoredMetaInfos =
+                                       serializationProxy.getStateMetaInfoSnapshots();
+
+                               // The state ids written into a handle are the positions of the states in the meta
+                               // info list of that very handle, so the id mapping must be rebuilt for every handle
+                               // instead of being carried over from previously read handles.
+                               final Map<Integer, StateMetaInfoSnapshot> kvStatesById =
+                                       createOrCheckStateForMetaInfo(restoredMetaInfos);
+
+                               readStateHandleStateData(
+                                       fsDataInputStream,
+                                       inView,
+                                       keyGroupsStateHandle.getGroupRangeOffsets(),
+                                       kvStatesById, restoredMetaInfos.size(),
+                                       serializationProxy.getReadVersion(),
+                                       serializationProxy.isUsingKeyGroupCompression());
+                       } finally {
+                               // only close the stream if the registry has not already taken care of it
+                               if (cancelStreamRegistry.unregisterCloseable(fsDataInputStream)) {
+                                       IOUtils.closeQuietly(fsDataInputStream);
+                               }
+                       }
+               }
+               return null;
+       }
+
+       /**
+        * Registers a state for every restored meta info that is not registered yet and builds the
+        * mapping from state id to meta info for the handle the meta infos were read from.
+        *
+        * @param restoredMetaInfo the meta infos of one state handle, in the order they were written.
+        * @return the meta infos keyed by their position in {@code restoredMetaInfo}.
+        */
+       private Map<Integer, StateMetaInfoSnapshot> createOrCheckStateForMetaInfo(
+               List<StateMetaInfoSnapshot> restoredMetaInfo) {
+
+               final Map<Integer, StateMetaInfoSnapshot> kvStatesById = new HashMap<>();
+
+               for (StateMetaInfoSnapshot metaInfoSnapshot : restoredMetaInfo) {
+
+                       switch (metaInfoSnapshot.getBackendStateType()) {
+                               case KEY_VALUE:
+                                       if (registeredKVStates.get(metaInfoSnapshot.getName()) == null) {
+                                               RegisteredKeyValueStateBackendMetaInfo<?, ?> registeredKeyedBackendStateMetaInfo =
+                                                       new RegisteredKeyValueStateBackendMetaInfo<>(metaInfoSnapshot);
+                                               registeredKVStates.put(
+                                                       metaInfoSnapshot.getName(),
+                                                       snapshotStrategy.newStateTable(backend, registeredKeyedBackendStateMetaInfo));
+                                       }
+                                       break;
+                               case PRIORITY_QUEUE:
+                                       if (registeredPQStates.get(metaInfoSnapshot.getName()) == null) {
+                                               createInternal(new RegisteredPriorityQueueStateBackendMetaInfo<>(metaInfoSnapshot));
+                                       }
+                                       break;
+                               default:
+                                       throw new IllegalStateException("Unexpected state type: " +
+                                               metaInfoSnapshot.getBackendStateType() + ".");
+                       }
+
+                       // always record the meta info, also for states that an earlier handle registered
+                       kvStatesById.put(kvStatesById.size(), metaInfoSnapshot);
+               }
+
+               return kvStatesById;
+       }
+
+       /**
+        * Creates an empty priority queue set for the given meta info and registers it, wrapped for
+        * snapshot/restore, under the state name.
+        */
+       private <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> void createInternal(
+               RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) {
+
+               final String stateName = metaInfo.getName();
+               final HeapPriorityQueueSet<T> priorityQueue = priorityQueueSetFactory.create(
+                       stateName,
+                       metaInfo.getElementSerializer());
+
+               HeapPriorityQueueSnapshotRestoreWrapper<T> wrapper =
+                       new HeapPriorityQueueSnapshotRestoreWrapper<>(
+                               priorityQueue,
+                               metaInfo,
+                               KeyExtractorFunction.forKeyedObjects(),
+                               keyGroupRange,
+                               numberOfKeyGroups);
+
+               registeredPQStates.put(stateName, wrapper);
+       }
+
+       /**
+        * Reads the data of all key-groups contained in one state handle.
+        *
+        * @param fsDataInputStream the stream of the handle, positioned via the key-group offsets.
+        * @param inView data view on top of {@code fsDataInputStream}.
+        * @param keyGroupOffsets the key-groups of the handle and their offsets in the stream.
+        * @param kvStatesById mapping from state id to meta info for this handle.
+        * @param numStates number of states written per key-group.
+        * @param readVersion version of the serialization format that was read from the handle.
+        * @param isCompressed whether the key-group data was written with compression.
+        * @throws IOException if reading from the stream fails.
+        */
+       private void readStateHandleStateData(
+               FSDataInputStream fsDataInputStream,
+               DataInputViewStreamWrapper inView,
+               KeyGroupRangeOffsets keyGroupOffsets,
+               Map<Integer, StateMetaInfoSnapshot> kvStatesById,
+               int numStates,
+               int readVersion,
+               boolean isCompressed) throws IOException {
+
+               final StreamCompressionDecorator streamCompressionDecorator = isCompressed ?
+                       SnappyStreamCompressionDecorator.INSTANCE : UncompressedStreamCompressionDecorator.INSTANCE;
+
+               for (Tuple2<Integer, Long> groupOffset : keyGroupOffsets) {
+                       int keyGroupIndex = groupOffset.f0;
+                       long offset = groupOffset.f1;
+
+                       // Check that restored key groups all belong to the backend.
+                       Preconditions.checkState(keyGroupRange.contains(keyGroupIndex), "The key group must belong to the backend.");
+
+                       fsDataInputStream.seek(offset);
+
+                       // the key-group index is written uncompressed in front of the key-group data
+                       int writtenKeyGroupIndex = inView.readInt();
+                       Preconditions.checkState(writtenKeyGroupIndex == keyGroupIndex,
+                               "Unexpected key-group in restore.");
+
+                       try (InputStream kgCompressionInStream =
+                                        streamCompressionDecorator.decorateWithCompression(fsDataInputStream)) {
+
+                               readKeyGroupStateData(
+                                       kgCompressionInStream,
+                                       kvStatesById,
+                                       keyGroupIndex,
+                                       numStates,
+                                       readVersion);
+                       }
+               }
+       }
+
+       /**
+        * Reads the data of {@code numStates} states for a single key-group from the given stream into
+        * the registered states.
+        *
+        * @throws IllegalStateException if the stream contains a state id without a known meta info or
+        *     a meta info with an unexpected state type.
+        * @throws IOException if reading from the stream fails.
+        */
+       private void readKeyGroupStateData(
+               InputStream inputStream,
+               Map<Integer, StateMetaInfoSnapshot> kvStatesById,
+               int keyGroupIndex,
+               int numStates,
+               int readVersion) throws IOException {
+
+               DataInputViewStreamWrapper inView =
+                       new DataInputViewStreamWrapper(inputStream);
+
+               for (int i = 0; i < numStates; i++) {
+
+                       final int kvStateId = inView.readShort();
+                       final StateMetaInfoSnapshot stateMetaInfoSnapshot = kvStatesById.get(kvStateId);
+                       if (stateMetaInfoSnapshot == null) {
+                               // fail with a descriptive message instead of a NullPointerException below
+                               throw new IllegalStateException("Unknown state id " + kvStateId +
+                                       " found in key-group " + keyGroupIndex + ".");
+                       }
+
+                       final StateSnapshotRestore registeredState;
+
+                       switch (stateMetaInfoSnapshot.getBackendStateType()) {
+                               case KEY_VALUE:
+                                       registeredState = registeredKVStates.get(stateMetaInfoSnapshot.getName());
+                                       break;
+                               case PRIORITY_QUEUE:
+                                       registeredState = registeredPQStates.get(stateMetaInfoSnapshot.getName());
+                                       break;
+                               default:
+                                       throw new IllegalStateException("Unexpected state type: " +
+                                               stateMetaInfoSnapshot.getBackendStateType() + ".");
+                       }
+
+                       StateSnapshotKeyGroupReader keyGroupReader = registeredState.keyGroupReader(readVersion);
+                       keyGroupReader.readMappingsInKeyGroup(inView, keyGroupIndex);
+               }
+       }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSnapshotStrategy.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSnapshotStrategy.java
new file mode 100644
index 0000000..8a7d280
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSnapshotStrategy.java
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.AbstractSnapshotStrategy;
+import org.apache.flink.runtime.state.AsyncSnapshotCallable;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
+import org.apache.flink.runtime.state.CheckpointedStateScope;
+import org.apache.flink.runtime.state.DoneFuture;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.LocalRecoveryConfig;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.runtime.state.StateSerializerProvider;
+import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotRestore;
+import org.apache.flink.runtime.state.StreamCompressionDecorator;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SupplierWithException;
+
+import javax.annotation.Nonnull;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.FutureTask;
+import java.util.concurrent.RunnableFuture;
+
+/**
+ * Snapshot strategy of the heap keyed state backend. It implements the snapshot algorithm and
+ * delegates the parts that depend on the synchronicity of the snapshot (the hook that runs before
+ * the snapshot task is returned, the asynchronicity flag and the creation of state tables) to a
+ * {@link SnapshotStrategySynchronicityBehavior}.
+ *
+ * @param <K> The type of the key.
+ */
+class HeapSnapshotStrategy<K>
+       extends AbstractSnapshotStrategy<KeyedStateHandle> implements SnapshotStrategySynchronicityBehavior<K> {
+
+       private final SnapshotStrategySynchronicityBehavior<K> snapshotStrategySynchronicityTrait;
+       private final Map<String, StateTable<K, ?, ?>> registeredKVStates;
+       private final Map<String, HeapPriorityQueueSnapshotRestoreWrapper> registeredPQStates;
+       private final StreamCompressionDecorator keyGroupCompressionDecorator;
+       private final LocalRecoveryConfig localRecoveryConfig;
+       private final KeyGroupRange keyGroupRange;
+       private final CloseableRegistry cancelStreamRegistry;
+       private final StateSerializerProvider<K> keySerializerProvider;
+
+       HeapSnapshotStrategy(
+               SnapshotStrategySynchronicityBehavior<K> snapshotStrategySynchronicityTrait,
+               Map<String, StateTable<K, ?, ?>> registeredKVStates,
+               Map<String, HeapPriorityQueueSnapshotRestoreWrapper> registeredPQStates,
+               StreamCompressionDecorator keyGroupCompressionDecorator,
+               LocalRecoveryConfig localRecoveryConfig,
+               KeyGroupRange keyGroupRange,
+               CloseableRegistry cancelStreamRegistry,
+               StateSerializerProvider<K> keySerializerProvider) {
+               super("Heap backend snapshot");
+               this.snapshotStrategySynchronicityTrait = snapshotStrategySynchronicityTrait;
+               this.registeredKVStates = registeredKVStates;
+               this.registeredPQStates = registeredPQStates;
+               this.keyGroupCompressionDecorator = keyGroupCompressionDecorator;
+               this.localRecoveryConfig = localRecoveryConfig;
+               this.keyGroupRange = keyGroupRange;
+               this.cancelStreamRegistry = cancelStreamRegistry;
+               this.keySerializerProvider = keySerializerProvider;
+       }
+
+       /**
+        * Creates the task that writes a snapshot of all registered states. If no state is registered,
+        * an already completed future with an empty result is returned.
+        *
+        * @param checkpointId id of the checkpoint, used when creating the duplicating stream for local recovery.
+        * @param timestamp timestamp of the checkpoint; not used by this strategy.
+        * @param primaryStreamFactory factory for the stream the snapshot is written to.
+        * @param checkpointOptions options of the checkpoint; not used by this strategy.
+        * @return the task that produces the snapshot result.
+        * @throws IllegalStateException if more than {@link Short#MAX_VALUE} states are registered.
+        */
+       @Nonnull
+       @Override
+       public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
+               long checkpointId,
+               long timestamp,
+               @Nonnull CheckpointStreamFactory primaryStreamFactory,
+               @Nonnull CheckpointOptions checkpointOptions) throws IOException {
+
+               if (!hasRegisteredState()) {
+                       return DoneFuture.of(SnapshotResult.empty());
+               }
+
+               int numStates = registeredKVStates.size() + registeredPQStates.size();
+
+               // state ids are written as shorts
+               Preconditions.checkState(numStates <= Short.MAX_VALUE,
+                       "Too many states: " + numStates +
+                               ". Currently at most " + Short.MAX_VALUE + " states are supported");
+
+               final List<StateMetaInfoSnapshot> metaInfoSnapshots = new ArrayList<>(numStates);
+               final Map<StateUID, Integer> stateNamesToId =
+                       new HashMap<>(numStates);
+               final Map<StateUID, StateSnapshot> cowStateStableSnapshots =
+                       new HashMap<>(numStates);
+
+               processSnapshotMetaInfoForAllStates(
+                       metaInfoSnapshots,
+                       cowStateStableSnapshots,
+                       stateNamesToId,
+                       registeredKVStates,
+                       StateMetaInfoSnapshot.BackendStateType.KEY_VALUE);
+
+               processSnapshotMetaInfoForAllStates(
+                       metaInfoSnapshots,
+                       cowStateStableSnapshots,
+                       stateNamesToId,
+                       registeredPQStates,
+                       StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE);
+
+               final KeyedBackendSerializationProxy<K> serializationProxy =
+                       new KeyedBackendSerializationProxy<>(
+                               // TODO: this code assumes that writing a serializer is threadsafe, we should support to
+                               // get a serialized form already at state registration time in the future
+                               getKeySerializer(),
+                               metaInfoSnapshots,
+                               !Objects.equals(UncompressedStreamCompressionDecorator.INSTANCE, keyGroupCompressionDecorator));
+
+               final SupplierWithException<CheckpointStreamWithResultProvider, Exception> checkpointStreamSupplier =
+
+                       localRecoveryConfig.isLocalRecoveryEnabled() ?
+
+                               () -> CheckpointStreamWithResultProvider.createDuplicatingStream(
+                                       checkpointId,
+                                       CheckpointedStateScope.EXCLUSIVE,
+                                       primaryStreamFactory,
+                                       localRecoveryConfig.getLocalStateDirectoryProvider()) :
+
+                               () -> CheckpointStreamWithResultProvider.createSimpleStream(
+                                       CheckpointedStateScope.EXCLUSIVE,
+                                       primaryStreamFactory);
+
+               //--------------------------------------------------- this becomes the end of sync part
+
+               final AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>> asyncSnapshotCallable =
+                       new AsyncSnapshotCallable<SnapshotResult<KeyedStateHandle>>() {
+                               @Override
+                               protected SnapshotResult<KeyedStateHandle> callInternal() throws Exception {
+
+                                       final CheckpointStreamWithResultProvider streamWithResultProvider =
+                                               checkpointStreamSupplier.get();
+
+                                       snapshotCloseableRegistry.registerCloseable(streamWithResultProvider);
+
+                                       final CheckpointStreamFactory.CheckpointStateOutputStream localStream =
+                                               streamWithResultProvider.getCheckpointOutputStream();
+
+                                       final DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(localStream);
+                                       serializationProxy.write(outView);
+
+                                       final long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()];
+
+                                       for (int keyGroupPos = 0; keyGroupPos < keyGroupRange.getNumberOfKeyGroups(); ++keyGroupPos) {
+                                               int keyGroupId = keyGroupRange.getKeyGroupId(keyGroupPos);
+                                               // remember where the key-group starts, then write its id uncompressed
+                                               keyGroupRangeOffsets[keyGroupPos] = localStream.getPos();
+                                               outView.writeInt(keyGroupId);
+
+                                               for (Map.Entry<StateUID, StateSnapshot> stateSnapshot :
+                                                       cowStateStableSnapshots.entrySet()) {
+                                                       StateSnapshot.StateKeyGroupWriter partitionedSnapshot =
+
+                                                               stateSnapshot.getValue().getKeyGroupWriter();
+                                                       try (
+                                                               OutputStream kgCompressionOut =
+                                                                       keyGroupCompressionDecorator.decorateWithCompression(localStream)) {
+                                                               DataOutputViewStreamWrapper kgCompressionView =
+                                                                       new DataOutputViewStreamWrapper(kgCompressionOut);
+                                                               kgCompressionView.writeShort(stateNamesToId.get(stateSnapshot.getKey()));
+                                                               partitionedSnapshot.writeStateInKeyGroup(kgCompressionView, keyGroupId);
+                                                       } // this will just close the outer compression stream
+                                               }
+                                       }
+
+                                       if (snapshotCloseableRegistry.unregisterCloseable(streamWithResultProvider)) {
+                                               KeyGroupRangeOffsets kgOffs = new KeyGroupRangeOffsets(keyGroupRange, keyGroupRangeOffsets);
+                                               SnapshotResult<StreamStateHandle> result =
+                                                       streamWithResultProvider.closeAndFinalizeCheckpointStreamResult();
+                                               return CheckpointStreamWithResultProvider.toKeyedStateHandleSnapshotResult(result, kgOffs);
+                                       } else {
+                                               throw new IOException("Stream already unregistered.");
+                                       }
+                               }
+
+                               @Override
+                               protected void cleanupProvidedResources() {
+                                       for (StateSnapshot tableSnapshot : cowStateStableSnapshots.values()) {
+                                               tableSnapshot.release();
+                                       }
+                               }
+
+                               @Override
+                               protected void logAsyncSnapshotComplete(long startTime) {
+                                       if (snapshotStrategySynchronicityTrait.isAsynchronous()) {
+                                               logAsyncCompleted(primaryStreamFactory, startTime);
+                                       }
+                               }
+                       };
+
+               final FutureTask<SnapshotResult<KeyedStateHandle>> task =
+                       asyncSnapshotCallable.toAsyncSnapshotFutureTask(cancelStreamRegistry);
+               finalizeSnapshotBeforeReturnHook(task);
+
+               return task;
+       }
+
+       /** Forwards the given runnable to the hook of the configured synchronicity behavior. */
+       @Override
+       public void finalizeSnapshotBeforeReturnHook(Runnable runnable) {
+               snapshotStrategySynchronicityTrait.finalizeSnapshotBeforeReturnHook(runnable);
+       }
+
+       /** Returns whether the configured synchronicity behavior is asynchronous. */
+       @Override
+       public boolean isAsynchronous() {
+               return snapshotStrategySynchronicityTrait.isAsynchronous();
+       }
+
+       /** Creates a state table through the configured synchronicity behavior. */
+       @Override
+       public <N, V> StateTable<K, N, V> newStateTable(
+               InternalKeyContext<K> keyContext,
+               RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo) {
+               return snapshotStrategySynchronicityTrait.newStateTable(keyContext, newMetaInfo);
+       }
+
+       /**
+        * Takes a snapshot of every non-null state in {@code registeredStates} and records its meta info,
+        * its snapshot and its state id in the given collections.
+        *
+        * @param metaInfoSnapshots receives the meta info snapshot of each state.
+        * @param cowStateStableSnapshots receives the state snapshot of each state.
+        * @param stateNamesToId receives the id of each state; ids continue from the current map size.
+        * @param registeredStates the states to process, keyed by state name.
+        * @param stateType the backend state type of all states in {@code registeredStates}.
+        */
+       private void processSnapshotMetaInfoForAllStates(
+               List<StateMetaInfoSnapshot> metaInfoSnapshots,
+               Map<StateUID, StateSnapshot> cowStateStableSnapshots,
+               Map<StateUID, Integer> stateNamesToId,
+               Map<String, ? extends StateSnapshotRestore> registeredStates,
+               StateMetaInfoSnapshot.BackendStateType stateType) {
+
+               for (Map.Entry<String, ? extends StateSnapshotRestore> kvState : registeredStates.entrySet()) {
+                       StateSnapshotRestore state = kvState.getValue();
+                       if (null != state) {
+                               final StateUID stateUid = StateUID.of(kvState.getKey(), stateType);
+                               // Only states that are actually written get an id, so that the id of a state
+                               // always equals the position of its meta info in metaInfoSnapshots.
+                               stateNamesToId.put(stateUid, stateNamesToId.size());
+                               final StateSnapshot stateSnapshot = state.stateSnapshot();
+                               metaInfoSnapshots.add(stateSnapshot.getMetaInfoSnapshot());
+                               cowStateStableSnapshots.put(stateUid, stateSnapshot);
+                       }
+               }
+       }
+
+       /** Returns whether at least one key/value or priority queue state is registered. */
+       private boolean hasRegisteredState() {
+               return !(registeredKVStates.isEmpty() && registeredPQStates.isEmpty());
+       }
+
+       /** Returns the key serializer for the current schema, as supplied by the key serializer provider. */
+       public TypeSerializer<K> getKeySerializer() {
+               return keySerializerProvider.currentSchemaSerializer();
+       }
+}
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/SnapshotStrategySynchronicityBehavior.java
similarity index 59%
copy from 
flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
copy to 
flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/SnapshotStrategySynchronicityBehavior.java
index ff70199..3e963ed 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/SnapshotStrategySynchronicityBehavior.java
@@ -16,14 +16,24 @@
  * limitations under the License.
  */
 
-package org.apache.flink.contrib.streaming.state.restore;
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 
 /**
- * Interface for RocksDB restore.
+ * Interface for synchronicity behavior of heap snapshot strategy.
+ *
+ * @param <K> The type of the keys of the state tables created by this behavior.
  */
-public interface RocksDBRestoreOperation {
-       /**
-        * Restores state that was previously snapshot-ed from the provided 
state handles.
-        */
-       RocksDBRestoreResult restore() throws Exception;
+interface SnapshotStrategySynchronicityBehavior<K> {
+
+       default void finalizeSnapshotBeforeReturnHook(Runnable runnable) {
+
+       }
+
+       boolean isAsynchronous();
+
+       <N, V> StateTable<K, N, V> newStateTable(
+               InternalKeyContext<K> keyContext,
+               RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo);
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateUID.java 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateUID.java
new file mode 100644
index 0000000..4575c7c
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateUID.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+
+import javax.annotation.Nonnull;
+import java.util.Objects;
+
+/**
+ * Unique identifier for registered state in this backend.
+ */
+final class StateUID {
+
+       @Nonnull
+       private final String stateName;
+
+       @Nonnull
+       private final StateMetaInfoSnapshot.BackendStateType stateType;
+
+       StateUID(@Nonnull String stateName, @Nonnull 
StateMetaInfoSnapshot.BackendStateType stateType) {
+               this.stateName = stateName;
+               this.stateType = stateType;
+       }
+
+       @Nonnull
+       public String getStateName() {
+               return stateName;
+       }
+
+       @Nonnull
+       public StateMetaInfoSnapshot.BackendStateType getStateType() {
+               return stateType;
+       }
+
+       @Override
+       public boolean equals(Object o) {
+               if (this == o) {
+                       return true;
+               }
+               if (o == null || getClass() != o.getClass()) {
+                       return false;
+               }
+               StateUID uid = (StateUID) o;
+               return Objects.equals(getStateName(), uid.getStateName()) &&
+                       getStateType() == uid.getStateType();
+       }
+
+       @Override
+       public int hashCode() {
+               return Objects.hash(getStateName(), getStateType());
+       }
+
+       public static StateUID of(@Nonnull String stateName, @Nonnull 
StateMetaInfoSnapshot.BackendStateType stateType) {
+               return new StateUID(stateName, stateType);
+       }
+}
diff --git 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/SyncSnapshotStrategySynchronicityBehavior.java
similarity index 50%
copy from 
flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
copy to 
flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/SyncSnapshotStrategySynchronicityBehavior.java
index ff70199..2d553ab 100644
--- 
a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/SyncSnapshotStrategySynchronicityBehavior.java
@@ -16,14 +16,32 @@
  * limitations under the License.
  */
 
-package org.apache.flink.contrib.streaming.state.restore;
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 
 /**
- * Interface for RocksDB restore.
+ * Synchronous behavior for heap snapshot strategy.
+ *
+ * @param <K> The type of the keys of the state tables created by this behavior.
  */
-public interface RocksDBRestoreOperation {
-       /**
-        * Restores state that was previously snapshot-ed from the provided 
state handles.
-        */
-       RocksDBRestoreResult restore() throws Exception;
+class SyncSnapshotStrategySynchronicityBehavior<K> implements 
SnapshotStrategySynchronicityBehavior<K> {
+
+       @Override
+       public void finalizeSnapshotBeforeReturnHook(Runnable runnable) {
+               // this triggers a synchronous execution from the main 
checkpointing thread.
+               runnable.run();
+       }
+
+       @Override
+       public boolean isAsynchronous() {
+               return false;
+       }
+
+       @Override
+       public <N, V> StateTable<K, N, V> newStateTable(
+               InternalKeyContext<K> keyContext,
+               RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo) {
+               return new NestedMapsStateTable<>(keyContext, newMetaInfo);
+       }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
index 6338c53..03e11ae 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
@@ -29,15 +29,17 @@ import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.AbstractStateBackend;
+import org.apache.flink.runtime.state.BackendBuildingException;
 import org.apache.flink.runtime.state.CheckpointStorage;
 import org.apache.flink.runtime.state.ConfigurableStateBackend;
 import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
-import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.TaskStateManager;
 import org.apache.flink.runtime.state.filesystem.AbstractFileStateBackend;
-import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.runtime.state.heap.HeapKeyedStateBackendBuilder;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.util.TernaryBoolean;
@@ -319,23 +321,25 @@ public class MemoryStateBackend extends 
AbstractFileStateBackend implements Conf
                TtlTimeProvider ttlTimeProvider,
                MetricGroup metricGroup,
                @Nonnull Collection<KeyedStateHandle> stateHandles,
-               CloseableRegistry cancelStreamRegistry) {
+               CloseableRegistry cancelStreamRegistry) throws 
BackendBuildingException {
 
                TaskStateManager taskStateManager = env.getTaskStateManager();
                HeapPriorityQueueSetFactory priorityQueueSetFactory =
                        new HeapPriorityQueueSetFactory(keyGroupRange, 
numberOfKeyGroups, 128);
-               return new HeapKeyedStateBackend<>(
-                               kvStateRegistry,
-                               keySerializer,
-                               env.getUserClassLoader(),
-                               numberOfKeyGroups,
-                               keyGroupRange,
-                               isUsingAsynchronousSnapshots(),
-                               env.getExecutionConfig(),
-                               taskStateManager.createLocalRecoveryConfig(),
-                               priorityQueueSetFactory,
-                               ttlTimeProvider,
-                       cancelStreamRegistry);
+               return new HeapKeyedStateBackendBuilder<>(
+                       kvStateRegistry,
+                       keySerializer,
+                       env.getUserClassLoader(),
+                       numberOfKeyGroups,
+                       keyGroupRange,
+                       env.getExecutionConfig(),
+                       ttlTimeProvider,
+                       stateHandles,
+                       
AbstractStateBackend.getCompressionDecorator(env.getExecutionConfig()),
+                       taskStateManager.createLocalRecoveryConfig(),
+                       priorityQueueSetFactory,
+                       isUsingAsynchronousSnapshots(),
+                       cancelStreamRegistry).build();
        }
 
        // 
------------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendMigrationTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendMigrationTestBase.java
index 13009a7..6a01583 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendMigrationTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendMigrationTestBase.java
@@ -1000,7 +1000,6 @@ public abstract class StateBackendMigrationTestBase<B 
extends AbstractStateBacke
                        numberOfKeyGroups,
                        keyGroupRange,
                        env.getTaskKvStateRegistry());
-               backend.restore(Collections.emptyList());
                return backend;
        }
 
@@ -1036,7 +1035,6 @@ public abstract class StateBackendMigrationTestBase<B 
extends AbstractStateBacke
                        env.getTaskKvStateRegistry()
                        , TtlTimeProvider.DEFAULT,
                        state);
-               backend.restore(new StateObjectCollection<>(state));
                return backend;
        }
 
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index b7d8cfc..0d0c7a4 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -60,7 +60,6 @@ import org.apache.flink.queryablestate.KvStateID;
 import 
org.apache.flink.queryablestate.client.state.serialization.KvStateSerializer;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
-import org.apache.flink.runtime.checkpoint.StateObjectCollection;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.highavailability.HighAvailabilityServices;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
@@ -81,6 +80,7 @@ import org.apache.flink.types.IntValue;
 import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.StateMigrationException;
 import org.apache.flink.util.TestLogger;
+import org.hamcrest.Matchers;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -109,6 +109,8 @@ import java.util.concurrent.TimeUnit;
 import java.util.stream.Stream;
 
 import static java.util.Arrays.asList;
+import static org.hamcrest.CoreMatchers.anyOf;
+import static org.hamcrest.CoreMatchers.isA;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.core.Is.is;
 import static org.junit.Assert.assertEquals;
@@ -179,8 +181,6 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                                env.getTaskKvStateRegistry(),
                                TtlTimeProvider.DEFAULT);
 
-               backend.restore(Collections.emptyList());
-
                return backend;
        }
 
@@ -218,8 +218,6 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        TtlTimeProvider.DEFAULT,
                        state);
 
-               backend.restore(new StateObjectCollection<>(state));
-
                return backend;
        }
 
@@ -666,7 +664,8 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                        // on the second restore, since the custom serializer 
will be used for
                        // deserialization, we expect the deliberate failure to 
be thrown
-                       
expectedException.expect(ExpectedKryoTestException.class);
+                       
expectedException.expect(anyOf(isA(ExpectedKryoTestException.class),
+                               Matchers.<Throwable>hasProperty("cause", 
isA(ExpectedKryoTestException.class))));
 
                        // state backends that eagerly deserializes (such as 
the memory state backend) will fail here
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot2, env);
@@ -768,7 +767,8 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                        // on the second restore, since the custom serializer 
will be used for
                        // deserialization, we expect the deliberate failure to 
be thrown
-                       
expectedException.expect(ExpectedKryoTestException.class);
+                       
expectedException.expect(anyOf(isA(ExpectedKryoTestException.class),
+                               Matchers.<Throwable>hasProperty("cause", 
isA(ExpectedKryoTestException.class))));
 
                        // state backends that eagerly deserializes (such as 
the memory state backend) will fail here
                        backend = restoreKeyedBackend(IntSerializer.INSTANCE, 
snapshot2, env);
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
index 64814c7..a10be26 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state;
 
+import org.apache.commons.io.IOUtils;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
@@ -26,16 +27,18 @@ import 
org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.StateObjectCollection;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.runtime.state.heap.HeapKeyedStateBackendBuilder;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import org.apache.flink.runtime.state.internal.InternalValueState;
 import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.util.TestLogger;
 
-import org.apache.commons.io.IOUtils;
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.util.Collection;
+import java.util.Collections;
 import java.util.concurrent.RunnableFuture;
 
 import static org.mockito.Mockito.mock;
@@ -43,23 +46,12 @@ import static org.mockito.Mockito.mock;
 public class StateSnapshotCompressionTest extends TestLogger {
 
        @Test
-       public void testCompressionConfiguration() {
+       public void testCompressionConfiguration() throws 
BackendBuildingException {
 
                ExecutionConfig executionConfig = new ExecutionConfig();
                executionConfig.setUseSnapshotCompression(true);
 
-               AbstractKeyedStateBackend<String> stateBackend = new 
HeapKeyedStateBackend<>(
-                       mock(TaskKvStateRegistry.class),
-                       StringSerializer.INSTANCE,
-                       StateSnapshotCompressionTest.class.getClassLoader(),
-                       16,
-                       new KeyGroupRange(0, 15),
-                       true,
-                       executionConfig,
-                       TestLocalRecoveryConfig.disabled(),
-                       mock(HeapPriorityQueueSetFactory.class),
-                       TtlTimeProvider.DEFAULT,
-                       new CloseableRegistry());
+               AbstractKeyedStateBackend<String> stateBackend = 
getStringHeapKeyedStateBackend(executionConfig);
 
                try {
                        Assert.assertTrue(
@@ -73,18 +65,7 @@ public class StateSnapshotCompressionTest extends TestLogger 
{
                executionConfig = new ExecutionConfig();
                executionConfig.setUseSnapshotCompression(false);
 
-               stateBackend = new HeapKeyedStateBackend<>(
-                       mock(TaskKvStateRegistry.class),
-                       StringSerializer.INSTANCE,
-                       StateSnapshotCompressionTest.class.getClassLoader(),
-                       16,
-                       new KeyGroupRange(0, 15),
-                       true,
-                       executionConfig,
-                       TestLocalRecoveryConfig.disabled(),
-                       mock(HeapPriorityQueueSetFactory.class),
-                       TtlTimeProvider.DEFAULT,
-                       new CloseableRegistry());
+               stateBackend = getStringHeapKeyedStateBackend(executionConfig);
 
                try {
                        Assert.assertTrue(
@@ -106,28 +87,42 @@ public class StateSnapshotCompressionTest extends 
TestLogger {
                snapshotRestoreRoundtrip(false);
        }
 
-       private void snapshotRestoreRoundtrip(boolean useCompression) throws 
Exception {
-
-               ExecutionConfig executionConfig = new ExecutionConfig();
-               executionConfig.setUseSnapshotCompression(useCompression);
-
-               KeyedStateHandle stateHandle = null;
-
-               ValueStateDescriptor<String> stateDescriptor = new 
ValueStateDescriptor<>("test", String.class);
-               stateDescriptor.initializeSerializerUnlessSet(executionConfig);
+       private HeapKeyedStateBackend<String> 
getStringHeapKeyedStateBackend(ExecutionConfig executionConfig)
+               throws BackendBuildingException {
+               return getStringHeapKeyedStateBackend(executionConfig, 
Collections.emptyList());
+       }
 
-               AbstractKeyedStateBackend<String> stateBackend = new 
HeapKeyedStateBackend<>(
+       private HeapKeyedStateBackend<String> getStringHeapKeyedStateBackend(
+               ExecutionConfig executionConfig,
+               Collection<KeyedStateHandle> stateHandles)
+               throws BackendBuildingException {
+               return new HeapKeyedStateBackendBuilder<>(
                        mock(TaskKvStateRegistry.class),
                        StringSerializer.INSTANCE,
                        StateSnapshotCompressionTest.class.getClassLoader(),
                        16,
                        new KeyGroupRange(0, 15),
-                       true,
                        executionConfig,
+                       TtlTimeProvider.DEFAULT,
+                       stateHandles,
+                       
AbstractStateBackend.getCompressionDecorator(executionConfig),
                        TestLocalRecoveryConfig.disabled(),
                        mock(HeapPriorityQueueSetFactory.class),
-                       TtlTimeProvider.DEFAULT,
-                       new CloseableRegistry());
+                       true,
+                       new CloseableRegistry()).build();
+       }
+
+       private void snapshotRestoreRoundtrip(boolean useCompression) throws 
Exception {
+
+               ExecutionConfig executionConfig = new ExecutionConfig();
+               executionConfig.setUseSnapshotCompression(useCompression);
+
+               KeyedStateHandle stateHandle;
+
+               ValueStateDescriptor<String> stateDescriptor = new 
ValueStateDescriptor<>("test", String.class);
+               stateDescriptor.initializeSerializerUnlessSet(executionConfig);
+
+               AbstractKeyedStateBackend<String> stateBackend = 
getStringHeapKeyedStateBackend(executionConfig);
 
                try {
 
@@ -160,22 +155,8 @@ public class StateSnapshotCompressionTest extends 
TestLogger {
 
                executionConfig = new ExecutionConfig();
 
-               stateBackend = new HeapKeyedStateBackend<>(
-                       mock(TaskKvStateRegistry.class),
-                       StringSerializer.INSTANCE,
-                       StateSnapshotCompressionTest.class.getClassLoader(),
-                       16,
-                       new KeyGroupRange(0, 15),
-                       true,
-                       executionConfig,
-                       TestLocalRecoveryConfig.disabled(),
-                       mock(HeapPriorityQueueSetFactory.class),
-                       TtlTimeProvider.DEFAULT,
-                       new CloseableRegistry());
+               stateBackend = getStringHeapKeyedStateBackend(executionConfig, 
StateObjectCollection.singleton(stateHandle));
                try {
-
-                       
stateBackend.restore(StateObjectCollection.singleton(stateHandle));
-
                        InternalValueState<String, VoidNamespace, String> state 
= stateBackend.createInternalState(
                                new VoidNamespaceSerializer(),
                                stateDescriptor);
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
index 1ca3e80..23e74cd 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
@@ -58,21 +58,19 @@ public class HeapKeyedStateBackendSnapshotMigrationTest 
extends HeapStateBackend
 
                Preconditions.checkNotNull(resource, "Binary snapshot resource 
not found!");
 
-               try (final HeapKeyedStateBackend<String> keyedBackend = 
createKeyedBackend()) {
+               final SnapshotResult<KeyedStateHandle> stateHandles;
+               try (BufferedInputStream bis = new BufferedInputStream((new 
FileInputStream(resource.getFile())))) {
+                       stateHandles = InstantiationUtil.deserializeObject(bis, 
Thread.currentThread().getContextClassLoader());
+               }
+               final KeyedStateHandle stateHandle = 
stateHandles.getJobManagerOwnedSnapshot();
+               try (final HeapKeyedStateBackend<String> keyedBackend = 
createKeyedBackend(StateObjectCollection.singleton(stateHandle))) {
                        final Integer namespace1 = 1;
                        final Integer namespace2 = 2;
                        final Integer namespace3 = 3;
 
-                       final SnapshotResult<KeyedStateHandle> stateHandles;
-                       try (BufferedInputStream bis = new 
BufferedInputStream((new FileInputStream(resource.getFile())))) {
-                               stateHandles = 
InstantiationUtil.deserializeObject(bis, 
Thread.currentThread().getContextClassLoader());
-                       }
-
                        final MapStateDescriptor<Long, Long> stateDescr = new 
MapStateDescriptor<>("my-map-state", Long.class, Long.class);
                        stateDescr.initializeSerializerUnlessSet(new 
ExecutionConfig());
 
-                       
keyedBackend.restore(StateObjectCollection.singleton(stateHandles.getJobManagerOwnedSnapshot()));
-
                        InternalMapState<String, Integer, Long, Long> state = 
keyedBackend.createInternalState(IntSerializer.INSTANCE, stateDescr);
 
                        keyedBackend.setCurrentKey("abc");
@@ -224,12 +222,11 @@ public class HeapKeyedStateBackendSnapshotMigrationTest 
extends HeapStateBackend
                final Integer namespace2 = 2;
                final Integer namespace3 = 3;
 
-               try (final HeapKeyedStateBackend<String> keyedBackend = 
createKeyedBackend()) {
-                       final KeyGroupsStateHandle stateHandle;
-                       try (BufferedInputStream bis = new 
BufferedInputStream((new FileInputStream(resource.getFile())))) {
-                               stateHandle = 
InstantiationUtil.deserializeObject(bis, 
Thread.currentThread().getContextClassLoader());
-                       }
-                       
keyedBackend.restore(StateObjectCollection.singleton(stateHandle));
+               final KeyGroupsStateHandle stateHandle;
+               try (BufferedInputStream bis = new BufferedInputStream((new 
FileInputStream(resource.getFile())))) {
+                       stateHandle = InstantiationUtil.deserializeObject(bis, 
Thread.currentThread().getContextClassLoader());
+               }
+               try (final HeapKeyedStateBackend<String> keyedBackend = 
createKeyedBackend(StateObjectCollection.singleton(stateHandle))) {
                        final ListStateDescriptor<Long> stateDescr = new 
ListStateDescriptor<>("my-state", Long.class);
                        stateDescr.initializeSerializerUnlessSet(new 
ExecutionConfig());
 
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java
index efcb727..4b1b641 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapStateBackendTestBase.java
@@ -23,7 +23,9 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.TestLocalRecoveryConfig;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 
@@ -46,25 +48,30 @@ public abstract class HeapStateBackendTestBase {
        @Parameterized.Parameter
        public boolean async;
 
-       public HeapKeyedStateBackend<String> createKeyedBackend() throws 
Exception {
-               return createKeyedBackend(StringSerializer.INSTANCE);
+       public HeapKeyedStateBackend<String> 
createKeyedBackend(Collection<KeyedStateHandle> stateHandles) throws Exception {
+               return createKeyedBackend(StringSerializer.INSTANCE, 
stateHandles);
        }
 
-       public <K> HeapKeyedStateBackend<K> 
createKeyedBackend(TypeSerializer<K> keySerializer) throws Exception {
+       public <K> HeapKeyedStateBackend<K> createKeyedBackend(
+               TypeSerializer<K> keySerializer,
+               Collection<KeyedStateHandle> stateHandles) throws Exception {
                final KeyGroupRange keyGroupRange = new KeyGroupRange(0, 15);
                final int numKeyGroups = keyGroupRange.getNumberOfKeyGroups();
+               ExecutionConfig executionConfig = new ExecutionConfig();
 
-               return new HeapKeyedStateBackend<>(
+               return new HeapKeyedStateBackendBuilder<>(
                        mock(TaskKvStateRegistry.class),
                        keySerializer,
                        HeapStateBackendTestBase.class.getClassLoader(),
                        numKeyGroups,
                        keyGroupRange,
-                       async,
-                       new ExecutionConfig(),
+                       executionConfig,
+                       TtlTimeProvider.DEFAULT,
+                       stateHandles,
+                       
AbstractStateBackend.getCompressionDecorator(executionConfig),
                        TestLocalRecoveryConfig.disabled(),
                        new HeapPriorityQueueSetFactory(keyGroupRange, 
numKeyGroups, 128),
-                       TtlTimeProvider.DEFAULT,
-                       new CloseableRegistry());
+                       async,
+                       new CloseableRegistry()).build();
        }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/StateBackendTestContext.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/StateBackendTestContext.java
index 1c4294d..3fdad0a 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/StateBackendTestContext.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/StateBackendTestContext.java
@@ -23,7 +23,6 @@ import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.StateObjectCollection;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
@@ -37,7 +36,6 @@ import 
org.apache.flink.runtime.state.internal.InternalKvState;
 import org.apache.flink.util.Preconditions;
 
 import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -137,17 +135,6 @@ public abstract class StateBackendTestContext {
                return snapshotRunnableFuture;
        }
 
-       void restoreSnapshot(@Nullable KeyedStateHandle snapshot) throws 
Exception {
-               Collection<KeyedStateHandle> snapshots = new ArrayList<>();
-               snapshots.add(snapshot);
-               Collection<KeyedStateHandle> restoreState =
-                       snapshot == null ? null : new 
StateObjectCollection<>(snapshots);
-               keyedStateBackend.restore(restoreState);
-               if (snapshot != null) {
-                       snapshots.add(snapshot);
-               }
-       }
-
        public void setCurrentKey(String key) {
                //noinspection resource
                Preconditions.checkNotNull(keyedStateBackend, "keyed backend is 
not initialised");
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
index 443d2a6..e6d5ba3 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
@@ -129,7 +129,6 @@ public abstract class TtlStateTestBase {
        protected <S extends State> StateDescriptor<S, Object> 
initTest(StateTtlConfig ttlConfig) throws Exception {
                this.ttlConfig = ttlConfig;
                sbetc.createAndRestoreKeyedStateBackend(null);
-               sbetc.restoreSnapshot(null);
                sbetc.setCurrentKey("defaultKey");
                StateDescriptor<S, Object> stateDesc = createState();
                ctx().initTestValues();
@@ -155,7 +154,6 @@ public abstract class TtlStateTestBase {
 
        private void restoreSnapshot(KeyedStateHandle snapshot, int 
numberOfKeyGroups) throws Exception {
                sbetc.createAndRestoreKeyedStateBackend(numberOfKeyGroups, 
snapshot);
-               sbetc.restoreSnapshot(snapshot);
                sbetc.setCurrentKey("defaultKey");
                createState();
        }
@@ -435,7 +433,6 @@ public abstract class TtlStateTestBase {
                KeyedStateHandle snapshot = sbetc.takeSnapshot();
                sbetc.createAndRestoreKeyedStateBackend(snapshot);
 
-               sbetc.restoreSnapshot(snapshot);
                sbetc.setCurrentKey("defaultKey");
                sbetc.createState(ctx().createStateDescriptor(), "");
        }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index bc0303b..5803c8e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -30,7 +30,6 @@ import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.CloseableRegistry;
-import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
@@ -57,7 +56,6 @@ import javax.annotation.Nonnull;
 
 import java.util.ArrayList;
 import java.util.Collection;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -87,9 +85,9 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
                        Tuple2.of(FoldingStateDescriptor.class, (StateFactory) 
MockInternalFoldingState::createState)
                ).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
-       private final Map<String, Map<K, Map<Object, Object>>> stateValues = 
new HashMap<>();
+       private final Map<String, Map<K, Map<Object, Object>>> stateValues;
 
-       private final Map<String, StateSnapshotTransformer<Object>> 
stateSnapshotFilters = new HashMap<>();
+       private final Map<String, StateSnapshotTransformer<Object>> 
stateSnapshotFilters;
 
        MockKeyedStateBackend(
                TaskKvStateRegistry kvStateRegistry,
@@ -99,10 +97,13 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
                KeyGroupRange keyGroupRange,
                ExecutionConfig executionConfig,
                TtlTimeProvider ttlTimeProvider,
-               MetricGroup operatorMetricGroup,
+               Map<String, Map<K, Map<Object, Object>>> stateValues,
+               Map<String, StateSnapshotTransformer<Object>> 
stateSnapshotFilters,
                CloseableRegistry cancelStreamRegistry) {
                super(kvStateRegistry, keySerializer, userCodeClassLoader,
                        numberOfKeyGroups, keyGroupRange, executionConfig, 
ttlTimeProvider, cancelStreamRegistry);
+               this.stateValues = stateValues;
+               this.stateSnapshotFilters = stateSnapshotFilters;
        }
 
        @Override
@@ -187,17 +188,10 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
        @SuppressWarnings("unchecked")
        @Override
        public void restore(Collection<KeyedStateHandle> state) {
-               stateValues.clear();
-               state = state == null ? Collections.emptyList() : state;
-               state.forEach(ksh -> 
stateValues.putAll(copy(((MockKeyedStateHandle<K>) ksh).snapshotStates)));
-       }
-
-       private static <K> Map<String, Map<K, Map<Object, Object>>> copy(
-               Map<String, Map<K, Map<Object, Object>>> stateValues) {
-               return copy(stateValues, Collections.emptyMap());
+               // all restore work done in builder and nothing to do here
        }
 
-       private static <K> Map<String, Map<K, Map<Object, Object>>> copy(
+       static <K> Map<String, Map<K, Map<Object, Object>>> copy(
                Map<String, Map<K, Map<Object, Object>>> stateValues, 
Map<String, StateSnapshotTransformer<Object>> stateSnapshotFilters) {
                Map<String, Map<K, Map<Object, Object>>> snapshotStates = new 
HashMap<>();
                for (String stateName : stateValues.keySet()) {
@@ -244,7 +238,7 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
                        0);
        }
 
-       private static class MockKeyedStateHandle<K> implements 
KeyedStateHandle {
+       static class MockKeyedStateHandle<K> implements KeyedStateHandle {
                private static final long serialVersionUID = 1L;
 
                final Map<String, Map<K, Map<Object, Object>>> snapshotStates;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackendBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackendBuilder.java
new file mode 100644
index 0000000..5706fb0
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackendBuilder.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.ttl.mock;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.runtime.query.TaskKvStateRegistry;
+import org.apache.flink.runtime.state.AbstractKeyedStateBackendBuilder;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StreamCompressionDecorator;
+import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
+
+import javax.annotation.Nonnull;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Builder class for {@link MockKeyedStateBackend}.
+ *
+ * @param <K> The data type that the key serializer serializes.
+ */
+public class MockKeyedStateBackendBuilder<K> extends 
AbstractKeyedStateBackendBuilder<K> {
+       public MockKeyedStateBackendBuilder(
+               TaskKvStateRegistry kvStateRegistry,
+               TypeSerializer<K> keySerializer,
+               ClassLoader userCodeClassLoader,
+               int numberOfKeyGroups,
+               KeyGroupRange keyGroupRange,
+               ExecutionConfig executionConfig,
+               TtlTimeProvider ttlTimeProvider,
+               @Nonnull Collection<KeyedStateHandle> stateHandles,
+               StreamCompressionDecorator keyGroupCompressionDecorator,
+               CloseableRegistry cancelStreamRegistry) {
+               super(
+                       kvStateRegistry,
+                       keySerializer,
+                       userCodeClassLoader,
+                       numberOfKeyGroups,
+                       keyGroupRange,
+                       executionConfig,
+                       ttlTimeProvider,
+                       stateHandles,
+                       keyGroupCompressionDecorator,
+                       cancelStreamRegistry);
+       }
+
+       @Override
+       public MockKeyedStateBackend<K> build() {
+               Map<String, Map<K, Map<Object, Object>>> stateValues = new 
HashMap<>();
+               Map<String, StateSnapshotTransformer<Object>> 
stateSnapshotFilters = new HashMap<>();
+               MockRestoreOperation<K> restoreOperation = new 
MockRestoreOperation<>(restoreStateHandles, stateValues);
+               restoreOperation.restore();
+               return new MockKeyedStateBackend<>(
+                       kvStateRegistry,
+                       keySerializer,
+                       userCodeClassLoader,
+                       numberOfKeyGroups,
+                       keyGroupRange,
+                       executionConfig,
+                       ttlTimeProvider,
+                       stateValues,
+                       stateSnapshotFilters,
+                       cancelStreamRegistry);
+       }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockRestoreOperation.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockRestoreOperation.java
new file mode 100644
index 0000000..ffc090a
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockRestoreOperation.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.ttl.mock;
+
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.RestoreOperation;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
+
+import static 
org.apache.flink.runtime.state.ttl.mock.MockKeyedStateBackend.copy;
+
+/**
+ * Implementation of mock restore operation.
+ *
+ * @param <K> The data type that the serializer serializes.
+ */
+public class MockRestoreOperation<K> implements RestoreOperation<Void> {
+       private final Collection<KeyedStateHandle> state;
+       private final Map<String, Map<K, Map<Object, Object>>> stateValues;
+
+       public MockRestoreOperation(
+               Collection<KeyedStateHandle> state,
+               Map<String, Map<K, Map<Object, Object>>> stateValues) {
+               this.state = state;
+               this.stateValues = stateValues;
+       }
+
+       @Override
+       public Void restore() {
+               state.forEach(ksh -> stateValues.putAll(
+                       copy(((MockKeyedStateBackend.MockKeyedStateHandle<K>) 
ksh).snapshotStates,
+                               Collections.emptyMap())));
+               return null;
+       }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
index 3a0bb1b..117db95 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
@@ -42,6 +42,7 @@ import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
+
 import java.util.Collection;
 
 /** mack state backend. */
@@ -127,7 +128,7 @@ public class MockStateBackend extends AbstractStateBackend {
                MetricGroup metricGroup,
                @Nonnull Collection<KeyedStateHandle> stateHandles,
                CloseableRegistry cancelStreamRegistry) {
-               return new MockKeyedStateBackend<>(
+               return new MockKeyedStateBackendBuilder<>(
                        new KvStateRegistry().createTaskRegistry(jobID, new 
JobVertexID()),
                        keySerializer,
                        env.getUserClassLoader(),
@@ -135,8 +136,9 @@ public class MockStateBackend extends AbstractStateBackend {
                        keyGroupRange,
                        env.getExecutionConfig(),
                        ttlTimeProvider,
-                       metricGroup,
-                       cancelStreamRegistry);
+                       stateHandles,
+                       
AbstractStateBackend.getCompressionDecorator(env.getExecutionConfig()),
+                       cancelStreamRegistry).build();
        }
 
        @Override
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index a4251d9..62e30ac 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -39,10 +39,8 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
 import org.apache.flink.runtime.state.OperatorStateBackend;
-import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.StateBackend;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
-import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.util.AbstractID;
@@ -521,14 +519,6 @@ public class RocksDBStateBackend extends AbstractStateBackend implements Configu
                return builder.build();
        }
 
-       public static StreamCompressionDecorator 
getCompressionDecorator(ExecutionConfig executionConfig) {
-               if (executionConfig != null && 
executionConfig.isUseSnapshotCompression()) {
-                       return SnappyStreamCompressionDecorator.INSTANCE;
-               } else {
-                       return UncompressedStreamCompressionDecorator.INSTANCE;
-               }
-       }
-
        @Override
        public OperatorStateBackend createOperatorStateBackend(
                        Environment env,
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
index ff70199..b90c4db 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBRestoreOperation.java
@@ -18,10 +18,12 @@
 
 package org.apache.flink.contrib.streaming.state.restore;
 
+import org.apache.flink.runtime.state.RestoreOperation;
+
 /**
  * Interface for RocksDB restore.
  */
-public interface RocksDBRestoreOperation {
+public interface RocksDBRestoreOperation extends 
RestoreOperation<RocksDBRestoreResult> {
        /**
         * Restores state that was previously snapshot-ed from the provided 
state handles.
         */
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
index ed3f6f6..819c864 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
@@ -32,6 +32,7 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.AbstractStateBackend;
 import org.apache.flink.runtime.state.IncrementalRemoteKeyedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
@@ -213,7 +214,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
                        TtlTimeProvider.DEFAULT,
                        new UnregisteredMetricsGroup(),
                        Collections.emptyList(),
-                       
RocksDBStateBackend.getCompressionDecorator(env.getExecutionConfig()),
+                       
AbstractStateBackend.getCompressionDecorator(env.getExecutionConfig()),
                        spy(db),
                        defaultCFHandle,
                        new CloseableRegistry()).build();
@@ -290,7 +291,7 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
                                TtlTimeProvider.DEFAULT,
                                new UnregisteredMetricsGroup(),
                                Collections.emptyList(),
-                               
RocksDBStateBackend.getCompressionDecorator(executionConfig),
+                               
AbstractStateBackend.getCompressionDecorator(executionConfig),
                                db,
                                defaultCFHandle,
                                new CloseableRegistry()).build();

Reply via email to