This is an automated email from the ASF dual-hosted git repository. dwysakowicz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 5010d2ae8f77f20304411b34b7893518da9cdf66 Author: Dawid Wysakowicz <dwysakow...@apache.org> AuthorDate: Mon Jan 25 09:05:49 2021 +0100 [FLINK-21206] Implement HeapKeyValueStateIterator --- .../flink/runtime/state/KeyValueStateIterator.java | 4 +- .../state/heap/HeapKeyValueStateIterator.java | 421 +++++++++++++++++++++ .../state/SavepointStateBackendSwitchTestBase.java | 65 ++-- 3 files changed, 467 insertions(+), 23 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyValueStateIterator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyValueStateIterator.java index 28a3266..cf9cf16 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyValueStateIterator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyValueStateIterator.java @@ -18,6 +18,8 @@ package org.apache.flink.runtime.state; +import java.io.IOException; + /** * Iterator that over all key-value state entries in a {@link KeyedStateBackend}. For use during * snapshotting. @@ -31,7 +33,7 @@ public interface KeyValueStateIterator extends AutoCloseable { * Advances the iterator. Should only be called if {@link #isValid()} returned true. Valid flag * can only change after calling {@link #next()}. */ - void next(); + void next() throws IOException; /** Returns the key-group for the current key. */ int keyGroup(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyValueStateIterator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyValueStateIterator.java new file mode 100644 index 0000000..3f37640 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyValueStateIterator.java @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.api.common.typeutils.base.MapSerializer; +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.runtime.state.CompositeKeySerializationUtils; +import org.apache.flink.runtime.state.IterableStateSnapshot; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyValueStateIterator; +import org.apache.flink.runtime.state.ListDelimitedSerializer; +import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo; +import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo; +import org.apache.flink.runtime.state.SerializedCompositeKeyBuilder; +import org.apache.flink.runtime.state.StateEntry; +import org.apache.flink.runtime.state.StateSnapshot; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; +import javax.annotation.concurrent.NotThreadSafe; + +import java.io.IOException; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * A {@link org.apache.flink.runtime.state.KeyValueStateIterator} over Heap backend snapshot + * resources. + */ +@Internal +@NotThreadSafe +public final class HeapKeyValueStateIterator implements KeyValueStateIterator { + + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + + private final Map<StateUID, Integer> stateNamesToId; + private final Map<StateUID, StateSnapshot> stateStableSnapshots; + private final int keyGroupPrefixBytes; + + private boolean isValid; + private boolean newKeyGroup; + private boolean newKVState; + private byte[] currentKey; + private byte[] currentValue; + + /** Iterator over the key groups of the corresponding key group range. */ + private final Iterator<Integer> keyGroupIterator; + /** The current value of the keyGroupIterator. */ + private int currentKeyGroup; + + /** Iterator over all states present in the snapshots. */ + private Iterator<StateUID> statesIterator; + /** The current value of the statesIterator. */ + private StateUID currentState; + /** + * An iterator over the values of the current state. It can be one of three: + * + * <ul> + * <li>{@link QueueIterator} for iterating over entries in a priority queue + * <li>{@link StateTableIterator} for iterating over entries in a StateTable + * <li>{@link MapStateIterator} for iterating over entries in a user map, this one falls back + * to the upper one automatically if exhausted + * </ul> + */ + private SingleStateIterator currentStateIterator; + /** Helpers for serializing state into the unified format. */ + private final DataOutputSerializer valueOut = new DataOutputSerializer(64); + + private final ListDelimitedSerializer listDelimitedSerializer = new ListDelimitedSerializer(); + private final SerializedCompositeKeyBuilder<Object> compositeKeyBuilder; + + public HeapKeyValueStateIterator( + @Nonnull final KeyGroupRange keyGroupRange, + @Nonnull final TypeSerializer<?> keySerializer, + @Nonnegative final int totalKeyGroups, + @Nonnull final Map<StateUID, Integer> stateNamesToId, + @Nonnull final Map<StateUID, StateSnapshot> stateSnapshots) + throws IOException { + checkNotNull(keyGroupRange); + checkNotNull(keySerializer); + this.stateNamesToId = checkNotNull(stateNamesToId); + this.stateStableSnapshots = checkNotNull(stateSnapshots); + + this.statesIterator = stateSnapshots.keySet().iterator(); + this.keyGroupIterator = keyGroupRange.iterator(); + + this.keyGroupPrefixBytes = + CompositeKeySerializationUtils.computeRequiredBytesInKeyGroupPrefix(totalKeyGroups); + this.compositeKeyBuilder = + new SerializedCompositeKeyBuilder<>( + castToType(keySerializer), keyGroupPrefixBytes, 32); + + if (!keyGroupIterator.hasNext() || !statesIterator.hasNext()) { + // stop early, no key groups or states + isValid = false; + } else { + currentKeyGroup = keyGroupIterator.next(); + next(); + this.newKeyGroup = true; + } + } + + @Override + public boolean isValid() { + return isValid; + } + + @Override + public boolean isNewKeyValueState() { + return this.newKVState; + } + + @Override + public boolean isNewKeyGroup() { + return this.newKeyGroup; + } + + @Override + public int keyGroup() { + return currentKeyGroup; + } + + @Override + public int kvStateId() { + return stateNamesToId.get(currentState); + } + + @Override + public void next() throws IOException { + this.newKVState = false; + this.newKeyGroup = false; + + boolean nextElementSet = false; + do { + if (currentState == null) { + boolean hasNextState = moveToNextState(); + if (!hasNextState) { + isValid = false; + return; + } + } + + boolean hasStateEntry = currentStateIterator != null && currentStateIterator.hasNext(); + if (!hasStateEntry) { + this.currentState = null; + } + + if (hasStateEntry) { + nextElementSet = currentStateIterator.writeOutNext(); + } + } while (!nextElementSet); + isValid = true; + } + + private boolean moveToNextState() throws IOException { + if (statesIterator.hasNext()) { + this.currentState = statesIterator.next(); + this.newKVState = true; + } else if (keyGroupIterator.hasNext()) { + this.currentKeyGroup = keyGroupIterator.next(); + resetStates(); + this.newKeyGroup = true; + this.newKVState = true; + } else { + return false; + } + + StateSnapshot stateSnapshot = this.stateStableSnapshots.get(currentState); + setCurrentStateIterator(stateSnapshot); + + // set to a valid entry + return true; + } + + private void resetStates() { + this.statesIterator = stateStableSnapshots.keySet().iterator(); + this.currentState = statesIterator.next(); + } + + @SuppressWarnings("unchecked") + private void setCurrentStateIterator(StateSnapshot stateSnapshot) throws IOException { + if (stateSnapshot instanceof IterableStateSnapshot) { + RegisteredKeyValueStateBackendMetaInfo<Object, Object> metaInfo = + new RegisteredKeyValueStateBackendMetaInfo<>( + stateSnapshot.getMetaInfoSnapshot()); + Iterator<? extends StateEntry<?, ?, ?>> snapshotIterator = + ((IterableStateSnapshot<?, ?, ?>) stateSnapshot).getIterator(currentKeyGroup); + this.currentStateIterator = new StateTableIterator(snapshotIterator, metaInfo); + } else if (stateSnapshot instanceof HeapPriorityQueueStateSnapshot) { + Iterator<Object> snapshotIterator = + ((HeapPriorityQueueStateSnapshot<Object>) stateSnapshot) + .getIteratorForKeyGroup(currentKeyGroup); + RegisteredPriorityQueueStateBackendMetaInfo<Object> metaInfo = + new RegisteredPriorityQueueStateBackendMetaInfo<>( + stateSnapshot.getMetaInfoSnapshot()); + this.currentStateIterator = new QueueIterator<>(snapshotIterator, metaInfo); + } else { + throw new IllegalStateException("Unknown snapshot type: " + stateSnapshot); + } + } + + /** A common interface for writing out a single entry in a state. */ + private interface SingleStateIterator { + + boolean hasNext(); + + /** + * Sets the {@link #currentKey} and {@link #currentValue} to the value of the next entry in + * the state. + * + * @return false if an entry was empty. It can be the case if we try to serialize an empty + * Map or List. In that case we should skip to a next entry. + */ + boolean writeOutNext() throws IOException; + } + + private final class StateTableIterator implements SingleStateIterator { + + private final Iterator<? extends StateEntry<?, ?, ?>> entriesIterator; + private final RegisteredKeyValueStateBackendMetaInfo<?, ?> stateSnapshot; + + private StateTableIterator( + Iterator<? extends StateEntry<?, ?, ?>> entriesIterator, + RegisteredKeyValueStateBackendMetaInfo<?, ?> stateSnapshot) { + this.entriesIterator = entriesIterator; + this.stateSnapshot = stateSnapshot; + } + + @Override + public boolean hasNext() { + return entriesIterator.hasNext(); + } + + @Override + public boolean writeOutNext() throws IOException { + StateEntry<?, ?, ?> currentEntry = entriesIterator.next(); + valueOut.clear(); + compositeKeyBuilder.setKeyAndKeyGroup(currentEntry.getKey(), keyGroup()); + compositeKeyBuilder.setNamespace( + currentEntry.getNamespace(), + castToType(stateSnapshot.getNamespaceSerializer())); + TypeSerializer<?> stateSerializer = stateSnapshot.getStateSerializer(); + switch (stateSnapshot.getStateType()) { + case AGGREGATING: + case REDUCING: + case FOLDING: + case VALUE: + return writeOutValue(currentEntry, stateSerializer); + case LIST: + return writeOutList(currentEntry, stateSerializer); + case MAP: + return writeOutMap(currentEntry, stateSerializer); + default: + throw new IllegalStateException(""); + } + } + + private boolean writeOutValue( + StateEntry<?, ?, ?> currentEntry, TypeSerializer<?> stateSerializer) + throws IOException { + currentKey = compositeKeyBuilder.build(); + castToType(stateSerializer).serialize(currentEntry.getState(), valueOut); + currentValue = valueOut.getCopyOfBuffer(); + return true; + } + + @SuppressWarnings("unchecked") + private boolean writeOutList( + StateEntry<?, ?, ?> currentEntry, TypeSerializer<?> stateSerializer) + throws IOException { + List<Object> state = (List<Object>) currentEntry.getState(); + if (state.isEmpty()) { + return false; + } + ListSerializer<Object> listSerializer = (ListSerializer<Object>) stateSerializer; + currentKey = compositeKeyBuilder.build(); + currentValue = + listDelimitedSerializer.serializeList( + state, listSerializer.getElementSerializer()); + return true; + } + + @SuppressWarnings("unchecked") + private boolean writeOutMap( + StateEntry<?, ?, ?> currentEntry, TypeSerializer<?> stateSerializer) + throws IOException { + Map<Object, Object> state = (Map<Object, Object>) currentEntry.getState(); + if (state.isEmpty()) { + return false; + } + MapSerializer<Object, Object> mapSerializer = + (MapSerializer<Object, Object>) stateSerializer; + currentStateIterator = + new MapStateIterator( + state, + mapSerializer.getKeySerializer(), + mapSerializer.getValueSerializer(), + this); + return currentStateIterator.writeOutNext(); + } + } + + private final class MapStateIterator implements SingleStateIterator { + + private final Iterator<Map.Entry<Object, Object>> mapEntries; + private final TypeSerializer<Object> userKeySerializer; + private final TypeSerializer<Object> userValueSerializer; + private final StateTableIterator parentIterator; + + private MapStateIterator( + Map<Object, Object> mapEntries, + TypeSerializer<Object> userKeySerializer, + TypeSerializer<Object> userValueSerializer, + StateTableIterator parentIterator) { + assert !mapEntries.isEmpty(); + this.mapEntries = mapEntries.entrySet().iterator(); + this.userKeySerializer = userKeySerializer; + this.userValueSerializer = userValueSerializer; + this.parentIterator = parentIterator; + } + + @Override + public boolean hasNext() { + // we should never end up here with an exhausted map iterator + // if an iterator is exhausted in the writeOutNext we switch back to + // the originating StateTableIterator + assert mapEntries.hasNext(); + return true; + } + + @Override + public boolean writeOutNext() throws IOException { + Map.Entry<Object, Object> entry = mapEntries.next(); + valueOut.clear(); + currentKey = + compositeKeyBuilder.buildCompositeKeyUserKey(entry.getKey(), userKeySerializer); + Object userValue = entry.getValue(); + valueOut.writeBoolean(userValue == null); + userValueSerializer.serialize(userValue, valueOut); + currentValue = valueOut.getCopyOfBuffer(); + + if (!mapEntries.hasNext()) { + currentStateIterator = parentIterator; + } + return true; + } + } + + private final class QueueIterator<T> implements SingleStateIterator { + private final Iterator<T> elementsForKeyGroup; + private final RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo; + private final DataOutputSerializer keyOut = new DataOutputSerializer(128); + private final int afterKeyMark; + + public QueueIterator( + Iterator<T> elementsForKeyGroup, + RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) + throws IOException { + this.elementsForKeyGroup = elementsForKeyGroup; + this.metaInfo = metaInfo; + CompositeKeySerializationUtils.writeKeyGroup(keyGroup(), keyGroupPrefixBytes, keyOut); + afterKeyMark = keyOut.length(); + } + + @Override + public boolean hasNext() { + return elementsForKeyGroup.hasNext(); + } + + @Override + public boolean writeOutNext() throws IOException { + currentValue = EMPTY_BYTE_ARRAY; + keyOut.setPosition(afterKeyMark); + T next = elementsForKeyGroup.next(); + metaInfo.getElementSerializer().serialize(next, keyOut); + currentKey = keyOut.getCopyOfBuffer(); + return true; + } + } + + @SuppressWarnings("unchecked") + @Nonnull + private static <T> TypeSerializer<T> castToType(@Nonnull TypeSerializer<?> serializer) { + return (TypeSerializer<T>) serializer; + } + + @Override + public byte[] key() { + return currentKey; + } + + @Override + public byte[] value() { + return currentValue; + } + + @Override + public void close() {} +} diff --git a/flink-tests/src/test/java/org/apache/flink/test/state/SavepointStateBackendSwitchTestBase.java b/flink-tests/src/test/java/org/apache/flink/test/state/SavepointStateBackendSwitchTestBase.java index 933b8df..d54b667 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/state/SavepointStateBackendSwitchTestBase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/state/SavepointStateBackendSwitchTestBase.java @@ -56,6 +56,7 @@ import java.util.Map; import java.util.concurrent.RunnableFuture; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.contains; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; @@ -101,6 +102,7 @@ public abstract class SavepointStateBackendSwitchTestBase { final Integer namespace1 = 1; final Integer namespace2 = 2; final Integer namespace3 = 3; + final Integer namespace4 = 4; try (final CheckpointableKeyedStateBackend<String> keyedBackend = fromBackend.createBackend( @@ -113,7 +115,8 @@ public abstract class SavepointStateBackendSwitchTestBase { listStateDescriptor, namespace1, namespace2, - namespace3); + namespace3, + namespace4); } final SnapshotResult<KeyedStateHandle> stateHandles; @@ -136,6 +139,7 @@ public abstract class SavepointStateBackendSwitchTestBase { namespace1, namespace2, namespace3, + namespace4, keyedBackend); } } @@ -159,7 +163,8 @@ public abstract class SavepointStateBackendSwitchTestBase { ListStateDescriptor<Long> listStateDescriptor, Integer namespace1, Integer namespace2, - Integer namespace3) + Integer namespace3, + Integer namespace4) throws Exception { InternalMapState<String, Integer, Long, Long> mapState = @@ -179,6 +184,7 @@ public abstract class SavepointStateBackendSwitchTestBase { mapState.setCurrentNamespace(namespace2); mapState.put(22L, 22L); mapState.put(11L, 11L); + listState.setCurrentNamespace(namespace2); listState.add(4L); listState.add(5L); @@ -194,6 +200,7 @@ public abstract class SavepointStateBackendSwitchTestBase { mapState.put(33L, 33L); mapState.put(44L, 44L); mapState.put(55L, 55L); + valueState.setCurrentNamespace(namespace3); valueState.update(1239L); listState.setCurrentNamespace(namespace3); @@ -201,6 +208,16 @@ public abstract class SavepointStateBackendSwitchTestBase { listState.add(2L); listState.add(3L); + mapState.setCurrentNamespace(namespace4); + mapState.put(1L, 1L); + // HEAP state backend will keep an empty map as an entry in the underlying State Table + // we should skip such entries when serializing + Iterator<Map.Entry<Long, Long>> iterator = mapState.iterator(); + while (iterator.hasNext()) { + iterator.next(); + iterator.remove(); + } + KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<String, Integer>> priorityQueue = keyedBackend.create( "event-time", @@ -234,9 +251,10 @@ public abstract class SavepointStateBackendSwitchTestBase { Integer namespace1, Integer namespace2, Integer namespace3, + Integer namespace4, CheckpointableKeyedStateBackend<String> keyedBackend) throws Exception { - InternalMapState<String, Integer, Long, Long> state = + InternalMapState<String, Integer, Long, Long> mapState = keyedBackend.createInternalState(IntSerializer.INSTANCE, mapStateDescriptor); InternalValueState<String, Integer, Long> valueState = @@ -246,35 +264,38 @@ public abstract class SavepointStateBackendSwitchTestBase { keyedBackend.createInternalState(IntSerializer.INSTANCE, listStateDescriptor); keyedBackend.setCurrentKey("abc"); - state.setCurrentNamespace(namespace1); - assertEquals(33L, (long) state.get(33L)); - assertEquals(55L, (long) state.get(55L)); - assertEquals(2, getStateSize(state)); - - state.setCurrentNamespace(namespace2); - assertEquals(22L, (long) state.get(22L)); - assertEquals(11L, (long) state.get(11L)); - assertEquals(2, getStateSize(state)); + mapState.setCurrentNamespace(namespace1); + assertEquals(33L, (long) mapState.get(33L)); + assertEquals(55L, (long) mapState.get(55L)); + assertEquals(2, getStateSize(mapState)); + + mapState.setCurrentNamespace(namespace2); + assertEquals(22L, (long) mapState.get(22L)); + assertEquals(11L, (long) mapState.get(11L)); + assertEquals(2, getStateSize(mapState)); listState.setCurrentNamespace(namespace2); assertThat(listState.get(), contains(4L, 5L, 6L)); - state.setCurrentNamespace(namespace3); - assertEquals(44L, (long) state.get(44L)); - assertEquals(1, getStateSize(state)); + mapState.setCurrentNamespace(namespace3); + assertEquals(44L, (long) mapState.get(44L)); + assertEquals(1, getStateSize(mapState)); keyedBackend.setCurrentKey("mno"); - state.setCurrentNamespace(namespace3); - assertEquals(11L, (long) state.get(11L)); - assertEquals(22L, (long) state.get(22L)); - assertEquals(33L, (long) state.get(33L)); - assertEquals(44L, (long) state.get(44L)); - assertEquals(55L, (long) state.get(55L)); - assertEquals(5, getStateSize(state)); + mapState.setCurrentNamespace(namespace3); + assertEquals(11L, (long) mapState.get(11L)); + assertEquals(22L, (long) mapState.get(22L)); + assertEquals(33L, (long) mapState.get(33L)); + assertEquals(44L, (long) mapState.get(44L)); + assertEquals(55L, (long) mapState.get(55L)); + assertEquals(5, getStateSize(mapState)); valueState.setCurrentNamespace(namespace3); assertEquals(1239L, (long) valueState.value()); listState.setCurrentNamespace(namespace3); assertThat(listState.get(), contains(1L, 2L, 3L)); + mapState.setCurrentNamespace(namespace4); + assertThat(mapState.isEmpty(), is(true)); + KeyGroupedInternalPriorityQueue<TimerHeapInternalTimer<String, Integer>> priorityQueue = keyedBackend.create( "event-time",