This is an automated email from the ASF dual-hosted git repository. srichter pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit a4ad86fb083f90503938a9c3d816cdda9dc22427 Author: Stefan Richter <srich...@confluent.io> AuthorDate: Fri Oct 13 16:27:39 2023 +0200 [FLINK-33341][state] Add support for rescaling from local keyed state to PrioritizedOperatorSubtaskState. --- .../PrioritizedOperatorSubtaskState.java | 132 +++++++++++++-- .../PrioritizedOperatorSubtaskStateTest.java | 184 ++++++++++++++++++--- .../runtime/checkpoint/StateHandleDummyUtil.java | 29 +++- 3 files changed, 305 insertions(+), 40 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskState.java index ef9bcd0440b..e41bcfe7338 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskState.java @@ -31,10 +31,14 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Function; +import java.util.stream.Collectors; /** * This class is a wrapper over multiple alternative {@link OperatorSubtaskState} that are (partial) @@ -286,14 +290,14 @@ public class PrioritizedOperatorSubtaskState { } return new PrioritizedOperatorSubtaskState( - resolvePrioritizedAlternatives( + computePrioritizedAlternatives( jobManagerState.getManagedKeyedState(), managedKeyedAlternatives, - eqStateApprover(KeyedStateHandle::getKeyGroupRange)), - resolvePrioritizedAlternatives( + KeyedStateHandle::getKeyGroupRange), + computePrioritizedAlternatives( jobManagerState.getRawKeyedState(), rawKeyedAlternatives, - eqStateApprover(KeyedStateHandle::getKeyGroupRange)), + KeyedStateHandle::getKeyGroupRange), 
resolvePrioritizedAlternatives( jobManagerState.getManagedOperatorState(), managedOperatorAlternatives, @@ -313,22 +317,121 @@ public class PrioritizedOperatorSubtaskState { restoredCheckpointId); } + /** + * This method creates an alternative recovery option by replacing as much job manager state + * as possible with higher-priority (i.e. local) alternatives. + * + * @param jobManagerState the state that the task got assigned from the job manager (this + * state lives in remote storage). + * @param alternativesByPriority local alternatives to the job manager state, ordered by + * priority. + * @param identityExtractor function to extract an identifier from a state object. + * @return prioritized state alternatives: the mixed local/remote alternative followed by the + * pure job manager state if at least one handle could be replaced; otherwise (including + * when there are no alternatives or no job manager state) only the job manager state. + * @param <STATE_OBJ_TYPE> the type of the state objects we process. + * @param <ID_TYPE> the type of object that represents the id of the state object type. + */ + <STATE_OBJ_TYPE extends StateObject, ID_TYPE> + List<StateObjectCollection<STATE_OBJ_TYPE>> computePrioritizedAlternatives( + StateObjectCollection<STATE_OBJ_TYPE> jobManagerState, + List<StateObjectCollection<STATE_OBJ_TYPE>> alternativesByPriority, + Function<STATE_OBJ_TYPE, ID_TYPE> identityExtractor) { + + if (alternativesByPriority != null + && !alternativesByPriority.isEmpty() + && jobManagerState.hasState()) { + + Optional<StateObjectCollection<STATE_OBJ_TYPE>> mergedAlternative = + tryComputeMixedLocalAndRemoteAlternative( + jobManagerState, alternativesByPriority, identityExtractor); + + // Return the mix of local/remote state as first and pure remote state as second + // alternative (in case we fail to recover from the local state, e.g. because + // of corruption).
+ if (mergedAlternative.isPresent()) { + return Arrays.asList(mergedAlternative.get(), jobManagerState); + } + } + + return Collections.singletonList(jobManagerState); + } + + /** + * This method creates an alternative recovery option by replacing as much job manager state + * as possible with higher-priority (i.e. local) alternatives. Returns empty Optional if the + * JM state is empty or nothing could be replaced. + * + * <p>Ids extracted from the JM state must be unique, because the index is built with + * {@code Collectors.toMap}, which throws an {@link IllegalStateException} on duplicate ids. + * + * @param jobManagerState the state that the task got assigned from the job manager (this + * state lives in remote storage). + * @param alternativesByPriority local alternatives to the job manager state, ordered by + * priority (highest first). + * @param identityExtractor function to extract an identifier from a state object. + * @return A state collection where all JM state handles for which we could find local + * alternatives are replaced by the alternative with the highest priority; JM state handles + * without a local alternative are kept as they are. Empty optional if no state could be + * replaced. + * @param <STATE_OBJ_TYPE> the type of the state objects we process. + * @param <ID_TYPE> the type of object that represents the id of the state object type.
+ */ + static <STATE_OBJ_TYPE extends StateObject, ID_TYPE> + Optional<StateObjectCollection<STATE_OBJ_TYPE>> + tryComputeMixedLocalAndRemoteAlternative( + StateObjectCollection<STATE_OBJ_TYPE> jobManagerState, + List<StateObjectCollection<STATE_OBJ_TYPE>> alternativesByPriority, + Function<STATE_OBJ_TYPE, ID_TYPE> identityExtractor) { + + List<STATE_OBJ_TYPE> result = Collections.emptyList(); + + // Build a hash index over the ids of the JM state (toMap throws on duplicate ids) + Map<ID_TYPE, STATE_OBJ_TYPE> indexById = + jobManagerState.stream() + .collect(Collectors.toMap(identityExtractor, Function.identity())); + + // Move through all alternatives in order from high to low priority + // NOTE(review): a null entry in alternativesByPriority would cause an NPE below, while + // resolvePrioritizedAlternatives skips null alternatives; confirm callers never pass null. + for (StateObjectCollection<STATE_OBJ_TYPE> alternative : alternativesByPriority) { + // Check whether each state object in the alternative can replace JM state + for (STATE_OBJ_TYPE stateHandle : alternative) { + // Remove the current state object's id from the index to check for a match + if (indexById.remove(identityExtractor.apply(stateHandle)) != null) { + if (result.isEmpty()) { + // Lazy init result collection (no allocation if nothing matches) + result = new ArrayList<>(jobManagerState.size()); + } + // If the id was still in the index, replace with higher prio alternative + result.add(stateHandle); + + // If the index is empty we are already done, all JM state was replaced with + // the best alternative. + if (indexById.isEmpty()) { + return Optional.of(new StateObjectCollection<>(result)); + } + } + } + } + + // Nothing was replaced, so there is no mixed alternative to return + if (result.isEmpty()) { + return Optional.empty(); + } + + // Add all remaining JM state objects that we could not replace from the index to the + // final result + result.addAll(indexById.values()); + return Optional.of(new StateObjectCollection<>(result)); + } + /** * This helper method resolves the dependencies between the ground truth of the operator * state obtained from the job manager and potential alternatives for recovery, e.g. from a * task-local source.
*/ - protected <T extends StateObject> - List<StateObjectCollection<T>> resolvePrioritizedAlternatives( - StateObjectCollection<T> jobManagerState, - List<StateObjectCollection<T>> alternativesByPriority, - BiFunction<T, T, Boolean> approveFun) { + <T extends StateObject> List<StateObjectCollection<T>> resolvePrioritizedAlternatives( + StateObjectCollection<T> jobManagerState, + List<StateObjectCollection<T>> alternativesByPriority, + BiFunction<T, T, Boolean> approveFun) { // Nothing to resolve if there are no alternatives, or the ground truth has already no - // state, or if we can - // assume that a rescaling happened because we find more than one handle in the JM state - // (this is more a sanity - // check). + // state, or if we can assume that a rescaling happened because we find more than one + // handle in the JM state + // (this is more a sanity check). if (alternativesByPriority == null || alternativesByPriority.isEmpty() || !jobManagerState.hasState() @@ -347,8 +450,7 @@ public class PrioritizedOperatorSubtaskState { for (StateObjectCollection<T> alternative : alternativesByPriority) { // We found an alternative to the JM state if it has state, we have a 1:1 - // relationship, and the - // approve-function signaled true. + // relationship, and the approve-function signaled true. 
if (alternative != null && alternative.hasState() && alternative.size() == 1 diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java index 11f41bc8baa..fa892cea9da 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java @@ -27,15 +27,19 @@ import org.apache.flink.runtime.state.ResultSubpartitionStateHandle; import org.apache.flink.runtime.state.StateObject; import org.apache.flink.util.Preconditions; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.Random; import java.util.function.Function; +import java.util.function.IntFunction; import java.util.stream.Collectors; import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewInputChannelStateHandle; @@ -51,6 +55,107 @@ class PrioritizedOperatorSubtaskStateTest { private static final Random RANDOM = new Random(0x42); + @Test + void testTryCreateMixedLocalAndRemoteAlternative() { + testTryCreateMixedLocalAndRemoteAlternative( + StateHandleDummyUtil::createKeyedStateHandleFromSeed, + KeyedStateHandle::getKeyGroupRange); + } + + <SH extends StateObject, ID> void testTryCreateMixedLocalAndRemoteAlternative( + IntFunction<SH> stateHandleFactory, Function<SH, ID> idExtractor) { + + SH remote0 = stateHandleFactory.apply(0); + SH remote1 = stateHandleFactory.apply(1); + SH remote2 = stateHandleFactory.apply(2); + SH remote3 = stateHandleFactory.apply(3); + + List<SH> jmState = Arrays.asList(remote0, remote1, remote2, remote3); + + SH local0 = 
stateHandleFactory.apply(0); + SH local3a = stateHandleFactory.apply(3); + + List<SH> alternativeA = Arrays.asList(local0, local3a); + + SH local1 = stateHandleFactory.apply(1); + SH local3b = stateHandleFactory.apply(3); + SH local5 = stateHandleFactory.apply(5); + + List<SH> alternativeB = Arrays.asList(local1, local3b, local5); + + List<StateObjectCollection<SH>> alternatives = + Arrays.asList( + new StateObjectCollection<>(alternativeA), + new StateObjectCollection<>(Collections.emptyList()), + new StateObjectCollection<>(alternativeB)); + + StateObjectCollection<SH> result = + PrioritizedOperatorSubtaskState.Builder.tryComputeMixedLocalAndRemoteAlternative( + new StateObjectCollection<>(jmState), alternatives, idExtractor) + .get(); + + assertThat(result).hasSameElementsAs(Arrays.asList(local0, local1, remote2, local3a)); + } + + @Test + void testTryCreateMixedLocalAndRemoteAlternativeEmptyAlternative() { + testTryCreateMixedLocalAndRemoteAlternativeEmptyAlternative( + StateHandleDummyUtil::createKeyedStateHandleFromSeed, + KeyedStateHandle::getKeyGroupRange); + } + + <SH extends StateObject, ID> void testTryCreateMixedLocalAndRemoteAlternativeEmptyAlternative( + IntFunction<SH> stateHandleFactory, Function<SH, ID> idExtractor) { + List<SH> jmState = + Arrays.asList( + stateHandleFactory.apply(0), + stateHandleFactory.apply(1), + stateHandleFactory.apply(2), + stateHandleFactory.apply(3)); + + Assertions.assertFalse( + PrioritizedOperatorSubtaskState.Builder.tryComputeMixedLocalAndRemoteAlternative( + new StateObjectCollection<>(jmState), + Collections.emptyList(), + idExtractor) + .isPresent()); + + Assertions.assertFalse( + PrioritizedOperatorSubtaskState.Builder.tryComputeMixedLocalAndRemoteAlternative( + new StateObjectCollection<>(jmState), + Collections.singletonList(new StateObjectCollection<>()), + idExtractor) + .isPresent()); + } + + @Test + void testTryCreateMixedLocalAndRemoteAlternativeEmptyJMState() { + 
testTryCreateMixedLocalAndRemoteAlternativeEmptyJMState( + StateHandleDummyUtil::createKeyedStateHandleFromSeed, + KeyedStateHandle::getKeyGroupRange); + } + + <SH extends StateObject, ID> void testTryCreateMixedLocalAndRemoteAlternativeEmptyJMState( + IntFunction<SH> stateHandleFactory, Function<SH, ID> idExtractor) { + List<SH> alternativeA = + Arrays.asList(stateHandleFactory.apply(0), stateHandleFactory.apply(3)); + + Assertions.assertFalse( + PrioritizedOperatorSubtaskState.Builder.tryComputeMixedLocalAndRemoteAlternative( + new StateObjectCollection<>(Collections.emptyList()), + Collections.singletonList( + new StateObjectCollection<>(alternativeA)), + idExtractor) + .isPresent()); + + Assertions.assertFalse( + PrioritizedOperatorSubtaskState.Builder.tryComputeMixedLocalAndRemoteAlternative( + new StateObjectCollection<>(Collections.emptyList()), + Collections.emptyList(), + KeyedStateHandle::getKeyGroupRange) + .isPresent()); + } + /** * This tests attempts to test (almost) the full space of significantly different options for * verifying and prioritizing {@link OperatorSubtaskState} options for local recovery over @@ -106,16 +211,17 @@ class PrioritizedOperatorSubtaskStateTest { : onlyPrimary)) .isTrue(); - assertThat( - checkResultAsExpected( - OperatorSubtaskState::getManagedKeyedState, - PrioritizedOperatorSubtaskState - ::getPrioritizedManagedKeyedState, - prioritizedOperatorSubtaskState, - primaryAndFallback.getManagedKeyedState().size() == 1 - ? 
validAlternatives - : onlyPrimary)) - .isTrue(); + StateObjectCollection<KeyedStateHandle> expManagedKeyed = + computeExpectedMixedState( + orderedAlternativesList, + primaryAndFallback, + OperatorSubtaskState::getManagedKeyedState, + KeyedStateHandle::getKeyGroupRange); + + assertResultAsExpected( + expManagedKeyed, + primaryAndFallback.getManagedKeyedState(), + prioritizedOperatorSubtaskState.getPrioritizedManagedKeyedState()); assertThat( checkResultAsExpected( @@ -128,16 +234,17 @@ class PrioritizedOperatorSubtaskStateTest { : onlyPrimary)) .isTrue(); - assertThat( - checkResultAsExpected( - OperatorSubtaskState::getRawKeyedState, - PrioritizedOperatorSubtaskState - ::getPrioritizedRawKeyedState, - prioritizedOperatorSubtaskState, - primaryAndFallback.getRawKeyedState().size() == 1 - ? validAlternatives - : onlyPrimary)) - .isTrue(); + StateObjectCollection<KeyedStateHandle> expRawKeyed = + computeExpectedMixedState( + orderedAlternativesList, + primaryAndFallback, + OperatorSubtaskState::getRawKeyedState, + KeyedStateHandle::getKeyGroupRange); + + assertResultAsExpected( + expRawKeyed, + primaryAndFallback.getRawKeyedState(), + prioritizedOperatorSubtaskState.getPrioritizedRawKeyedState()); } } } @@ -390,4 +497,41 @@ class PrioritizedOperatorSubtaskStateTest { throw new IllegalStateException(); } } + + private <T extends StateObject, ID> StateObjectCollection<T> computeExpectedMixedState( + List<OperatorSubtaskState> orderedAlternativesList, + OperatorSubtaskState primaryAndFallback, + Function<OperatorSubtaskState, StateObjectCollection<T>> stateExtractor, + Function<T, ID> idExtractor) { + + List<OperatorSubtaskState> reverseAlternatives = new ArrayList<>(orderedAlternativesList); + Collections.reverse(reverseAlternatives); + + Map<ID, T> map = + stateExtractor.apply(primaryAndFallback).stream() + .collect(Collectors.toMap(idExtractor, Function.identity())); + + reverseAlternatives.stream() + .flatMap(x -> stateExtractor.apply(x).stream()) + .forEach(x -> 
map.replace(idExtractor.apply(x), x)); + + return new StateObjectCollection<>(map.values()); + } + + static <SH extends StateObject> void assertResultAsExpected( + StateObjectCollection<SH> expected, + StateObjectCollection<SH> primary, + List<StateObjectCollection<SH>> actual) { + Assertions.assertTrue(!actual.isEmpty() && actual.size() <= 2); + Assertions.assertTrue(isSameContentUnordered(expected, actual.get(0))); + if (actual.size() == 1) { + Assertions.assertTrue(isSameContentUnordered(primary, actual.get(0))); + } else { + Assertions.assertTrue(isSameContentUnordered(primary, actual.get(1))); + } + } + + static <T> boolean isSameContentUnordered(Collection<T> a, Collection<T> b) { + return a.size() == b.size() && a.containsAll(b); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java index f7e4a7ef2dc..52a8bf032b6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java @@ -55,7 +55,8 @@ public class StateHandleDummyUtil { OperatorStateHandle.StateMetaInfo metaInfo = new OperatorStateHandle.StateMetaInfo( offsets, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE); - operatorStateMetaData.put(String.valueOf(UUID.randomUUID()), metaInfo); + operatorStateMetaData.put( + String.valueOf(new UUID(random.nextLong(), random.nextLong())), metaInfo); } return new OperatorStreamStateHandle( operatorStateMetaData, createStreamStateHandle(numNamedStates, random)); @@ -65,7 +66,8 @@ public class StateHandleDummyUtil { int numNamedStates, Random random) { byte[] streamData = new byte[numNamedStates * 4]; random.nextBytes(streamData); - return new ByteStreamStateHandle(String.valueOf(UUID.randomUUID()), streamData); + return new ByteStreamStateHandle( + String.valueOf(new UUID(random.nextLong(), 
random.nextLong())), streamData); } /** Creates a new test {@link KeyedStateHandle} for the given key-group. */ @@ -149,11 +151,11 @@ public class StateHandleDummyUtil { } public static ResultSubpartitionStateHandle createNewResultSubpartitionStateHandle( - int i, Random random) { + int numNamedStates, Random random) { return new ResultSubpartitionStateHandle( new ResultSubpartitionInfo(random.nextInt(), random.nextInt()), - createStreamStateHandle(i, random), - genOffsets(i, random)); + createStreamStateHandle(numNamedStates, random), + genOffsets(numNamedStates, random)); } private static ArrayList<Long> genOffsets(int size, Random random) { @@ -164,6 +166,23 @@ public class StateHandleDummyUtil { return offsets; } + public static KeyedStateHandle createKeyedStateHandleFromSeed(int seed) { + return createNewKeyedStateHandle(KeyGroupRange.of(seed * 4, seed * 4 + 3)); + } + + public static OperatorStateHandle createOperatorStateHandleFromSeed(int seed) { + return createNewOperatorStateHandle(1 + (seed % 3), new Random(seed)); + } + + public static InputChannelStateHandle createInputChannelStateHandleFromSeed(int seed) { + return createNewInputChannelStateHandle(1 + (seed % 3), new Random(seed)); + } + + public static ResultSubpartitionStateHandle createResultSubpartitionStateHandleFromSeed( + int seed) { + return createNewResultSubpartitionStateHandle(1 + (seed % 3), new Random(seed)); + } + /** KeyedStateHandle that only holds a key-group information. */ private static class DummyKeyedStateHandle implements KeyedStateHandle {