http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java new file mode 100644 index 0000000..d63b6d3 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java @@ -0,0 +1,1066 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.runtime.state.heap; + +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo; +import org.apache.flink.runtime.state.StateTransformationFunction; +import org.apache.flink.util.MathUtils; +import org.apache.flink.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.ConcurrentModificationException; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.TreeSet; + +/** + * Implementation of Flink's in-memory state tables with copy-on-write support. This map does not support null values + * for key or namespace. + * <p> + * {@link CopyOnWriteStateTable} sacrifices some peak performance and memory efficiency for features like incremental + * rehashing and asynchronous snapshots through copy-on-write. Copy-on-write tries to minimize the amount of copying by + * maintaining version meta data for both, the map structure and the state objects. However, we must often proactively + * copy state objects when we hand them to the user. + * <p> + * As for any state backend, user should not keep references on state objects that they obtained from state backends + * outside the scope of the user function calls. + * <p> + * Some brief maintenance notes: + * <p> + * 1) Flattening the underlying data structure from nested maps (namespace) -> (key) -> (state) to one flat map + * (key, namespace) -> (state) brings certain performance trade-offs. In theory, the flat map has one less level of + * indirection compared to the nested map. However, the nested map naturally de-duplicates namespace objects for which + * #equals() is true. This leads to potentially a lot of redundant namespace objects for the flattened version. 
Those, + * in turn, can again introduce more cache misses because we need to follow the namespace object on all operations to + * ensure entry identities. Obviously, copy-on-write can also add memory overhead. So does the meta data to track + * copy-on-write requirement (state and entry versions on {@link StateTableEntry}). + * <p> + * 2) A flat map structure is a lot easier when it comes to tracking copy-on-write of the map structure. + * <p> + * 3) Nested structure had the (never used) advantage that we can easily drop and iterate whole namespaces. This could + * give locality advantages for certain access pattern, e.g. iterating a namespace. + * <p> + * 4) Serialization format is changed from namespace-prefix compressed (as naturally provided from the old nested + * structure) to making all entries self contained as (key, namespace, state). + * <p> + * 5) We got rid of having multiple nested tables, one for each key-group. Instead, we partition state into key-groups + * on-the-fly, during the asynchronous part of a snapshot. + * <p> + * 6) Currently, a state table can only grow, but never shrinks on low load. We could easily add this if required. + * <p> + * 7) Heap based state backends like this can easily cause a lot of GC activity. Besides using G1 as garbage collector, + * we should provide an additional state backend that operates on off-heap memory. This would sacrifice peak performance + * (due to de/serialization of objects) for a lower, but more constant throughput and potentially huge simplifications + * w.r.t. copy-on-write. + * <p> + * 8) We could try a hybrid of a serialized and object based backends, where key and namespace of the entries are both + * serialized in one byte-array. + * <p> + * 9) We could consider smaller types (e.g. short) for the version counting and think about some reset strategy before + * overflows, when there is no snapshot running. However, this would have to touch all entries in the map. 
+ * <p> + * This class was initially based on the {@link java.util.HashMap} implementation of the Android JDK, but is now heavily + * customized towards the use case of table for state entries. + * + * IMPORTANT: the contracts for this class rely on the user not holding any references to objects returned by this map + * beyond the life cycle of per-element operations. Or phrased differently, all get-update-put operations on a mapping + * should be within one call of processElement. Otherwise, the user must take care of taking deep copies, e.g. for + * caching purposes. + * + * @param <K> type of key. + * @param <N> type of namespace. + * @param <S> type of value. + */ +public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implements Iterable<StateEntry<K, N, S>> { + + /** + * The logger. + */ + private static final Logger LOG = LoggerFactory.getLogger(HeapKeyedStateBackend.class); + + /** + * Min capacity (other than zero) for a {@link CopyOnWriteStateTable}. Must be a power of two + * greater than 1 (and less than 1 << 30). + */ + private static final int MINIMUM_CAPACITY = 4; + + /** + * Max capacity for a {@link CopyOnWriteStateTable}. Must be a power of two >= MINIMUM_CAPACITY. + */ + private static final int MAXIMUM_CAPACITY = 1 << 30; + + /** + * Minimum number of entries that one step of incremental rehashing migrates from the old to the new sub-table. + */ + private static final int MIN_TRANSFERRED_PER_INCREMENTAL_REHASH = 4; + + /** + * An empty table shared by all zero-capacity maps (typically from default + * constructor). It is never written to, and replaced on first put. Its size + * is set to half the minimum, so that the first resize will create a + * minimum-sized table. + */ + private static final StateTableEntry<?, ?, ?>[] EMPTY_TABLE = new StateTableEntry[MINIMUM_CAPACITY >>> 1]; + + /** + * Empty entry that we use to bootstrap our {@link CopyOnWriteStateTable.StateEntryIterator}. 
+ */ + private static final StateTableEntry<?, ?, ?> ITERATOR_BOOTSTRAP_ENTRY = new StateTableEntry<>(); + + /** + * Maintains an ordered set of version ids that are still in use by unreleased snapshots. + */ + private final TreeSet<Integer> snapshotVersions; + + /** + * This is the primary entry array (hash directory) of the state table. If no incremental rehash is ongoing, this + * is the only used table. + **/ + private StateTableEntry<K, N, S>[] primaryTable; + + /** + * We maintain a secondary entry array while performing an incremental rehash. The purpose is to slowly migrate + * entries from the primary table to this resized table array. When all entries are migrated, this becomes the new + * primary table. + */ + private StateTableEntry<K, N, S>[] incrementalRehashTable; + + /** + * The current number of mappings in the primary table. + */ + private int primaryTableSize; + + /** + * The current number of mappings in the rehash table. + */ + private int incrementalRehashTableSize; + + /** + * The next index for a step of incremental rehashing in the primary table. + */ + private int rehashIndex; + + /** + * The current version of this map. Used for copy-on-write mechanics. + */ + private int stateTableVersion; + + /** + * The highest version of this map that is still required by any unreleased snapshot. + */ + private int highestRequiredSnapshotVersion; + + /** + * The last namespace that was actually inserted. This is a small optimization to reduce duplicate namespace objects. + */ + private N lastNamespace; + + /** + * The {@link CopyOnWriteStateTable} is rehashed when its size exceeds this threshold. + * The value of this field is generally .75 * capacity, except when + * the capacity is zero, as described in the EMPTY_TABLE declaration + * above. + */ + private int threshold; + + /** + * Incremented by "structural modifications" to allow (best effort) + * detection of concurrent modification. 
+ */ + private int modCount; + + /** + * Constructs a new {@code StateTable} with default capacity of 1024. + * + * @param keyContext the key context. + * @param metaInfo the meta information, including the type serializer for state copy-on-write. + */ + CopyOnWriteStateTable(InternalKeyContext<K> keyContext, RegisteredBackendStateMetaInfo<N, S> metaInfo) { + this(keyContext, metaInfo, 1024); + } + + /** + * Constructs a new {@code StateTable} instance with the specified capacity. + * + * @param keyContext the key context. + * @param metaInfo the meta information, including the type serializer for state copy-on-write. + * @param capacity the initial capacity of this hash map. + * @throws IllegalArgumentException when the capacity is less than zero. + */ + @SuppressWarnings("unchecked") + private CopyOnWriteStateTable(InternalKeyContext<K> keyContext, RegisteredBackendStateMetaInfo<N, S> metaInfo, int capacity) { + super(keyContext, metaInfo); + + // initialized tables to EMPTY_TABLE. + this.primaryTable = (StateTableEntry<K, N, S>[]) EMPTY_TABLE; + this.incrementalRehashTable = (StateTableEntry<K, N, S>[]) EMPTY_TABLE; + + // initialize sizes to 0. + this.primaryTableSize = 0; + this.incrementalRehashTableSize = 0; + + this.rehashIndex = 0; + this.stateTableVersion = 0; + this.highestRequiredSnapshotVersion = 0; + this.snapshotVersions = new TreeSet<>(); + + if (capacity < 0) { + throw new IllegalArgumentException("Capacity: " + capacity); + } + + if (capacity == 0) { + threshold = -1; + return; + } + + if (capacity < MINIMUM_CAPACITY) { + capacity = MINIMUM_CAPACITY; + } else if (capacity > MAXIMUM_CAPACITY) { + capacity = MAXIMUM_CAPACITY; + } else { + capacity = MathUtils.roundUpToPowerOfTwo(capacity); + } + primaryTable = makeTable(capacity); + } + + // Public API from AbstractStateTable ------------------------------------------------------------------------------ + + /** + * Returns the total number of entries in this {@link CopyOnWriteStateTable}. 
	 * This is the sum of both sub-tables.
	 *
	 * @return the number of entries in this {@link CopyOnWriteStateTable}.
	 */
	@Override
	public int size() {
		return primaryTableSize + incrementalRehashTableSize;
	}

	/**
	 * Returns the state for the given key/namespace composite key, or {@code null} if no mapping exists.
	 * Performs lazy copy-on-write: if the found entry's state was created before the highest version still
	 * required by an unreleased snapshot, the entry chain and the state object are copied before the state
	 * is handed out, so the snapshot keeps seeing the old object.
	 */
	@Override
	public S get(K key, N namespace) {

		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
		final int requiredVersion = highestRequiredSnapshotVersion;
		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
		int index = hash & (tab.length - 1);

		for (StateTableEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
			final K eKey = e.key;
			final N eNamespace = e.namespace;
			if ((e.hash == hash && key.equals(eKey) && namespace.equals(eNamespace))) {

				// copy-on-write check for state
				if (e.stateVersion < requiredVersion) {
					// copy-on-write check for entry: the entry itself may still be shared with a
					// snapshot, so it must be replaced before we can mutate its state field.
					if (e.entryVersion < requiredVersion) {
						e = handleChainedEntryCopyOnWrite(tab, hash & (tab.length - 1), e);
					}
					e.stateVersion = stateTableVersion;
					e.state = getStateSerializer().copy(e.state);
				}

				return e.state;
			}
		}

		return null;
	}

	// note: the keyGroup parameter is not used by this implementation; key-group partitioning
	// happens on-the-fly during the asynchronous part of a snapshot (see class javadoc, item 5).
	@Override
	public void put(K key, int keyGroup, N namespace, S state) {
		put(key, namespace, state);
	}

	// The following methods delegate to the (key, namespace) variants, using the current key
	// from the key context.

	@Override
	public S get(N namespace) {
		return get(keyContext.getCurrentKey(), namespace);
	}

	@Override
	public boolean containsKey(N namespace) {
		return containsKey(keyContext.getCurrentKey(), namespace);
	}

	@Override
	public void put(N namespace, S state) {
		put(keyContext.getCurrentKey(), namespace, state);
	}

	@Override
	public S putAndGetOld(N namespace, S state) {
		return putAndGetOld(keyContext.getCurrentKey(), namespace, state);
	}

	@Override
	public void remove(N namespace) {
		remove(keyContext.getCurrentKey(), namespace);
	}

	@Override
	public S removeAndGetOld(N namespace) {
		return removeAndGetOld(keyContext.getCurrentKey(), namespace);
	}

	@Override
	public <T> void transform(N namespace, T value,
			StateTransformationFunction<S, T> transformation) throws Exception {
		transform(keyContext.getCurrentKey(), namespace, value, transformation);
	}

	// Private implementation details of the API methods ---------------------------------------------------------------

	/**
	 * Returns whether this table contains the specified key/namespace composite key.
	 *
	 * @param key the key in the composite key to search for. Not null.
	 * @param namespace the namespace in the composite key to search for. Not null.
	 * @return {@code true} if this map contains the specified key/namespace composite key,
	 * {@code false} otherwise.
	 */
	boolean containsKey(K key, N namespace) {

		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
		int index = hash & (tab.length - 1);

		for (StateTableEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
			final K eKey = e.key;
			final N eNamespace = e.namespace;

			if ((e.hash == hash && key.equals(eKey) && namespace.equals(eNamespace))) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Maps the specified key/namespace composite key to the specified value. This method should be preferred
	 * over {@link #putAndGetOld(Object, Object, Object)} when the caller is not interested
	 * in the old value, because this can potentially reduce copy-on-write activity.
	 *
	 * @param key the key. Not null.
	 * @param namespace the namespace. Not null.
	 * @param value the value. Can be null.
	 */
	void put(K key, N namespace, S value) {
		final StateTableEntry<K, N, S> e = putEntry(key, namespace);

		e.state = value;
		e.stateVersion = stateTableVersion;
	}

	/**
	 * Maps the specified key/namespace composite key to the specified value. Returns the previous state that was
	 * registered under the composite key.
	 *
	 * @param key the key. Not null.
	 * @param namespace the namespace. Not null.
	 * @param value the value. Can be null.
	 * @return the value of any previous mapping with the specified key or
	 * {@code null} if there was no such mapping.
	 */
	S putAndGetOld(K key, N namespace, S value) {

		final StateTableEntry<K, N, S> e = putEntry(key, namespace);

		// copy-on-write check for state: the old state object may still be referenced by an
		// unreleased snapshot, in which case we hand the caller a copy instead.
		S oldState = (e.stateVersion < highestRequiredSnapshotVersion) ?
				getStateSerializer().copy(e.state) :
				e.state;

		e.state = value;
		e.stateVersion = stateTableVersion;

		return oldState;
	}

	/**
	 * Removes the mapping with the specified key/namespace composite key from this map. This method should be preferred
	 * over {@link #removeAndGetOld(Object, Object)} when the caller is not interested in the old value, because this
	 * can potentially reduce copy-on-write activity.
	 *
	 * @param key the key of the mapping to remove. Not null.
	 * @param namespace the namespace of the mapping to remove. Not null.
	 */
	void remove(K key, N namespace) {
		removeEntry(key, namespace);
	}

	/**
	 * Removes the mapping with the specified key/namespace composite key from this map, returning the state that was
	 * found under the entry.
	 *
	 * @param key the key of the mapping to remove. Not null.
	 * @param namespace the namespace of the mapping to remove. Not null.
	 * @return the value of the removed mapping or {@code null} if no mapping
	 * for the specified key was found.
	 */
	S removeAndGetOld(K key, N namespace) {

		final StateTableEntry<K, N, S> e = removeEntry(key, namespace);

		return e != null ?
				// copy-on-write check for state
				(e.stateVersion < highestRequiredSnapshotVersion ?
						getStateSerializer().copy(e.state) :
						e.state) :
				null;
	}

	/**
	 * Applies the given transformation function to the state under the given composite key, using the
	 * given value as the second function input, and stores the function result as the new state.
	 *
	 * @param key the key of the mapping to transform. Not null.
	 * @param namespace the namespace of the mapping to transform. Not null.
	 * @param value the value that is the second input for the transformation.
	 * @param transformation the transformation function to apply on the old state and the given value.
	 * @param <T> type of the value that is the second input to the {@link StateTransformationFunction}.
	 * @throws Exception exception that happen on applying the function.
	 * @see #transform(Object, Object, StateTransformationFunction)
	 */
	<T> void transform(
			K key,
			N namespace,
			T value,
			StateTransformationFunction<S, T> transformation) throws Exception {

		final StateTableEntry<K, N, S> entry = putEntry(key, namespace);

		// copy-on-write check for state: hand a copy to the function if the current state object
		// is still required by an unreleased snapshot.
		entry.state = transformation.apply(
				(entry.stateVersion < highestRequiredSnapshotVersion) ?
						getStateSerializer().copy(entry.state) :
						entry.state,
				value);
		entry.stateVersion = stateTableVersion;
	}

	/**
	 * Helper method that is the basis for operations that add mappings. Returns the (possibly freshly
	 * created or copied) entry for the composite key, ready to have its state written. May trigger a
	 * capacity doubling when the size threshold is exceeded.
	 */
	private StateTableEntry<K, N, S> putEntry(K key, N namespace) {

		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
		int index = hash & (tab.length - 1);

		for (StateTableEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
			if (e.hash == hash && key.equals(e.key) && namespace.equals(e.namespace)) {

				// copy-on-write check for entry
				if (e.entryVersion < highestRequiredSnapshotVersion) {
					e = handleChainedEntryCopyOnWrite(tab, index, e);
				}

				return e;
			}
		}

		++modCount;
		if (size() > threshold) {
			doubleCapacity();
		}

		return addNewStateTableEntry(tab, key, namespace, hash);
	}

	/**
	 * Helper method that is the basis for operations that remove mappings. Returns the removed entry,
	 * or {@code null} if the composite key was not present. Unlinking in the middle of a chain first
	 * performs copy-on-write of the predecessor if a snapshot still shares it.
	 */
	private StateTableEntry<K, N, S> removeEntry(K key, N namespace) {

		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
		int index = hash & (tab.length - 1);

		for (StateTableEntry<K, N, S> e = tab[index], prev = null; e != null; prev = e, e = e.next) {
			if (e.hash == hash && key.equals(e.key) && namespace.equals(e.namespace)) {
				if (prev == null) {
					tab[index] = e.next;
				} else {
					// copy-on-write check for entry
					if (prev.entryVersion < highestRequiredSnapshotVersion) {
						prev = handleChainedEntryCopyOnWrite(tab, index, prev);
					}
					prev.next = e.next;
				}
				++modCount;
				if (tab == primaryTable) {
					--primaryTableSize;
				} else {
					--incrementalRehashTableSize;
				}
				return e;
			}
		}
		return null;
	}

	// Validates that neither key nor namespace is null; both are required for the composite key.
	private void checkKeyNamespacePreconditions(K key, N namespace) {
		Preconditions.checkNotNull(key, "No key set. This method should not be called outside of a keyed context.");
		Preconditions.checkNotNull(namespace, "Provided namespace is null.");
	}

	// Meta data setter / getter and toString --------------------------------------------------------------------------

	@Override
	public TypeSerializer<S> getStateSerializer() {
		return metaInfo.getStateSerializer();
	}

	@Override
	public TypeSerializer<N> getNamespaceSerializer() {
		return metaInfo.getNamespaceSerializer();
	}

	@Override
	public RegisteredBackendStateMetaInfo<N, S> getMetaInfo() {
		return metaInfo;
	}

	@Override
	public void setMetaInfo(RegisteredBackendStateMetaInfo<N, S> metaInfo) {
		this.metaInfo = metaInfo;
	}

	// Iteration ------------------------------------------------------------------------------------------------------

	@Override
	public Iterator<StateEntry<K, N, S>> iterator() {
		return new StateEntryIterator();
	}

	// Private utility functions for StateTable management -------------------------------------------------------------

	/**
	 * Releases the snapshot with the given version id, so that this version no longer constrains
	 * copy-on-write decisions.
	 *
	 * @see #releaseSnapshot(CopyOnWriteStateTableSnapshot)
	 */
	@VisibleForTesting
	void releaseSnapshot(int snapshotVersion) {
		// we guard against concurrent modifications of highestRequiredSnapshotVersion between snapshot and release.
		// Only stale reads of from the result of #releaseSnapshot calls are ok.
		synchronized (snapshotVersions) {
			Preconditions.checkState(snapshotVersions.remove(snapshotVersion), "Attempt to release unknown snapshot version");
			highestRequiredSnapshotVersion = snapshotVersions.isEmpty() ? 0 : snapshotVersions.last();
		}
	}

	/**
	 * Creates (combined) copy of the table arrays for a snapshot. This method must be called by the same Thread that
	 * does modifications to the {@link CopyOnWriteStateTable}.
	 */
	@VisibleForTesting
	@SuppressWarnings("unchecked")
	StateTableEntry<K, N, S>[] snapshotTableArrays() {

		// we guard against concurrent modifications of highestRequiredSnapshotVersion between snapshot and release.
		// Only stale reads of from the result of #releaseSnapshot calls are ok. This is why we must call this method
		// from the same thread that does all the modifications to the table.
		synchronized (snapshotVersions) {

			// increase the table version for copy-on-write and register the snapshot
			if (++stateTableVersion < 0) {
				// this is just a safety net against overflows, but should never happen in practice (i.e., only after 2^31 snapshots)
				throw new IllegalStateException("Version count overflow in CopyOnWriteStateTable. Enforcing restart.");
			}

			highestRequiredSnapshotVersion = stateTableVersion;
			snapshotVersions.add(highestRequiredSnapshotVersion);
		}

		StateTableEntry<K, N, S>[] table = primaryTable;
		if (isRehashing()) {
			// consider both tables for the snapshot, the rehash index tells us which part of the two tables we need
			final int localRehashIndex = rehashIndex;
			final int localCopyLength = table.length - localRehashIndex;
			StateTableEntry<K, N, S>[] copy = new StateTableEntry[localRehashIndex + table.length];
			// for the primary table, take every index >= rhIdx.
			System.arraycopy(table, localRehashIndex, copy, 0, localCopyLength);

			// for the new table, we are sure that two regions contain all the entries:
			// [0, rhIdx[ AND [table.length / 2, table.length / 2 + rhIdx[
			table = incrementalRehashTable;
			System.arraycopy(table, 0, copy, localCopyLength, localRehashIndex);
			System.arraycopy(table, table.length >>> 1, copy, localCopyLength + localRehashIndex, localRehashIndex);

			return copy;
		} else {
			// we only need to copy the primary table
			return Arrays.copyOf(table, table.length);
		}
	}

	/**
	 * Allocate a table of the given capacity and set the threshold accordingly.
	 *
	 * @param newCapacity must be a power of two
	 */
	private StateTableEntry<K, N, S>[] makeTable(int newCapacity) {

		if (MAXIMUM_CAPACITY == newCapacity) {
			LOG.warn("Maximum capacity of 2^30 in StateTable reached. Cannot increase hash table size. This can lead " +
					"to more collisions and lower performance. Please consider scaling-out your job or using a " +
					"different keyed state backend implementation!");
		}

		threshold = (newCapacity >> 1) + (newCapacity >> 2); // 3/4 capacity
		@SuppressWarnings("unchecked") StateTableEntry<K, N, S>[] newTable
				= (StateTableEntry<K, N, S>[]) new StateTableEntry[newCapacity];
		return newTable;
	}

	/**
	 * Creates and inserts a new {@link StateTableEntry}, chained in front of the bucket it hashes to.
	 */
	private StateTableEntry<K, N, S> addNewStateTableEntry(
			StateTableEntry<K, N, S>[] table,
			K key,
			N namespace,
			int hash) {

		// small optimization that aims to avoid holding references on duplicate namespace objects
		if (namespace.equals(lastNamespace)) {
			namespace = lastNamespace;
		} else {
			lastNamespace = namespace;
		}

		int index = hash & (table.length - 1);
		StateTableEntry<K, N, S> newEntry = new StateTableEntry<>(
				key,
				namespace,
				null,
				hash,
				table[index],
				stateTableVersion,
				stateTableVersion);
		table[index] = newEntry;

		if (table == primaryTable) {
			++primaryTableSize;
		} else {
			++incrementalRehashTableSize;
		}
		return newEntry;
	}

	/**
	 * Select the sub-table which is responsible for entries with the given hash code.
	 *
	 * @param hashCode the hash code which we use to decide about the table that is responsible.
	 * @return the index of the sub-table that is responsible for the entry with the given hash code.
	 */
	private StateTableEntry<K, N, S>[] selectActiveTable(int hashCode) {
		// buckets below rehashIndex have already been migrated to the rehash table.
		return (hashCode & (primaryTable.length - 1)) >= rehashIndex ? primaryTable : incrementalRehashTable;
	}

	/**
	 * Doubles the capacity of the hash table. Existing entries are placed in
	 * the correct bucket on the enlarged table. If the current capacity is,
	 * MAXIMUM_CAPACITY, this method is a no-op. Returns the table, which
	 * will be new unless we were already at MAXIMUM_CAPACITY.
	 */
	private void doubleCapacity() {

		// There can only be one rehash in flight. From the amount of incremental rehash steps we take, this should always hold.
		Preconditions.checkState(!isRehashing(), "There is already a rehash in progress.");

		StateTableEntry<K, N, S>[] oldTable = primaryTable;

		int oldCapacity = oldTable.length;

		if (oldCapacity == MAXIMUM_CAPACITY) {
			return;
		}

		// note: this only allocates the new table; entries migrate lazily via incrementalRehash().
		incrementalRehashTable = makeTable(oldCapacity * 2);
	}

	/**
	 * Returns true, if an incremental rehash is in progress.
	 */
	@VisibleForTesting
	boolean isRehashing() {
		// if we rehash, the secondary table is not empty
		return EMPTY_TABLE != incrementalRehashTable;
	}

	/**
	 * Computes the hash for the composite of key and namespace and performs some steps of incremental rehash if
	 * incremental rehashing is in progress.
	 */
	private int computeHashForOperationAndDoIncrementalRehash(K key, N namespace) {

		checkKeyNamespacePreconditions(key, namespace);

		if (isRehashing()) {
			incrementalRehash();
		}

		return compositeHash(key, namespace);
	}

	/**
	 * Runs a number of steps for incremental rehashing, migrating at least
	 * {@link #MIN_TRANSFERRED_PER_INCREMENTAL_REHASH} entries (or finishing the rehash).
	 */
	@SuppressWarnings("unchecked")
	private void incrementalRehash() {

		StateTableEntry<K, N, S>[] oldTable = primaryTable;
		StateTableEntry<K, N, S>[] newTable = incrementalRehashTable;

		int oldCapacity = oldTable.length;
		int newMask = newTable.length - 1;
		int requiredVersion = highestRequiredSnapshotVersion;
		int rhIdx = rehashIndex;
		int transferred = 0;

		// we migrate a certain minimum amount of entries from the old to the new table
		while (transferred < MIN_TRANSFERRED_PER_INCREMENTAL_REHASH) {

			StateTableEntry<K, N, S> e = oldTable[rhIdx];

			while (e != null) {
				// copy-on-write check for entry: a snapshot may still hold the entry, so we must
				// not mutate its next-pointer in place.
				if (e.entryVersion < requiredVersion) {
					e = new StateTableEntry<>(e, stateTableVersion);
				}
				StateTableEntry<K, N, S> n = e.next;
				int pos = e.hash & newMask;
				e.next = newTable[pos];
				newTable[pos] = e;
				e = n;
				++transferred;
			}

			oldTable[rhIdx] = null;
			if (++rhIdx == oldCapacity) {
				//here, the rehash is complete and we release resources and reset fields
				primaryTable = newTable;
				incrementalRehashTable = (StateTableEntry<K, N, S>[]) EMPTY_TABLE;
				primaryTableSize += incrementalRehashTableSize;
				incrementalRehashTableSize = 0;
				rehashIndex = 0;
				return;
			}
		}

		// sync our local bookkeeping the with official bookkeeping fields
		primaryTableSize -= transferred;
		incrementalRehashTableSize += transferred;
		rehashIndex = rhIdx;
	}

	/**
	 * Perform copy-on-write for entry chains. We iterate the (hopefully and probably) still cached chain, replace
	 * all links up to the 'untilEntry', which we actually wanted to modify.
	 */
	private StateTableEntry<K, N, S> handleChainedEntryCopyOnWrite(
			StateTableEntry<K, N, S>[] tab,
			int tableIdx,
			StateTableEntry<K, N, S> untilEntry) {

		final int required = highestRequiredSnapshotVersion;

		StateTableEntry<K, N, S> current = tab[tableIdx];
		StateTableEntry<K, N, S> copy;

		if (current.entryVersion < required) {
			copy = new StateTableEntry<>(current, stateTableVersion);
			tab[tableIdx] = copy;
		} else {
			// nothing to do, just advance copy to current
			copy = current;
		}

		// we iterate the chain up to 'until entry'
		while (current != untilEntry) {

			//advance current
			current = current.next;

			if (current.entryVersion < required) {
				// copy and advance the current's copy
				copy.next = new StateTableEntry<>(current, stateTableVersion);
				copy = copy.next;
			} else {
				// nothing to do, just advance copy to current
				copy = current;
			}
		}

		// returns the (possibly copied) counterpart of 'untilEntry', safe for in-place mutation.
		return copy;
	}

	// Shared, empty entry used to start iteration; cast is safe because the entry carries no data.
	@SuppressWarnings("unchecked")
	private static <K, N, S> StateTableEntry<K, N, S> getBootstrapEntry() {
		return (StateTableEntry<K, N, S>) ITERATOR_BOOTSTRAP_ENTRY;
	}

	/**
	 * Helper function that creates and scrambles a composite hash for key and namespace.
	 */
	private static int compositeHash(Object key, Object namespace) {
		// create composite key through XOR, then apply some bit-mixing for better distribution of skewed keys.
		return MathUtils.bitMix(key.hashCode() ^ namespace.hashCode());
	}

	// Snapshotting ----------------------------------------------------------------------------------------------------

	// current copy-on-write version of this table; exposed for snapshot bookkeeping.
	int getStateTableVersion() {
		return stateTableVersion;
	}

	/**
	 * Creates a snapshot of this {@link CopyOnWriteStateTable}, to be written in checkpointing.
	 * The snapshot integrity
	 * is protected through copy-on-write from the {@link CopyOnWriteStateTable}. Users should call
	 * {@link #releaseSnapshot(CopyOnWriteStateTableSnapshot)} after using the returned object.
	 *
	 * @return a snapshot from this {@link CopyOnWriteStateTable}, for checkpointing.
	 */
	@Override
	public CopyOnWriteStateTableSnapshot<K, N, S> createSnapshot() {
		return new CopyOnWriteStateTableSnapshot<>(this);
	}

	/**
	 * Releases a snapshot for this {@link CopyOnWriteStateTable}. This method should be called once a snapshot is no more needed,
	 * so that the {@link CopyOnWriteStateTable} can stop considering this snapshot for copy-on-write, thus avoiding unnecessary
	 * object creation.
	 *
	 * @param snapshotToRelease the snapshot to release, which was previously created by this state table.
	 */
	void releaseSnapshot(CopyOnWriteStateTableSnapshot<K, N, S> snapshotToRelease) {

		Preconditions.checkArgument(snapshotToRelease.isOwner(this),
				"Cannot release snapshot which is owned by a different state table.");

		releaseSnapshot(snapshotToRelease.getSnapshotVersion());
	}

	// StateTableEntry -------------------------------------------------------------------------------------------------

	/**
	 * One entry in the {@link CopyOnWriteStateTable}. This is a triplet of key, namespace, and state. Thereby, key and
	 * namespace together serve as a composite key for the state. This class also contains some management meta data for
	 * copy-on-write, a pointer to link other {@link StateTableEntry}s to a list, and cached hash code.
	 *
	 * @param <K> type of key.
	 * @param <N> type of namespace.
	 * @param <S> type of state.
	 */
	static class StateTableEntry<K, N, S> implements StateEntry<K, N, S> {

		/**
		 * The key. Assumed to be immutable and not null.
		 */
		final K key;

		/**
		 * The namespace. Assumed to be immutable and not null.
		 */
		final N namespace;

		/**
		 * The state. This is not final to allow exchanging the object for copy-on-write. Can be null.
		 */
		S state;

		/**
		 * Link to another {@link StateTableEntry}. This is used to resolve collisions in the
		 * {@link CopyOnWriteStateTable} through chaining.
		 */
		StateTableEntry<K, N, S> next;

		/**
		 * The version of this {@link StateTableEntry}. This is meta data for copy-on-write of the table structure.
		 */
		int entryVersion;

		/**
		 * The version of the state object in this entry. This is meta data for copy-on-write of the state object itself.
		 */
		int stateVersion;

		/**
		 * The computed secondary hash for the composite of key and namespace.
		 */
		final int hash;

		// bootstrap/sentinel entry: all fields null/zero (see ITERATOR_BOOTSTRAP_ENTRY).
		StateTableEntry() {
			this(null, null, null, 0, null, 0, 0);
		}

		// copy constructor for entry copy-on-write: clones all fields but stamps a new entry version.
		StateTableEntry(StateTableEntry<K, N, S> other, int entryVersion) {
			this(other.key, other.namespace, other.state, other.hash, other.next, entryVersion, other.stateVersion);
		}

		StateTableEntry(
				K key,
				N namespace,
				S state,
				int hash,
				StateTableEntry<K, N, S> next,
				int entryVersion,
				int stateVersion) {
			this.key = key;
			this.namespace = namespace;
			this.hash = hash;
			this.next = next;
			this.entryVersion = entryVersion;
			this.state = state;
			this.stateVersion = stateVersion;
		}

		public final void setState(S value, int mapVersion) {
			// naturally, we can update the state version every time we replace the old state with a different object
			if (value != state) {
				this.state = value;
				this.stateVersion = mapVersion;
			}
		}

		@Override
		public K getKey() {
			return key;
		}

		@Override
		public N getNamespace() {
			return namespace;
		}

		@Override
		public S getState() {
			return state;
		}

		@Override
		public final boolean equals(Object o) {
			if (!(o instanceof CopyOnWriteStateTable.StateTableEntry)) {
				return false;
			}

			StateEntry<?, ?, ?> e = (StateEntry<?, ?, ?>) o;
			return e.getKey().equals(key)
					&& e.getNamespace().equals(namespace)
					&& Objects.equals(e.getState(), state);
		}

+ @Override + public final int hashCode() { + return (key.hashCode() ^ namespace.hashCode()) ^ Objects.hashCode(state); + } + + @Override + public final String toString() { + return "(" + key + "|" + namespace + ")=" + state; + } + } + + // For testing ---------------------------------------------------------------------------------------------------- + + @Override + public int sizeOfNamespace(Object namespace) { + int count = 0; + for (StateEntry<K, N, S> entry : this) { + if (null != entry && namespace.equals(entry.getNamespace())) { + ++count; + } + } + return count; + } + + + // StateEntryIterator --------------------------------------------------------------------------------------------- + + /** + * Iterator over the entries in a {@link CopyOnWriteStateTable}. + */ + class StateEntryIterator implements Iterator<StateEntry<K, N, S>> { + private StateTableEntry<K, N, S>[] activeTable; + private int nextTablePosition; + private StateTableEntry<K, N, S> nextEntry; + private int expectedModCount = modCount; + + StateEntryIterator() { + this.activeTable = primaryTable; + this.nextTablePosition = 0; + this.expectedModCount = modCount; + this.nextEntry = getBootstrapEntry(); + advanceIterator(); + } + + private StateTableEntry<K, N, S> advanceIterator() { + + StateTableEntry<K, N, S> entryToReturn = nextEntry; + StateTableEntry<K, N, S> next = entryToReturn.next; + + // consider both sub-tables tables to cover the case of rehash + while (next == null) { + + StateTableEntry<K, N, S>[] tab = activeTable; + + while (nextTablePosition < tab.length) { + next = tab[nextTablePosition++]; + + if (next != null) { + nextEntry = next; + return entryToReturn; + } + } + + if (activeTable == incrementalRehashTable) { + break; + } + + activeTable = incrementalRehashTable; + nextTablePosition = 0; + } + + nextEntry = next; + return entryToReturn; + } + + @Override + public boolean hasNext() { + return nextEntry != null; + } + + @Override + public StateTableEntry<K, N, S> next() { + 
if (modCount != expectedModCount) { + throw new ConcurrentModificationException(); + } + + if (nextEntry == null) { + throw new NoSuchElementException(); + } + + return advanceIterator(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("Read-only iterator"); + } + } +} \ No newline at end of file
http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java new file mode 100644 index 0000000..c83fce0 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeAssignment; + +import java.io.IOException; + +/** + * This class represents the snapshot of a {@link CopyOnWriteStateTable} and has a role in operator state checkpointing. 
Besides + * holding the {@link CopyOnWriteStateTable}s internal entries at the time of the snapshot, this class is also responsible for + * preparing and writing the state in the process of checkpointing. + * <p> + * IMPORTANT: Please notice that snapshot integrity of entries in this class relies on proper copy-on-write semantics + * through the {@link CopyOnWriteStateTable} that created the snapshot object, but all objects in this snapshot must be considered + * as READ-ONLY! The reason is that the objects held by this class may or may not be deep copies of original objects + * that may still be used in the {@link CopyOnWriteStateTable}. This depends for each entry on whether or not it was subject to + * copy-on-write operations by the {@link CopyOnWriteStateTable}. Phrased differently: the {@link CopyOnWriteStateTable} provides + * copy-on-write isolation for this snapshot, but this snapshot does not isolate modifications from the + * {@link CopyOnWriteStateTable}! + * + * @param <K> type of key + * @param <N> type of namespace + * @param <S> type of state + */ +@Internal +public class CopyOnWriteStateTableSnapshot<K, N, S> + extends AbstractStateTableSnapshot<K, N, S, CopyOnWriteStateTable<K, N, S>> { + + /** + * Version of the {@link CopyOnWriteStateTable} when this snapshot was created. This can be used to release the snapshot. + */ + private final int snapshotVersion; + + /** + * The number of entries in the {@link CopyOnWriteStateTable} at the time of creating this snapshot. + */ + private final int stateTableSize; + + /** + * The state table entries, as of the time this snapshot was created. Objects in this array may or may not be deep + * copies of the current entries in the {@link CopyOnWriteStateTable} that created this snapshot. This depends for each entry + * on whether or not it was subject to copy-on-write operations by the {@link CopyOnWriteStateTable}. 
+ */ + private final CopyOnWriteStateTable.StateTableEntry<K, N, S>[] snapshotData; + + /** + * Offsets for the individual key-groups. This is lazily created when the snapshot is grouped by key-group during + * the process of writing this snapshot to an output as part of checkpointing. + */ + private int[] keyGroupOffsets; + + /** + * Creates a new {@link CopyOnWriteStateTableSnapshot}. + * + * @param owningStateTable the {@link CopyOnWriteStateTable} for which this object represents a snapshot. + */ + CopyOnWriteStateTableSnapshot(CopyOnWriteStateTable<K, N, S> owningStateTable) { + + super(owningStateTable); + this.snapshotData = owningStateTable.snapshotTableArrays(); + this.snapshotVersion = owningStateTable.getStateTableVersion(); + this.stateTableSize = owningStateTable.size(); + this.keyGroupOffsets = null; + } + + /** + * Returns the internal version of the {@link CopyOnWriteStateTable} when this snapshot was created. This value must be used to + * tell the {@link CopyOnWriteStateTable} when to release this snapshot. + */ + int getSnapshotVersion() { + return snapshotVersion; + } + + /** + * Partitions the snapshot data by key-group. The algorithm first builds a histogram for the distribution of keys + * into key-groups. Then, the histogram is accumulated to obtain the boundaries of each key-group in an array. + * Last, we use the accumulated counts as write position pointers for the key-group's bins when reordering the + * entries by key-group. This operation is lazily performed before the first writing of a key-group. + * <p> + * As a possible future optimization, we could perform the repartitioning in-place, using a scheme similar to the + * cuckoo cycles in cuckoo hashing. This can trade some performance for a smaller memory footprint. 
+ */ + @SuppressWarnings("unchecked") + private void partitionEntriesByKeyGroup() { + + // We only have to perform this step once before the first key-group is written + if (null != keyGroupOffsets) { + return; + } + + final KeyGroupRange keyGroupRange = owningStateTable.keyContext.getKeyGroupRange(); + final int totalKeyGroups = owningStateTable.keyContext.getNumberOfKeyGroups(); + final int baseKgIdx = keyGroupRange.getStartKeyGroup(); + final int[] histogram = new int[keyGroupRange.getNumberOfKeyGroups() + 1]; + + CopyOnWriteStateTable.StateTableEntry<K, N, S>[] unfold = new CopyOnWriteStateTable.StateTableEntry[stateTableSize]; + + // 1) In this step we i) 'unfold' the linked list of entries to a flat array and ii) build a histogram for key-groups + int unfoldIndex = 0; + for (CopyOnWriteStateTable.StateTableEntry<K, N, S> entry : snapshotData) { + while (null != entry) { + int effectiveKgIdx = + KeyGroupRangeAssignment.computeKeyGroupForKeyHash(entry.key.hashCode(), totalKeyGroups) - baseKgIdx + 1; + ++histogram[effectiveKgIdx]; + unfold[unfoldIndex++] = entry; + entry = entry.next; + } + } + + // 2) We accumulate the histogram bins to obtain key-group ranges in the final array + for (int i = 1; i < histogram.length; ++i) { + histogram[i] += histogram[i - 1]; + } + + // 3) We repartition the entries by key-group, using the histogram values as write indexes + for (CopyOnWriteStateTable.StateTableEntry<K, N, S> t : unfold) { + int effectiveKgIdx = + KeyGroupRangeAssignment.computeKeyGroupForKeyHash(t.key.hashCode(), totalKeyGroups) - baseKgIdx; + snapshotData[histogram[effectiveKgIdx]++] = t; + } + + // 4) As byproduct, we also created the key-group offsets + this.keyGroupOffsets = histogram; + } + + @Override + public void release() { + owningStateTable.releaseSnapshot(this); + } + + @Override + public void writeMappingsInKeyGroup(DataOutputView dov, int keyGroupId) throws IOException { + + if (null == keyGroupOffsets) { + partitionEntriesByKeyGroup(); + } + + 
final CopyOnWriteStateTable.StateTableEntry<K, N, S>[] groupedOut = snapshotData; + KeyGroupRange keyGroupRange = owningStateTable.keyContext.getKeyGroupRange(); + int keyGroupOffsetIdx = keyGroupId - keyGroupRange.getStartKeyGroup() - 1; + int startOffset = keyGroupOffsetIdx < 0 ? 0 : keyGroupOffsets[keyGroupOffsetIdx]; + int endOffset = keyGroupOffsets[keyGroupOffsetIdx + 1]; + + TypeSerializer<K> keySerializer = owningStateTable.keyContext.getKeySerializer(); + TypeSerializer<N> namespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer(); + TypeSerializer<S> stateSerializer = owningStateTable.metaInfo.getStateSerializer(); + + // write number of mappings in key-group + dov.writeInt(endOffset - startOffset); + + // write mappings + for (int i = startOffset; i < endOffset; ++i) { + CopyOnWriteStateTable.StateTableEntry<K, N, S> toWrite = groupedOut[i]; + groupedOut[i] = null; // free asap for GC + namespaceSerializer.serialize(toWrite.namespace, dov); + keySerializer.serialize(toWrite.key, dov); + stateSerializer.serialize(toWrite.state, dov); + } + } + + /** + * Returns true iff the given state table is the owner of this snapshot object. 
+ */ + boolean isOwner(CopyOnWriteStateTable<K, N, S> stateTable) { + return stateTable == owningStateTable; + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapAggregatingState.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapAggregatingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapAggregatingState.java index 624b83e..64fc1db 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapAggregatingState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapAggregatingState.java @@ -23,18 +23,16 @@ import org.apache.flink.api.common.state.AggregatingState; import org.apache.flink.api.common.state.AggregatingStateDescriptor; import org.apache.flink.api.common.state.ReducingState; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.StateTransformationFunction; import org.apache.flink.runtime.state.internal.InternalAggregatingState; +import org.apache.flink.util.Preconditions; import java.io.IOException; -import java.util.Map; - -import static org.apache.flink.util.Preconditions.checkState; /** * Heap-backed partitioned {@link ReducingState} that is * snapshotted into files. - * + * * @param <K> The type of the key. * @param <N> The type of the namespace. * @param <IN> The type of the value added to the state. 
@@ -45,13 +43,11 @@ public class HeapAggregatingState<K, N, IN, ACC, OUT> extends AbstractHeapMergingState<K, N, IN, OUT, ACC, AggregatingState<IN, OUT>, AggregatingStateDescriptor<IN, ACC, OUT>> implements InternalAggregatingState<N, IN, OUT> { - private final AggregateFunction<IN, ACC, OUT> aggFunction; + private final AggregateTransformation<IN, ACC, OUT> aggregateTransformation; /** * Creates a new key/value state for the given hash map of key/value pairs. * - * @param backend - * The state backend backing that created this state. * @param stateDesc * The state identifier for the state. This contains name and can create a default state value. * @param stateTable @@ -60,14 +56,13 @@ public class HeapAggregatingState<K, N, IN, ACC, OUT> * The serializer for the type that indicates the namespace */ public HeapAggregatingState( - KeyedStateBackend<K> backend, AggregatingStateDescriptor<IN, ACC, OUT> stateDesc, StateTable<K, N, ACC> stateTable, TypeSerializer<K> keySerializer, TypeSerializer<N> namespaceSerializer) { - super(backend, stateDesc, stateTable, keySerializer, namespaceSerializer); - this.aggFunction = stateDesc.getAggregateFunction(); + super(stateDesc, stateTable, keySerializer, namespaceSerializer); + this.aggregateTransformation = new AggregateTransformation<>(stateDesc.getAggregateFunction()); } // ------------------------------------------------------------------------ @@ -76,64 +71,25 @@ public class HeapAggregatingState<K, N, IN, ACC, OUT> @Override public OUT get() { - final K key = backend.getCurrentKey(); - - checkState(currentNamespace != null, "No namespace set."); - checkState(key != null, "No key set."); - - Map<N, Map<K, ACC>> namespaceMap = - stateTable.get(backend.getCurrentKeyGroupIndex()); - - if (namespaceMap == null) { - return null; - } - Map<K, ACC> keyedMap = namespaceMap.get(currentNamespace); - - if (keyedMap == null) { - return null; - } - - ACC accumulator = keyedMap.get(key); - return aggFunction.getResult(accumulator); + ACC 
accumulator = stateTable.get(currentNamespace); + return accumulator != null ? aggregateTransformation.aggFunction.getResult(accumulator) : null; } @Override public void add(IN value) throws IOException { - final K key = backend.getCurrentKey(); - - checkState(currentNamespace != null, "No namespace set."); - checkState(key != null, "No key set."); + final N namespace = currentNamespace; if (value == null) { clear(); return; } - Map<N, Map<K, ACC>> namespaceMap = - stateTable.get(backend.getCurrentKeyGroupIndex()); - - if (namespaceMap == null) { - namespaceMap = createNewMap(); - stateTable.set(backend.getCurrentKeyGroupIndex(), namespaceMap); - } - - Map<K, ACC> keyedMap = namespaceMap.get(currentNamespace); - - if (keyedMap == null) { - keyedMap = createNewMap(); - namespaceMap.put(currentNamespace, keyedMap); - } - - // if this is the first value for the key, create a new accumulator - ACC accumulator = keyedMap.get(key); - if (accumulator == null) { - accumulator = aggFunction.createAccumulator(); - keyedMap.put(key, accumulator); + try { + stateTable.transform(namespace, value, aggregateTransformation); + } catch (Exception e) { + throw new IOException("Exception while applying AggregateFunction in aggregating state", e); } - - // - aggFunction.add(value, accumulator); } // ------------------------------------------------------------------------ @@ -142,6 +98,24 @@ public class HeapAggregatingState<K, N, IN, ACC, OUT> @Override protected ACC mergeState(ACC a, ACC b) throws Exception { - return aggFunction.merge(a, b); + return aggregateTransformation.aggFunction.merge(a, b); + } + + static final class AggregateTransformation<IN, ACC, OUT> implements StateTransformationFunction<ACC, IN> { + + private final AggregateFunction<IN, ACC, OUT> aggFunction; + + AggregateTransformation(AggregateFunction<IN, ACC, OUT> aggFunction) { + this.aggFunction = Preconditions.checkNotNull(aggFunction); + } + + @Override + public ACC apply(ACC accumulator, IN value) throws 
Exception { + if (accumulator == null) { + accumulator = aggFunction.createAccumulator(); + } + aggFunction.add(value, accumulator); + return accumulator; + } } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapFoldingState.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapFoldingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapFoldingState.java index 6df3f5d..dad6d0d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapFoldingState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapFoldingState.java @@ -22,12 +22,11 @@ import org.apache.flink.api.common.functions.FoldFunction; import org.apache.flink.api.common.state.FoldingState; import org.apache.flink.api.common.state.FoldingStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.runtime.state.KeyedStateBackend; +import org.apache.flink.runtime.state.StateTransformationFunction; import org.apache.flink.runtime.state.internal.InternalFoldingState; import org.apache.flink.util.Preconditions; import java.io.IOException; -import java.util.Map; /** * Heap-backed partitioned {@link FoldingState} that is @@ -43,24 +42,22 @@ public class HeapFoldingState<K, N, T, ACC> implements InternalFoldingState<N, T, ACC> { /** The function used to fold the state */ - private final FoldFunction<T, ACC> foldFunction; + private final FoldTransformation<T, ACC> foldTransformation; /** * Creates a new key/value state for the given hash map of key/value pairs. * - * @param backend The state backend backing that created this state. * @param stateDesc The state identifier for the state. This contains name * and can create a default state value. 
* @param stateTable The state table to use in this key/value state. May contain initial state. */ public HeapFoldingState( - KeyedStateBackend<K> backend, FoldingStateDescriptor<T, ACC> stateDesc, StateTable<K, N, ACC> stateTable, TypeSerializer<K> keySerializer, TypeSerializer<N> namespaceSerializer) { - super(backend, stateDesc, stateTable, keySerializer, namespaceSerializer); - this.foldFunction = stateDesc.getFoldFunction(); + super(stateDesc, stateTable, keySerializer, namespaceSerializer); + this.foldTransformation = new FoldTransformation<>(stateDesc); } // ------------------------------------------------------------------------ @@ -69,62 +66,37 @@ public class HeapFoldingState<K, N, T, ACC> @Override public ACC get() { - Preconditions.checkState(currentNamespace != null, "No namespace set."); - Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); - - Map<N, Map<K, ACC>> namespaceMap = - stateTable.get(backend.getCurrentKeyGroupIndex()); - - if (namespaceMap == null) { - return null; - } - - Map<K, ACC> keyedMap = namespaceMap.get(currentNamespace); - - if (keyedMap == null) { - return null; - } - - return keyedMap.get(backend.<K>getCurrentKey()); + return stateTable.get(currentNamespace); } @Override public void add(T value) throws IOException { - Preconditions.checkState(currentNamespace != null, "No namespace set."); - Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); if (value == null) { clear(); return; } - Map<N, Map<K, ACC>> namespaceMap = - stateTable.get(backend.getCurrentKeyGroupIndex()); - - if (namespaceMap == null) { - namespaceMap = createNewMap(); - stateTable.set(backend.getCurrentKeyGroupIndex(), namespaceMap); + try { + stateTable.transform(currentNamespace, value, foldTransformation); + } catch (Exception e) { + throw new IOException("Could not add value to folding state.", e); } + } - Map<K, ACC> keyedMap = namespaceMap.get(currentNamespace); - - if (keyedMap == null) { - keyedMap = 
createNewMap(); - namespaceMap.put(currentNamespace, keyedMap); - } + static final class FoldTransformation<T, ACC> implements StateTransformationFunction<ACC, T> { - ACC currentValue = keyedMap.get(backend.<K>getCurrentKey()); + private final FoldingStateDescriptor<T, ACC> stateDescriptor; + private final FoldFunction<T, ACC> foldFunction; - try { + FoldTransformation(FoldingStateDescriptor<T, ACC> stateDesc) { + this.stateDescriptor = Preconditions.checkNotNull(stateDesc); + this.foldFunction = Preconditions.checkNotNull(stateDesc.getFoldFunction()); + } - if (currentValue == null) { - keyedMap.put(backend.<K>getCurrentKey(), - foldFunction.fold(stateDesc.getDefaultValue(), value)); - } else { - keyedMap.put(backend.<K>getCurrentKey(), foldFunction.fold(currentValue, value)); - } - } catch (Exception e) { - throw new RuntimeException("Could not add value to folding state.", e); + @Override + public ACC apply(ACC previousState, T value) throws Exception { + return foldFunction.fold((previousState != null) ? previousState : stateDescriptor.getDefaultValue(), value); } } -} +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java index a4a08c1..0335933 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java @@ -7,7 +7,7 @@ * "License"); you may not use this file except in compliance * with the License. 
You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -18,6 +18,7 @@ package org.apache.flink.runtime.state.heap; +import org.apache.commons.collections.map.HashedMap; import org.apache.commons.io.IOUtils; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.ExecutionConfig; @@ -29,18 +30,15 @@ import org.apache.flink.api.common.state.ReducingStateDescriptor; import org.apache.flink.api.common.state.StateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.base.VoidSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.core.fs.FileSystem; -import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.DataInputViewStreamWrapper; -import org.apache.flink.core.memory.DataOutputView; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.migration.MigrationUtil; import org.apache.flink.migration.runtime.state.KvStateSnapshot; -import org.apache.flink.migration.runtime.state.filesystem.AbstractFsStateSnapshot; -import org.apache.flink.migration.runtime.state.memory.AbstractMemStateSnapshot; +import org.apache.flink.migration.runtime.state.memory.MigrationRestoreSnapshot; +import org.apache.flink.runtime.io.async.AbstractAsyncIOCallable; +import org.apache.flink.runtime.io.async.AsyncStoppableTaskWithCallback; import org.apache.flink.runtime.checkpoint.CheckpointOptions; import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; @@ -54,8 +52,6 @@ import 
org.apache.flink.runtime.state.KeyGroupsStateHandle; import org.apache.flink.runtime.state.KeyedBackendSerializationProxy; import org.apache.flink.runtime.state.RegisteredBackendStateMetaInfo; import org.apache.flink.runtime.state.StreamStateHandle; -import org.apache.flink.runtime.state.VoidNamespace; -import org.apache.flink.runtime.state.VoidNamespaceSerializer; import org.apache.flink.runtime.state.internal.InternalAggregatingState; import org.apache.flink.runtime.state.internal.InternalFoldingState; import org.apache.flink.runtime.state.internal.InternalListState; @@ -74,6 +70,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.RunnableFuture; +import java.util.concurrent.atomic.AtomicBoolean; /** * A {@link AbstractKeyedStateBackend} that keeps state on the Java Heap and will serialize state to @@ -94,7 +91,13 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { * but we can't put them here because different key/value states with different types and * namespace types share this central list of tables. */ - private final Map<String, StateTable<K, ?, ?>> stateTables = new HashMap<>(); + private final HashMap<String, StateTable<K, ?, ?>> stateTables = new HashMap<>(); + + /** + * Determines whether or not we run snapshots asynchronously. This impacts the choice of the underlying + * {@link StateTable} implementation. 
+ */ + private final boolean asynchronousSnapshots; public HeapKeyedStateBackend( TaskKvStateRegistry kvStateRegistry, @@ -102,10 +105,11 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { ClassLoader userCodeClassLoader, int numberOfKeyGroups, KeyGroupRange keyGroupRange, + boolean asynchronousSnapshots, ExecutionConfig executionConfig) { super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange, executionConfig); - + this.asynchronousSnapshots = asynchronousSnapshots; LOG.info("Initializing heap keyed state backend with stream factory."); } @@ -124,7 +128,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { private <N, V> StateTable<K, N, V> tryRegisterStateTable( String stateName, StateDescriptor.Type stateType, - TypeSerializer<N> namespaceSerializer, + TypeSerializer<N> namespaceSerializer, TypeSerializer<V> valueSerializer) { final RegisteredBackendStateMetaInfo<N, V> newMetaInfo = @@ -134,7 +138,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { StateTable<K, N, V> stateTable = (StateTable<K, N, V>) stateTables.get(stateName); if (stateTable == null) { - stateTable = new StateTable<>(newMetaInfo, keyGroupRange); + stateTable = newStateTable(newMetaInfo); stateTables.put(stateName, stateTable); } else { if (!newMetaInfo.isCompatibleWith(stateTable.getMetaInfo())) { @@ -152,7 +156,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { ValueStateDescriptor<V> stateDesc) throws Exception { StateTable<K, N, V> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); - return new HeapValueState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer); + return new HeapValueState<>(stateDesc, stateTable, keySerializer, namespaceSerializer); } @Override @@ -170,7 +174,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { namespaceSerializer, new 
ArrayListSerializer<T>(stateDesc.getElementSerializer())); - return new HeapListState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer); + return new HeapListState<>(stateDesc, stateTable, keySerializer, namespaceSerializer); } @Override @@ -179,7 +183,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { ReducingStateDescriptor<T> stateDesc) throws Exception { StateTable<K, N, T> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); - return new HeapReducingState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer); + return new HeapReducingState<>(stateDesc, stateTable, keySerializer, namespaceSerializer); } @Override @@ -188,7 +192,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { AggregatingStateDescriptor<T, ACC, R> stateDesc) throws Exception { StateTable<K, N, ACC> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); - return new HeapAggregatingState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer); + return new HeapAggregatingState<>(stateDesc, stateTable, keySerializer, namespaceSerializer); } @Override @@ -197,83 +201,151 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { FoldingStateDescriptor<T, ACC> stateDesc) throws Exception { StateTable<K, N, ACC> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc); - return new HeapFoldingState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer); + return new HeapFoldingState<>(stateDesc, stateTable, keySerializer, namespaceSerializer); } @Override public <N, UK, UV> InternalMapState<N, UK, UV> createMapState(TypeSerializer<N> namespaceSerializer, MapStateDescriptor<UK, UV> stateDesc) throws Exception { - + StateTable<K, N, HashMap<UK, UV>> stateTable = tryRegisterStateTable( stateDesc.getName(), stateDesc.getType(), namespaceSerializer, new HashMapSerializer<>(stateDesc.getKeySerializer(), 
stateDesc.getValueSerializer())); - - return new HeapMapState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer); + + return new HeapMapState<>(stateDesc, stateTable, keySerializer, namespaceSerializer); } @Override @SuppressWarnings("unchecked") - public RunnableFuture<KeyGroupsStateHandle> snapshot( - long checkpointId, - long timestamp, - CheckpointStreamFactory streamFactory, + public RunnableFuture<KeyGroupsStateHandle> snapshot( + final long checkpointId, + final long timestamp, + final CheckpointStreamFactory streamFactory, CheckpointOptions checkpointOptions) throws Exception { if (stateTables.isEmpty()) { return new DoneFuture<>(null); } - try (CheckpointStreamFactory.CheckpointStateOutputStream stream = streamFactory. - createCheckpointStateOutputStream(checkpointId, timestamp)) { - - DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(stream); + long syncStartTime = System.currentTimeMillis(); - Preconditions.checkState(stateTables.size() <= Short.MAX_VALUE, - "Too many KV-States: " + stateTables.size() + - ". Currently at most " + Short.MAX_VALUE + " states are supported"); + Preconditions.checkState(stateTables.size() <= Short.MAX_VALUE, + "Too many KV-States: " + stateTables.size() + + ". 
Currently at most " + Short.MAX_VALUE + " states are supported"); - List<KeyedBackendSerializationProxy.StateMetaInfo<?, ?>> metaInfoProxyList = new ArrayList<>(stateTables.size()); + List<KeyedBackendSerializationProxy.StateMetaInfo<?, ?>> metaInfoProxyList = new ArrayList<>(stateTables.size()); - Map<String, Integer> kVStateToId = new HashMap<>(stateTables.size()); + final Map<String, Integer> kVStateToId = new HashMap<>(stateTables.size()); - for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) { + final Map<StateTable<K, ?, ?>, StateTableSnapshot> cowStateStableSnapshots = new HashedMap(stateTables.size()); - RegisteredBackendStateMetaInfo<?, ?> metaInfo = kvState.getValue().getMetaInfo(); - KeyedBackendSerializationProxy.StateMetaInfo<?, ?> metaInfoProxy = new KeyedBackendSerializationProxy.StateMetaInfo( - metaInfo.getStateType(), - metaInfo.getName(), - metaInfo.getNamespaceSerializer(), - metaInfo.getStateSerializer()); + for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) { + RegisteredBackendStateMetaInfo<?, ?> metaInfo = kvState.getValue().getMetaInfo(); + KeyedBackendSerializationProxy.StateMetaInfo<?, ?> metaInfoProxy = new KeyedBackendSerializationProxy.StateMetaInfo( + metaInfo.getStateType(), + metaInfo.getName(), + metaInfo.getNamespaceSerializer(), + metaInfo.getStateSerializer()); - metaInfoProxyList.add(metaInfoProxy); - kVStateToId.put(kvState.getKey(), kVStateToId.size()); + metaInfoProxyList.add(metaInfoProxy); + kVStateToId.put(kvState.getKey(), kVStateToId.size()); + StateTable<K, ?, ?> stateTable = kvState.getValue(); + if (null != stateTable) { + cowStateStableSnapshots.put(stateTable, stateTable.createSnapshot()); } + } - KeyedBackendSerializationProxy serializationProxy = - new KeyedBackendSerializationProxy(keySerializer, metaInfoProxyList); + final KeyedBackendSerializationProxy serializationProxy = + new KeyedBackendSerializationProxy(keySerializer, metaInfoProxyList); + + 
//--------------------------------------------------- this becomes the end of sync part + + // implementation of the async IO operation, based on FutureTask + final AbstractAsyncIOCallable<KeyGroupsStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream> ioCallable = + new AbstractAsyncIOCallable<KeyGroupsStateHandle, CheckpointStreamFactory.CheckpointStateOutputStream>() { + + AtomicBoolean open = new AtomicBoolean(false); + + @Override + public CheckpointStreamFactory.CheckpointStateOutputStream openIOHandle() throws Exception { + if (open.compareAndSet(false, true)) { + CheckpointStreamFactory.CheckpointStateOutputStream stream = + streamFactory.createCheckpointStateOutputStream(checkpointId, timestamp); + try { + cancelStreamRegistry.registerClosable(stream); + return stream; + } catch (Exception ex) { + open.set(false); + throw ex; + } + } else { + throw new IOException("Operation already opened."); + } + } - serializationProxy.write(outView); + @Override + public KeyGroupsStateHandle performOperation() throws Exception { + long asyncStartTime = System.currentTimeMillis(); + CheckpointStreamFactory.CheckpointStateOutputStream stream = getIoHandle(); + DataOutputViewStreamWrapper outView = new DataOutputViewStreamWrapper(stream); + serializationProxy.write(outView); - int offsetCounter = 0; - long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()]; + long[] keyGroupRangeOffsets = new long[keyGroupRange.getNumberOfKeyGroups()]; - for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); keyGroupIndex++) { - keyGroupRangeOffsets[offsetCounter++] = stream.getPos(); - outView.writeInt(keyGroupIndex); - for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) { - outView.writeShort(kVStateToId.get(kvState.getKey())); - writeStateTableForKeyGroup(outView, kvState.getValue(), keyGroupIndex); - } - } + for (int keyGroupPos = 0; keyGroupPos < 
keyGroupRange.getNumberOfKeyGroups(); ++keyGroupPos) { + int keyGroupId = keyGroupRange.getKeyGroupId(keyGroupPos); + keyGroupRangeOffsets[keyGroupPos] = stream.getPos(); + outView.writeInt(keyGroupId); + + for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) { + outView.writeShort(kVStateToId.get(kvState.getKey())); + cowStateStableSnapshots.get(kvState.getValue()).writeMappingsInKeyGroup(outView, keyGroupId); + } + } + + if (open.compareAndSet(true, false)) { + StreamStateHandle streamStateHandle = stream.closeAndGetHandle(); + KeyGroupRangeOffsets offsets = new KeyGroupRangeOffsets(keyGroupRange, keyGroupRangeOffsets); + final KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(offsets, streamStateHandle); + + if (asynchronousSnapshots) { + LOG.info("Heap backend snapshot ({}, asynchronous part) in thread {} took {} ms.", + streamFactory, Thread.currentThread(), (System.currentTimeMillis() - asyncStartTime)); + } + + return keyGroupsStateHandle; + } else { + throw new IOException("Checkpoint stream already closed."); + } + } + + @Override + public void done(boolean canceled) { + if (open.compareAndSet(true, false)) { + CheckpointStreamFactory.CheckpointStateOutputStream stream = getIoHandle(); + if (null != stream) { + cancelStreamRegistry.unregisterClosable(stream); + IOUtils.closeQuietly(stream); + } + } + for (StateTableSnapshot snapshot : cowStateStableSnapshots.values()) { + snapshot.release(); + } + } + }; - StreamStateHandle streamStateHandle = stream.closeAndGetHandle(); + AsyncStoppableTaskWithCallback<KeyGroupsStateHandle> task = AsyncStoppableTaskWithCallback.from(ioCallable); - KeyGroupRangeOffsets offsets = new KeyGroupRangeOffsets(keyGroupRange, keyGroupRangeOffsets); - final KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(offsets, streamStateHandle); - return new DoneFuture<>(keyGroupsStateHandle); + if (!asynchronousSnapshots) { + task.run(); } + + LOG.info("Heap backend snapshot (" + 
streamFactory + ", synchronous part) in thread " + + Thread.currentThread() + " took " + (System.currentTimeMillis() - syncStartTime) + " ms."); + + return task; } @SuppressWarnings("deprecation") @@ -292,42 +364,12 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { } } - private <N, S> void writeStateTableForKeyGroup( - DataOutputView outView, - StateTable<K, N, S> stateTable, - int keyGroupIndex) throws IOException { - - TypeSerializer<N> namespaceSerializer = stateTable.getNamespaceSerializer(); - TypeSerializer<S> stateSerializer = stateTable.getStateSerializer(); - - Map<N, Map<K, S>> namespaceMap = stateTable.get(keyGroupIndex); - if (namespaceMap == null) { - outView.writeByte(0); - } else { - outView.writeByte(1); - - // number of namespaces - outView.writeInt(namespaceMap.size()); - for (Map.Entry<N, Map<K, S>> namespace : namespaceMap.entrySet()) { - namespaceSerializer.serialize(namespace.getKey(), outView); - - Map<K, S> entryMap = namespace.getValue(); - - // number of entries - outView.writeInt(entryMap.size()); - for (Map.Entry<K, S> entry : entryMap.entrySet()) { - keySerializer.serialize(entry.getKey(), outView); - stateSerializer.serialize(entry.getValue(), outView); - } - } - } - } - @SuppressWarnings({"unchecked"}) private void restorePartitionedState(Collection<KeyGroupsStateHandle> state) throws Exception { + final Map<Integer, String> kvStatesById = new HashMap<>(); int numRegisteredKvStates = 0; - Map<Integer, String> kvStatesById = new HashMap<>(); + stateTables.clear(); for (KeyGroupsStateHandle keyGroupsHandle : state) { @@ -359,7 +401,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { RegisteredBackendStateMetaInfo<?, ?> registeredBackendStateMetaInfo = new RegisteredBackendStateMetaInfo<>(metaInfoSerializationProxy); - stateTable = new StateTable<>(registeredBackendStateMetaInfo, keyGroupRange); + stateTable = newStateTable(registeredBackendStateMetaInfo); 
stateTables.put(metaInfoSerializationProxy.getStateName(), stateTable); kvStatesById.put(numRegisteredKvStates, metaInfoSerializationProxy.getStateName()); ++numRegisteredKvStates; @@ -372,20 +414,20 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { fsDataInputStream.seek(offset); int writtenKeyGroupIndex = inView.readInt(); - assert writtenKeyGroupIndex == keyGroupIndex; + + Preconditions.checkState(writtenKeyGroupIndex == keyGroupIndex, + "Unexpected key-group in restore."); for (int i = 0; i < metaInfoList.size(); i++) { int kvStateId = inView.readShort(); - - byte isPresent = inView.readByte(); - if (isPresent == 0) { - continue; - } - StateTable<K, ?, ?> stateTable = stateTables.get(kvStatesById.get(kvStateId)); - Preconditions.checkNotNull(stateTable); - readStateTableForKeyGroup(inView, stateTable, keyGroupIndex); + StateTableByKeyGroupReader keyGroupReader = + StateTableByKeyGroupReaders.readerForVersion( + stateTable, + serializationProxy.getRestoredVersion()); + + keyGroupReader.readMappingsInKeyGroup(inView, keyGroupIndex); } } } finally { @@ -395,38 +437,12 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { } } - private <N, S> void readStateTableForKeyGroup( - DataInputView inView, - StateTable<K, N, S> stateTable, - int keyGroupIndex) throws IOException { - - TypeSerializer<N> namespaceSerializer = stateTable.getNamespaceSerializer(); - TypeSerializer<S> stateSerializer = stateTable.getStateSerializer(); - - Map<N, Map<K, S>> namespaceMap = new HashMap<>(); - stateTable.set(keyGroupIndex, namespaceMap); - - int numNamespaces = inView.readInt(); - for (int k = 0; k < numNamespaces; k++) { - N namespace = namespaceSerializer.deserialize(inView); - Map<K, S> entryMap = new HashMap<>(); - namespaceMap.put(namespace, entryMap); - - int numEntries = inView.readInt(); - for (int l = 0; l < numEntries; l++) { - K key = keySerializer.deserialize(inView); - S state = stateSerializer.deserialize(inView); 
- entryMap.put(key, state); - } - } - } - @Override public String toString() { return "HeapKeyedStateBackend"; } - @SuppressWarnings({"unchecked", "rawtypes", "deprecation"}) + @SuppressWarnings({"unchecked", "rawtypes", "DeprecatedIsStillUsed"}) @Deprecated private void restoreOldSavepointKeyedState( Collection<KeyGroupsStateHandle> stateHandles) throws IOException, ClassNotFoundException { @@ -444,118 +460,18 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { for (Map.Entry<String, KvStateSnapshot<K, ?, ?, ?>> nameToState : namedStates.entrySet()) { - KvStateSnapshot<K, ?, ?, ?> genericSnapshot = nameToState.getValue(); - - final RestoredState restoredState; - - if (genericSnapshot instanceof AbstractMemStateSnapshot) { - - AbstractMemStateSnapshot<K, ?, ?, ?, ?> stateSnapshot = - (AbstractMemStateSnapshot<K, ?, ?, ?, ?>) nameToState.getValue(); + final String stateName = nameToState.getKey(); + final KvStateSnapshot<K, ?, ?, ?> genericSnapshot = nameToState.getValue(); - restoredState = restoreHeapState(stateSnapshot); - - } else if (genericSnapshot instanceof AbstractFsStateSnapshot) { - - AbstractFsStateSnapshot<K, ?, ?, ?, ?> stateSnapshot = - (AbstractFsStateSnapshot<K, ?, ?, ?, ?>) nameToState.getValue(); - restoredState = restoreFsState(stateSnapshot); + if (genericSnapshot instanceof MigrationRestoreSnapshot) { + MigrationRestoreSnapshot<K, ?, ?> stateSnapshot = (MigrationRestoreSnapshot<K, ?, ?>) genericSnapshot; + final StateTable rawResultMap = + stateSnapshot.deserialize(stateName, this); + // add named state to the backend + stateTables.put(stateName, rawResultMap); } else { throw new IllegalStateException("Unknown state: " + genericSnapshot); } - - Map rawResultMap = restoredState.getRawResultMap(); - TypeSerializer<?> namespaceSerializer = restoredState.getNamespaceSerializer(); - TypeSerializer<?> stateSerializer = restoredState.getStateSerializer(); - - if (namespaceSerializer instanceof VoidSerializer) { - 
namespaceSerializer = VoidNamespaceSerializer.INSTANCE; - } - - Map nullNameSpaceFix = (Map) rawResultMap.remove(null); - - if (null != nullNameSpaceFix) { - rawResultMap.put(VoidNamespace.INSTANCE, nullNameSpaceFix); - } - - RegisteredBackendStateMetaInfo<?, ?> registeredBackendStateMetaInfo = - new RegisteredBackendStateMetaInfo<>( - StateDescriptor.Type.UNKNOWN, - nameToState.getKey(), - namespaceSerializer, - stateSerializer); - - StateTable<K, ?, ?> stateTable = new StateTable<>(registeredBackendStateMetaInfo, keyGroupRange); - stateTable.getState()[0] = rawResultMap; - - // add named state to the backend - stateTables.put(registeredBackendStateMetaInfo.getName(), stateTable); - } - } - - @SuppressWarnings("deprecation") - private RestoredState restoreHeapState(AbstractMemStateSnapshot<K, ?, ?, ?, ?> stateSnapshot) throws IOException { - return new RestoredState( - stateSnapshot.deserialize(), - stateSnapshot.getNamespaceSerializer(), - stateSnapshot.getStateSerializer()); - } - - @SuppressWarnings({"rawtypes", "unchecked", "deprecation"}) - private RestoredState restoreFsState(AbstractFsStateSnapshot<K, ?, ?, ?, ?> stateSnapshot) throws IOException { - FileSystem fs = stateSnapshot.getFilePath().getFileSystem(); - //TODO register closeable to support fast cancelation? 
- try (FSDataInputStream inStream = fs.open(stateSnapshot.getFilePath())) { - - DataInputViewStreamWrapper inView = new DataInputViewStreamWrapper(inStream); - - final int numNamespaces = inView.readInt(); - HashMap rawResultMap = new HashMap<>(numNamespaces); - - TypeSerializer<K> keySerializer = stateSnapshot.getKeySerializer(); - TypeSerializer<?> namespaceSerializer = stateSnapshot.getNamespaceSerializer(); - TypeSerializer<?> stateSerializer = stateSnapshot.getStateSerializer(); - - for (int i = 0; i < numNamespaces; i++) { - Object namespace = namespaceSerializer.deserialize(inView); - final int numKV = inView.readInt(); - Map<K, Object> namespaceMap = new HashMap<>(numKV); - rawResultMap.put(namespace, namespaceMap); - for (int j = 0; j < numKV; j++) { - K key = keySerializer.deserialize(inView); - Object value = stateSerializer.deserialize(inView); - namespaceMap.put(key, value); - } - } - return new RestoredState(rawResultMap, namespaceSerializer, stateSerializer); - } catch (Exception e) { - throw new IOException("Failed to restore state from file system", e); - } - } - - @SuppressWarnings("rawtypes") - static final class RestoredState { - - private final Map rawResultMap; - private final TypeSerializer<?> namespaceSerializer; - private final TypeSerializer<?> stateSerializer ; - - public RestoredState(Map rawResultMap, TypeSerializer<?> namespaceSerializer, TypeSerializer<?> stateSerializer) { - this.rawResultMap = rawResultMap; - this.namespaceSerializer = namespaceSerializer; - this.stateSerializer = stateSerializer; - } - - public Map getRawResultMap() { - return rawResultMap; - } - - public TypeSerializer<?> getNamespaceSerializer() { - return namespaceSerializer; - } - - public TypeSerializer<?> getStateSerializer() { - return stateSerializer; } } @@ -567,15 +483,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { public int numStateEntries() { int sum = 0; for (StateTable<K, ?, ?> stateTable : stateTables.values()) { - 
for (Map namespaceMap : stateTable.getState()) { - if (namespaceMap == null) { - continue; - } - Map<?, Map> typedMap = (Map<?, Map>) namespaceMap; - for (Map entriesMap : typedMap.values()) { - sum += entriesMap.size(); - } - } + sum += stateTable.size(); } return sum; } @@ -584,22 +492,22 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> { * Returns the total number of state entries across all keys for the given namespace. */ @VisibleForTesting - @SuppressWarnings("unchecked") - public <N> int numStateEntries(N namespace) { + public int numStateEntries(Object namespace) { int sum = 0; for (StateTable<K, ?, ?> stateTable : stateTables.values()) { - for (Map namespaceMap : stateTable.getState()) { - if (namespaceMap == null) { - continue; - } - Map<?, Map> typedMap = (Map<?, Map>) namespaceMap; - Map singleNamespace = typedMap.get(namespace); - if (singleNamespace != null) { - sum += singleNamespace.size(); - } - } + sum += stateTable.sizeOfNamespace(namespace); } return sum; } -} + public <N, V> StateTable<K, N, V> newStateTable(RegisteredBackendStateMetaInfo<N, V> newMetaInfo) { + return asynchronousSnapshots ? 
+ new CopyOnWriteStateTable<>(this, newMetaInfo) : + new NestedMapsStateTable<>(this, newMetaInfo); + } + + @Override + public boolean supportsAsynchronousSnapshots() { + return asynchronousSnapshots; + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/ab014ef9/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapListState.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapListState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapListState.java index 02c3067..d3f67f0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapListState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapListState.java @@ -22,14 +22,11 @@ import org.apache.flink.api.common.state.ListState; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; -import org.apache.flink.runtime.state.KeyGroupRangeAssignment; -import org.apache.flink.runtime.state.KeyedStateBackend; import org.apache.flink.runtime.state.internal.InternalListState; import org.apache.flink.util.Preconditions; import java.io.ByteArrayOutputStream; import java.util.ArrayList; -import java.util.Map; /** * Heap-backed partitioned {@link org.apache.flink.api.common.state.ListState} that is snapshotted @@ -46,18 +43,16 @@ public class HeapListState<K, N, V> /** * Creates a new key/value state for the given hash map of key/value pairs. * - * @param backend The state backend backing that created this state. * @param stateDesc The state identifier for the state. This contains name * and can create a default state value. * @param stateTable The state tab;e to use in this kev/value state. May contain initial state. 
*/ public HeapListState( - KeyedStateBackend<K> backend, ListStateDescriptor<V> stateDesc, StateTable<K, N, ArrayList<V>> stateTable, TypeSerializer<K> keySerializer, TypeSerializer<N> namespaceSerializer) { - super(backend, stateDesc, stateTable, keySerializer, namespaceSerializer); + super(stateDesc, stateTable, keySerializer, namespaceSerializer); } // ------------------------------------------------------------------------ @@ -66,55 +61,24 @@ public class HeapListState<K, N, V> @Override public Iterable<V> get() { - Preconditions.checkState(currentNamespace != null, "No namespace set."); - Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); - - Map<N, Map<K, ArrayList<V>>> namespaceMap = - stateTable.get(backend.getCurrentKeyGroupIndex()); - - if (namespaceMap == null) { - return null; - } - - Map<K, ArrayList<V>> keyedMap = namespaceMap.get(currentNamespace); - - if (keyedMap == null) { - return null; - } - - return keyedMap.get(backend.<K>getCurrentKey()); + return stateTable.get(currentNamespace); } @Override public void add(V value) { - Preconditions.checkState(currentNamespace != null, "No namespace set."); - Preconditions.checkState(backend.getCurrentKey() != null, "No key set."); + final N namespace = currentNamespace; if (value == null) { clear(); return; } - Map<N, Map<K, ArrayList<V>>> namespaceMap = - stateTable.get(backend.getCurrentKeyGroupIndex()); - - if (namespaceMap == null) { - namespaceMap = createNewMap(); - stateTable.set(backend.getCurrentKeyGroupIndex(), namespaceMap); - } - - Map<K, ArrayList<V>> keyedMap = namespaceMap.get(currentNamespace); - - if (keyedMap == null) { - keyedMap = createNewMap(); - namespaceMap.put(currentNamespace, keyedMap); - } - - ArrayList<V> list = keyedMap.get(backend.<K>getCurrentKey()); + final StateTable<K, N, ArrayList<V>> map = stateTable; + ArrayList<V> list = map.get(namespace); if (list == null) { list = new ArrayList<>(); - keyedMap.put(backend.<K>getCurrentKey(), list); + 
map.put(namespace, list); } list.add(value); } @@ -124,20 +88,7 @@ public class HeapListState<K, N, V> Preconditions.checkState(namespace != null, "No namespace given."); Preconditions.checkState(key != null, "No key given."); - Map<N, Map<K, ArrayList<V>>> namespaceMap = - stateTable.get(KeyGroupRangeAssignment.assignToKeyGroup(key, backend.getNumberOfKeyGroups())); - - if (namespaceMap == null) { - return null; - } - - Map<K, ArrayList<V>> keyedMap = namespaceMap.get(currentNamespace); - - if (keyedMap == null) { - return null; - } - - ArrayList<V> result = keyedMap.get(key); + ArrayList<V> result = stateTable.get(key, namespace); if (result == null) { return null;