StefanRRichter commented on code in PR #22584: URL: https://github.com/apache/flink/pull/22584#discussion_r1193845132
########## flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java: ########## @@ -785,6 +789,130 @@ public void testOnlyUpstreamChannelStateAssignment() } } + /** FLINK-31963: Tests rescaling for stateless operators and upstream result partition state. */ + @Test + public void testOnlyUpstreamChannelRescaleStateAssignment() + throws JobException, JobExecutionException { + Random random = new Random(); + OperatorSubtaskState upstreamOpState = + OperatorSubtaskState.builder() + .setResultSubpartitionState( + new StateObjectCollection<>( + asList( + createNewResultSubpartitionStateHandle(10, random), + createNewResultSubpartitionStateHandle( + 10, random)))) + .build(); + testOnlyUpstreamOrDownstreamRescalingInternal(upstreamOpState, null, 5, 7); + } + + /** FLINK-31963: Tests rescaling for stateless operators and downstream input channel state. */ + @Test + public void testOnlyDownstreamChannelRescaleStateAssignment() + throws JobException, JobExecutionException { + Random random = new Random(); + OperatorSubtaskState downstreamOpState = + OperatorSubtaskState.builder() + .setInputChannelState( + new StateObjectCollection<>( + asList( + createNewInputChannelStateHandle(10, random), + createNewInputChannelStateHandle(10, random)))) + .build(); + testOnlyUpstreamOrDownstreamRescalingInternal(null, downstreamOpState, 5, 5); + } + + private void testOnlyUpstreamOrDownstreamRescalingInternal( + @Nullable OperatorSubtaskState upstreamOpState, + @Nullable OperatorSubtaskState downstreamOpState, + int expectedUpstreamCount, + int expectedDownstreamCount) + throws JobException, JobExecutionException { + + if ((upstreamOpState == null && downstreamOpState == null) + || (upstreamOpState != null && downstreamOpState != null)) { + // Either upstream or downstream state must exist, but not both. + return; + } + + // Start from parallelism 5 for both operators + int upstreamParallelism = 5; + int downstreamParallelism = 5; + + // Build states + List<OperatorID> operatorIds = buildOperatorIds(2); + Map<OperatorID, OperatorState> states = new HashMap<>(); + OperatorState upstreamState = + new OperatorState(operatorIds.get(0), upstreamParallelism, MAX_P); + OperatorState downstreamState = + new OperatorState(operatorIds.get(1), downstreamParallelism, MAX_P); + + states.put(operatorIds.get(0), upstreamState); + states.put(operatorIds.get(1), downstreamState); + + if (upstreamOpState != null) { + upstreamState.putState(0, upstreamOpState); + // rescale downstream 5 -> 3 + downstreamParallelism = 3; + } + + if (downstreamOpState != null) { + downstreamState.putState(0, downstreamOpState); + // rescale upstream 5 -> 3 + upstreamParallelism = 3; + } + + List<OperatorIdWithParallelism> opIdWithParallelism = new ArrayList<>(2); + opIdWithParallelism.add( + new OperatorIdWithParallelism(operatorIds.get(0), upstreamParallelism)); + opIdWithParallelism.add( + new OperatorIdWithParallelism(operatorIds.get(1), downstreamParallelism)); + + Map<OperatorID, ExecutionJobVertex> vertices = + buildVertices(opIdWithParallelism, RANGE, ROUND_ROBIN); + + // Run state assignment + new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + .assignStates(); + + // Check results + ExecutionJobVertex upstreamExecutionJobVertex = vertices.get(operatorIds.get(0)); + ExecutionJobVertex downstreamExecutionJobVertex = vertices.get(operatorIds.get(1)); + + List<TaskStateSnapshot> upstreamRescalingDescriptors = + getRescalingDescriptorsFromVertex(upstreamExecutionJobVertex); + List<TaskStateSnapshot> downstreamRescalingDescriptors = + getRescalingDescriptorsFromVertex(downstreamExecutionJobVertex); + + checkMappings( + upstreamRescalingDescriptors, + TaskStateSnapshot::getOutputRescalingDescriptor, + expectedUpstreamCount); + + checkMappings( + downstreamRescalingDescriptors, + TaskStateSnapshot::getInputRescalingDescriptor, + expectedDownstreamCount); + } + + private void checkMappings( + List<TaskStateSnapshot> taskStateSnapshots, + Function<TaskStateSnapshot, InflightDataRescalingDescriptor> extractFun, + int expectedCount) { + Assert.assertEquals( + expectedCount, + taskStateSnapshots.stream() + .map(extractFun) + .mapToInt( + x -> { + int len = x.getOldSubtaskIndexes(0).length; + // Assert that there is a mapping. + Assert.assertTrue(len > 0); + return len; + }) + .sum()); Review Comment: I was thinking about this and decided to keep the test targeted at just checking that a remapping has happened. I'd hope there are already tests that check the correctness of such remappings thoroughly. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org