This is an automated email from the ASF dual-hosted git repository. dwysakowicz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 351d6e9 [FLINK-21167] Make StateTable snapshots iterable 351d6e9 is described below commit 351d6e924db155518113d21ab99c566740a0f194 Author: Dawid Wysakowicz <dwysakow...@apache.org> AuthorDate: Wed Jan 27 21:04:38 2021 +0100 [FLINK-21167] Make StateTable snapshots iterable In order to implement an iterator required by a binary unified savepoint we need a way to iterate a snapshot. --- .../flink/runtime/state/IterableStateSnapshot.java | 32 ++++++ .../org/apache/flink/runtime/state/StateEntry.java | 9 ++ .../state/heap/AbstractStateTableSnapshot.java | 16 ++- .../state/heap/CopyOnWriteStateMapSnapshot.java | 25 +++-- .../runtime/state/heap/NestedStateMapSnapshot.java | 23 ++++ .../flink/runtime/state/heap/StateMapSnapshot.java | 8 ++ .../state/heap/CopyOnWriteStateMapTest.java | 95 ++++++++++++++++- .../state/heap/NestedMapsStateTableTest.java | 118 +++++++++++++++++++++ .../runtime/state/testutils/StateEntryMatcher.java | 58 ++++++++++ .../flink-statebackend-heap-spillable/pom.xml | 14 ++- .../heap/CopyOnWriteSkipListStateMapSnapshot.java | 64 +++++++++++ .../CopyOnWriteSkipListStateMapComplexOpTest.java | 56 ++++++++-- .../heap/CopyOnWriteSkipListStateMapTestUtils.java | 103 +++++++++++++++--- 13 files changed, 582 insertions(+), 39 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/IterableStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/IterableStateSnapshot.java new file mode 100644 index 0000000..e68c74aa --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/IterableStateSnapshot.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state; + +import org.apache.flink.annotation.Internal; + +import java.util.Iterator; + +/** + * A {@link StateSnapshot} that can return an iterator over all contained {@link StateEntry + * StateEntries}. + */ +@Internal +public interface IterableStateSnapshot<K, N, S> extends StateSnapshot { + Iterator<StateEntry<K, N, S>> getIterator(int keyGroup); +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateEntry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateEntry.java index b5b8fe8..af3d1ed 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateEntry.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateEntry.java @@ -36,6 +36,15 @@ public interface StateEntry<K, N, S> { /** Returns the state of this entry. */ S getState(); + default StateEntry<K, N, S> filterOrTransform(StateSnapshotTransformer<S> transformer) { + S newState = transformer.filterOrTransform(getState()); + if (newState != null) { + return new SimpleStateEntry<>(getKey(), getNamespace(), newState); + } else { + return null; + } + } + class SimpleStateEntry<K, N, S> implements StateEntry<K, N, S> { private final K key; private final N namespace; diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java index 1ea074e..5955834 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java @@ -21,6 +21,8 @@ package org.apache.flink.runtime.state.heap; import org.apache.flink.annotation.Internal; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.state.IterableStateSnapshot; +import org.apache.flink.runtime.state.StateEntry; import org.apache.flink.runtime.state.StateSnapshot; import org.apache.flink.runtime.state.StateSnapshotTransformer; import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot; @@ -30,6 +32,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.io.IOException; +import java.util.Iterator; /** * Abstract base class for snapshots of a {@link StateTable}. Offers a way to serialize the snapshot @@ -37,7 +40,7 @@ import java.io.IOException; */ @Internal abstract class AbstractStateTableSnapshot<K, N, S> - implements StateSnapshot, StateSnapshot.StateKeyGroupWriter { + implements IterableStateSnapshot<K, N, S>, StateSnapshot.StateKeyGroupWriter { /** The {@link StateTable} from which this snapshot was created. */ protected final StateTable<K, N, S> owningStateTable; @@ -88,6 +91,17 @@ abstract class AbstractStateTableSnapshot<K, N, S> return this; } + @Override + public Iterator<StateEntry<K, N, S>> getIterator(int keyGroupId) { + StateMapSnapshot<K, N, S, ? extends StateMap<K, N, S>> stateMapSnapshot = + getStateMapSnapshotForKeyGroup(keyGroupId); + return stateMapSnapshot.getIterator( + localKeySerializer, + localNamespaceSerializer, + localStateSerializer, + stateSnapshotTransformer); + } + /** * Implementation note: we currently chose the same format between {@link NestedMapsStateTable} * and {@link CopyOnWriteStateTable}. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapSnapshot.java index 567ad35..d6c06f0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapSnapshot.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapSnapshot.java @@ -111,6 +111,19 @@ public class CopyOnWriteStateMapSnapshot<K, N, S> } @Override + public SnapshotIterator<K, N, S> getIterator( + @Nonnull TypeSerializer<K> keySerializer, + @Nonnull TypeSerializer<N> namespaceSerializer, + @Nonnull TypeSerializer<S> stateSerializer, + @Nullable final StateSnapshotTransformer<S> stateSnapshotTransformer) { + + return stateSnapshotTransformer == null + ? new NonTransformSnapshotIterator<>(numberOfEntriesInSnapshotData, snapshotData) + : new TransformedSnapshotIterator<>( + numberOfEntriesInSnapshotData, snapshotData, stateSnapshotTransformer); + } + + @Override public void writeState( TypeSerializer<K> keySerializer, TypeSerializer<N> namespaceSerializer, @@ -119,13 +132,11 @@ public class CopyOnWriteStateMapSnapshot<K, N, S> @Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) throws IOException { SnapshotIterator<K, N, S> snapshotIterator = - stateSnapshotTransformer == null - ? new NonTransformSnapshotIterator<>( - numberOfEntriesInSnapshotData, snapshotData) - : new TransformedSnapshotIterator<>( - numberOfEntriesInSnapshotData, - snapshotData, - stateSnapshotTransformer); + getIterator( + keySerializer, + namespaceSerializer, + stateSerializer, + stateSnapshotTransformer); int size = snapshotIterator.size(); dov.writeInt(size); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedStateMapSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedStateMapSnapshot.java index 4a3f7bd..e58cf91 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedStateMapSnapshot.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedStateMapSnapshot.java @@ -20,6 +20,7 @@ package org.apache.flink.runtime.state.heap; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.state.StateEntry; import org.apache.flink.runtime.state.StateSnapshotTransformer; import javax.annotation.Nonnull; @@ -27,7 +28,11 @@ import javax.annotation.Nullable; import java.io.IOException; import java.util.HashMap; +import java.util.Iterator; import java.util.Map; +import java.util.Objects; +import java.util.Spliterators; +import java.util.stream.StreamSupport; /** * This class represents the snapshot of a {@link NestedStateMap}. @@ -49,6 +54,24 @@ public class NestedStateMapSnapshot<K, N, S> } @Override + public Iterator<StateEntry<K, N, S>> getIterator( + @Nonnull TypeSerializer<K> keySerializer, + @Nonnull TypeSerializer<N> namespaceSerializer, + @Nonnull TypeSerializer<S> stateSerializer, + @Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) { + if (stateSnapshotTransformer == null) { + return owningStateMap.iterator(); + } else { + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize(owningStateMap.iterator(), 0), + false) + .map(entry -> entry.filterOrTransform(stateSnapshotTransformer)) + .filter(Objects::nonNull) + .iterator(); + } + } + + @Override public void writeState( TypeSerializer<K> keySerializer, TypeSerializer<N> namespaceSerializer, diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateMapSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateMapSnapshot.java index f343b94..cafad96 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateMapSnapshot.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateMapSnapshot.java @@ -20,6 +20,7 @@ package org.apache.flink.runtime.state.heap; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.state.StateEntry; import org.apache.flink.runtime.state.StateSnapshotTransformer; import org.apache.flink.util.Preconditions; @@ -27,6 +28,7 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.io.IOException; +import java.util.Iterator; /** * Base class for snapshots of a {@link StateMap}. @@ -52,6 +54,12 @@ public abstract class StateMapSnapshot<K, N, S, T extends StateMap<K, N, S>> { /** Release the snapshot. */ public void release() {} + public abstract Iterator<StateEntry<K, N, S>> getIterator( + @Nonnull TypeSerializer<K> keySerializer, + @Nonnull TypeSerializer<N> namespaceSerializer, + @Nonnull TypeSerializer<S> stateSerializer, + @Nullable final StateSnapshotTransformer<S> stateSnapshotTransformer); + /** * Writes the state in this snapshot to output. The state need to be transformed with the given * transformer if the transformer is non-null. diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapTest.java index a9d1787..ff9615a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapTest.java @@ -19,10 +19,13 @@ package org.apache.flink.runtime.state.heap; import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.runtime.state.ArrayListSerializer; import org.apache.flink.runtime.state.StateEntry; +import org.apache.flink.runtime.state.StateSnapshotTransformer; import org.apache.flink.runtime.state.StateTransformationFunction; import org.apache.flink.runtime.state.internal.InternalKvState.StateIncrementalVisitor; import org.apache.flink.util.TestLogger; @@ -31,13 +34,22 @@ import org.hamcrest.Matchers; import org.junit.Assert; import org.junit.Test; +import javax.annotation.Nullable; + import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; +import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Random; +import static org.apache.flink.runtime.state.testutils.StateEntryMatcher.entry; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertThat; + /** Test for {@link CopyOnWriteStateMap}. */ public class CopyOnWriteStateMapTest extends TestLogger { @@ -398,6 +410,82 @@ public class CopyOnWriteStateMapTest extends TestLogger { Assert.assertSame(originalState5, stateMap.get(5, 1)); } + @Test + public void testIteratingOverSnapshot() { + ListSerializer<Integer> stateSerializer = new ListSerializer<>(IntSerializer.INSTANCE); + final CopyOnWriteStateMap<Integer, Integer, List<Integer>> stateMap = + new CopyOnWriteStateMap<>(stateSerializer); + + List<Integer> originalState1 = new ArrayList<>(1); + List<Integer> originalState2 = new ArrayList<>(1); + List<Integer> originalState3 = new ArrayList<>(1); + List<Integer> originalState4 = new ArrayList<>(1); + List<Integer> originalState5 = new ArrayList<>(1); + + originalState1.add(1); + originalState2.add(2); + originalState3.add(3); + originalState4.add(4); + originalState5.add(5); + + stateMap.put(1, 1, originalState1); + stateMap.put(2, 1, originalState2); + stateMap.put(3, 1, originalState3); + stateMap.put(4, 1, originalState4); + stateMap.put(5, 1, originalState5); + + CopyOnWriteStateMapSnapshot<Integer, Integer, List<Integer>> snapshot = + stateMap.stateSnapshot(); + + Iterator<StateEntry<Integer, Integer, List<Integer>>> iterator = + snapshot.getIterator( + IntSerializer.INSTANCE, IntSerializer.INSTANCE, stateSerializer, null); + assertThat( + () -> iterator, + containsInAnyOrder( + entry(1, 1, originalState1), + entry(2, 1, originalState2), + entry(3, 1, originalState3), + entry(4, 1, originalState4), + entry(5, 1, originalState5))); + } + + @Test + public void testIteratingOverSnapshotWithTransform() { + final CopyOnWriteStateMap<Integer, Integer, Long> stateMap = + new CopyOnWriteStateMap<>(LongSerializer.INSTANCE); + + stateMap.put(1, 1, 10L); + stateMap.put(2, 1, 11L); + stateMap.put(3, 1, 12L); + stateMap.put(4, 1, 13L); + stateMap.put(5, 1, 14L); + + StateMapSnapshot<Integer, Integer, Long, ? extends StateMap<Integer, Integer, Long>> + snapshot = stateMap.stateSnapshot(); + + Iterator<StateEntry<Integer, Integer, Long>> iterator = + snapshot.getIterator( + IntSerializer.INSTANCE, + IntSerializer.INSTANCE, + LongSerializer.INSTANCE, + new StateSnapshotTransformer<Long>() { + @Nullable + @Override + public Long filterOrTransform(@Nullable Long value) { + if (value == 12L) { + return null; + } else { + return value + 2L; + } + } + }); + assertThat( + () -> iterator, + containsInAnyOrder( + entry(1, 1, 12L), entry(2, 1, 13L), entry(4, 1, 15L), entry(5, 1, 16L))); + } + /** This tests that snapshot can be released correctly. */ @Test public void testSnapshotRelease() { @@ -410,16 +498,15 @@ public class CopyOnWriteStateMapTest extends TestLogger { CopyOnWriteStateMapSnapshot<Integer, Integer, Integer> snapshot = stateMap.stateSnapshot(); Assert.assertFalse(snapshot.isReleased()); - Assert.assertThat( - stateMap.getSnapshotVersions(), Matchers.contains(snapshot.getSnapshotVersion())); + assertThat(stateMap.getSnapshotVersions(), contains(snapshot.getSnapshotVersion())); snapshot.release(); Assert.assertTrue(snapshot.isReleased()); - Assert.assertThat(stateMap.getSnapshotVersions(), Matchers.empty()); + assertThat(stateMap.getSnapshotVersions(), Matchers.empty()); // verify that snapshot will release itself only once snapshot.release(); - Assert.assertThat(stateMap.getSnapshotVersions(), Matchers.empty()); + assertThat(stateMap.getSnapshotVersions(), Matchers.empty()); } @SuppressWarnings("unchecked") diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/NestedMapsStateTableTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/NestedMapsStateTableTest.java new file mode 100644 index 0000000..d8ff6c2 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/NestedMapsStateTableTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.heap; + +import org.apache.flink.api.common.typeutils.base.IntSerializer; +import org.apache.flink.api.common.typeutils.base.ListSerializer; +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.runtime.state.StateEntry; +import org.apache.flink.runtime.state.StateSnapshotTransformer; + +import org.junit.Test; + +import javax.annotation.Nullable; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import static org.apache.flink.runtime.state.testutils.StateEntryMatcher.entry; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertThat; + +/** Tests for {@link NestedMapsStateTable}. */ +public class NestedMapsStateTableTest { + @Test + public void testIteratingOverSnapshot() { + ListSerializer<Integer> stateSerializer = new ListSerializer<>(IntSerializer.INSTANCE); + final NestedStateMap<Integer, Integer, List<Integer>> stateMap = new NestedStateMap<>(); + + List<Integer> originalState1 = new ArrayList<>(1); + List<Integer> originalState2 = new ArrayList<>(1); + List<Integer> originalState3 = new ArrayList<>(1); + List<Integer> originalState4 = new ArrayList<>(1); + List<Integer> originalState5 = new ArrayList<>(1); + + originalState1.add(1); + originalState2.add(2); + originalState3.add(3); + originalState4.add(4); + originalState5.add(5); + + stateMap.put(1, 1, originalState1); + stateMap.put(2, 1, originalState2); + stateMap.put(3, 1, originalState3); + stateMap.put(4, 1, originalState4); + stateMap.put(5, 1, originalState5); + + StateMapSnapshot< + Integer, + Integer, + List<Integer>, + ? extends StateMap<Integer, Integer, List<Integer>>> + snapshot = stateMap.stateSnapshot(); + + Iterator<StateEntry<Integer, Integer, List<Integer>>> iterator = + snapshot.getIterator( + IntSerializer.INSTANCE, IntSerializer.INSTANCE, stateSerializer, null); + assertThat( + () -> iterator, + containsInAnyOrder( + entry(1, 1, originalState1), + entry(2, 1, originalState2), + entry(3, 1, originalState3), + entry(4, 1, originalState4), + entry(5, 1, originalState5))); + } + + @Test + public void testIteratingOverSnapshotWithTransform() { + final NestedStateMap<Integer, Integer, Long> stateMap = new NestedStateMap<>(); + + stateMap.put(1, 1, 10L); + stateMap.put(2, 1, 11L); + stateMap.put(3, 1, 12L); + stateMap.put(4, 1, 13L); + stateMap.put(5, 1, 14L); + + StateMapSnapshot<Integer, Integer, Long, ? extends StateMap<Integer, Integer, Long>> + snapshot = stateMap.stateSnapshot(); + + Iterator<StateEntry<Integer, Integer, Long>> iterator = + snapshot.getIterator( + IntSerializer.INSTANCE, + IntSerializer.INSTANCE, + LongSerializer.INSTANCE, + new StateSnapshotTransformer<Long>() { + @Nullable + @Override + public Long filterOrTransform(@Nullable Long value) { + if (value == 12L) { + return null; + } else { + return value + 2L; + } + } + }); + assertThat( + () -> iterator, + containsInAnyOrder( + entry(1, 1, 12L), entry(2, 1, 13L), entry(4, 1, 15L), entry(5, 1, 16L))); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/testutils/StateEntryMatcher.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/testutils/StateEntryMatcher.java new file mode 100644 index 0000000..cd1787a --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/testutils/StateEntryMatcher.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.state.testutils; + +import org.apache.flink.runtime.state.StateEntry; + +import org.hamcrest.Description; +import org.hamcrest.TypeSafeMatcher; + +import java.util.Objects; + +/** A matcher for {@link StateEntry StateEntries}. */ +public class StateEntryMatcher<K, N, S> extends TypeSafeMatcher<StateEntry<K, N, S>> { + private final K key; + private final N namespace; + private final S state; + + StateEntryMatcher(K key, N namespace, S state) { + this.key = key; + this.namespace = namespace; + this.state = state; + } + + public static <K, N, S> StateEntryMatcher<K, N, S> entry(K key, N namespace, S state) { + return new StateEntryMatcher<>(key, namespace, state); + } + + @Override + protected boolean matchesSafely(StateEntry<K, N, S> item) { + return Objects.equals(item.getKey(), key) + && Objects.equals(item.getNamespace(), namespace) + && Objects.equals(item.getState(), state); + } + + @Override + public void describeTo(Description description) { + description.appendText( + String.format( + "expected entry: key: %s, namespace: %s, state: %s", + key, namespace, state)); + } +} diff --git a/flink-state-backends/flink-statebackend-heap-spillable/pom.xml b/flink-state-backends/flink-statebackend-heap-spillable/pom.xml index 799051c..663dc88 100644 --- a/flink-state-backends/flink-statebackend-heap-spillable/pom.xml +++ b/flink-state-backends/flink-statebackend-heap-spillable/pom.xml @@ -19,8 +19,8 @@ under the License. --> <project xmlns="http://maven.apache.org/POM/4.0.0" - xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" - xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> <modelVersion>4.0.0</modelVersion> @@ -46,11 +46,19 @@ under the License. <scope>provided</scope> </dependency> - <!-- test dependencies --> + <!-- test dependencies --> <dependency> <groupId>org.apache.flink</groupId> <artifactId>flink-test-utils-junit</artifactId> </dependency> + + <dependency> + <groupId>org.apache.flink</groupId> + <artifactId>flink-runtime_${scala.binary.version}</artifactId> + <version>${project.version}</version> + <scope>test</scope> + <type>test-jar</type> + </dependency> </dependencies> </project> diff --git a/flink-state-backends/flink-statebackend-heap-spillable/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapSnapshot.java b/flink-state-backends/flink-statebackend-heap-spillable/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapSnapshot.java index ead4126..35fc52c 100644 --- a/flink-state-backends/flink-statebackend-heap-spillable/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapSnapshot.java +++ b/flink-state-backends/flink-statebackend-heap-spillable/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapSnapshot.java @@ -20,9 +20,12 @@ package org.apache.flink.runtime.state.heap; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.core.memory.DataInputDeserializer; import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.runtime.state.StateEntry; import org.apache.flink.runtime.state.StateSnapshotTransformer; import org.apache.flink.util.ResourceGuard; +import org.apache.flink.util.WrappingRuntimeException; import javax.annotation.Nonnegative; import javax.annotation.Nonnull; @@ -31,6 +34,9 @@ import javax.annotation.Nullable; import java.io.IOException; import java.util.Iterator; import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Spliterators; +import java.util.stream.StreamSupport; import static org.apache.flink.runtime.state.heap.SkipListUtils.HEAD_NODE; import static org.apache.flink.runtime.state.heap.SkipListUtils.NIL_NODE; @@ -86,6 +92,64 @@ public class CopyOnWriteSkipListStateMapSnapshot<K, N, S> } @Override + public Iterator<StateEntry<K, N, S>> getIterator( + @Nonnull TypeSerializer<K> keySerializer, + @Nonnull TypeSerializer<N> namespaceSerializer, + @Nonnull TypeSerializer<S> stateSerializer, + @Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) { + SkipListValueSerializer<S> skipListValueSerializer = + new SkipListValueSerializer<>(stateSerializer); + + DataInputDeserializer inputDeserializer = new DataInputDeserializer(); + // 1. iterates nodes to get size after transform + Iterator<Tuple2<Long, Long>> transformNodeIterator = new SnapshotNodeIterator(true); + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize(transformNodeIterator, 0), false) + .map( + tuple -> + transformEntry( + keySerializer, + namespaceSerializer, + stateSnapshotTransformer, + skipListValueSerializer, + inputDeserializer, + tuple)) + .filter(Objects::nonNull) + .iterator(); + } + + private StateEntry<K, N, S> transformEntry( + TypeSerializer<K> keySerializer, + TypeSerializer<N> namespaceSerializer, + StateSnapshotTransformer<S> stateSnapshotTransformer, + SkipListValueSerializer<S> skipListValueSerializer, + DataInputDeserializer inputDeserializer, + Tuple2<Long, Long> pointers) { + try { + final S oldState = owningStateMap.helpGetState(pointers.f1, skipListValueSerializer); + final S newState; + if (stateSnapshotTransformer != null) { + newState = stateSnapshotTransformer.filterOrTransform(oldState); + } else { + newState = oldState; + } + Tuple2<byte[], byte[]> keyAndNamespace = + owningStateMap.helpGetBytesForKeyAndNamespace(pointers.f0); + if (newState == null) { + return null; + } else { + inputDeserializer.setBuffer(keyAndNamespace.f0); + K key = keySerializer.deserialize(inputDeserializer); + inputDeserializer.setBuffer(keyAndNamespace.f1); + N namespace = namespaceSerializer.deserialize(inputDeserializer); + return new StateEntry.SimpleStateEntry<>(key, namespace, newState); + } + } catch (IOException e) { + throw new WrappingRuntimeException(e); + } + } + + @Override public void writeState( TypeSerializer<K> keySerializer, TypeSerializer<N> namespaceSerializer, diff --git a/flink-state-backends/flink-statebackend-heap-spillable/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapComplexOpTest.java b/flink-state-backends/flink-statebackend-heap-spillable/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapComplexOpTest.java index b56c0bf..76d909f 100644 --- a/flink-state-backends/flink-statebackend-heap-spillable/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapComplexOpTest.java +++ b/flink-state-backends/flink-statebackend-heap-spillable/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapComplexOpTest.java @@ -25,6 +25,7 @@ import org.apache.flink.api.common.typeutils.base.IntSerializer; import org.apache.flink.api.common.typeutils.base.LongSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.runtime.state.StateSnapshotTransformer; +import org.apache.flink.runtime.state.heap.CopyOnWriteSkipListStateMapTestUtils.SnapshotVerificationMode; import org.apache.flink.util.IOUtils; import org.apache.flink.util.TestLogger; import org.apache.flink.util.function.TriFunction; @@ -540,7 +541,18 @@ public class CopyOnWriteSkipListStateMapComplexOpTest extends TestLogger { /** Tests that remove states physically during sync part of snapshot. */ @Test + public void testPhysicallyRemoveDuringSyncPartOfSnapshotWithIterator() throws IOException { + testPhysicallyRemoveDuringSyncPartOfSnapshot(SnapshotVerificationMode.ITERATOR); + } + + /** Tests that remove states physically during sync part of snapshot. */ + @Test public void testPhysicallyRemoveDuringSyncPartOfSnapshot() throws IOException { + testPhysicallyRemoveDuringSyncPartOfSnapshot(SnapshotVerificationMode.SERIALIZED); + } + + private void testPhysicallyRemoveDuringSyncPartOfSnapshot( + SnapshotVerificationMode verificationMode) throws IOException { TestAllocator spaceAllocator = new TestAllocator(256); // set logicalRemovedKeysRatio to 0 so that all logically removed states will be deleted // when snapshot @@ -585,7 +597,12 @@ public class CopyOnWriteSkipListStateMapComplexOpTest extends TestLogger { verifyState(referenceStates, stateMap); verifySnapshotWithoutTransform( - expectedSnapshot1, snapshot1, keySerializer, namespaceSerializer, stateSerializer); + expectedSnapshot1, + snapshot1, + keySerializer, + namespaceSerializer, + stateSerializer, + verificationMode); snapshot1.release(); // no spaces should be free @@ -609,7 +626,12 @@ public class CopyOnWriteSkipListStateMapComplexOpTest extends TestLogger { assertEquals(0, spaceAllocator.getTotalSpaceNumber()); verifySnapshotWithoutTransform( - expectedSnapshot2, snapshot2, keySerializer, namespaceSerializer, stateSerializer); + expectedSnapshot2, + snapshot2, + keySerializer, + namespaceSerializer, + stateSerializer, + verificationMode); snapshot2.release(); assertEquals(0, stateMap.size()); @@ -760,6 +782,17 @@ public class CopyOnWriteSkipListStateMapComplexOpTest extends TestLogger { /** Tests concurrent snapshots. */ @Test public void testConcurrentSnapshots() throws IOException { + testConcurrentSnapshots(SnapshotVerificationMode.SERIALIZED); + } + + /** Tests concurrent snapshots. */ + @Test + public void testConcurrentSnapshotsWithIterator() throws IOException { + testConcurrentSnapshots(SnapshotVerificationMode.ITERATOR); + } + + private void testConcurrentSnapshots(SnapshotVerificationMode verificationMode) + throws IOException { TestAllocator spaceAllocator = new TestAllocator(256); // set logicalRemovedKeysRatio to 0 so that all logically removed states will be deleted // when snapshot @@ -817,7 +850,8 @@ public class CopyOnWriteSkipListStateMapComplexOpTest extends TestLogger { transformer, keySerializer, namespaceSerializer, - stateSerializer); + stateSerializer, + verificationMode); snapshot2.release(); // update states @@ -826,7 +860,12 @@ public class CopyOnWriteSkipListStateMapComplexOpTest extends TestLogger { // complete snapshot1 verifySnapshotWithoutTransform( - expectedSnapshot1, snapshot1, keySerializer, namespaceSerializer, stateSerializer); + expectedSnapshot1, + snapshot1, + keySerializer, + namespaceSerializer, + stateSerializer, + verificationMode); snapshot1.release(); // create snapshot4 @@ -846,7 +885,8 @@ public class CopyOnWriteSkipListStateMapComplexOpTest extends TestLogger { transformer, keySerializer, namespaceSerializer, - stateSerializer); + stateSerializer, + verificationMode); snapshot3.release(); // update states @@ -860,7 +900,8 @@ public class CopyOnWriteSkipListStateMapComplexOpTest extends TestLogger { transformer, keySerializer, namespaceSerializer, - stateSerializer); + stateSerializer, + verificationMode); snapshot4.release(); verifyState(referenceStates, stateMap); @@ -878,7 +919,8 @@ public class CopyOnWriteSkipListStateMapComplexOpTest extends TestLogger { transformer, keySerializer, namespaceSerializer, - stateSerializer); + stateSerializer, + verificationMode); snapshot5.release(); verifyState(referenceStates, stateMap); diff --git a/flink-state-backends/flink-statebackend-heap-spillable/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapTestUtils.java b/flink-state-backends/flink-statebackend-heap-spillable/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapTestUtils.java index 6bccba3..627184c 100644 --- a/flink-state-backends/flink-statebackend-heap-spillable/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapTestUtils.java +++ b/flink-state-backends/flink-statebackend-heap-spillable/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteSkipListStateMapTestUtils.java @@ -35,6 +35,8 @@ import org.apache.flink.runtime.state.StateSnapshotTransformer; import org.apache.flink.runtime.state.heap.space.Allocator; import org.apache.flink.runtime.state.internal.InternalKvState; +import org.hamcrest.Matcher; + import javax.annotation.Nonnull; import java.io.IOException; @@ -46,10 +48,13 @@ import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import static org.apache.flink.runtime.state.heap.CopyOnWriteSkipListStateMap.DEFAULT_LOGICAL_REMOVED_KEYS_RATIO; import static org.apache.flink.runtime.state.heap.CopyOnWriteSkipListStateMap.DEFAULT_MAX_KEYS_TO_DELETE_ONE_TIME; import static org.apache.flink.runtime.state.heap.SkipListUtils.NIL_NODE; +import static org.apache.flink.runtime.state.testutils.StateEntryMatcher.entry; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; @@ -173,6 +178,11 @@ class CopyOnWriteSkipListStateMapTestUtils { } } + enum SnapshotVerificationMode { + ITERATOR, + SERIALIZED + } + static <K, N, S> void verifySnapshotWithoutTransform( Map<N, Map<K, S>> referenceStates, @Nonnull CopyOnWriteSkipListStateMapSnapshot<K, N, S> snapshot, @@ -180,17 +190,51 @@ class CopyOnWriteSkipListStateMapTestUtils { TypeSerializer<N> namespaceSerializer, TypeSerializer<S> stateSerializer) throws IOException { + verifySnapshotWithoutTransform( + referenceStates, + snapshot, + keySerializer, + namespaceSerializer, + stateSerializer, + SnapshotVerificationMode.SERIALIZED); + } + + static <K, N, S> void verifySnapshotWithoutTransform( + Map<N, Map<K, S>> referenceStates, + @Nonnull CopyOnWriteSkipListStateMapSnapshot<K, N, S> snapshot, + TypeSerializer<K> keySerializer, + TypeSerializer<N> namespaceSerializer, + TypeSerializer<S> stateSerializer, + SnapshotVerificationMode verificationMode) + throws IOException { ByteArrayOutputStreamWithPos outputStream = new ByteArrayOutputStreamWithPos(); DataOutputView outputView = new DataOutputViewStreamWrapper(outputStream); - snapshot.writeState(keySerializer, namespaceSerializer, stateSerializer, outputView, null); + if (verificationMode == SnapshotVerificationMode.ITERATOR) { + Iterator<StateEntry<K, N, S>> iterator = + snapshot.getIterator(keySerializer, namespaceSerializer, stateSerializer, null); + assertThat(() -> iterator, containsInAnyOrder(toMatchers(referenceStates))); + } else { + snapshot.writeState( + keySerializer, namespaceSerializer, stateSerializer, outputView, null); + + Map<N, Map<K, S>> actualStates = + readStateFromSnapshot( + outputStream.toByteArray(), + keySerializer, + namespaceSerializer, + stateSerializer); + assertEquals(referenceStates, actualStates); + } + } - Map<N, Map<K, S>> actualStates = - readStateFromSnapshot( - outputStream.toByteArray(), - keySerializer, - namespaceSerializer, - stateSerializer); - assertEquals(referenceStates, actualStates); + private static <K, N, S> List<Matcher<? super StateEntry<K, N, S>>> toMatchers( + Map<N, Map<K, S>> referenceStates) { + return referenceStates.entrySet().stream() + .flatMap( + e -> + e.getValue().entrySet().stream() + .map(ks -> entry(ks.getKey(), e.getKey(), ks.getValue()))) + .collect(Collectors.toList()); } static <K, N, S> void verifySnapshotWithTransform( @@ -201,10 +245,27 @@ class CopyOnWriteSkipListStateMapTestUtils { TypeSerializer<N> namespaceSerializer, TypeSerializer<S> stateSerializer) throws IOException { + verifySnapshotWithTransform( + referenceStates, + snapshot, + transformer, + keySerializer, + namespaceSerializer, + stateSerializer, + SnapshotVerificationMode.SERIALIZED); + } + + static <K, N, S> void verifySnapshotWithTransform( + @Nonnull Map<N, Map<K, S>> referenceStates, + @Nonnull CopyOnWriteSkipListStateMapSnapshot<K, N, S> snapshot, + StateSnapshotTransformer<S> transformer, + TypeSerializer<K> keySerializer, + TypeSerializer<N> namespaceSerializer, + TypeSerializer<S> stateSerializer, + SnapshotVerificationMode verificationMode) + throws IOException { ByteArrayOutputStreamWithPos outputStream = new ByteArrayOutputStreamWithPos(); DataOutputView outputView = new DataOutputViewStreamWrapper(outputStream); - snapshot.writeState( - keySerializer, namespaceSerializer, stateSerializer, outputView, transformer); Map<N, Map<K, S>> transformedStates = new HashMap<>(); for (Map.Entry<N, Map<K, S>> namespaceEntry : referenceStates.entrySet()) { @@ -218,13 +279,21 @@ class CopyOnWriteSkipListStateMapTestUtils { } } - Map<N, Map<K, S>> actualStates = - readStateFromSnapshot( - outputStream.toByteArray(), - keySerializer, - namespaceSerializer, - stateSerializer); - assertEquals(transformedStates, actualStates); + if (verificationMode == SnapshotVerificationMode.SERIALIZED) { + snapshot.writeState( + keySerializer, namespaceSerializer, stateSerializer, outputView, transformer); + Map<N, Map<K, S>> actualStates = + readStateFromSnapshot( + outputStream.toByteArray(), + keySerializer, + namespaceSerializer, + stateSerializer); + assertEquals(transformedStates, actualStates); + } else { + Iterator<StateEntry<K, N, S>> iterator = + snapshot.getIterator(keySerializer, namespaceSerializer, stateSerializer, null); + assertThat(() -> iterator, containsInAnyOrder(toMatchers(referenceStates))); + } } private static <K, N, S> Map<N, Map<K, S>> readStateFromSnapshot(