http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java
index 0360161..f393237 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMapState.java
@@ -22,8 +22,6 @@ import org.apache.flink.api.common.state.MapState;
 import org.apache.flink.api.common.state.MapStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
-import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.internal.InternalMapState;
 import org.apache.flink.util.Preconditions;
 
@@ -47,35 +45,23 @@ public class HeapMapState<K, N, UK, UV>
        /**
         * Creates a new key/value state for the given hash map of key/value pairs.
         *
-        * @param backend    The state backend backing that created this state.
         * @param stateDesc  The state identifier for the state. This contains name
         *                   and can create a default state value.
         * @param stateTable The state tab;e to use in this kev/value state. May contain initial state.
         */
-       public HeapMapState(KeyedStateBackend<K> backend,
+       public HeapMapState(
                        MapStateDescriptor<UK, UV> stateDesc,
                        StateTable<K, N, HashMap<UK, UV>> stateTable,
                        TypeSerializer<K> keySerializer,
                        TypeSerializer<N> namespaceSerializer) {
-               super(backend, stateDesc, stateTable, keySerializer, namespaceSerializer);
+               super(stateDesc, stateTable, keySerializer, namespaceSerializer);
        }
 
        @Override
        public UV get(UK userKey) {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
 
-               Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = 
stateTable.get(backend.getCurrentKeyGroupIndex());
-               if (namespaceMap == null) {
-                       return null;
-               }
-
-               Map<K, HashMap<UK, UV>> keyedMap = 
namespaceMap.get(currentNamespace);
-               if (keyedMap == null) {
-                       return null;
-               }
+               HashMap<UK, UV> userMap = stateTable.get(currentNamespace);
 
-               HashMap<UK, UV> userMap = 
keyedMap.get(backend.<K>getCurrentKey());
                if (userMap == null) {
                        return null;
                }
@@ -85,25 +71,11 @@ public class HeapMapState<K, N, UK, UV>
 
        @Override
        public void put(UK userKey, UV userValue) {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
-
-               Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = 
stateTable.get(backend.getCurrentKeyGroupIndex());
-               if (namespaceMap == null) {
-                       namespaceMap = createNewMap();
-                       stateTable.set(backend.getCurrentKeyGroupIndex(), 
namespaceMap);
-               }
-
-               Map<K, HashMap<UK, UV>> keyedMap = 
namespaceMap.get(currentNamespace);
-               if (keyedMap == null) {
-                       keyedMap = createNewMap();
-                       namespaceMap.put(currentNamespace, keyedMap);
-               }
 
-               HashMap<UK, UV> userMap = keyedMap.get(backend.getCurrentKey());
+               HashMap<UK, UV> userMap = stateTable.get(currentNamespace);
                if (userMap == null) {
                        userMap = new HashMap<>();
-                       keyedMap.put(backend.getCurrentKey(), userMap);
+                       stateTable.put(currentNamespace, userMap);
                }
 
                userMap.put(userKey, userValue);
@@ -111,52 +83,27 @@ public class HeapMapState<K, N, UK, UV>
 
        @Override
        public void putAll(Map<UK, UV> value) {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
 
-               Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = 
stateTable.get(backend.getCurrentKeyGroupIndex());
-               if (namespaceMap == null) {
-                       namespaceMap = createNewMap();
-                       stateTable.set(backend.getCurrentKeyGroupIndex(), 
namespaceMap);
-               }
-
-               Map<K, HashMap<UK, UV>> keyedMap = 
namespaceMap.get(currentNamespace);
-               if (keyedMap == null) {
-                       keyedMap = createNewMap();
-                       namespaceMap.put(currentNamespace, keyedMap);
-               }
+               HashMap<UK, UV> userMap = stateTable.get(currentNamespace);
 
-               HashMap<UK, UV> userMap = keyedMap.get(backend.getCurrentKey());
                if (userMap == null) {
                        userMap = new HashMap<>();
-                       keyedMap.put(backend.getCurrentKey(), userMap);
+                       stateTable.put(currentNamespace, userMap);
                }
 
                userMap.putAll(value);
        }
-       
+
        @Override
        public void remove(UK userKey) {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
-
-               Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = 
stateTable.get(backend.getCurrentKeyGroupIndex());
-               if (namespaceMap == null) {
-                       return;
-               }
 
-               Map<K, HashMap<UK, UV>> keyedMap = 
namespaceMap.get(currentNamespace);
-               if (keyedMap == null) {
-                       return;
-               }
-
-               HashMap<UK, UV> userMap = keyedMap.get(backend.getCurrentKey());
+               HashMap<UK, UV> userMap = stateTable.get(currentNamespace);
                if (userMap == null) {
                        return;
                }
 
                userMap.remove(userKey);
-               
+
                if (userMap.isEmpty()) {
                        clear();
                }
@@ -164,101 +111,31 @@ public class HeapMapState<K, N, UK, UV>
 
        @Override
        public boolean contains(UK userKey) {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
-
-               Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = 
stateTable.get(backend.getCurrentKeyGroupIndex());
-               if (namespaceMap == null) {
-                       return false;
-               }
-
-               Map<K, HashMap<UK, UV>> keyedMap = 
namespaceMap.get(currentNamespace);
-               if (keyedMap == null) {
-                       return false;
-               }
-
-               HashMap<UK, UV> userMap = 
keyedMap.get(backend.<K>getCurrentKey());
-               
+               HashMap<UK, UV> userMap = stateTable.get(currentNamespace);
                return userMap != null && userMap.containsKey(userKey);
        }
 
        @Override
        public Iterable<Map.Entry<UK, UV>> entries() {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
-
-               Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = 
stateTable.get(backend.getCurrentKeyGroupIndex());
-               if (namespaceMap == null) {
-                       return null;
-               }
-
-               Map<K, HashMap<UK, UV>> keyedMap = 
namespaceMap.get(currentNamespace);
-               if (keyedMap == null) {
-                       return null;
-               }
-
-               HashMap<UK, UV> userMap = 
keyedMap.get(backend.<K>getCurrentKey());
-
+               HashMap<UK, UV> userMap = stateTable.get(currentNamespace);
                return userMap == null ? null : userMap.entrySet();
        }
        
        @Override
        public Iterable<UK> keys() {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
-
-               Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = 
stateTable.get(backend.getCurrentKeyGroupIndex());
-               if (namespaceMap == null) {
-                       return null;
-               }
-
-               Map<K, HashMap<UK, UV>> keyedMap = 
namespaceMap.get(currentNamespace);
-               if (keyedMap == null) {
-                       return null;
-               }
-
-               HashMap<UK, UV> userMap = 
keyedMap.get(backend.<K>getCurrentKey());
-
+               HashMap<UK, UV> userMap = stateTable.get(currentNamespace);
                return userMap == null ? null : userMap.keySet();
        }
 
        @Override
        public Iterable<UV> values() {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
-
-               Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = 
stateTable.get(backend.getCurrentKeyGroupIndex());
-               if (namespaceMap == null) {
-                       return null;
-               }
-
-               Map<K, HashMap<UK, UV>> keyedMap = 
namespaceMap.get(currentNamespace);
-               if (keyedMap == null) {
-                       return null;
-               }
-
-               HashMap<UK, UV> userMap = 
keyedMap.get(backend.<K>getCurrentKey());
-
+               HashMap<UK, UV> userMap = stateTable.get(currentNamespace);
                return userMap == null ? null : userMap.values();
        }
 
        @Override
        public Iterator<Map.Entry<UK, UV>> iterator() {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
-
-               Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = 
stateTable.get(backend.getCurrentKeyGroupIndex());
-               if (namespaceMap == null) {
-                       return null;
-               }
-
-               Map<K, HashMap<UK, UV>> keyedMap = 
namespaceMap.get(currentNamespace);
-               if (keyedMap == null) {
-                       return null;
-               }
-
-               HashMap<UK, UV> userMap = 
keyedMap.get(backend.<K>getCurrentKey());
-
+               HashMap<UK, UV> userMap = stateTable.get(currentNamespace);
                return userMap == null ? null : userMap.entrySet().iterator();
        }
 
@@ -267,22 +144,12 @@ public class HeapMapState<K, N, UK, UV>
                Preconditions.checkState(namespace != null, "No namespace given.");
                Preconditions.checkState(key != null, "No key given.");
 
-               Map<N, Map<K, HashMap<UK, UV>>> namespaceMap = 
stateTable.get(KeyGroupRangeAssignment.assignToKeyGroup(key, 
backend.getNumberOfKeyGroups()));
-
-               if (namespaceMap == null) {
-                       return null;
-               }
+               HashMap<UK, UV> result = stateTable.get(key, namespace);
 
-               Map<K, HashMap<UK, UV>> keyedMap = namespaceMap.get(namespace);
-               if (keyedMap == null) {
+               if (null == result) {
                        return null;
                }
 
-               HashMap<UK, UV> result = keyedMap.get(key);
-               if (result == null) {
-                       return null;
-               }
-               
                TypeSerializer<UK> userKeySerializer = stateDesc.getKeySerializer();
                TypeSerializer<UV> userValueSerializer = stateDesc.getValueSerializer();
 

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapReducingState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapReducingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapReducingState.java
index 090a660..6e11327 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapReducingState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapReducingState.java
@@ -22,17 +22,16 @@ import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.state.ReducingState;
 import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.state.KeyedStateBackend;
+import org.apache.flink.runtime.state.StateTransformationFunction;
 import org.apache.flink.runtime.state.internal.InternalReducingState;
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
-import java.util.Map;
 
 /**
  * Heap-backed partitioned {@link org.apache.flink.api.common.state.ReducingState} that is
  * snapshotted into files.
- * 
+ *
  * @param <K> The type of the key.
  * @param <N> The type of the namespace.
  * @param <V> The type of the value.
@@ -41,25 +40,23 @@ public class HeapReducingState<K, N, V>
                extends AbstractHeapMergingState<K, N, V, V, V, ReducingState<V>, ReducingStateDescriptor<V>>
                implements InternalReducingState<N, V> {
 
-       private final ReduceFunction<V> reduceFunction;
+       private final ReduceTransformation<V> reduceTransformation;
 
        /**
         * Creates a new key/value state for the given hash map of key/value pairs.
         *
-        * @param backend The state backend backing that created this state.
         * @param stateDesc The state identifier for the state. This contains name
         *                           and can create a default state value.
         * @param stateTable The state table to use in this kev/value state. May contain initial state.
         */
        public HeapReducingState(
-                       KeyedStateBackend<K> backend,
                        ReducingStateDescriptor<V> stateDesc,
                        StateTable<K, N, V> stateTable,
                        TypeSerializer<K> keySerializer,
                        TypeSerializer<N> namespaceSerializer) {
 
-               super(backend, stateDesc, stateTable, keySerializer, namespaceSerializer);
-               this.reduceFunction = stateDesc.getReduceFunction();
+               super(stateDesc, stateTable, keySerializer, namespaceSerializer);
+               this.reduceTransformation = new ReduceTransformation<>(stateDesc.getReduceFunction());
        }
 
        // ------------------------------------------------------------------------
@@ -68,62 +65,21 @@ public class HeapReducingState<K, N, V>
 
        @Override
        public V get() {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
-
-               Map<N, Map<K, V>> namespaceMap =
-                               
stateTable.get(backend.getCurrentKeyGroupIndex());
-
-               if (namespaceMap == null) {
-                       return null;
-               }
-
-               Map<K, V> keyedMap = namespaceMap.get(currentNamespace);
-
-               if (keyedMap == null) {
-                       return null;
-               }
-
-               return keyedMap.get(backend.<K>getCurrentKey());
+               return stateTable.get(currentNamespace);
        }
 
        @Override
        public void add(V value) throws IOException {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
 
                if (value == null) {
                        clear();
                        return;
                }
 
-               Map<N, Map<K, V>> namespaceMap =
-                               
stateTable.get(backend.getCurrentKeyGroupIndex());
-
-               if (namespaceMap == null) {
-                       namespaceMap = createNewMap();
-                       stateTable.set(backend.getCurrentKeyGroupIndex(), 
namespaceMap);
-               }
-
-               Map<K, V> keyedMap = namespaceMap.get(currentNamespace);
-
-               if (keyedMap == null) {
-                       keyedMap = createNewMap();
-                       namespaceMap.put(currentNamespace, keyedMap);
-               }
-
-               V currentValue = keyedMap.put(backend.<K>getCurrentKey(), 
value);
-
-               if (currentValue == null) {
-                       // we're good, just added the new value
-               } else {
-                       V reducedValue;
-                       try {
-                               reducedValue = 
reduceFunction.reduce(currentValue, value);
-                       } catch (Exception e) {
-                               throw new IOException("Exception while applying 
ReduceFunction in reducing state", e);
-                       }
-                       keyedMap.put(backend.<K>getCurrentKey(), reducedValue);
+               try {
+                       stateTable.transform(currentNamespace, value, reduceTransformation);
+               } catch (Exception e) {
+                       throw new IOException("Exception while applying ReduceFunction in reducing state", e);
                }
        }
 
@@ -133,6 +89,20 @@ public class HeapReducingState<K, N, V>
 
        @Override
        protected V mergeState(V a, V b) throws Exception {
-               return reduceFunction.reduce(a, b);
+               return reduceTransformation.apply(a, b);
+       }
+
+       static final class ReduceTransformation<V> implements StateTransformationFunction<V, V> {
+
+               private final ReduceFunction<V> reduceFunction;
+
+               ReduceTransformation(ReduceFunction<V> reduceFunction) {
+                       this.reduceFunction = Preconditions.checkNotNull(reduceFunction);
+               }
+
+               @Override
+               public V apply(V previousState, V value) throws Exception {
+                       return previousState != null ? reduceFunction.reduce(previousState, value) : value;
+               }
        }
-}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapValueState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapValueState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapValueState.java
index 9e042fe..6de62a8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapValueState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapValueState.java
@@ -21,16 +21,12 @@ package org.apache.flink.runtime.state.heap;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.internal.InternalValueState;
-import org.apache.flink.util.Preconditions;
-
-import java.util.Map;
 
 /**
  * Heap-backed partitioned {@link org.apache.flink.api.common.state.ValueState} that is snapshotted
  * into files.
- * 
+ *
  * @param <K> The type of the key.
  * @param <N> The type of the namespace.
  * @param <V> The type of the value.
@@ -42,39 +38,21 @@ public class HeapValueState<K, N, V>
        /**
         * Creates a new key/value state for the given hash map of key/value pairs.
         *
-        * @param backend The state backend backing that created this state.
         * @param stateDesc The state identifier for the state. This contains name
         *                           and can create a default state value.
         * @param stateTable The state tab;e to use in this kev/value state. May contain initial state.
         */
        public HeapValueState(
-                       KeyedStateBackend<K> backend,
                        ValueStateDescriptor<V> stateDesc,
                        StateTable<K, N, V> stateTable,
                        TypeSerializer<K> keySerializer,
                        TypeSerializer<N> namespaceSerializer) {
-               super(backend, stateDesc, stateTable, keySerializer, namespaceSerializer);
+               super(stateDesc, stateTable, keySerializer, namespaceSerializer);
        }
 
        @Override
        public V value() {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
-
-               Map<N, Map<K, V>> namespaceMap =
-                               
stateTable.get(backend.getCurrentKeyGroupIndex());
-
-               if (namespaceMap == null) {
-                       return stateDesc.getDefaultValue();
-               }
-
-               Map<K, V> keyedMap = namespaceMap.get(currentNamespace);
-
-               if (keyedMap == null) {
-                       return stateDesc.getDefaultValue();
-               }
-
-               V result = keyedMap.get(backend.<K>getCurrentKey());
+               final V result = stateTable.get(currentNamespace);
 
                if (result == null) {
                        return stateDesc.getDefaultValue();
@@ -85,29 +63,12 @@ public class HeapValueState<K, N, V>
 
        @Override
        public void update(V value) {
-               Preconditions.checkState(currentNamespace != null, "No 
namespace set.");
-               Preconditions.checkState(backend.getCurrentKey() != null, "No 
key set.");
 
                if (value == null) {
                        clear();
                        return;
                }
 
-               Map<N, Map<K, V>> namespaceMap =
-                               
stateTable.get(backend.getCurrentKeyGroupIndex());
-
-               if (namespaceMap == null) {
-                       namespaceMap = createNewMap();
-                       stateTable.set(backend.getCurrentKeyGroupIndex(), 
namespaceMap);
-               }
-
-               Map<K, V> keyedMap = namespaceMap.get(currentNamespace);
-
-               if (keyedMap == null) {
-                       keyedMap = createNewMap();
-                       namespaceMap.put(currentNamespace, keyedMap);
-               }
-
-               keyedMap.put(backend.<K>getCurrentKey(), value);
+               stateTable.put(currentNamespace, value);
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/InternalKeyContext.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/InternalKeyContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/InternalKeyContext.java
new file mode 100644
index 0000000..cb0582b
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/InternalKeyContext.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.KeyGroupRange;
+
+/**
+ * This interface is the current context of a keyed state. It provides information about the currently selected key in
+ * the context, the corresponding key-group, and other key and key-grouping related information.
+ * <p>
+ * The typical use case for this interface is providing a view on the current-key selection aspects of
+ * {@link org.apache.flink.runtime.state.KeyedStateBackend}.
+ */
+@Internal
+public interface InternalKeyContext<K> {
+
+       /**
+        * Used by states to access the current key.
+        */
+       K getCurrentKey();
+
+       /**
+        * Returns the key-group to which the current key belongs.
+        */
+       int getCurrentKeyGroupIndex();
+
+       /**
+        * Returns the number of key-groups aka max parallelism.
+        */
+       int getNumberOfKeyGroups();
+
+       /**
+        * Returns the key groups for this backend.
+        */
+       KeyGroupRange getKeyGroupRange();
+
+       /**
+        * {@link TypeSerializer} for the state backend key type.
+        */
+       TypeSerializer<K> getKeySerializer();
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
new file mode 100644
index 0000000..22f344d
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
@@ -0,0 +1,363 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
+import org.apache.flink.runtime.state.StateTransformationFunction;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * This implementation of {@link StateTable} uses nested {@link HashMap}
+ * objects. It is also maintaining a partitioning by key-group.
+ * <p>
+ * In contrast to {@link CopyOnWriteStateTable}, this implementation does
+ * not support asynchronous snapshots. However, it might have a better
+ * memory footprint for some use-cases, e.g. it is naturally de-duplicating
+ * namespace objects.
+ *
+ * @param <K> type of key.
+ * @param <N> type of namespace.
+ * @param <S> type of state.
+ */
+@Internal
+public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
+
+       /**
+        * Map for holding the actual state objects. The outer array
+        * represents the key-groups of the key-group range. The nested maps
+        * provide an outer scope by namespace and an inner scope by key.
+        */
+       private final Map<N, Map<K, S>>[] state;
+
+       /**
+        * The offset to the contiguous key groups, i.e. the start key-group
+        * of the key-group range, which is stored at array position 0.
+        */
+       private final int keyGroupOffset;
+
+       // ------------------------------------------------------------------
+
+       /**
+        * Creates a new {@link NestedMapsStateTable} for the given key
+        * context and meta info.
+        *
+        * @param keyContext the key context.
+        * @param metaInfo the meta information for this state table.
+        */
+       public NestedMapsStateTable(
+                       InternalKeyContext<K> keyContext,
+                       RegisteredBackendStateMetaInfo<N, S> metaInfo) {
+               super(keyContext, metaInfo);
+               this.keyGroupOffset =
+                       keyContext.getKeyGroupRange().getStartKeyGroup();
+
+               // Allocate one slot per key-group of the key-group range.
+               // getNumberOfKeyGroups() on the key context is the max
+               // parallelism; sizing the array with it over-allocates and
+               // lets out-of-range key-groups pass the bounds checks below.
+               final int keyGroupsInRange =
+                       keyContext.getKeyGroupRange().getNumberOfKeyGroups();
+
+               @SuppressWarnings("unchecked")
+               Map<N, Map<K, S>>[] state =
+                       (Map<N, Map<K, S>>[]) new Map[keyGroupsInRange];
+               this.state = state;
+       }
+
+       // ------------------------------------------------------------------
+       //  access to maps
+       // ------------------------------------------------------------------
+
+       /**
+        * Returns the internal data structure.
+        */
+       @VisibleForTesting
+       public Map<N, Map<K, S>>[] getState() {
+               return state;
+       }
+
+       /**
+        * Returns the namespace map of the given key-group.
+        *
+        * @param keyGroupIndex the id of the key-group.
+        * @return the namespace map, or {@code null} if the key-group is
+        * outside of the array bounds or has no map yet.
+        */
+       @VisibleForTesting
+       Map<N, Map<K, S>> getMapForKeyGroup(int keyGroupIndex) {
+               final int pos = indexToOffset(keyGroupIndex);
+               if (pos >= 0 && pos < state.length) {
+                       return state[pos];
+               } else {
+                       return null;
+               }
+       }
+
+       /**
+        * Sets the given map for the given key-group.
+        *
+        * @throws IllegalArgumentException if the key-group is outside of
+        * the array bounds.
+        */
+       private void setMapForKeyGroup(int keyGroupId, Map<N, Map<K, S>> map) {
+               final int pos = indexToOffset(keyGroupId);
+               // Explicit bounds check instead of catching
+               // ArrayIndexOutOfBoundsException for control flow.
+               if (pos < 0 || pos >= state.length) {
+                       throw new IllegalArgumentException(
+                               "Key group index out of range of "
+                                       + "key group range ["
+                                       + keyGroupOffset + ", "
+                                       + (keyGroupOffset + state.length)
+                                       + ").");
+               }
+               state[pos] = map;
+       }
+
+       /**
+        * Translates a key-group id to the internal array offset.
+        */
+       private int indexToOffset(int index) {
+               return index - keyGroupOffset;
+       }
+
+       // ------------------------------------------------------------------
+       //  operations scoped by the current key of the key context
+       // ------------------------------------------------------------------
+
+       @Override
+       public int size() {
+               int count = 0;
+               for (Map<N, Map<K, S>> namespaceMap : state) {
+                       if (null != namespaceMap) {
+                               for (Map<K, S> keyMap : namespaceMap.values()) {
+                                       if (null != keyMap) {
+                                               count += keyMap.size();
+                                       }
+                               }
+                       }
+               }
+               return count;
+       }
+
+       @Override
+       public S get(N namespace) {
+               return get(
+                       keyContext.getCurrentKey(),
+                       keyContext.getCurrentKeyGroupIndex(),
+                       namespace);
+       }
+
+       @Override
+       public boolean containsKey(N namespace) {
+               return containsKey(
+                       keyContext.getCurrentKey(),
+                       keyContext.getCurrentKeyGroupIndex(),
+                       namespace);
+       }
+
+       @Override
+       public void put(N namespace, S state) {
+               put(
+                       keyContext.getCurrentKey(),
+                       keyContext.getCurrentKeyGroupIndex(),
+                       namespace,
+                       state);
+       }
+
+       @Override
+       public S putAndGetOld(N namespace, S state) {
+               return putAndGetOld(
+                       keyContext.getCurrentKey(),
+                       keyContext.getCurrentKeyGroupIndex(),
+                       namespace,
+                       state);
+       }
+
+       @Override
+       public void remove(N namespace) {
+               remove(
+                       keyContext.getCurrentKey(),
+                       keyContext.getCurrentKeyGroupIndex(),
+                       namespace);
+       }
+
+       @Override
+       public S removeAndGetOld(N namespace) {
+               return removeAndGetOld(
+                       keyContext.getCurrentKey(),
+                       keyContext.getCurrentKeyGroupIndex(),
+                       namespace);
+       }
+
+       @Override
+       public S get(K key, N namespace) {
+               // The key-group is computed from the given key instead of
+               // being taken from the key context.
+               int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(
+                       key, keyContext.getNumberOfKeyGroups());
+               return get(key, keyGroup, namespace);
+       }
+
+       // ------------------------------------------------------------------
+       //  operations with explicit key and key-group
+       // ------------------------------------------------------------------
+
+       private boolean containsKey(K key, int keyGroupIndex, N namespace) {
+
+               checkKeyNamespacePreconditions(key, namespace);
+
+               Map<N, Map<K, S>> namespaceMap =
+                       getMapForKeyGroup(keyGroupIndex);
+
+               if (namespaceMap == null) {
+                       return false;
+               }
+
+               Map<K, S> keyedMap = namespaceMap.get(namespace);
+
+               return keyedMap != null && keyedMap.containsKey(key);
+       }
+
+       S get(K key, int keyGroupIndex, N namespace) {
+
+               checkKeyNamespacePreconditions(key, namespace);
+
+               Map<N, Map<K, S>> namespaceMap =
+                       getMapForKeyGroup(keyGroupIndex);
+
+               if (namespaceMap == null) {
+                       return null;
+               }
+
+               Map<K, S> keyedMap = namespaceMap.get(namespace);
+
+               if (keyedMap == null) {
+                       return null;
+               }
+
+               return keyedMap.get(key);
+       }
+
+       @Override
+       public void put(K key, int keyGroupIndex, N namespace, S value) {
+               putAndGetOld(key, keyGroupIndex, namespace, value);
+       }
+
+       private S putAndGetOld(K key, int keyGroupIndex, N namespace, S value) {
+
+               checkKeyNamespacePreconditions(key, namespace);
+
+               return getOrCreateKeyedMap(keyGroupIndex, namespace)
+                       .put(key, value);
+       }
+
+       private void remove(K key, int keyGroupIndex, N namespace) {
+               removeAndGetOld(key, keyGroupIndex, namespace);
+       }
+
+       private S removeAndGetOld(K key, int keyGroupIndex, N namespace) {
+
+               checkKeyNamespacePreconditions(key, namespace);
+
+               Map<N, Map<K, S>> namespaceMap =
+                       getMapForKeyGroup(keyGroupIndex);
+
+               if (namespaceMap == null) {
+                       return null;
+               }
+
+               Map<K, S> keyedMap = namespaceMap.get(namespace);
+
+               if (keyedMap == null) {
+                       return null;
+               }
+
+               S removed = keyedMap.remove(key);
+
+               // Drop the keyed map of the namespace once it is empty.
+               if (keyedMap.isEmpty()) {
+                       namespaceMap.remove(namespace);
+               }
+
+               return removed;
+       }
+
+       /**
+        * Returns the key-to-state map for the given key-group and
+        * namespace, creating and registering the namespace map and the
+        * keyed map if they do not exist yet.
+        */
+       private Map<K, S> getOrCreateKeyedMap(int keyGroupIndex, N namespace) {
+               Map<N, Map<K, S>> namespaceMap =
+                       getMapForKeyGroup(keyGroupIndex);
+
+               if (namespaceMap == null) {
+                       namespaceMap = new HashMap<>();
+                       setMapForKeyGroup(keyGroupIndex, namespaceMap);
+               }
+
+               Map<K, S> keyedMap = namespaceMap.get(namespace);
+
+               if (keyedMap == null) {
+                       keyedMap = new HashMap<>();
+                       namespaceMap.put(namespace, keyedMap);
+               }
+
+               return keyedMap;
+       }
+
+       private void checkKeyNamespacePreconditions(K key, N namespace) {
+               Preconditions.checkNotNull(
+                       key,
+                       "No key set. This method should not be called "
+                               + "outside of a keyed context.");
+               Preconditions.checkNotNull(
+                       namespace,
+                       "Provided namespace is null.");
+       }
+
+       @Override
+       public int sizeOfNamespace(Object namespace) {
+               int count = 0;
+               for (Map<N, Map<K, S>> namespaceMap : state) {
+                       if (null != namespaceMap) {
+                               Map<K, S> keyMap = namespaceMap.get(namespace);
+                               count += keyMap != null ? keyMap.size() : 0;
+                       }
+               }
+
+               return count;
+       }
+
+       @Override
+       public <T> void transform(
+                       N namespace,
+                       T value,
+                       StateTransformationFunction<S, T> transformation)
+                       throws Exception {
+               final K key = keyContext.getCurrentKey();
+               checkKeyNamespacePreconditions(key, namespace);
+               final int keyGroupIndex = keyContext.getCurrentKeyGroupIndex();
+
+               // The maps are created before the transformation is applied.
+               final Map<K, S> keyedMap =
+                       getOrCreateKeyedMap(keyGroupIndex, namespace);
+
+               final S newState =
+                       transformation.apply(keyedMap.get(key), value);
+               keyedMap.put(key, newState);
+       }
+
+       // snapshots --------------------------------------------------------
+
+       /**
+        * Counts all key/state mappings over all namespaces of the given
+        * key-group map.
+        */
+       private static <K, N, S> int countMappingsInKeyGroup(
+                       final Map<N, Map<K, S>> keyGroupMap) {
+               int count = 0;
+               for (Map<K, S> namespaceMap : keyGroupMap.values()) {
+                       count += namespaceMap.size();
+               }
+
+               return count;
+       }
+
+       @Override
+       public NestedMapsStateTableSnapshot<K, N, S> createSnapshot() {
+               return new NestedMapsStateTableSnapshot<>(this);
+       }
+
+       /**
+        * This class encapsulates the snapshot logic.
+        *
+        * @param <K> type of key.
+        * @param <N> type of namespace.
+        * @param <S> type of state.
+        */
+       static class NestedMapsStateTableSnapshot<K, N, S>
+                       extends AbstractStateTableSnapshot<
+                               K, N, S, NestedMapsStateTable<K, N, S>> {
+
+               NestedMapsStateTableSnapshot(
+                               NestedMapsStateTable<K, N, S> owningTable) {
+                       super(owningTable);
+               }
+
+               /**
+                * Writes the number of mappings of the key-group, followed
+                * by one (namespace, key, state) triple per mapping. A
+                * key-group without a map is written as a count of 0.
+                * <p>
+                * Implementation note: we currently chose the same format
+                * between {@link NestedMapsStateTable} and
+                * {@link CopyOnWriteStateTable}.
+                * <p>
+                * {@link NestedMapsStateTable} could naturally support a
+                * kind of prefix-compressed format (grouping by namespace,
+                * writing the namespace only once per group instead for
+                * each mapping). We might implement support for different
+                * formats later (tailored towards different state table
+                * implementations).
+                *
+                * @param dov        the output to write to.
+                * @param keyGroupId the key-group to write.
+                * @throws IOException on write related problems.
+                */
+               @Override
+               public void writeMappingsInKeyGroup(
+                               DataOutputView dov,
+                               int keyGroupId) throws IOException {
+                       final Map<N, Map<K, S>> keyGroupMap =
+                               owningStateTable.getMapForKeyGroup(keyGroupId);
+
+                       if (null == keyGroupMap) {
+                               dov.writeInt(0);
+                               return;
+                       }
+
+                       final TypeSerializer<K> keySerializer =
+                               owningStateTable.keyContext.getKeySerializer();
+                       final RegisteredBackendStateMetaInfo<N, S> metaInfo =
+                               owningStateTable.metaInfo;
+                       final TypeSerializer<N> namespaceSerializer =
+                               metaInfo.getNamespaceSerializer();
+                       final TypeSerializer<S> stateSerializer =
+                               metaInfo.getStateSerializer();
+
+                       dov.writeInt(countMappingsInKeyGroup(keyGroupMap));
+
+                       for (Map.Entry<N, Map<K, S>> namespaceEntry
+                                       : keyGroupMap.entrySet()) {
+                               final N namespace = namespaceEntry.getKey();
+                               final Map<K, S> namespaceMap =
+                                       namespaceEntry.getValue();
+
+                               for (Map.Entry<K, S> keyEntry
+                                               : namespaceMap.entrySet()) {
+                                       namespaceSerializer.serialize(
+                                               namespace, dov);
+                                       keySerializer.serialize(
+                                               keyEntry.getKey(), dov);
+                                       stateSerializer.serialize(
+                                               keyEntry.getValue(), dov);
+                               }
+                       }
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateEntry.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateEntry.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateEntry.java
new file mode 100644
index 0000000..8e29cb2
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateEntry.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+/**
+ * Interface of entries in a state table. Entries are triples of key,
+ * namespace, and state.
+ *
+ * @param <K> type of key.
+ * @param <N> type of namespace.
+ * @param <S> type of state.
+ */
+public interface StateEntry<K, N, S> {
+
+       /**
+        * Returns the key of this entry.
+        *
+        * @return the key of this entry.
+        */
+       K getKey();
+
+       /**
+        * Returns the namespace of this entry.
+        *
+        * @return the namespace of this entry.
+        */
+       N getNamespace();
+
+       /**
+        * Returns the state of this entry.
+        *
+        * @return the state of this entry.
+        */
+       S getState();
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
index 21265f4..62fc869 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
@@ -7,7 +7,7 @@
  * "License"); you may not use this file except in compliance
  * with the License.  You may obtain a copy of the License at
  *
- *     http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -15,72 +15,152 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
-import org.apache.flink.runtime.state.KeyGroupRange;
-
-import java.util.Map;
-
-public class StateTable<K, N, ST> {
-
-       /** Map for holding the actual state objects. */
-       private final Map<N, Map<K, ST>>[] state;
-
-       /** The offset to the contiguous key groups */
-       private final int keyGroupOffset;
-
-       /** Combined meta information such as name and serializers for this 
state */
-       private RegisteredBackendStateMetaInfo<N, ST> metaInfo;
-
-       // 
------------------------------------------------------------------------
-       public StateTable(RegisteredBackendStateMetaInfo<N, ST> metaInfo, 
KeyGroupRange keyGroupRange) {
-               this.metaInfo = metaInfo;
-               this.keyGroupOffset = keyGroupRange.getStartKeyGroup();
-
-               @SuppressWarnings("unchecked")
-               Map<N, Map<K, ST>>[] state = (Map<N, Map<K, ST>>[]) new 
Map[keyGroupRange.getNumberOfKeyGroups()];
-               this.state = state;
-       }
-
-       // 
------------------------------------------------------------------------
-       //  access to maps
-       // 
------------------------------------------------------------------------
+import org.apache.flink.runtime.state.StateTransformationFunction;
+import org.apache.flink.util.Preconditions;
 
-       public Map<N, Map<K, ST>>[] getState() {
-               return state;
-       }
-
-       public Map<N, Map<K, ST>> get(int index) {
-               final int pos = indexToOffset(index);
-               if (pos >= 0 && pos < state.length) {
-                       return state[pos];
-               } else {
-                       return null;
-               }
-       }
-
-       public void set(int index, Map<N, Map<K, ST>> map) {
-               try {
-                       state[indexToOffset(index)] = map;
-               }
-               catch (ArrayIndexOutOfBoundsException e) {
-                       throw new IllegalArgumentException("Key group index out 
of range of key group range [" +
-                                       keyGroupOffset + ", " + (keyGroupOffset 
+ state.length) + ").");
-               }
+/**
+ * Base class for state tables. Accesses to state are typically scoped by the 
currently active key, as provided
+ * through the {@link InternalKeyContext}.
+ *
+ * @param <K> type of key
+ * @param <N> type of namespace
+ * @param <S> type of state
+ */
+public abstract class StateTable<K, N, S> {
+
+       /**
+        * The key context view on the backend. This provides information, such 
as the currently active key.
+        */
+       protected final InternalKeyContext<K> keyContext;
+
+       /**
+        * Combined meta information such as name and serializers for this state
+        */
+       protected RegisteredBackendStateMetaInfo<N, S> metaInfo;
+
+       /**
+        *
+        * @param keyContext the key context provides the key scope for all 
put/get/delete operations.
+        * @param metaInfo the meta information, including the type serializer 
for state copy-on-write.
+        */
+       public StateTable(InternalKeyContext<K> keyContext, 
RegisteredBackendStateMetaInfo<N, S> metaInfo) {
+               this.keyContext = Preconditions.checkNotNull(keyContext);
+               this.metaInfo = Preconditions.checkNotNull(metaInfo);
        }
 
-       private int indexToOffset(int index) {
-               return index - keyGroupOffset;
+       // Main interface methods of StateTable 
-------------------------------------------------------
+
+       /**
+        * Returns whether this {@link StateTable} is empty.
+        *
+        * @return {@code true} if this {@link StateTable} has no elements,
+        * {@code false} otherwise.
+        * @see #size()
+        */
+       public boolean isEmpty() {
+               return size() == 0;
        }
 
-       // 
------------------------------------------------------------------------
-       //  metadata
-       // 
------------------------------------------------------------------------
-       
-       public TypeSerializer<ST> getStateSerializer() {
+       /**
+        * Returns the total number of entries in this {@link StateTable}.
+        *
+        * @return the number of entries in this {@link StateTable}.
+        */
+       public abstract int size();
+
+       /**
+        * Returns the state of the mapping for the composite of active key
+        * and given namespace.
+        *
+        * @param namespace the namespace. Not null.
+        * @return the state of the mapping with the specified key/namespace
+        * composite key, or {@code null} if no mapping for the specified key
+        * is found.
+        */
+       public abstract S get(N namespace);
+
+       /**
+        * Returns whether this table contains a mapping for the composite of 
active key and given namespace.
+        *
+        * @param namespace the namespace in the composite key to search for. 
Not null.
+        * @return {@code true} if this map contains the specified 
key/namespace composite key,
+        * {@code false} otherwise.
+        */
+       public abstract boolean containsKey(N namespace);
+
+       /**
+        * Maps the composite of active key and given namespace to the
+        * specified state. This method should be preferred over
+        * {@link #putAndGetOld(Object, Object)} when the caller is not
+        * interested in the old state.
+        *
+        * @param namespace the namespace. Not null.
+        * @param state     the state. Can be null.
+        */
+       public abstract void put(N namespace, S state);
+
+       /**
+        * Maps the composite of active key and given namespace to the 
specified state. Returns the previous state that
+        * was registered under the composite key.
+        *
+        * @param namespace the namespace. Not null.
+        * @param state     the state. Can be null.
+        * @return the state of any previous mapping with the specified key or
+        * {@code null} if there was no such mapping.
+        */
+       public abstract S putAndGetOld(N namespace, S state);
+
+       /**
+        * Removes the mapping for the composite of active key and given
+        * namespace. This method should be preferred over
+        * {@link #removeAndGetOld(Object)} when the caller is not interested
+        * in the old state.
+        *
+        * @param namespace the namespace of the mapping to remove. Not null.
+        */
+       public abstract void remove(N namespace);
+
+       /**
+        * Removes the mapping for the composite of active key and given 
namespace, returning the state that was
+        * found under the entry.
+        *
+        * @param namespace the namespace of the mapping to remove. Not null.
+        * @return the state of the removed mapping or {@code null} if no 
mapping
+        * for the specified key was found.
+        */
+       public abstract S removeAndGetOld(N namespace);
+
+       /**
+        * Applies the given {@link StateTransformationFunction} to the state 
(1st input argument), using the given value as
+        * second input argument. The result of {@link 
StateTransformationFunction#apply(Object, Object)} is then stored as
+        * the new state. This function is basically an optimization for 
get-update-put pattern.
+        *
+        * @param namespace      the namespace. Not null.
+        * @param value          the value to use in transforming the state. 
Can be null.
+        * @param transformation the transformation function.
+        * @throws Exception if some exception happens in the transformation 
function.
+        */
+       public abstract <T> void transform(
+                       N namespace,
+                       T value,
+                       StateTransformationFunction<S, T> transformation) 
throws Exception;
+
+       // For queryable state 
------------------------------------------------------------------------
+
+       /**
+        * Returns the state for the composite of given key and given
+        * namespace. This is typically used by queryable state.
+        *
+        * @param key       the key. Not null.
+        * @param namespace the namespace. Not null.
+        * @return the state of the mapping with the specified key/namespace
+        * composite key, or {@code null} if no mapping for the specified key
+        * is found.
+        */
+       public abstract S get(K key, N namespace);
+
+       // Meta data setter / getter and toString 
-----------------------------------------------------
+
+       public TypeSerializer<S> getStateSerializer() {
                return metaInfo.getStateSerializer();
        }
 
@@ -88,28 +168,22 @@ public class StateTable<K, N, ST> {
                return metaInfo.getNamespaceSerializer();
        }
 
-       public RegisteredBackendStateMetaInfo<N, ST> getMetaInfo() {
+       public RegisteredBackendStateMetaInfo<N, S> getMetaInfo() {
                return metaInfo;
        }
 
-       public void setMetaInfo(RegisteredBackendStateMetaInfo<N, ST> metaInfo) 
{
+       public void setMetaInfo(RegisteredBackendStateMetaInfo<N, S> metaInfo) {
                this.metaInfo = metaInfo;
        }
 
-       // 
------------------------------------------------------------------------
-       //  for testing
-       // 
------------------------------------------------------------------------
+       // Snapshot / Restore 
-------------------------------------------------------------------------
+
+       abstract StateTableSnapshot createSnapshot();
+
+       public abstract void put(K key, int keyGroup, N namespace, S state);
+
+       // For testing 
--------------------------------------------------------------------------------
 
        @VisibleForTesting
-       boolean isEmpty() {
-               for (Map<N, Map<K, ST>> map : state) {
-                       if (map != null) {
-                               if (!map.isEmpty()) {
-                                       return false;
-                               }
-                       }
-               }
-
-               return true;
-       }
-}
+       public abstract int sizeOfNamespace(Object namespace);
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReader.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReader.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReader.java
new file mode 100644
index 0000000..659c174
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReader.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.core.memory.DataInputView;
+
+import java.io.IOException;
+
+/**
+ * Interface for state de-serialization into {@link StateTable}s by key-group.
+ */
+interface StateTableByKeyGroupReader {
+
+       /**
+        * Read the data for the specified key-group from the input.
+        *
+        * @param div        the input
+        * @param keyGroupId the key-group to read
+        * @throws IOException on read related problems
+        */
+       void readMappingsInKeyGroup(DataInputView div, int keyGroupId) throws 
IOException;
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReaders.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReaders.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReaders.java
new file mode 100644
index 0000000..53ec349
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReaders.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataInputView;
+
+import java.io.IOException;
+
+/**
+ * This class provides a static factory method to create different 
implementations of {@link StateTableByKeyGroupReader}
+ * depending on the provided serialization format version.
+ * <p>
+ * The implementations are also located here as inner classes.
+ */
+class StateTableByKeyGroupReaders {
+
+       /**
+        * Creates a new StateTableByKeyGroupReader that inserts de-serialized 
mappings into the given table, using the
+        * de-serialization algorithm that matches the given version.
+        *
+        * @param table the {@link StateTable} into which de-serialized 
mappings are inserted.
+        * @param version version for the de-serialization algorithm.
+        * @param <K> type of key.
+        * @param <N> type of namespace.
+        * @param <S> type of state.
+        * @return the appropriate reader.
+        */
+       static <K, N, S> StateTableByKeyGroupReader 
readerForVersion(StateTable<K, N, S> table, int version) {
+               switch (version) {
+                       case 1:
+                               return new 
StateTableByKeyGroupReaderV1<>(table);
+                       case 2:
+                               return new 
StateTableByKeyGroupReaderV2<>(table);
+                       default:
+                               throw new IllegalArgumentException("Unknown 
version: " + version);
+               }
+       }
+
+       static abstract class AbstractStateTableByKeyGroupReader<K, N, S>
+                       implements StateTableByKeyGroupReader {
+
+               protected final StateTable<K, N, S> stateTable;
+
+               AbstractStateTableByKeyGroupReader(StateTable<K, N, S> 
stateTable) {
+                       this.stateTable = stateTable;
+               }
+
+               @Override
+               public abstract void readMappingsInKeyGroup(DataInputView div, 
int keyGroupId) throws IOException;
+
+               protected TypeSerializer<K> getKeySerializer() {
+                       return stateTable.keyContext.getKeySerializer();
+               }
+
+               protected TypeSerializer<N> getNamespaceSerializer() {
+                       return stateTable.getNamespaceSerializer();
+               }
+
+               protected TypeSerializer<S> getStateSerializer() {
+                       return stateTable.getStateSerializer();
+               }
+       }
+
+       static final class StateTableByKeyGroupReaderV1<K, N, S>
+                       extends AbstractStateTableByKeyGroupReader<K, N, S> {
+
+               StateTableByKeyGroupReaderV1(StateTable<K, N, S> stateTable) {
+                       super(stateTable);
+               }
+
+               @Override
+               public void readMappingsInKeyGroup(DataInputView inView, int 
keyGroupId) throws IOException {
+
+                       if (inView.readByte() == 0) {
+                               return;
+                       }
+
+                       final TypeSerializer<K> keySerializer = 
getKeySerializer();
+                       final TypeSerializer<N> namespaceSerializer = 
getNamespaceSerializer();
+                       final TypeSerializer<S> stateSerializer = 
getStateSerializer();
+
+                       // V1 uses a namespace-compressed format: entries are grouped by namespace
+                       int numNamespaces = inView.readInt();
+                       for (int k = 0; k < numNamespaces; k++) {
+                               N namespace = 
namespaceSerializer.deserialize(inView);
+                               int numEntries = inView.readInt();
+                               for (int l = 0; l < numEntries; l++) {
+                                       K key = 
keySerializer.deserialize(inView);
+                                       S state = 
stateSerializer.deserialize(inView);
+                                       stateTable.put(key, keyGroupId, 
namespace, state);
+                               }
+                       }
+               }
+       }
+
+       private static final class StateTableByKeyGroupReaderV2<K, N, S>
+                       extends AbstractStateTableByKeyGroupReader<K, N, S> {
+
+               StateTableByKeyGroupReaderV2(StateTable<K, N, S> stateTable) {
+                       super(stateTable);
+               }
+
+               @Override
+               public void readMappingsInKeyGroup(DataInputView inView, int 
keyGroupId) throws IOException {
+
+                       final TypeSerializer<K> keySerializer = 
getKeySerializer();
+                       final TypeSerializer<N> namespaceSerializer = 
getNamespaceSerializer();
+                       final TypeSerializer<S> stateSerializer = 
getStateSerializer();
+
+                       int numKeys = inView.readInt();
+                       for (int i = 0; i < numKeys; ++i) {
+                               N namespace = 
namespaceSerializer.deserialize(inView);
+                               K key = keySerializer.deserialize(inView);
+                               S state = stateSerializer.deserialize(inView);
+                               stateTable.put(key, keyGroupId, namespace, 
state);
+                       }
+               }
+       }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableSnapshot.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableSnapshot.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableSnapshot.java
new file mode 100644
index 0000000..d4244d7
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableSnapshot.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.core.memory.DataOutputView;
+
+import java.io.IOException;
+
+/**
+ * Interface for the snapshots of a {@link StateTable}. Offers a way to 
serialize the snapshot (by key-group). All
+ * snapshots should be released after usage.
+ */
+interface StateTableSnapshot {
+
+       /**
+        * Writes the data for the specified key-group to the output.
+        *
+        * @param dov the output
+        * @param keyGroupId the key-group to write
+        * @throws IOException on write related problems
+        */
+       void writeMappingsInKeyGroup(DataOutputView dov, int keyGroupId) throws 
IOException;
+
+       /**
+        * Release the snapshot. All snapshots should be released when they are 
no longer used because some implementations
+        * can only release resources after a release.
+        */
+       void release();
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
index da01c09..f0bac1b 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
@@ -45,6 +45,9 @@ public class MemoryStateBackend extends AbstractStateBackend {
        /** The maximal size that the snapshotted memory state may have */
        private final int maxStateSize;
 
+       /** Switch to choose between synchronous and asynchronous snapshots */
+       private final boolean asynchronousSnapshots;
+
        /**
         * Creates a new memory state backend that accepts states whose 
serialized forms are
         * up to the default state size (5 MB).
@@ -60,7 +63,29 @@ public class MemoryStateBackend extends AbstractStateBackend 
{
         * @param maxStateSize The maximal size of the serialized state
         */
        public MemoryStateBackend(int maxStateSize) {
+               this(maxStateSize, false);
+       }
+
+       /**
+        * Creates a new memory state backend that accepts states whose 
serialized forms are
+        * up to the default state size (5 MB).
+        *
+        * @param asynchronousSnapshots Switch to enable asynchronous snapshots.
+        */
+       public MemoryStateBackend(boolean asynchronousSnapshots) {
+               this(DEFAULT_MAX_STATE_SIZE, asynchronousSnapshots);
+       }
+
+       /**
+        * Creates a new memory state backend that accepts states whose 
serialized forms are
+        * up to the given number of bytes.
+        *
+        * @param maxStateSize The maximal size of the serialized state
+        * @param asynchronousSnapshots Switch to enable asynchronous snapshots.
+        */
+       public MemoryStateBackend(int maxStateSize, boolean 
asynchronousSnapshots) {
                this.maxStateSize = maxStateSize;
+               this.asynchronousSnapshots = asynchronousSnapshots;
        }
 
        @Override
@@ -98,6 +123,7 @@ public class MemoryStateBackend extends AbstractStateBackend 
{
                                env.getUserClassLoader(),
                                numberOfKeyGroups,
                                keyGroupRange,
+                               asynchronousSnapshots,
                                env.getExecutionConfig());
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
index 2c385c1..b4d6eb7 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/QueryableStateClientTest.java
@@ -33,12 +33,12 @@ import org.apache.flink.runtime.query.netty.KvStateServer;
 import org.apache.flink.runtime.query.netty.UnknownKvStateID;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
-import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
 import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo;
 import org.apache.flink.runtime.state.VoidNamespace;
 import org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.heap.HeapValueState;
-import org.apache.flink.runtime.state.heap.StateTable;
+import org.apache.flink.runtime.state.heap.NestedMapsStateTable;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.util.MathUtils;
 import org.junit.AfterClass;
@@ -278,9 +278,8 @@ public class QueryableStateClientTest {
 
                                // Register state
                                HeapValueState<Integer, VoidNamespace, Integer> 
kvState = new HeapValueState<>(
-                                               keyedStateBackend,
                                                descriptor,
-                                               new StateTable<Integer, 
VoidNamespace, Integer>(registeredBackendStateMetaInfo, new KeyGroupRange(0, 
1)),
+                                               new 
NestedMapsStateTable<Integer, VoidNamespace, Integer>(keyedStateBackend, 
registeredBackendStateMetaInfo),
                                                IntSerializer.INSTANCE,
                                                
VoidNamespaceSerializer.INSTANCE);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java
index 93094a4..4ed63a2 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/query/netty/message/KvStateRequestSerializerTest.java
@@ -21,12 +21,10 @@ package org.apache.flink.runtime.query.netty.message;
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.ByteBufAllocator;
 import io.netty.buffer.UnpooledByteBufAllocator;
-
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MapStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.api.common.typeutils.base.ByteSerializer;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
@@ -38,13 +36,16 @@ import 
org.apache.flink.runtime.state.VoidNamespaceSerializer;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 import org.apache.flink.runtime.state.internal.InternalKvState;
 import org.apache.flink.runtime.state.internal.InternalListState;
-
 import org.apache.flink.runtime.state.internal.InternalMapState;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -54,10 +55,19 @@ import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.mockito.Mockito.mock;
 
+@RunWith(Parameterized.class)
 public class KvStateRequestSerializerTest {
 
        private final ByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT;
 
+       @Parameterized.Parameters
+       public static Collection<Boolean> parameters() {
+               return Arrays.asList(false, true);
+       }
+
+       @Parameterized.Parameter
+       public boolean async;
+
        /**
         * Tests KvState request serialization.
         */
@@ -332,7 +342,9 @@ public class KvStateRequestSerializerTest {
                                mock(TaskKvStateRegistry.class),
                                LongSerializer.INSTANCE,
                                ClassLoader.getSystemClassLoader(),
-                               1, new KeyGroupRange(0, 0),
+                               1,
+                               new KeyGroupRange(0, 0),
+                               async,
                                new ExecutionConfig()
                        );
                longHeapKeyedStateBackend.setCurrentKey(key);
@@ -418,7 +430,7 @@ public class KvStateRequestSerializerTest {
                KvStateRequestSerializer.deserializeList(new byte[] {1, 1, 1, 
1, 1, 1, 1, 1, 2, 3},
                        LongSerializer.INSTANCE);
        }
-       
+
        /**
         * Tests map serialization utils.
         */
@@ -429,11 +441,13 @@ public class KvStateRequestSerializerTest {
                // objects for heap state list serialisation
                final HeapKeyedStateBackend<Long> longHeapKeyedStateBackend =
                        new HeapKeyedStateBackend<>(
-                               mock(TaskKvStateRegistry.class),
-                               LongSerializer.INSTANCE,
-                               ClassLoader.getSystemClassLoader(),
-                               1, new KeyGroupRange(0, 0),
-                               new ExecutionConfig()
+                                       mock(TaskKvStateRegistry.class),
+                                       LongSerializer.INSTANCE,
+                                       ClassLoader.getSystemClassLoader(),
+                                       1,
+                                       new KeyGroupRange(0, 0),
+                                       async,
+                                       new ExecutionConfig()
                        );
                longHeapKeyedStateBackend.setCurrentKey(key);
 
@@ -481,7 +495,7 @@ public class KvStateRequestSerializerTest {
                        KvStateRequestSerializer.serializeKeyAndNamespace(
                                key, LongSerializer.INSTANCE,
                                VoidNamespace.INSTANCE, 
VoidNamespaceSerializer.INSTANCE);
-               
+
                final byte[] serializedValues = 
mapState.getSerializedValue(serializedKey);
 
                Map<Long, String> actualValues = 
KvStateRequestSerializer.deserializeMap(serializedValues, userKeySerializer, 
userValueSerializer);
@@ -534,7 +548,7 @@ public class KvStateRequestSerializerTest {
                KvStateRequestSerializer.deserializeMap(new byte[]{1, 1, 1, 1, 
1, 1, 1, 1, 0},
                                LongSerializer.INSTANCE, 
LongSerializer.INSTANCE);
        }
-       
+
        /**
         * Tests map deserialization with too few bytes.
         */

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncFileStateBackendTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncFileStateBackendTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncFileStateBackendTest.java
new file mode 100644
index 0000000..dd73e42
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncFileStateBackendTest.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+public class AsyncFileStateBackendTest extends FileStateBackendTest {
+
+       @Override
+       protected boolean useAsyncMode() {
+               return true;
+       }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncMemoryStateBackendTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncMemoryStateBackendTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncMemoryStateBackendTest.java
new file mode 100644
index 0000000..ba4a89d
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/AsyncMemoryStateBackendTest.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+public class AsyncMemoryStateBackendTest extends MemoryStateBackendTest {
+
+       @Override
+       protected boolean useAsyncMode() {
+               return true;
+       }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
index 75014e7..6be2343 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/FileStateBackendTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.filesystem.FsStateBackend;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.junit.Ignore;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -50,7 +51,11 @@ public class FileStateBackendTest extends 
StateBackendTestBase<FsStateBackend> {
        @Override
        protected FsStateBackend getStateBackend() throws Exception {
                File checkpointPath = tempFolder.newFolder();
-               return new FsStateBackend(localFileUri(checkpointPath));
+               return new FsStateBackend(localFileUri(checkpointPath), 
useAsyncMode());
+       }
+
+       protected boolean useAsyncMode() {
+               return false;
        }
 
        // disable these because the verification does not work for this state 
backend
@@ -208,6 +213,7 @@ public class FileStateBackendTest extends 
StateBackendTestBase<FsStateBackend> {
                }
        }
 
+       @Ignore
        @Test
        public void testConcurrentMapIfQueryable() throws Exception {
                super.testConcurrentMapIfQueryable();

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
index 362fcd6..48d56e2 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.junit.Ignore;
 import org.junit.Test;
 
 import java.io.IOException;
@@ -44,7 +45,11 @@ public class MemoryStateBackendTest extends 
StateBackendTestBase<MemoryStateBack
 
        @Override
        protected MemoryStateBackend getStateBackend() throws Exception {
-               return new MemoryStateBackend();
+               return new MemoryStateBackend(useAsyncMode());
+       }
+
+       protected boolean useAsyncMode() {
+               return false;
        }
 
        // disable these because the verification does not work for this state 
backend
@@ -193,6 +198,7 @@ public class MemoryStateBackendTest extends 
StateBackendTestBase<MemoryStateBack
                }
        }
 
+       @Ignore
        @Test
        public void testConcurrentMapIfQueryable() throws Exception {
                super.testConcurrentMapIfQueryable();

http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 40ac72c..331c6bd 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -42,6 +42,7 @@ import 
org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.core.testutils.CheckedThread;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
@@ -50,10 +51,15 @@ import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.KvStateRegistryListener;
 import org.apache.flink.runtime.query.netty.message.KvStateRequestSerializer;
 import org.apache.flink.runtime.state.heap.AbstractHeapState;
+import org.apache.flink.runtime.state.heap.NestedMapsStateTable;
 import org.apache.flink.runtime.state.heap.StateTable;
 import org.apache.flink.runtime.state.internal.InternalKvState;
+import org.apache.flink.runtime.state.internal.InternalValueState;
+import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory;
 import org.apache.flink.types.IntValue;
+import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.TestLogger;
+import org.junit.Assert;
 import org.junit.Test;
 
 import java.io.IOException;
@@ -67,6 +73,7 @@ import java.util.Random;
 import java.util.Timer;
 import java.util.TimerTask;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.RunnableFuture;
 
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -791,7 +798,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        fail(e.getMessage());
                }
        }
-       
+
        @Test
        @SuppressWarnings("unchecked,rawtypes")
        public void testMapState() {
@@ -823,7 +830,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        backend.setCurrentKey(1);
                        assertTrue(state.contains(1));
                        assertEquals("1", state.get(1));
-                       assertEquals(new HashMap<Integer, String>() {{ put (1, 
"1"); }}, 
+                       assertEquals(new HashMap<Integer, String>() {{ put (1, 
"1"); }},
                                        getSerializedMap(kvState, 1, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, 
userValueSerializer));
 
                        // draw a snapshot
@@ -848,12 +855,12 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                                        getSerializedMap(kvState, 1, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, 
userValueSerializer));
                        backend.setCurrentKey(2);
                        assertEquals("102", state.get(102));
-                       assertEquals(new HashMap<Integer, String>() {{ put(2, 
"2"); put(102, "102"); }}, 
+                       assertEquals(new HashMap<Integer, String>() {{ put(2, 
"2"); put(102, "102"); }},
                                        getSerializedMap(kvState, 2, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, 
userValueSerializer));
                        backend.setCurrentKey(3);
                        assertTrue(state.contains(103));
                        assertEquals("103", state.get(103));
-                       assertEquals(new HashMap<Integer, String>() {{ put(103, 
"103"); put(1031, "1031"); put(1032, "1032"); }}, 
+                       assertEquals(new HashMap<Integer, String>() {{ put(103, 
"103"); put(1031, "1031"); put(1032, "1032"); }},
                                        getSerializedMap(kvState, 3, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, 
userValueSerializer));
 
                        List<Integer> keys = new ArrayList<>();
@@ -912,11 +919,11 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                        backend.setCurrentKey(1);
                        assertEquals("1", restored1.get(1));
-                       assertEquals(new HashMap<Integer, String>() {{ put (1, 
"1"); }}, 
+                       assertEquals(new HashMap<Integer, String>() {{ put (1, 
"1"); }},
                                        getSerializedMap(restoredKvState1, 1, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, 
userValueSerializer));
                        backend.setCurrentKey(2);
                        assertEquals("2", restored1.get(2));
-                       assertEquals(new HashMap<Integer, String>() {{ put (2, 
"2"); }}, 
+                       assertEquals(new HashMap<Integer, String>() {{ put (2, 
"2"); }},
                                        getSerializedMap(restoredKvState1, 2, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, 
userValueSerializer));
 
                        backend.dispose();
@@ -931,15 +938,15 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                        backend.setCurrentKey(1);
                        assertEquals("101", restored2.get(1));
-                       assertEquals(new HashMap<Integer, String>() {{ put (1, 
"101"); }}, 
+                       assertEquals(new HashMap<Integer, String>() {{ put (1, 
"101"); }},
                                        getSerializedMap(restoredKvState2, 1, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, 
userValueSerializer));
                        backend.setCurrentKey(2);
                        assertEquals("102", restored2.get(102));
-                       assertEquals(new HashMap<Integer, String>() {{ put(2, 
"2"); put (102, "102"); }}, 
+                       assertEquals(new HashMap<Integer, String>() {{ put(2, 
"2"); put (102, "102"); }},
                                        getSerializedMap(restoredKvState2, 2, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, 
userValueSerializer));
                        backend.setCurrentKey(3);
                        assertEquals("103", restored2.get(103));
-                       assertEquals(new HashMap<Integer, String>() {{ put(103, 
"103"); put(1031, "1031"); put(1032, "1032"); }}, 
+                       assertEquals(new HashMap<Integer, String>() {{ put(103, 
"103"); put(1031, "1031"); put(1032, "1032"); }},
                                        getSerializedMap(restoredKvState2, 3, 
keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, 
userValueSerializer));
 
                        backend.dispose();
@@ -1111,7 +1118,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
 
                backend.dispose();
        }
-       
+
        /**
         * This test verifies that state is correctly assigned to key groups 
and that restore
         * restores the relevant key groups in the backend.
@@ -1364,7 +1371,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        fail(e.getMessage());
                }
        }
-       
+
        @Test
        @SuppressWarnings("unchecked")
        public void testMapStateRestoreWithWrongSerializers() {
@@ -1507,11 +1514,8 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        backend.setCurrentKey(1);
                        state.update(121818273);
 
-                       int keyGroupIndex = 
KeyGroupRangeAssignment.assignToKeyGroup(1, numberOfKeyGroups);
-                       StateTable stateTable = ((AbstractHeapState) 
kvState).getStateTable();
-                       assertNotNull("State not set", 
stateTable.get(keyGroupIndex));
-                       assertTrue(stateTable.get(keyGroupIndex) instanceof 
ConcurrentHashMap);
-                       
assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof 
ConcurrentHashMap);
+                       StateTable<?, ?, ?> stateTable = ((AbstractHeapState<?, 
?,? ,?, ?>) kvState).getStateTable();
+                       checkConcurrentStateTable(stateTable, 
numberOfKeyGroups);
 
                }
 
@@ -1533,11 +1537,8 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        backend.setCurrentKey(1);
                        state.add(121818273);
 
-                       int keyGroupIndex = 
KeyGroupRangeAssignment.assignToKeyGroup(1, numberOfKeyGroups);
-                       StateTable stateTable = ((AbstractHeapState) 
kvState).getStateTable();
-                       assertNotNull("State not set", 
stateTable.get(keyGroupIndex));
-                       assertTrue(stateTable.get(keyGroupIndex) instanceof 
ConcurrentHashMap);
-                       
assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof 
ConcurrentHashMap);
+                       StateTable<?, ?, ?> stateTable = ((AbstractHeapState<?, 
?,? ,?, ?>) kvState).getStateTable();
+                       checkConcurrentStateTable(stateTable, 
numberOfKeyGroups);
                }
 
                {
@@ -1564,11 +1565,8 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        backend.setCurrentKey(1);
                        state.add(121818273);
 
-                       int keyGroupIndex = 
KeyGroupRangeAssignment.assignToKeyGroup(1, numberOfKeyGroups);
-                       StateTable stateTable = ((AbstractHeapState) 
kvState).getStateTable();
-                       assertNotNull("State not set", 
stateTable.get(keyGroupIndex));
-                       assertTrue(stateTable.get(keyGroupIndex) instanceof 
ConcurrentHashMap);
-                       
assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof 
ConcurrentHashMap);
+                       StateTable<?, ?, ?> stateTable = ((AbstractHeapState<?, 
?,? ,?, ?>) kvState).getStateTable();
+                       checkConcurrentStateTable(stateTable, 
numberOfKeyGroups);
                }
 
                {
@@ -1595,13 +1593,10 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        backend.setCurrentKey(1);
                        state.add(121818273);
 
-                       int keyGroupIndex = 
KeyGroupRangeAssignment.assignToKeyGroup(1, numberOfKeyGroups);
-                       StateTable stateTable = ((AbstractHeapState) 
kvState).getStateTable();
-                       assertNotNull("State not set", 
stateTable.get(keyGroupIndex));
-                       assertTrue(stateTable.get(keyGroupIndex) instanceof 
ConcurrentHashMap);
-                       
assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof 
ConcurrentHashMap);
+                       StateTable<?, ?, ?> stateTable = ((AbstractHeapState<?, 
?,? ,?, ?>) kvState).getStateTable();
+                       checkConcurrentStateTable(stateTable, 
numberOfKeyGroups);
                }
-               
+
                {
                        // MapState
                        MapStateDescriptor<Integer, String> desc = new 
MapStateDescriptor<>("map-state", Integer.class, String.class);
@@ -1623,13 +1618,22 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        int keyGroupIndex = 
KeyGroupRangeAssignment.assignToKeyGroup(1, numberOfKeyGroups);
                        StateTable stateTable = ((AbstractHeapState) 
kvState).getStateTable();
                        assertNotNull("State not set", 
stateTable.get(keyGroupIndex));
-                       assertTrue(stateTable.get(keyGroupIndex) instanceof 
ConcurrentHashMap);
-                       
assertTrue(stateTable.get(keyGroupIndex).get(VoidNamespace.INSTANCE) instanceof 
ConcurrentHashMap);
+                       checkConcurrentStateTable(stateTable, 
numberOfKeyGroups);
                }
 
                backend.dispose();
        }
 
+       private void checkConcurrentStateTable(StateTable<?, ?, ?> stateTable, 
int numberOfKeyGroups) {
+               assertNotNull("State not set", stateTable);
+               if (stateTable instanceof NestedMapsStateTable) {
+                       int keyGroupIndex = 
KeyGroupRangeAssignment.assignToKeyGroup(1, numberOfKeyGroups);
+                       NestedMapsStateTable<?, ?, ?> nestedMapsStateTable = 
(NestedMapsStateTable<?, ?, ?>) stateTable;
+                       
assertTrue(nestedMapsStateTable.getState()[keyGroupIndex] instanceof 
ConcurrentHashMap);
+                       
assertTrue(nestedMapsStateTable.getState()[keyGroupIndex].get(VoidNamespace.INSTANCE)
 instanceof ConcurrentHashMap);
+               }
+       }
+
        /**
         * Tests registration with the KvStateRegistry.
         */
@@ -1688,7 +1692,8 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        ListStateDescriptor<String> kvId = new 
ListStateDescriptor<>("id", String.class);
 
                        // draw a snapshot
-                       KeyGroupsStateHandle snapshot = 
runSnapshot(backend.snapshot(682375462379L, 1, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
+                       KeyGroupsStateHandle snapshot =
+                                       
runSnapshot(backend.snapshot(682375462379L, 1, streamFactory, 
CheckpointOptions.forFullCheckpoint()));
                        assertNull(snapshot);
                        backend.dispose();
 
@@ -1708,6 +1713,145 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                }
        }
 
+       @Test
+       public void testAsyncSnapshot() throws Exception {
+               OneShotLatch waiter = new OneShotLatch();
+               BlockerCheckpointStreamFactory streamFactory = new 
BlockerCheckpointStreamFactory(1024 * 1024);
+               streamFactory.setWaiterLatch(waiter);
+
+               AbstractKeyedStateBackend<Integer> backend = null;
+               KeyGroupsStateHandle stateHandle = null;
+
+               try {
+                       backend = createKeyedBackend(IntSerializer.INSTANCE);
+                       InternalValueState<VoidNamespace, Integer> valueState = 
backend.createValueState(
+                                       VoidNamespaceSerializer.INSTANCE,
+                                       new ValueStateDescriptor<>("test", 
IntSerializer.INSTANCE));
+
+                       valueState.setCurrentNamespace(VoidNamespace.INSTANCE);
+
+                       for (int i = 0; i < 10; ++i) {
+                               backend.setCurrentKey(i);
+                               valueState.update(i);
+                       }
+
+                       RunnableFuture<KeyGroupsStateHandle> snapshot =
+                                       backend.snapshot(0L, 0L, streamFactory, 
CheckpointOptions.forFullCheckpoint());
+                       Thread runner = new Thread(snapshot);
+                       runner.start();
+                       for (int i = 0; i < 20; ++i) {
+                               backend.setCurrentKey(i);
+                               valueState.update(i + 1);
+                               if (10 == i) {
+                                       waiter.await();
+                               }
+                       }
+
+                       runner.join();
+                       stateHandle = snapshot.get();
+
+                       // test isolation: the backend must still see the values written after the snapshot was triggered
+                       for (int i = 0; i < 20; ++i) {
+                               backend.setCurrentKey(i);
+                               Assert.assertEquals(i + 1, (int) 
valueState.value());
+                       }
+
+               } finally {
+                       if (null != backend) {
+                               IOUtils.closeQuietly(backend);
+                               backend.dispose();
+                       }
+               }
+
+               Assert.assertNotNull(stateHandle);
+
+               backend = createKeyedBackend(IntSerializer.INSTANCE);
+               try {
+                       backend.restore(Collections.singleton(stateHandle));
+                       InternalValueState<VoidNamespace, Integer> valueState = 
backend.createValueState(
+                                       VoidNamespaceSerializer.INSTANCE,
+                                       new ValueStateDescriptor<>("test", 
IntSerializer.INSTANCE));
+
+                       valueState.setCurrentNamespace(VoidNamespace.INSTANCE);
+
+                       for (int i = 0; i < 10; ++i) {
+                               backend.setCurrentKey(i);
+                               Assert.assertEquals(i, (int) 
valueState.value());
+                       }
+
+                       backend.setCurrentKey(11);
+                       Assert.assertEquals(null, valueState.value());
+               } finally {
+                       if (null != backend) {
+                               IOUtils.closeQuietly(backend);
+                               backend.dispose();
+                       }
+               }
+       }
+
+       @Test
+       public void testAsyncSnapshotCancellation() throws Exception {
+               OneShotLatch blocker = new OneShotLatch();
+               OneShotLatch waiter = new OneShotLatch();
+               BlockerCheckpointStreamFactory streamFactory = new 
BlockerCheckpointStreamFactory(1024 * 1024);
+               streamFactory.setWaiterLatch(waiter);
+               streamFactory.setBlockerLatch(blocker);
+               streamFactory.setAfterNumberInvocations(100);
+
+               AbstractKeyedStateBackend<Integer> backend = null;
+               try {
+                       backend = createKeyedBackend(IntSerializer.INSTANCE);
+
+                       if (!backend.supportsAsynchronousSnapshots()) {
+                               return;
+                       }
+
+                       InternalValueState<VoidNamespace, Integer> valueState = 
backend.createValueState(
+                                       VoidNamespaceSerializer.INSTANCE,
+                                       new ValueStateDescriptor<>("test", 
IntSerializer.INSTANCE));
+
+                       valueState.setCurrentNamespace(VoidNamespace.INSTANCE);
+
+                       for (int i = 0; i < 10; ++i) {
+                               backend.setCurrentKey(i);
+                               valueState.update(i);
+                       }
+
+                       RunnableFuture<KeyGroupsStateHandle> snapshot =
+                                       backend.snapshot(0L, 0L, streamFactory, 
CheckpointOptions.forFullCheckpoint());
+
+                       Thread runner = new Thread(snapshot);
+                       runner.start();
+
+                       // wait until the snapshot thread has reached the blocking point in the stream
+                       waiter.await();
+
+                       // close the backend to see if the close is propagated 
to the stream
+                       backend.close();
+
+                       //unblock the stream so that it can run into the 
IOException
+                       blocker.trigger();
+
+                       //dispose the backend
+                       backend.dispose();
+
+                       runner.join();
+
+                       try {
+                               snapshot.get();
+                               fail("Close was not propagated.");
+                       } catch (ExecutionException ex) {
+                               // expected: the close must surface as a failure of the snapshot future
+                       }
+
+               } finally {
+                       if (null != backend) {
+                               IOUtils.closeQuietly(backend);
+                               backend.dispose();
+                       }
+               }
+       }
+
        private static class AppendingFold implements FoldFunction<Integer, 
String> {
                private static final long serialVersionUID = 1L;
 
@@ -1764,7 +1908,7 @@ public abstract class StateBackendTestBase<B extends 
AbstractStateBackend> exten
                        return 
KvStateRequestSerializer.deserializeList(serializedValue, valueSerializer);
                }
        }
-       
+
        /**
         * Returns the value by getting the serialized value and deserializing 
it
         * if it is not null.

Reply via email to