This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit a4ad86fb083f90503938a9c3d816cdda9dc22427
Author: Stefan Richter <srich...@confluent.io>
AuthorDate: Fri Oct 13 16:27:39 2023 +0200

    [FLINK-33341][state] Add support for rescaling from local keyed state to 
PrioritizedOperatorSubtaskState.
---
 .../PrioritizedOperatorSubtaskState.java           | 132 +++++++++++++--
 .../PrioritizedOperatorSubtaskStateTest.java       | 184 ++++++++++++++++++---
 .../runtime/checkpoint/StateHandleDummyUtil.java   |  29 +++-
 3 files changed, 305 insertions(+), 40 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskState.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskState.java
index ef9bcd0440b..e41bcfe7338 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskState.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskState.java
@@ -31,10 +31,14 @@ import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
+import java.util.Optional;
 import java.util.function.BiFunction;
 import java.util.function.Function;
+import java.util.stream.Collectors;
 
 /**
  * This class is a wrapper over multiple alternative {@link 
OperatorSubtaskState} that are (partial)
@@ -286,14 +290,14 @@ public class PrioritizedOperatorSubtaskState {
             }
 
             return new PrioritizedOperatorSubtaskState(
-                    resolvePrioritizedAlternatives(
+                    computePrioritizedAlternatives(
                             jobManagerState.getManagedKeyedState(),
                             managedKeyedAlternatives,
-                            
eqStateApprover(KeyedStateHandle::getKeyGroupRange)),
-                    resolvePrioritizedAlternatives(
+                            KeyedStateHandle::getKeyGroupRange),
+                    computePrioritizedAlternatives(
                             jobManagerState.getRawKeyedState(),
                             rawKeyedAlternatives,
-                            
eqStateApprover(KeyedStateHandle::getKeyGroupRange)),
+                            KeyedStateHandle::getKeyGroupRange),
                     resolvePrioritizedAlternatives(
                             jobManagerState.getManagedOperatorState(),
                             managedOperatorAlternatives,
@@ -313,22 +317,121 @@ public class PrioritizedOperatorSubtaskState {
                     restoredCheckpointId);
         }
 
+        /**
+         * This method creates an alternative recovery option by replacing as 
much job manager state
+         * with higher prioritized (=local) alternatives as possible.
+         *
+         * @param jobManagerState the state that the task got assigned from 
the job manager (this
+         *     state lives in remote storage).
+         * @param alternativesByPriority local alternatives to the job manager 
state, ordered by
+         *     priority.
+         * @param identityExtractor function to extract an identifier from a 
state object.
+         * @return prioritized state alternatives.
+         * @param <STATE_OBJ_TYPE> the type of the state objects we process.
+         * @param <ID_TYPE> the type of object that represents the id of the 
state object type.
+         */
+        <STATE_OBJ_TYPE extends StateObject, ID_TYPE>
+                List<StateObjectCollection<STATE_OBJ_TYPE>> 
computePrioritizedAlternatives(
+                        StateObjectCollection<STATE_OBJ_TYPE> jobManagerState,
+                        List<StateObjectCollection<STATE_OBJ_TYPE>> 
alternativesByPriority,
+                        Function<STATE_OBJ_TYPE, ID_TYPE> identityExtractor) {
+
+            if (alternativesByPriority != null
+                    && !alternativesByPriority.isEmpty()
+                    && jobManagerState.hasState()) {
+
+                Optional<StateObjectCollection<STATE_OBJ_TYPE>> 
mergedAlternative =
+                        tryComputeMixedLocalAndRemoteAlternative(
+                                jobManagerState, alternativesByPriority, 
identityExtractor);
+
+                // Return the mix of local/remote state as first and pure 
remote state as second
+                // alternative (in case that we fail to recover from the local 
state, e.g. because
+                // of corruption).
+                if (mergedAlternative.isPresent()) {
+                    return Arrays.asList(mergedAlternative.get(), 
jobManagerState);
+                }
+            }
+
+            return Collections.singletonList(jobManagerState);
+        }
+
+        /**
+         * This method creates an alternative recovery option by replacing as 
much job manager state
+         * with higher prioritized (=local) alternatives as possible. Returns 
empty Optional if the
+         * JM state is empty or nothing could be replaced.
+         *
+         * @param jobManagerState the state that the task got assigned from 
the job manager (this
+         *     state lives in remote storage).
+         * @param alternativesByPriority local alternatives to the job manager 
state, ordered by
+         *     priority.
+         * @param identityExtractor function to extract an identifier from a 
state object.
+         * @return A state collection where all JM state handles for which we 
could find local
+         *     alternatives are replaced by the alternative with the highest 
priority. Empty
+         *     optional if no state could be replaced.
+         * @param <STATE_OBJ_TYPE> the type of the state objects we process.
+         * @param <ID_TYPE> the type of object that represents the id of the 
state object type.
+         */
+        static <STATE_OBJ_TYPE extends StateObject, ID_TYPE>
+                Optional<StateObjectCollection<STATE_OBJ_TYPE>>
+                        tryComputeMixedLocalAndRemoteAlternative(
+                                StateObjectCollection<STATE_OBJ_TYPE> 
jobManagerState,
+                                List<StateObjectCollection<STATE_OBJ_TYPE>> 
alternativesByPriority,
+                                Function<STATE_OBJ_TYPE, ID_TYPE> 
identityExtractor) {
+
+            List<STATE_OBJ_TYPE> result = Collections.emptyList();
+
+            // Build hash index over ids of the JM state
+            Map<ID_TYPE, STATE_OBJ_TYPE> indexById =
+                    jobManagerState.stream()
+                            .collect(Collectors.toMap(identityExtractor, 
Function.identity()));
+
+            // Move through all alternatives in order from high to low priority
+            for (StateObjectCollection<STATE_OBJ_TYPE> alternative : 
alternativesByPriority) {
+                // Check all the state objects in the alternative if they can 
replace JM state
+                for (STATE_OBJ_TYPE stateHandle : alternative) {
+                    // Remove the current state object's id from the index to 
check for a match
+                    if (indexById.remove(identityExtractor.apply(stateHandle)) 
!= null) {
+                        if (result.isEmpty()) {
+                            // Lazy init result collection
+                            result = new ArrayList<>(jobManagerState.size());
+                        }
+                        // If the id was still in the index, replace with 
higher prio alternative
+                        result.add(stateHandle);
+
+                        // If the index is empty we are already done, all JM 
state was replaced with
+                        // the best alternative.
+                        if (indexById.isEmpty()) {
+                            return Optional.of(new 
StateObjectCollection<>(result));
+                        }
+                    }
+                }
+            }
+
+            // Nothing useful to return
+            if (result.isEmpty()) {
+                return Optional.empty();
+            }
+
+            // Add all remaining JM state objects that we could not replace 
from the index to the
+            // final result
+            result.addAll(indexById.values());
+            return Optional.of(new StateObjectCollection<>(result));
+        }
+
         /**
          * This helper method resolves the dependencies between the ground 
truth of the operator
          * state obtained from the job manager and potential alternatives for 
recovery, e.g. from a
          * task-local source.
          */
-        protected <T extends StateObject>
-                List<StateObjectCollection<T>> resolvePrioritizedAlternatives(
-                        StateObjectCollection<T> jobManagerState,
-                        List<StateObjectCollection<T>> alternativesByPriority,
-                        BiFunction<T, T, Boolean> approveFun) {
+        <T extends StateObject> List<StateObjectCollection<T>> 
resolvePrioritizedAlternatives(
+                StateObjectCollection<T> jobManagerState,
+                List<StateObjectCollection<T>> alternativesByPriority,
+                BiFunction<T, T, Boolean> approveFun) {
 
             // Nothing to resolve if there are no alternatives, or the ground 
truth has already no
-            // state, or if we can
-            // assume that a rescaling happened because we find more than one 
handle in the JM state
-            // (this is more a sanity
-            // check).
+            // state, or if we can assume that a rescaling happened because we 
find more than one
+            // handle in the JM state
+            // (this is more a sanity check).
             if (alternativesByPriority == null
                     || alternativesByPriority.isEmpty()
                     || !jobManagerState.hasState()
@@ -347,8 +450,7 @@ public class PrioritizedOperatorSubtaskState {
             for (StateObjectCollection<T> alternative : 
alternativesByPriority) {
 
                 // We found an alternative to the JM state if it has state, we 
have a 1:1
-                // relationship, and the
-                // approve-function signaled true.
+                // relationship, and the approve-function signaled true.
                 if (alternative != null
                         && alternative.hasState()
                         && alternative.size() == 1
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java
index 11f41bc8baa..fa892cea9da 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java
@@ -27,15 +27,19 @@ import 
org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.util.Preconditions;
 
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.Random;
 import java.util.function.Function;
+import java.util.function.IntFunction;
 import java.util.stream.Collectors;
 
 import static 
org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewInputChannelStateHandle;
@@ -51,6 +55,107 @@ class PrioritizedOperatorSubtaskStateTest {
 
     private static final Random RANDOM = new Random(0x42);
 
+    @Test
+    void testTryCreateMixedLocalAndRemoteAlternative() {
+        testTryCreateMixedLocalAndRemoteAlternative(
+                StateHandleDummyUtil::createKeyedStateHandleFromSeed,
+                KeyedStateHandle::getKeyGroupRange);
+    }
+
+    <SH extends StateObject, ID> void 
testTryCreateMixedLocalAndRemoteAlternative(
+            IntFunction<SH> stateHandleFactory, Function<SH, ID> idExtractor) {
+
+        SH remote0 = stateHandleFactory.apply(0);
+        SH remote1 = stateHandleFactory.apply(1);
+        SH remote2 = stateHandleFactory.apply(2);
+        SH remote3 = stateHandleFactory.apply(3);
+
+        List<SH> jmState = Arrays.asList(remote0, remote1, remote2, remote3);
+
+        SH local0 = stateHandleFactory.apply(0);
+        SH local3a = stateHandleFactory.apply(3);
+
+        List<SH> alternativeA = Arrays.asList(local0, local3a);
+
+        SH local1 = stateHandleFactory.apply(1);
+        SH local3b = stateHandleFactory.apply(3);
+        SH local5 = stateHandleFactory.apply(5);
+
+        List<SH> alternativeB = Arrays.asList(local1, local3b, local5);
+
+        List<StateObjectCollection<SH>> alternatives =
+                Arrays.asList(
+                        new StateObjectCollection<>(alternativeA),
+                        new StateObjectCollection<>(Collections.emptyList()),
+                        new StateObjectCollection<>(alternativeB));
+
+        StateObjectCollection<SH> result =
+                
PrioritizedOperatorSubtaskState.Builder.tryComputeMixedLocalAndRemoteAlternative(
+                                new StateObjectCollection<>(jmState), 
alternatives, idExtractor)
+                        .get();
+
+        assertThat(result).hasSameElementsAs(Arrays.asList(local0, local1, 
remote2, local3a));
+    }
+
+    @Test
+    void testTryCreateMixedLocalAndRemoteAlternativeEmptyAlternative() {
+        testTryCreateMixedLocalAndRemoteAlternativeEmptyAlternative(
+                StateHandleDummyUtil::createKeyedStateHandleFromSeed,
+                KeyedStateHandle::getKeyGroupRange);
+    }
+
+    <SH extends StateObject, ID> void 
testTryCreateMixedLocalAndRemoteAlternativeEmptyAlternative(
+            IntFunction<SH> stateHandleFactory, Function<SH, ID> idExtractor) {
+        List<SH> jmState =
+                Arrays.asList(
+                        stateHandleFactory.apply(0),
+                        stateHandleFactory.apply(1),
+                        stateHandleFactory.apply(2),
+                        stateHandleFactory.apply(3));
+
+        Assertions.assertFalse(
+                
PrioritizedOperatorSubtaskState.Builder.tryComputeMixedLocalAndRemoteAlternative(
+                                new StateObjectCollection<>(jmState),
+                                Collections.emptyList(),
+                                idExtractor)
+                        .isPresent());
+
+        Assertions.assertFalse(
+                
PrioritizedOperatorSubtaskState.Builder.tryComputeMixedLocalAndRemoteAlternative(
+                                new StateObjectCollection<>(jmState),
+                                Collections.singletonList(new 
StateObjectCollection<>()),
+                                idExtractor)
+                        .isPresent());
+    }
+
+    @Test
+    void testTryCreateMixedLocalAndRemoteAlternativeEmptyJMState() {
+        testTryCreateMixedLocalAndRemoteAlternativeEmptyJMState(
+                StateHandleDummyUtil::createKeyedStateHandleFromSeed,
+                KeyedStateHandle::getKeyGroupRange);
+    }
+
+    <SH extends StateObject, ID> void 
testTryCreateMixedLocalAndRemoteAlternativeEmptyJMState(
+            IntFunction<SH> stateHandleFactory, Function<SH, ID> idExtractor) {
+        List<SH> alternativeA =
+                Arrays.asList(stateHandleFactory.apply(0), 
stateHandleFactory.apply(3));
+
+        Assertions.assertFalse(
+                
PrioritizedOperatorSubtaskState.Builder.tryComputeMixedLocalAndRemoteAlternative(
+                                new 
StateObjectCollection<>(Collections.emptyList()),
+                                Collections.singletonList(
+                                        new 
StateObjectCollection<>(alternativeA)),
+                                idExtractor)
+                        .isPresent());
+
+        Assertions.assertFalse(
+                
PrioritizedOperatorSubtaskState.Builder.tryComputeMixedLocalAndRemoteAlternative(
+                                new 
StateObjectCollection<>(Collections.emptyList()),
+                                Collections.emptyList(),
+                                KeyedStateHandle::getKeyGroupRange)
+                        .isPresent());
+    }
+
     /**
      * This tests attempts to test (almost) the full space of significantly 
different options for
      * verifying and prioritizing {@link OperatorSubtaskState} options for 
local recovery over
@@ -106,16 +211,17 @@ class PrioritizedOperatorSubtaskStateTest {
                                                 : onlyPrimary))
                         .isTrue();
 
-                assertThat(
-                                checkResultAsExpected(
-                                        
OperatorSubtaskState::getManagedKeyedState,
-                                        PrioritizedOperatorSubtaskState
-                                                
::getPrioritizedManagedKeyedState,
-                                        prioritizedOperatorSubtaskState,
-                                        
primaryAndFallback.getManagedKeyedState().size() == 1
-                                                ? validAlternatives
-                                                : onlyPrimary))
-                        .isTrue();
+                StateObjectCollection<KeyedStateHandle> expManagedKeyed =
+                        computeExpectedMixedState(
+                                orderedAlternativesList,
+                                primaryAndFallback,
+                                OperatorSubtaskState::getManagedKeyedState,
+                                KeyedStateHandle::getKeyGroupRange);
+
+                assertResultAsExpected(
+                        expManagedKeyed,
+                        primaryAndFallback.getManagedKeyedState(),
+                        
prioritizedOperatorSubtaskState.getPrioritizedManagedKeyedState());
 
                 assertThat(
                                 checkResultAsExpected(
@@ -128,16 +234,17 @@ class PrioritizedOperatorSubtaskStateTest {
                                                 : onlyPrimary))
                         .isTrue();
 
-                assertThat(
-                                checkResultAsExpected(
-                                        OperatorSubtaskState::getRawKeyedState,
-                                        PrioritizedOperatorSubtaskState
-                                                ::getPrioritizedRawKeyedState,
-                                        prioritizedOperatorSubtaskState,
-                                        
primaryAndFallback.getRawKeyedState().size() == 1
-                                                ? validAlternatives
-                                                : onlyPrimary))
-                        .isTrue();
+                StateObjectCollection<KeyedStateHandle> expRawKeyed =
+                        computeExpectedMixedState(
+                                orderedAlternativesList,
+                                primaryAndFallback,
+                                OperatorSubtaskState::getRawKeyedState,
+                                KeyedStateHandle::getKeyGroupRange);
+
+                assertResultAsExpected(
+                        expRawKeyed,
+                        primaryAndFallback.getRawKeyedState(),
+                        
prioritizedOperatorSubtaskState.getPrioritizedRawKeyedState());
             }
         }
     }
@@ -390,4 +497,41 @@ class PrioritizedOperatorSubtaskStateTest {
             throw new IllegalStateException();
         }
     }
+
+    private <T extends StateObject, ID> StateObjectCollection<T> 
computeExpectedMixedState(
+            List<OperatorSubtaskState> orderedAlternativesList,
+            OperatorSubtaskState primaryAndFallback,
+            Function<OperatorSubtaskState, StateObjectCollection<T>> 
stateExtractor,
+            Function<T, ID> idExtractor) {
+
+        List<OperatorSubtaskState> reverseAlternatives = new 
ArrayList<>(orderedAlternativesList);
+        Collections.reverse(reverseAlternatives);
+
+        Map<ID, T> map =
+                stateExtractor.apply(primaryAndFallback).stream()
+                        .collect(Collectors.toMap(idExtractor, 
Function.identity()));
+
+        reverseAlternatives.stream()
+                .flatMap(x -> stateExtractor.apply(x).stream())
+                .forEach(x -> map.replace(idExtractor.apply(x), x));
+
+        return new StateObjectCollection<>(map.values());
+    }
+
+    static <SH extends StateObject> void assertResultAsExpected(
+            StateObjectCollection<SH> expected,
+            StateObjectCollection<SH> primary,
+            List<StateObjectCollection<SH>> actual) {
+        Assertions.assertTrue(!actual.isEmpty() && actual.size() <= 2);
+        Assertions.assertTrue(isSameContentUnordered(expected, actual.get(0)));
+        if (actual.size() == 1) {
+            Assertions.assertTrue(isSameContentUnordered(primary, 
actual.get(0)));
+        } else {
+            Assertions.assertTrue(isSameContentUnordered(primary, 
actual.get(1)));
+        }
+    }
+
+    static <T> boolean isSameContentUnordered(Collection<T> a, Collection<T> 
b) {
+        return a.size() == b.size() && a.containsAll(b);
+    }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
index f7e4a7ef2dc..52a8bf032b6 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
@@ -55,7 +55,8 @@ public class StateHandleDummyUtil {
             OperatorStateHandle.StateMetaInfo metaInfo =
                     new OperatorStateHandle.StateMetaInfo(
                             offsets, 
OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
-            operatorStateMetaData.put(String.valueOf(UUID.randomUUID()), 
metaInfo);
+            operatorStateMetaData.put(
+                    String.valueOf(new UUID(random.nextLong(), 
random.nextLong())), metaInfo);
         }
         return new OperatorStreamStateHandle(
                 operatorStateMetaData, createStreamStateHandle(numNamedStates, 
random));
@@ -65,7 +66,8 @@ public class StateHandleDummyUtil {
             int numNamedStates, Random random) {
         byte[] streamData = new byte[numNamedStates * 4];
         random.nextBytes(streamData);
-        return new ByteStreamStateHandle(String.valueOf(UUID.randomUUID()), 
streamData);
+        return new ByteStreamStateHandle(
+                String.valueOf(new UUID(random.nextLong(), 
random.nextLong())), streamData);
     }
 
     /** Creates a new test {@link KeyedStateHandle} for the given key-group. */
@@ -149,11 +151,11 @@ public class StateHandleDummyUtil {
     }
 
     public static ResultSubpartitionStateHandle 
createNewResultSubpartitionStateHandle(
-            int i, Random random) {
+            int numNamedStates, Random random) {
         return new ResultSubpartitionStateHandle(
                 new ResultSubpartitionInfo(random.nextInt(), random.nextInt()),
-                createStreamStateHandle(i, random),
-                genOffsets(i, random));
+                createStreamStateHandle(numNamedStates, random),
+                genOffsets(numNamedStates, random));
     }
 
     private static ArrayList<Long> genOffsets(int size, Random random) {
@@ -164,6 +166,23 @@ public class StateHandleDummyUtil {
         return offsets;
     }
 
+    public static KeyedStateHandle createKeyedStateHandleFromSeed(int seed) {
+        return createNewKeyedStateHandle(KeyGroupRange.of(seed * 4, seed * 4 + 
3));
+    }
+
+    public static OperatorStateHandle createOperatorStateHandleFromSeed(int 
seed) {
+        return createNewOperatorStateHandle(1 + (seed % 3), new Random(seed));
+    }
+
+    public static InputChannelStateHandle 
createInputChannelStateHandleFromSeed(int seed) {
+        return createNewInputChannelStateHandle(1 + (seed % 3), new 
Random(seed));
+    }
+
+    public static ResultSubpartitionStateHandle 
createResultSubpartitionStateHandleFromSeed(
+            int seed) {
+        return createNewResultSubpartitionStateHandle(1 + (seed % 3), new 
Random(seed));
+    }
+
     /** KeyedStateHandle that only holds a key-group information. */
     private static class DummyKeyedStateHandle implements KeyedStateHandle {
 

Reply via email to