pnowojski commented on code in PR #22584:
URL: https://github.com/apache/flink/pull/22584#discussion_r1193812795


##########
flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java:
##########
@@ -785,6 +789,130 @@ public void testOnlyUpstreamChannelStateAssignment()
         }
     }
 
+    /** FLINK-31963: Tests rescaling for stateless operators and upstream 
result partition state. */
+    @Test
+    public void testOnlyUpstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState upstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setResultSubpartitionState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                
createNewResultSubpartitionStateHandle(10, random),
+                                                
createNewResultSubpartitionStateHandle(
+                                                        10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(upstreamOpState, null, 
5, 7);
+    }
+
+    /** FLINK-31963: Tests rescaling for stateless operators and downstream 
input channel state. */
+    @Test
+    public void testOnlyDownstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState downstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setInputChannelState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                
createNewInputChannelStateHandle(10, random),
+                                                
createNewInputChannelStateHandle(10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(null, downstreamOpState, 
5, 5);
+    }
+
+    private void testOnlyUpstreamOrDownstreamRescalingInternal(
+            @Nullable OperatorSubtaskState upstreamOpState,
+            @Nullable OperatorSubtaskState downstreamOpState,
+            int expectedUpstreamCount,
+            int expectedDownstreamCount)
+            throws JobException, JobExecutionException {
+
+        if ((upstreamOpState == null && downstreamOpState == null)
+                || (upstreamOpState != null && downstreamOpState != null)) {
+            // Either upstream or downstream state must exist, but not both.
+            return;
+        }
+
+        // Start from parallelism 5 for both operators
+        int upstreamParallelism = 5;
+        int downstreamParallelism = 5;
+
+        // Build states
+        List<OperatorID> operatorIds = buildOperatorIds(2);
+        Map<OperatorID, OperatorState> states = new HashMap<>();
+        OperatorState upstreamState =
+                new OperatorState(operatorIds.get(0), upstreamParallelism, 
MAX_P);
+        OperatorState downstreamState =
+                new OperatorState(operatorIds.get(1), downstreamParallelism, 
MAX_P);
+
+        states.put(operatorIds.get(0), upstreamState);
+        states.put(operatorIds.get(1), downstreamState);
+
+        if (upstreamOpState != null) {
+            upstreamState.putState(0, upstreamOpState);
+            // rescale downstream 5 -> 3
+            downstreamParallelism = 3;
+        }
+
+        if (downstreamOpState != null) {
+            downstreamState.putState(0, downstreamOpState);
+            // rescale upstream 5 -> 3
+            upstreamParallelism = 3;
+        }
+
+        List<OperatorIdWithParallelism> opIdWithParallelism = new 
ArrayList<>(2);
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(0), 
upstreamParallelism));
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(1), 
downstreamParallelism));
+
+        Map<OperatorID, ExecutionJobVertex> vertices =
+                buildVertices(opIdWithParallelism, RANGE, ROUND_ROBIN);
+
+        // Run state assignment
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), 
states, false)
+                .assignStates();
+
+        // Check results
+        ExecutionJobVertex upstreamExecutionJobVertex = 
vertices.get(operatorIds.get(0));
+        ExecutionJobVertex downstreamExecutionJobVertex = 
vertices.get(operatorIds.get(1));
+
+        List<TaskStateSnapshot> upstreamRescalingDescriptors =
+                getRescalingDescriptorsFromVertex(upstreamExecutionJobVertex);
+        List<TaskStateSnapshot> downstreamRescalingDescriptors =
+                
getRescalingDescriptorsFromVertex(downstreamExecutionJobVertex);
+
+        checkMappings(
+                upstreamRescalingDescriptors,
+                TaskStateSnapshot::getOutputRescalingDescriptor,
+                expectedUpstreamCount);
+
+        checkMappings(
+                downstreamRescalingDescriptors,
+                TaskStateSnapshot::getInputRescalingDescriptor,
+                expectedDownstreamCount);
+    }
+
+    private void checkMappings(
+            List<TaskStateSnapshot> taskStateSnapshots,
+            Function<TaskStateSnapshot, InflightDataRescalingDescriptor> 
extractFun,
+            int expectedCount) {
+        Assert.assertEquals(
+                expectedCount,
+                taskStateSnapshots.stream()
+                        .map(extractFun)
+                        .mapToInt(
+                                x -> {
+                                    int len = x.getOldSubtaskIndexes(0).length;
+                                    // Assert that there is a mapping.
+                                    Assert.assertTrue(len > 0);
+                                    return len;
+                                })
+                        .sum());

Review Comment:
   Instead of asserting only the length of the mappings, should we assert the 
actual mappings? 🤔 



##########
flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java:
##########
@@ -785,6 +789,130 @@ public void testOnlyUpstreamChannelStateAssignment()
         }
     }
 
+    /** FLINK-31963: Tests rescaling for stateless operators and upstream 
result partition state. */
+    @Test
+    public void testOnlyUpstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState upstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setResultSubpartitionState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                
createNewResultSubpartitionStateHandle(10, random),
+                                                
createNewResultSubpartitionStateHandle(
+                                                        10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(upstreamOpState, null, 
5, 7);
+    }
+
+    /** FLINK-31963: Tests rescaling for stateless operators and downstream 
input channel state. */
+    @Test
+    public void testOnlyDownstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState downstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setInputChannelState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                
createNewInputChannelStateHandle(10, random),
+                                                
createNewInputChannelStateHandle(10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(null, downstreamOpState, 
5, 5);
+    }
+
+    private void testOnlyUpstreamOrDownstreamRescalingInternal(
+            @Nullable OperatorSubtaskState upstreamOpState,
+            @Nullable OperatorSubtaskState downstreamOpState,
+            int expectedUpstreamCount,
+            int expectedDownstreamCount)
+            throws JobException, JobExecutionException {
+
+        if ((upstreamOpState == null && downstreamOpState == null)
+                || (upstreamOpState != null && downstreamOpState != null)) {
+            // Either upstream or downstream state must exist, but not both.
+            return;
+        }
+
+        // Start from parallelism 5 for both operators
+        int upstreamParallelism = 5;
+        int downstreamParallelism = 5;
+
+        // Build states
+        List<OperatorID> operatorIds = buildOperatorIds(2);
+        Map<OperatorID, OperatorState> states = new HashMap<>();
+        OperatorState upstreamState =
+                new OperatorState(operatorIds.get(0), upstreamParallelism, 
MAX_P);
+        OperatorState downstreamState =
+                new OperatorState(operatorIds.get(1), downstreamParallelism, 
MAX_P);
+
+        states.put(operatorIds.get(0), upstreamState);
+        states.put(operatorIds.get(1), downstreamState);
+
+        if (upstreamOpState != null) {
+            upstreamState.putState(0, upstreamOpState);
+            // rescale downstream 5 -> 3
+            downstreamParallelism = 3;
+        }
+
+        if (downstreamOpState != null) {
+            downstreamState.putState(0, downstreamOpState);
+            // rescale upstream 5 -> 3
+            upstreamParallelism = 3;
+        }
+
+        List<OperatorIdWithParallelism> opIdWithParallelism = new 
ArrayList<>(2);
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(0), 
upstreamParallelism));
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(1), 
downstreamParallelism));
+
+        Map<OperatorID, ExecutionJobVertex> vertices =
+                buildVertices(opIdWithParallelism, RANGE, ROUND_ROBIN);
+
+        // Run state assignment
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), 
states, false)
+                .assignStates();
+
+        // Check results
+        ExecutionJobVertex upstreamExecutionJobVertex = 
vertices.get(operatorIds.get(0));
+        ExecutionJobVertex downstreamExecutionJobVertex = 
vertices.get(operatorIds.get(1));
+
+        List<TaskStateSnapshot> upstreamRescalingDescriptors =
+                getRescalingDescriptorsFromVertex(upstreamExecutionJobVertex);
+        List<TaskStateSnapshot> downstreamRescalingDescriptors =
+                
getRescalingDescriptorsFromVertex(downstreamExecutionJobVertex);

Review Comment:
   Rename `*RescalingDescriptors` -> `*TaskStateSnapshots`? These lists hold 
`TaskStateSnapshot`s; the rescaling descriptors are only extracted from them in 
the next step, within `checkMappings`.



##########
flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java:
##########
@@ -785,6 +789,130 @@ public void testOnlyUpstreamChannelStateAssignment()
         }
     }
 
+    /** FLINK-31963: Tests rescaling for stateless operators and upstream 
result partition state. */
+    @Test
+    public void testOnlyUpstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState upstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setResultSubpartitionState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                
createNewResultSubpartitionStateHandle(10, random),
+                                                
createNewResultSubpartitionStateHandle(
+                                                        10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(upstreamOpState, null, 
5, 7);
+    }
+
+    /** FLINK-31963: Tests rescaling for stateless operators and downstream 
input channel state. */
+    @Test
+    public void testOnlyDownstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState downstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setInputChannelState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                
createNewInputChannelStateHandle(10, random),
+                                                
createNewInputChannelStateHandle(10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(null, downstreamOpState, 
5, 5);
+    }
+
+    private void testOnlyUpstreamOrDownstreamRescalingInternal(
+            @Nullable OperatorSubtaskState upstreamOpState,
+            @Nullable OperatorSubtaskState downstreamOpState,
+            int expectedUpstreamCount,
+            int expectedDownstreamCount)
+            throws JobException, JobExecutionException {
+
+        if ((upstreamOpState == null && downstreamOpState == null)
+                || (upstreamOpState != null && downstreamOpState != null)) {
+            // Either upstream or downstream state must exist, but not both.
+            return;
+        }

Review Comment:
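   Use `checkArgument(..., "Either upstream or downstream state must exist, but not 
both")` instead of silently returning? As written, a misconfigured call would make 
the test pass without checking anything.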



##########
flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java:
##########
@@ -785,6 +789,130 @@ public void testOnlyUpstreamChannelStateAssignment()
         }
     }
 
+    /** FLINK-31963: Tests rescaling for stateless operators and upstream 
result partition state. */
+    @Test
+    public void testOnlyUpstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState upstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setResultSubpartitionState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                
createNewResultSubpartitionStateHandle(10, random),
+                                                
createNewResultSubpartitionStateHandle(
+                                                        10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(upstreamOpState, null, 
5, 7);
+    }
+
+    /** FLINK-31963: Tests rescaling for stateless operators and downstream 
input channel state. */
+    @Test
+    public void testOnlyDownstreamChannelRescaleStateAssignment()
+            throws JobException, JobExecutionException {
+        Random random = new Random();
+        OperatorSubtaskState downstreamOpState =
+                OperatorSubtaskState.builder()
+                        .setInputChannelState(
+                                new StateObjectCollection<>(
+                                        asList(
+                                                
createNewInputChannelStateHandle(10, random),
+                                                
createNewInputChannelStateHandle(10, random))))
+                        .build();
+        testOnlyUpstreamOrDownstreamRescalingInternal(null, downstreamOpState, 
5, 5);
+    }
+
+    private void testOnlyUpstreamOrDownstreamRescalingInternal(
+            @Nullable OperatorSubtaskState upstreamOpState,
+            @Nullable OperatorSubtaskState downstreamOpState,
+            int expectedUpstreamCount,
+            int expectedDownstreamCount)
+            throws JobException, JobExecutionException {
+
+        if ((upstreamOpState == null && downstreamOpState == null)
+                || (upstreamOpState != null && downstreamOpState != null)) {
+            // Either upstream or downstream state must exist, but not both.
+            return;
+        }
+
+        // Start from parallelism 5 for both operators
+        int upstreamParallelism = 5;
+        int downstreamParallelism = 5;
+
+        // Build states
+        List<OperatorID> operatorIds = buildOperatorIds(2);
+        Map<OperatorID, OperatorState> states = new HashMap<>();
+        OperatorState upstreamState =
+                new OperatorState(operatorIds.get(0), upstreamParallelism, 
MAX_P);
+        OperatorState downstreamState =
+                new OperatorState(operatorIds.get(1), downstreamParallelism, 
MAX_P);
+
+        states.put(operatorIds.get(0), upstreamState);
+        states.put(operatorIds.get(1), downstreamState);
+
+        if (upstreamOpState != null) {
+            upstreamState.putState(0, upstreamOpState);
+            // rescale downstream 5 -> 3
+            downstreamParallelism = 3;
+        }
+
+        if (downstreamOpState != null) {
+            downstreamState.putState(0, downstreamOpState);
+            // rescale upstream 5 -> 3
+            upstreamParallelism = 3;
+        }
+
+        List<OperatorIdWithParallelism> opIdWithParallelism = new 
ArrayList<>(2);
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(0), 
upstreamParallelism));
+        opIdWithParallelism.add(
+                new OperatorIdWithParallelism(operatorIds.get(1), 
downstreamParallelism));
+
+        Map<OperatorID, ExecutionJobVertex> vertices =
+                buildVertices(opIdWithParallelism, RANGE, ROUND_ROBIN);
+
+        // Run state assignment
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), 
states, false)
+                .assignStates();
+
+        // Check results
+        ExecutionJobVertex upstreamExecutionJobVertex = 
vertices.get(operatorIds.get(0));
+        ExecutionJobVertex downstreamExecutionJobVertex = 
vertices.get(operatorIds.get(1));
+
+        List<TaskStateSnapshot> upstreamRescalingDescriptors =
+                getRescalingDescriptorsFromVertex(upstreamExecutionJobVertex);
+        List<TaskStateSnapshot> downstreamRescalingDescriptors =
+                
getRescalingDescriptorsFromVertex(downstreamExecutionJobVertex);
+
+        checkMappings(
+                upstreamRescalingDescriptors,
+                TaskStateSnapshot::getOutputRescalingDescriptor,
+                expectedUpstreamCount);
+
+        checkMappings(
+                downstreamRescalingDescriptors,
+                TaskStateSnapshot::getInputRescalingDescriptor,
+                expectedDownstreamCount);

Review Comment:
   nit: instead of passing extractor functions into `checkMappings`, I would accept 
a little bit of code duplication and replace those calls with:
   ```
           checkMappings(
                   upstreamTaskStateSnapshots.stream().map(TaskStateSnapshot::getOutputRescalingDescriptor),
                   expectedUpstreamCount);
   
           checkMappings(
                   downstreamTaskStateSnapshots.stream().map(TaskStateSnapshot::getInputRescalingDescriptor),
                   expectedDownstreamCount);
   ```



##########
flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java:
##########
@@ -459,46 +478,59 @@ public void processBroadcastElement(Long value, Context 
ctx, Collector<Long> out
         }
     }
 
-    @Parameterized.Parameters(name = "{0} {1} from {2} to {3}, 
buffersPerChannel = {4}")
+    @Parameterized.Parameters(
+            name = "{0} {1} from {2} to {3}, sourceSleepMs = {4}, 
buffersPerChannel = {5}")

Review Comment:
   I would add a comment above this line explaining why we sometimes want a 
non-zero `sourceSleepMs`: we want to test rescaling without backpressure, with 
only an occasional couple of captured in-flight records.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to