This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch release-1.17
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.17 by this push:
     new 8d8a486aaa8 [FLINK-31963][state] Fix rescaling bug in recovery from unaligned checkpoints. (#22584) (#22594)
8d8a486aaa8 is described below

commit 8d8a486aaa8360d6beabacc6980280c96bf900ea
Author: Stefan Richter <srich...@apache.org>
AuthorDate: Wed May 17 12:05:04 2023 +0200

    [FLINK-31963][state] Fix rescaling bug in recovery from unaligned checkpoints. (#22584) (#22594)

    This commit fixes problems in StateAssignmentOperation for unaligned checkpoints with stateless operators that have upstream operators with output partition state or downstream operators with input channel state.
    
    (cherry picked from commit 354c0f455b92c083299d8028f161f0dd113ab614)
---
 .../checkpoint/StateAssignmentOperation.java       |  28 ++--
 .../runtime/checkpoint/TaskStateAssignment.java    |  19 ++-
 .../checkpoint/StateAssignmentOperationTest.java   | 178 ++++++++++++++++++++-
 .../checkpointing/UnalignedCheckpointITCase.java   |  18 ++-
 .../UnalignedCheckpointRescaleITCase.java          | 137 ++++++++++------
 .../checkpointing/UnalignedCheckpointTestBase.java |  32 +++-
 6 files changed, 335 insertions(+), 77 deletions(-)
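
[Editorial sketch, not part of the original patch mail] The heart of the change below is the widened trigger condition in StateAssignmentOperation: a task now takes part in channel-state repartitioning not only when it carries non-finished state of its own, but also when an upstream task holds result-subpartition (output) state or a downstream task holds input-channel state. The following standalone Java sketch illustrates that condition with simplified placeholder types; TaskNode, needsRepartitioning and the boolean fields are illustrative stand-ins, not Flink's actual classes.

    import java.util.List;

    // Simplified stand-in for a task in the job graph; not Flink's TaskStateAssignment.
    final class TaskNode {
        final boolean hasNonFinishedState; // the task's own operator state
        final boolean hasOutputState;      // unaligned-checkpoint result subpartition state
        final boolean hasInputState;       // unaligned-checkpoint input channel state
        final List<TaskNode> upstream;
        final List<TaskNode> downstream;

        TaskNode(
                boolean hasNonFinishedState,
                boolean hasOutputState,
                boolean hasInputState,
                List<TaskNode> upstream,
                List<TaskNode> downstream) {
            this.hasNonFinishedState = hasNonFinishedState;
            this.hasOutputState = hasOutputState;
            this.hasInputState = hasInputState;
            this.upstream = upstream;
            this.downstream = downstream;
        }

        boolean hasUpstreamOutputStates() {
            return upstream.stream().anyMatch(t -> t.hasOutputState);
        }

        boolean hasDownstreamInputStates() {
            return downstream.stream().anyMatch(t -> t.hasInputState);
        }

        // Before the fix, only hasNonFinishedState was checked, so a stateless task
        // sitting next to stateful neighbours was skipped and their in-flight channel
        // state was never rescaled. The fix widens the condition as shown here.
        boolean needsRepartitioning() {
            return hasNonFinishedState || hasUpstreamOutputStates() || hasDownstreamInputStates();
        }
    }

    public class Flink31963Sketch {
        public static void main(String[] args) {
            TaskNode statefulUpstream =
                    new TaskNode(true, true, false, List.of(), List.of());
            // Stateless middle task: no state of its own, but its upstream has output state.
            TaskNode statelessTask =
                    new TaskNode(false, false, false, List.of(statefulUpstream), List.of());
            System.out.println(statelessTask.needsRepartitioning()); // true with the widened check
        }
    }

The actual patch additionally caches these two lookups in TaskStateAssignment (hasUpstreamOutputStates / hasDownstreamInputStates) so they are computed at most once per assignment.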

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 681e0b18df1..e476c6b65ec 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -136,19 +136,24 @@ public class StateAssignmentOperation {
 
         // repartition state
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
-            if (stateAssignment.hasNonFinishedState) {
+            if (stateAssignment.hasNonFinishedState
+                    // FLINK-31963: We need to run repartitioning for stateless operators that have
+                    // upstream output or downstream input states.
+                    || stateAssignment.hasUpstreamOutputStates()
+                    || stateAssignment.hasDownstreamInputStates()) {
                 assignAttemptState(stateAssignment);
             }
         }
 
         // actually assign the state
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
-            // If upstream has output states, even the empty task state should be assigned for the
-            // current task in order to notify this task that the old states will send to it which
-            // likely should be filtered.
+            // If upstream has output states or downstream has input states, even the empty task
+            // state should be assigned for the current task in order to notify this task that the
+            // old states will send to it which likely should be filtered.
             if (stateAssignment.hasNonFinishedState
                     || stateAssignment.isFullyFinished
-                    || stateAssignment.hasUpstreamOutputStates()) {
+                    || stateAssignment.hasUpstreamOutputStates()
+                    || stateAssignment.hasDownstreamInputStates()) {
                 assignTaskStateToExecutionJobVertices(stateAssignment);
             }
         }
@@ -345,9 +350,10 @@ public class StateAssignmentOperation {
                                         newParallelism)));
     }
 
-    public <I, T extends AbstractChannelStateHandle<I>> void reDistributeResultSubpartitionStates(
-            TaskStateAssignment assignment) {
-        if (!assignment.hasOutputState) {
+    public void reDistributeResultSubpartitionStates(TaskStateAssignment assignment) {
+        // FLINK-31963: We can skip this phase if there is no output state AND downstream has no
+        // input states
+        if (!assignment.hasOutputState && !assignment.hasDownstreamInputStates()) {
             return;
         }
 
@@ -394,7 +400,9 @@ public class StateAssignmentOperation {
     }
 
     public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment) {
-        if (!stateAssignment.hasInputState) {
+        // FLINK-31963: We can skip this phase only if there is no input state AND upstream has no
+        // output states
+        if (!stateAssignment.hasInputState && !stateAssignment.hasUpstreamOutputStates()) {
             return;
         }
 
@@ -435,7 +443,7 @@ public class StateAssignmentOperation {
                             : getPartitionState(
                                     inputOperatorState, InputChannelInfo::getGateIdx, gateIndex);
             final MappingBasedRepartitioner<InputChannelStateHandle> repartitioner =
-                    new MappingBasedRepartitioner(mapping);
+                    new MappingBasedRepartitioner<>(mapping);
             final Map<OperatorInstanceID, List<InputChannelStateHandle>> repartitioned =
                     applyRepartitioner(
                             stateAssignment.inputOperatorID,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
index 75ffc71d058..e9f9d11421e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
@@ -84,6 +84,8 @@ class TaskStateAssignment {
 
     @Nullable private TaskStateAssignment[] downstreamAssignments;
     @Nullable private TaskStateAssignment[] upstreamAssignments;
+    @Nullable private Boolean hasUpstreamOutputStates;
+    @Nullable private Boolean hasDownstreamInputStates;
 
     private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment;
     private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;
@@ -202,8 +204,21 @@ class TaskStateAssignment {
     }
 
     public boolean hasUpstreamOutputStates() {
-        return Arrays.stream(getUpstreamAssignments())
-                .anyMatch(assignment -> assignment.hasOutputState);
+        if (hasUpstreamOutputStates == null) {
+            hasUpstreamOutputStates =
+                    Arrays.stream(getUpstreamAssignments())
+                            .anyMatch(assignment -> assignment.hasOutputState);
+        }
+        return hasUpstreamOutputStates;
+    }
+
+    public boolean hasDownstreamInputStates() {
+        if (hasDownstreamInputStates == null) {
+            hasDownstreamInputStates =
+                    Arrays.stream(getDownstreamAssignments())
+                            .anyMatch(assignment -> assignment.hasInputState);
+        }
+        return hasDownstreamInputStates;
     }
 
     private InflightDataGateOrPartitionRescalingDescriptor log(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
index f9cb551c5ad..bffdd8686ae 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.runtime.JobException;
 import org.apache.flink.runtime.OperatorIDPair;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
 import org.apache.flink.runtime.client.JobExecutionException;
+import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
@@ -51,6 +52,9 @@ import org.junit.Assert;
 import org.junit.ClassRule;
 import org.junit.Test;
 
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.EnumMap;
@@ -82,6 +86,7 @@ import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNew
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ARBITRARY;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.RANGE;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ROUND_ROBIN;
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -785,6 +790,129 @@ public class StateAssignmentOperationTest extends TestLogger {
         }
     }
 
+    /** FLINK-31963: Tests rescaling for stateless operators and upstream result partition state. */
+    @Test
+    public void testOnlyUpstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState upstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setResultSubpartitionState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                createNewResultSubpartitionStateHandle(10, random),
+                                                createNewResultSubpartitionStateHandle(
+                                                        10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(upstreamOpState, null, 5, 7);
+    }
+
+    /** FLINK-31963: Tests rescaling for stateless operators and downstream input channel state. */
+    @Test
+    public void testOnlyDownstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState downstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setInputChannelState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                createNewInputChannelStateHandle(10, random),
+                                                createNewInputChannelStateHandle(10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(null, downstreamOpState, 5, 5);
+    }
+
+    private void testOnlyUpstreamOrDownstreamRescalingInternal(
+            @Nullable OperatorSubtaskState upstreamOpState,
+            @Nullable OperatorSubtaskState downstreamOpState,
+            int expectedUpstreamCount,
+            int expectedDownstreamCount)
+            throws JobException, JobExecutionException {
+
+        checkArgument(
+                upstreamOpState != downstreamOpState
+                        && (upstreamOpState == null || downstreamOpState == null),
+                "Either upstream or downstream state must exist, but not both");
+
+        // Start from parallelism 5 for both operators
+        int upstreamParallelism = 5;
+        int downstreamParallelism = 5;
+
+        // Build states
+        List<OperatorID> operatorIds = buildOperatorIds(2);
+        Map<OperatorID, OperatorState> states = new HashMap<>();
+        OperatorState upstreamState =
+                new OperatorState(operatorIds.get(0), upstreamParallelism, MAX_P);
+        OperatorState downstreamState =
+                new OperatorState(operatorIds.get(1), downstreamParallelism, MAX_P);
+
+        states.put(operatorIds.get(0), upstreamState);
+        states.put(operatorIds.get(1), downstreamState);
+
+        if (upstreamOpState != null) {
+            upstreamState.putState(0, upstreamOpState);
+            // rescale downstream 5 -> 3
+            downstreamParallelism = 3;
+        }
+
+        if (downstreamOpState != null) {
+            downstreamState.putState(0, downstreamOpState);
+            // rescale upstream 5 -> 3
+            upstreamParallelism = 3;
+        }
+
+        List<OperatorIdWithParallelism> opIdWithParallelism = new ArrayList<>(2);
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(0), upstreamParallelism));
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(1), downstreamParallelism));
+
+        Map<OperatorID, ExecutionJobVertex> vertices =
+                buildVertices(opIdWithParallelism, RANGE, ROUND_ROBIN);
+
+        // Run state assignment
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+                .assignStates();
+
+        // Check results
+        ExecutionJobVertex upstreamExecutionJobVertex = vertices.get(operatorIds.get(0));
+        ExecutionJobVertex downstreamExecutionJobVertex = vertices.get(operatorIds.get(1));
+
+        List<TaskStateSnapshot> upstreamTaskStateSnapshots =
+                getTaskStateSnapshotFromVertex(upstreamExecutionJobVertex);
+        List<TaskStateSnapshot> downstreamTaskStateSnapshots =
+                getTaskStateSnapshotFromVertex(downstreamExecutionJobVertex);
+
+        checkMappings(
+                upstreamTaskStateSnapshots,
+                TaskStateSnapshot::getOutputRescalingDescriptor,
+                expectedUpstreamCount);
+
+        checkMappings(
+                downstreamTaskStateSnapshots,
+                TaskStateSnapshot::getInputRescalingDescriptor,
+                expectedDownstreamCount);
+    }
+
+    private void checkMappings(
+            List<TaskStateSnapshot> taskStateSnapshots,
+            Function<TaskStateSnapshot, InflightDataRescalingDescriptor> extractFun,
+            int expectedCount) {
+        Assert.assertEquals(
+                expectedCount,
+                taskStateSnapshots.stream()
+                        .map(extractFun)
+                        .mapToInt(
+                                x -> {
+                                    int len = x.getOldSubtaskIndexes(0).length;
+                                    // Assert that there is a mapping.
+                                    Assert.assertTrue(len > 0);
+                                    return len;
+                                })
+                        .sum());
+    }
+
     @Test
     public void testStateWithFullyFinishedOperators() throws JobException, JobExecutionException {
         List<OperatorID> operatorIds = buildOperatorIds(2);
@@ -949,15 +1077,50 @@ public class StateAssignmentOperationTest extends TestLogger {
                                 }));
     }
 
+    private static class OperatorIdWithParallelism {
+        private final OperatorID operatorID;
+        private final int parallelism;
+
+        public OperatorID getOperatorID() {
+            return operatorID;
+        }
+
+        public int getParallelism() {
+            return parallelism;
+        }
+
+        public OperatorIdWithParallelism(OperatorID operatorID, int parallelism) {
+            this.operatorID = operatorID;
+            this.parallelism = parallelism;
+        }
+    }
+
     private Map<OperatorID, ExecutionJobVertex> buildVertices(
             List<OperatorID> operatorIds,
-            int parallelism,
+            int parallelisms,
             SubtaskStateMapper downstreamRescaler,
             SubtaskStateMapper upstreamRescaler)
             throws JobException, JobExecutionException {
-        final JobVertex[] jobVertices =
+        List<OperatorIdWithParallelism> opIdsWithParallelism =
                 operatorIds.stream()
-                        .map(id -> createJobVertex(id, id, parallelism))
+                        .map(operatorID -> new OperatorIdWithParallelism(operatorID, parallelisms))
+                        .collect(Collectors.toList());
+        return buildVertices(opIdsWithParallelism, downstreamRescaler, upstreamRescaler);
+    }
+
+    private Map<OperatorID, ExecutionJobVertex> buildVertices(
+            List<OperatorIdWithParallelism> operatorIdsAndParallelism,
+            SubtaskStateMapper downstreamRescaler,
+            SubtaskStateMapper upstreamRescaler)
+            throws JobException, JobExecutionException {
+        final JobVertex[] jobVertices =
+                operatorIdsAndParallelism.stream()
+                        .map(
+                                idWithParallelism ->
+                                        createJobVertex(
+                                                idWithParallelism.getOperatorID(),
+                                                idWithParallelism.getOperatorID(),
+                                                idWithParallelism.getParallelism()))
                         .toArray(JobVertex[]::new);
         for (int index = 1; index < jobVertices.length; index++) {
             connectVertices(
@@ -1029,6 +1192,15 @@ public class StateAssignmentOperationTest extends TestLogger {
         return jobVertex;
     }
 
+    private List<TaskStateSnapshot> getTaskStateSnapshotFromVertex(
+            ExecutionJobVertex executionJobVertex) {
+        return Arrays.stream(executionJobVertex.getTaskVertices())
+                .map(ExecutionVertex::getCurrentExecutionAttempt)
+                .map(Execution::getTaskRestore)
+                .map(JobManagerTaskRestore::getTaskStateSnapshot)
+                .collect(Collectors.toList());
+    }
+
     private OperatorSubtaskState getAssignedState(
             ExecutionJobVertex executionJobVertex, OperatorID operatorId, int subtaskIdx) {
         return executionJobVertex
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
index 5a6efd174e1..dc1b21f07b4 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointITCase.java
@@ -113,7 +113,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
                 final int parallelism = env.getParallelism();
                 final SingleOutputStreamOperator<Long> stream =
                         env.fromSource(
@@ -121,7 +122,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                                                 minCheckpoints,
                                                 parallelism,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source")
                                 .slotSharingGroup(slotSharing ? "default" : "source")
@@ -144,7 +146,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
                 final int parallelism = env.getParallelism();
                 DataStream<Long> combinedSource = null;
                 for (int inputIndex = 0; inputIndex < NUM_SOURCES; inputIndex++) {
@@ -154,7 +157,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                                                     minCheckpoints,
                                                     parallelism,
                                                     expectedRestarts,
-                                                    env.getCheckpointInterval()),
+                                                    env.getCheckpointInterval(),
+                                                    sourceSleepMs),
                                             noWatermarks(),
                                             "source" + inputIndex)
                                     .slotSharingGroup(
@@ -182,7 +186,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
                 final int parallelism = env.getParallelism();
                 DataStream<Tuple2<Integer, Long>> combinedSource = null;
                 for (int inputIndex = 0; inputIndex < NUM_SOURCES; inputIndex++) {
@@ -193,7 +198,8 @@ public class UnalignedCheckpointITCase extends UnalignedCheckpointTestBase {
                                                     minCheckpoints,
                                                     parallelism,
                                                     expectedRestarts,
-                                                    env.getCheckpointInterval()),
+                                                    env.getCheckpointInterval(),
+                                                    sourceSleepMs),
                                             noWatermarks(),
                                             "source" + inputIndex)
                                     .slotSharingGroup(
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
index 4216cc5469b..1f5d4c73afc 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java
@@ -68,6 +68,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
     private final int oldParallelism;
     private final int newParallelism;
     private final int buffersPerChannel;
+    private final long sourceSleepMs;
 
     enum Topology implements DagCreator {
         PIPELINE {
@@ -76,7 +77,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMillis) {
                 final int parallelism = env.getParallelism();
                 final DataStream<Long> source =
                         createSourcePipeline(
@@ -86,6 +88,7 @@ public class UnalignedCheckpointRescaleITCase extends 
UnalignedCheckpointTestBas
                                 expectedRestarts,
                                 parallelism,
                                 0,
+                                sourceSleepMillis,
                                 val -> true);
                 addFailingSink(source, minCheckpoints, slotSharing);
             }
@@ -97,7 +100,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 DataStream<Long> combinedSource = null;
@@ -111,6 +115,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                     expectedRestarts,
                                     parallelism,
                                     inputIndex,
+                                    sourceSleepMs,
                                     val -> withoutHeader(val) % NUM_SOURCES == finalInputIndex);
                     combinedSource =
                             combinedSource == null
@@ -134,10 +139,10 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
-                checkState(parallelism >= 4);
                 final DataStream<Long> source1 =
                         createSourcePipeline(
                                 env,
@@ -146,6 +151,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                 expectedRestarts,
                                 parallelism / 2,
                                 0,
+                                sourceSleepMs,
                                 val -> withoutHeader(val) % 2 == 0);
                 final DataStream<Long> source2 =
                         createSourcePipeline(
@@ -155,6 +161,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                 expectedRestarts,
                                 parallelism / 3,
                                 1,
+                                sourceSleepMs,
                                 val -> withoutHeader(val) % 2 == 1);
 
                 KeySelector<Long, Long> keySelector = i -> withoutHeader(i) % NUM_GROUPS;
@@ -174,7 +181,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 DataStream<Long> combinedSource = null;
@@ -188,6 +196,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                     expectedRestarts,
                                     parallelism,
                                     inputIndex,
+                                    sourceSleepMs,
                                     val -> withoutHeader(val) % NUM_SOURCES == finalInputIndex);
                     combinedSource = combinedSource == null ? source : combinedSource.union(source);
                 }
@@ -202,7 +211,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 final DataStream<Long> broadcastSide =
@@ -211,7 +221,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                         minCheckpoints,
                                         parallelism,
                                         expectedRestarts,
-                                        env.getCheckpointInterval()),
+                                        env.getCheckpointInterval(),
+                                        sourceSleepMs),
                                 noWatermarks(),
                                 "source");
                 final DataStream<Long> source =
@@ -222,6 +233,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                         expectedRestarts,
                                         parallelism,
                                         0,
+                                        sourceSleepMs,
                                         val -> true)
                                 .map(i -> checkHeader(i))
                                 .name("map")
@@ -249,7 +261,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                     StreamExecutionEnvironment env,
                     int minCheckpoints,
                     boolean slotSharing,
-                    int expectedRestarts) {
+                    int expectedRestarts,
+                    long sourceSleepMs) {
 
                 final int parallelism = env.getParallelism();
                 final DataStream<Long> broadcastSide1 =
@@ -258,7 +271,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                                 minCheckpoints,
                                                 1,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source-1")
                                 .setParallelism(1);
@@ -268,7 +282,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                                 minCheckpoints,
                                                 1,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source-2")
                                 .setParallelism(1);
@@ -278,7 +293,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                                 minCheckpoints,
                                                 1,
                                                 expectedRestarts,
-                                                env.getCheckpointInterval()),
+                                                env.getCheckpointInterval(),
+                                                sourceSleepMs),
                                         noWatermarks(),
                                         "source-3")
                                 .setParallelism(1);
@@ -290,6 +306,7 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                                         expectedRestarts,
                                         parallelism,
                                         0,
+                                        sourceSleepMs,
                                         val -> true)
                                 .map(i -> checkHeader(i))
                                 .name("map")
@@ -349,13 +366,15 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                 int expectedRestarts,
                 int parallelism,
                 int inputIndex,
+                long sourceSleepMs,
                 FilterFunction<Long> sourceFilter) {
             return env.fromSource(
                             new LongSource(
                                     minCheckpoints,
                                     parallelism,
                                     expectedRestarts,
-                                    env.getCheckpointInterval()),
+                                    env.getCheckpointInterval(),
+                                    sourceSleepMs),
                             noWatermarks(),
                             "source" + inputIndex)
                     .uid("source" + inputIndex)
@@ -459,46 +478,61 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
         }
     }
 
-    @Parameterized.Parameters(name = "{0} {1} from {2} to {3}, buffersPerChannel = {4}")
+    @Parameterized.Parameters(
+            name = "{0} {1} from {2} to {3}, sourceSleepMs = {4}, buffersPerChannel = {5}")
     public static Object[][] getScaleFactors() {
+        // We use `sourceSleepMs` > 0 to test rescaling without backpressure and only very few
+        // captured in-flight records, see FLINK-31963.
         Object[][] parameters =
                 new Object[][] {
-                    new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 12, 7},
-                    new Object[] {"upscale", Topology.KEYED_DIFFERENT_PARALLELISM, 7, 12},
-                    new Object[] {"downscale", Topology.KEYED_BROADCAST, 7, 2},
-                    new Object[] {"upscale", Topology.KEYED_BROADCAST, 2, 7},
-                    new Object[] {"downscale", Topology.BROADCAST, 5, 2},
-                    new Object[] {"upscale", Topology.BROADCAST, 2, 5},
-                    new Object[] {"upscale", Topology.PIPELINE, 1, 2},
-                    new Object[] {"upscale", Topology.PIPELINE, 2, 3},
-                    new Object[] {"upscale", Topology.PIPELINE, 3, 7},
-                    new Object[] {"upscale", Topology.PIPELINE, 4, 8},
-                    new Object[] {"upscale", Topology.PIPELINE, 20, 21},
-                    new Object[] {"downscale", Topology.PIPELINE, 2, 1},
-                    new Object[] {"downscale", Topology.PIPELINE, 3, 2},
-                    new Object[] {"downscale", Topology.PIPELINE, 7, 3},
-                    new Object[] {"downscale", Topology.PIPELINE, 8, 4},
-                    new Object[] {"downscale", Topology.PIPELINE, 21, 20},
-                    new Object[] {"no scale", Topology.PIPELINE, 1, 1},
-                    new Object[] {"no scale", Topology.PIPELINE, 3, 3},
-                    new Object[] {"no scale", Topology.PIPELINE, 7, 7},
-                    new Object[] {"no scale", Topology.PIPELINE, 20, 20},
-                    new Object[] {"upscale", Topology.UNION, 1, 2},
-                    new Object[] {"upscale", Topology.UNION, 2, 3},
-                    new Object[] {"upscale", Topology.UNION, 3, 7},
-                    new Object[] {"downscale", Topology.UNION, 2, 1},
-                    new Object[] {"downscale", Topology.UNION, 3, 2},
-                    new Object[] {"downscale", Topology.UNION, 7, 3},
-                    new Object[] {"no scale", Topology.UNION, 1, 1},
-                    new Object[] {"no scale", Topology.UNION, 7, 7},
-                    new Object[] {"upscale", Topology.MULTI_INPUT, 1, 2},
-                    new Object[] {"upscale", Topology.MULTI_INPUT, 2, 3},
-                    new Object[] {"upscale", Topology.MULTI_INPUT, 3, 7},
-                    new Object[] {"downscale", Topology.MULTI_INPUT, 2, 1},
-                    new Object[] {"downscale", Topology.MULTI_INPUT, 3, 2},
-                    new Object[] {"downscale", Topology.MULTI_INPUT, 7, 3},
-                    new Object[] {"no scale", Topology.MULTI_INPUT, 1, 1},
-                    new Object[] {"no scale", Topology.MULTI_INPUT, 7, 7},
+                    new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 12, 7, 0L},
+                    new Object[] {"upscale", Topology.KEYED_DIFFERENT_PARALLELISM, 7, 12, 0L},
+                    new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 5, 3, 5L},
+                    new Object[] {"upscale", Topology.KEYED_DIFFERENT_PARALLELISM, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.KEYED_BROADCAST, 7, 2, 0L},
+                    new Object[] {"upscale", Topology.KEYED_BROADCAST, 2, 7, 0L},
+                    new Object[] {"downscale", Topology.KEYED_BROADCAST, 5, 3, 5L},
+                    new Object[] {"upscale", Topology.KEYED_BROADCAST, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.BROADCAST, 5, 2, 0L},
+                    new Object[] {"upscale", Topology.BROADCAST, 2, 5, 0L},
+                    new Object[] {"downscale", Topology.BROADCAST, 5, 3, 5L},
+                    new Object[] {"upscale", Topology.BROADCAST, 3, 5, 5L},
+                    new Object[] {"upscale", Topology.PIPELINE, 1, 2, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 2, 3, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 3, 7, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 4, 8, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 20, 21, 0L},
+                    new Object[] {"upscale", Topology.PIPELINE, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.PIPELINE, 2, 1, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 3, 2, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 7, 3, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 8, 4, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 21, 20, 0L},
+                    new Object[] {"downscale", Topology.PIPELINE, 5, 3, 5L},
+                    new Object[] {"no scale", Topology.PIPELINE, 1, 1, 0L},
+                    new Object[] {"no scale", Topology.PIPELINE, 3, 3, 0L},
+                    new Object[] {"no scale", Topology.PIPELINE, 7, 7, 0L},
+                    new Object[] {"no scale", Topology.PIPELINE, 20, 20, 0L},
+                    new Object[] {"upscale", Topology.UNION, 1, 2, 0L},
+                    new Object[] {"upscale", Topology.UNION, 2, 3, 0L},
+                    new Object[] {"upscale", Topology.UNION, 3, 7, 0L},
+                    new Object[] {"upscale", Topology.UNION, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.UNION, 2, 1, 0L},
+                    new Object[] {"downscale", Topology.UNION, 3, 2, 0L},
+                    new Object[] {"downscale", Topology.UNION, 7, 3, 0L},
+                    new Object[] {"downscale", Topology.UNION, 5, 3, 5L},
+                    new Object[] {"no scale", Topology.UNION, 1, 1, 0L},
+                    new Object[] {"no scale", Topology.UNION, 7, 7, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 1, 2, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 2, 3, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 3, 7, 0L},
+                    new Object[] {"upscale", Topology.MULTI_INPUT, 3, 5, 5L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 2, 1, 0L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 3, 2, 0L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 7, 3, 0L},
+                    new Object[] {"downscale", Topology.MULTI_INPUT, 5, 3, 5L},
+                    new Object[] {"no scale", Topology.MULTI_INPUT, 1, 1, 0L},
+                    new Object[] {"no scale", Topology.MULTI_INPUT, 7, 7, 0L},
                 };
         return Arrays.stream(parameters)
                 .map(
@@ -516,10 +550,12 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
             Topology topology,
             int oldParallelism,
             int newParallelism,
+            long sourceSleepMs,
             int buffersPerChannel) {
         this.topology = topology;
         this.oldParallelism = oldParallelism;
         this.newParallelism = newParallelism;
+        this.sourceSleepMs = sourceSleepMs;
         this.buffersPerChannel = buffersPerChannel;
     }
 
@@ -529,7 +565,8 @@ public class UnalignedCheckpointRescaleITCase extends UnalignedCheckpointTestBas
                 new UnalignedSettings(topology)
                         .setParallelism(oldParallelism)
                         .setExpectedFailures(1)
-                        .setBuffersPerChannel(buffersPerChannel);
+                        .setBuffersPerChannel(buffersPerChannel)
+                        .setSourceSleepMs(sourceSleepMs);
         prescaleSettings.setGenerateCheckpoint(true);
         final File checkpointDir = super.execute(prescaleSettings);
 
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
index 3420cf6f090..7ae3ad63146 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java
@@ -207,7 +207,8 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
                 setupEnv,
                 settings.minCheckpoints,
                 settings.channelType.slotSharing,
-                settings.expectedFailures - settings.failuresAfterSourceFinishes);
+                settings.expectedFailures - settings.failuresAfterSourceFinishes,
+                settings.sourceSleepMs);
 
         return setupEnv.getStreamGraph();
     }
@@ -221,16 +222,19 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
         private final int numSplits;
         private final int expectedRestarts;
         private final long checkpointingInterval;
+        private final long sourceSleepMs;
 
         protected LongSource(
                 int minCheckpoints,
                 int numSplits,
                 int expectedRestarts,
-                long checkpointingInterval) {
+                long checkpointingInterval,
+                long sourceSleepMs) {
             this.minCheckpoints = minCheckpoints;
             this.numSplits = numSplits;
             this.expectedRestarts = expectedRestarts;
             this.checkpointingInterval = checkpointingInterval;
+            this.sourceSleepMs = sourceSleepMs;
         }
 
         @Override
@@ -244,7 +248,8 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
                     readerContext.getIndexOfSubtask(),
                     minCheckpoints,
                     expectedRestarts,
-                    checkpointingInterval);
+                    checkpointingInterval,
+                    sourceSleepMs);
         }
 
         @Override
@@ -285,17 +290,20 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
             private int numCompletedCheckpoints;
             private boolean finishing;
             private boolean recovered;
+            private final long sourceSleepMs;
             @Nullable private Deadline pumpingUntil = null;
 
             public LongSourceReader(
                     int subtaskIndex,
                     int minCheckpoints,
                     int expectedRestarts,
-                    long checkpointingInterval) {
+                    long checkpointingInterval,
+                    long sourceSleepMs) {
                 this.subtaskIndex = subtaskIndex;
                 this.minCheckpoints = minCheckpoints;
                 this.expectedRestarts = expectedRestarts;
-                pumpInterval = Duration.ofMillis(checkpointingInterval);
+                this.pumpInterval = Duration.ofMillis(checkpointingInterval);
+                this.sourceSleepMs = sourceSleepMs;
             }
 
             @Override
@@ -304,6 +312,9 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
             @Override
             public InputStatus pollNext(ReaderOutput<Long> output) throws InterruptedException {
                 for (LongSplit split : splits) {
+                    if (sourceSleepMs > 0L) {
+                        Thread.sleep(sourceSleepMs);
+                    }
                     output.collect(withHeader(split.nextNumber), split.nextNumber);
                     split.nextNumber += split.increment;
                 }
@@ -627,7 +638,8 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
                 StreamExecutionEnvironment environment,
                 int minCheckpoints,
                 boolean slotSharing,
-                int expectedFailuresUntilSourceFinishes);
+                int expectedFailuresUntilSourceFinishes,
+                long sourceSleepMs);
     }
 
     /** Which channels are used to connect the tasks. */
@@ -664,6 +676,7 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
         private int failuresAfterSourceFinishes = 0;
         private ChannelType channelType = ChannelType.MIXED;
         private int buffersPerChannel = 1;
+        private long sourceSleepMs = 0;
 
         public UnalignedSettings(DagCreator dagCreator) {
             this.dagCreator = dagCreator;
@@ -719,6 +732,11 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
             return this;
         }
 
+        public UnalignedSettings setSourceSleepMs(long sourceSleepMs) {
+            this.sourceSleepMs = sourceSleepMs;
+            return this;
+        }
+
         public void configure(StreamExecutionEnvironment env) {
             env.enableCheckpointing(Math.max(100L, parallelism * 50L));
             env.getCheckpointConfig().setAlignmentTimeout(Duration.ofMillis(alignmentTimeout));
@@ -791,6 +809,8 @@ public abstract class UnalignedCheckpointTestBase extends TestLogger {
                     + failuresAfterSourceFinishes
                     + ", channelType="
                     + channelType
+                    + ", sourceSleepMs="
+                    + sourceSleepMs
                     + '}';
         }
     }

