This is an automated email from the ASF dual-hosted git repository. fanrui pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 686c00f8e3b727f15b7b922397207372913ec101 Author: Rui Fan <[email protected]> AuthorDate: Mon Nov 3 22:03:34 2025 +0100 [FLINK-38542][checkpoint] Recover output buffers of upstream task on downstream task side directly --- .../generated/checkpointing_configuration.html | 6 + .../flink/configuration/CheckpointingOptions.java | 10 + .../runtime/checkpoint/CheckpointCoordinator.java | 9 +- .../runtime/checkpoint/OperatorSubtaskState.java | 26 +- .../checkpoint/StateAssignmentOperation.java | 36 +- .../runtime/checkpoint/TaskStateAssignment.java | 145 +++++- .../channel/SequentialChannelStateReaderImpl.java | 5 + .../tasks/CheckpointCoordinatorConfiguration.java | 29 +- .../flink/streaming/api/graph/StreamGraph.java | 5 + .../checkpoint/StateAssignmentOperationTest.java | 580 +++++++++++++++++---- .../runtime/checkpoint/StateHandleDummyUtil.java | 8 +- .../test/state/ChangelogRecoveryCachingITCase.java | 4 + 12 files changed, 742 insertions(+), 121 deletions(-) diff --git a/docs/layouts/shortcodes/generated/checkpointing_configuration.html b/docs/layouts/shortcodes/generated/checkpointing_configuration.html index 1766e2d7589..cb38355a92a 100644 --- a/docs/layouts/shortcodes/generated/checkpointing_configuration.html +++ b/docs/layouts/shortcodes/generated/checkpointing_configuration.html @@ -182,6 +182,12 @@ <td>Integer</td> <td>Defines the maximum number of subtasks that share the same channel state file. It can reduce the number of small files when enable unaligned checkpoint. 
Each subtask will create a new channel state file when this is configured to 1.</td> </tr> + <tr> + <td><h5>execution.checkpointing.unaligned.recover-output-on-downstream.enabled</h5></td> + <td style="word-wrap: break-word;">false</td> + <td>Boolean</td> + <td>Whether recovering output buffers of upstream task on downstream task directly when job restores from the unaligned checkpoint.</td> + </tr> <tr> <td><h5>execution.checkpointing.write-buffer-size</h5></td> <td style="word-wrap: break-word;">4096</td> diff --git a/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java index 8e4ae89184e..9921ffc8f5f 100644 --- a/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java +++ b/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java @@ -646,6 +646,16 @@ public class CheckpointingOptions { + "It can reduce the number of small files when enable unaligned checkpoint. " + "Each subtask will create a new channel state file when this is configured to 1."); + @Experimental + public static final ConfigOption<Boolean> UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM = + ConfigOptions.key( + "execution.checkpointing.unaligned.recover-output-on-downstream.enabled") + .booleanType() + .defaultValue(false) + .withDescription( + "Whether recovering output buffers of upstream task on downstream task directly " + + "when job restores from the unaligned checkpoint."); + /** * Determines whether checkpointing is enabled based on the configuration. 
* diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java index 85622387daa..bc65b09a7ef 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java @@ -241,6 +241,8 @@ public class CheckpointCoordinator { private final boolean isExactlyOnceMode; + private final boolean recoverOutputOnDownstreamTask; + /** Flag represents there is an in-flight trigger request. */ private boolean isTriggering = false; @@ -344,6 +346,7 @@ public class CheckpointCoordinator { this.clock = checkNotNull(clock); this.isExactlyOnceMode = chkConfig.isExactlyOnce(); this.unalignedCheckpointsEnabled = chkConfig.isUnalignedCheckpointsEnabled(); + this.recoverOutputOnDownstreamTask = chkConfig.isRecoverOutputOnDownstreamTask(); this.alignedCheckpointTimeout = chkConfig.getAlignedCheckpointTimeout(); this.checkpointIdOfIgnoredInFlightData = chkConfig.getCheckpointIdOfIgnoredInFlightData(); @@ -1816,7 +1819,11 @@ public class CheckpointCoordinator { StateAssignmentOperation stateAssignmentOperation = new StateAssignmentOperation( - latest.getCheckpointID(), tasks, operatorStates, allowNonRestoredState); + latest.getCheckpointID(), + tasks, + operatorStates, + allowNonRestoredState, + recoverOutputOnDownstreamTask); stateAssignmentOperation.assignStates(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java index a70b85175d6..007bb5334d2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java @@ -21,6 +21,7 @@ package 
org.apache.flink.runtime.checkpoint; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.runtime.state.ChannelState; import org.apache.flink.runtime.state.CompositeStateHandle; +import org.apache.flink.runtime.state.InputChannelStateHandle; import org.apache.flink.runtime.state.InputStateHandle; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; @@ -89,6 +90,8 @@ public class OperatorSubtaskState implements CompositeStateHandle { private final StateObjectCollection<InputStateHandle> inputChannelState; + private final StateObjectCollection<InputChannelStateHandle> upstreamOutputBufferState; + private final StateObjectCollection<OutputStateHandle> resultSubpartitionState; /** @@ -123,6 +126,7 @@ public class OperatorSubtaskState implements CompositeStateHandle { StateObjectCollection<KeyedStateHandle> managedKeyedState, StateObjectCollection<KeyedStateHandle> rawKeyedState, StateObjectCollection<InputStateHandle> inputChannelState, + StateObjectCollection<InputChannelStateHandle> upstreamOutputBufferState, StateObjectCollection<OutputStateHandle> resultSubpartitionState, InflightDataRescalingDescriptor inputRescalingDescriptor, InflightDataRescalingDescriptor outputRescalingDescriptor) { @@ -132,6 +136,7 @@ public class OperatorSubtaskState implements CompositeStateHandle { this.managedKeyedState = checkNotNull(managedKeyedState); this.rawKeyedState = checkNotNull(rawKeyedState); this.inputChannelState = checkNotNull(inputChannelState); + this.upstreamOutputBufferState = checkNotNull(upstreamOutputBufferState); this.resultSubpartitionState = checkNotNull(resultSubpartitionState); this.inputRescalingDescriptor = checkNotNull(inputRescalingDescriptor); this.outputRescalingDescriptor = checkNotNull(outputRescalingDescriptor); @@ -152,7 +157,8 @@ public class OperatorSubtaskState implements CompositeStateHandle { } private Stream<StateObjectCollection<? 
extends ChannelState>> streamChannelStates() { - return Stream.of(inputChannelState, resultSubpartitionState).filter(Objects::nonNull); + return Stream.of(inputChannelState, upstreamOutputBufferState, resultSubpartitionState) + .filter(Objects::nonNull); } @VisibleForTesting @@ -164,6 +170,7 @@ public class OperatorSubtaskState implements CompositeStateHandle { StateObjectCollection.empty(), StateObjectCollection.empty(), StateObjectCollection.empty(), + StateObjectCollection.empty(), InflightDataRescalingDescriptor.NO_RESCALE, InflightDataRescalingDescriptor.NO_RESCALE); } @@ -190,6 +197,10 @@ public class OperatorSubtaskState implements CompositeStateHandle { return inputChannelState; } + public StateObjectCollection<InputChannelStateHandle> getUpstreamOutputBufferState() { + return upstreamOutputBufferState; + } + public StateObjectCollection<OutputStateHandle> getResultSubpartitionState() { return resultSubpartitionState; } @@ -343,6 +354,8 @@ public class OperatorSubtaskState implements CompositeStateHandle { + rawKeyedState + ", inputChannelState=" + inputChannelState + + ", upstreamOutputBufferState=" + + upstreamOutputBufferState + ", resultSubpartitionState=" + resultSubpartitionState + ", stateSize=" @@ -358,6 +371,7 @@ public class OperatorSubtaskState implements CompositeStateHandle { || managedKeyedState.hasState() || rawKeyedState.hasState() || inputChannelState.hasState() + || upstreamOutputBufferState.hasState() || resultSubpartitionState.hasState(); } @@ -368,6 +382,7 @@ public class OperatorSubtaskState implements CompositeStateHandle { .setRawOperatorState(rawOperatorState) .setRawKeyedState(rawKeyedState) .setInputChannelState(inputChannelState) + .setUpstreamOutputBufferState(upstreamOutputBufferState) .setResultSubpartitionState(resultSubpartitionState) .setInputRescalingDescriptor(inputRescalingDescriptor) .setOutputRescalingDescriptor(outputRescalingDescriptor); @@ -392,6 +407,8 @@ public class OperatorSubtaskState implements 
CompositeStateHandle { StateObjectCollection.empty(); private StateObjectCollection<InputStateHandle> inputChannelState = StateObjectCollection.empty(); + private StateObjectCollection<InputChannelStateHandle> upstreamOutputBufferState = + StateObjectCollection.empty(); private StateObjectCollection<OutputStateHandle> resultSubpartitionState = StateObjectCollection.empty(); private InflightDataRescalingDescriptor inputRescalingDescriptor = @@ -449,6 +466,12 @@ public class OperatorSubtaskState implements CompositeStateHandle { return this; } + public Builder setUpstreamOutputBufferState( + StateObjectCollection<InputChannelStateHandle> upstreamOutputBufferState) { + this.upstreamOutputBufferState = checkNotNull(upstreamOutputBufferState); + return this; + } + public Builder setResultSubpartitionState( StateObjectCollection<OutputStateHandle> resultSubpartitionState) { this.resultSubpartitionState = checkNotNull(resultSubpartitionState); @@ -474,6 +497,7 @@ public class OperatorSubtaskState implements CompositeStateHandle { managedKeyedState, rawKeyedState, inputChannelState, + upstreamOutputBufferState, resultSubpartitionState, inputRescalingDescriptor, outputRescalingDescriptor); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java index 70611f7b40e..6bc30d488c4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java @@ -75,6 +75,7 @@ public class StateAssignmentOperation { private final long restoreCheckpointId; private final boolean allowNonRestoredState; + private final boolean recoverOutputOnDownstreamTask; /** The state assignments for each ExecutionJobVertex that will be filled in multiple passes. 
*/ private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments; @@ -90,18 +91,29 @@ public class StateAssignmentOperation { long restoreCheckpointId, Set<ExecutionJobVertex> tasks, Map<OperatorID, OperatorState> operatorStates, - boolean allowNonRestoredState) { + boolean allowNonRestoredState, + boolean recoverOutputOnDownstreamTask) { this.restoreCheckpointId = restoreCheckpointId; this.tasks = Preconditions.checkNotNull(tasks); this.operatorStates = Preconditions.checkNotNull(operatorStates); this.allowNonRestoredState = allowNonRestoredState; + this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask; this.vertexAssignments = CollectionUtil.newHashMapWithExpectedSize(tasks.size()); } public void assignStates() { checkStateMappingCompleteness(allowNonRestoredState, operatorStates, tasks); + buildStateAssignments(); + + repartitionState(); + + // actually assign the state + applyStateAssignments(); + } + + private void buildStateAssignments() { Map<OperatorID, OperatorState> localOperators = new HashMap<>(operatorStates); // find the states of all operators belonging to this task and compute additional @@ -135,13 +147,16 @@ public class StateAssignmentOperation { executionJobVertex, operatorStates, consumerAssignment, - vertexAssignments); + vertexAssignments, + recoverOutputOnDownstreamTask); vertexAssignments.put(executionJobVertex, stateAssignment); for (final IntermediateResult producedDataSet : executionJobVertex.getInputs()) { consumerAssignment.put(producedDataSet.getId(), stateAssignment); } } + } + private void repartitionState() { // repartition state for (TaskStateAssignment stateAssignment : vertexAssignments.values()) { if (stateAssignment.hasNonFinishedState @@ -153,7 +168,22 @@ public class StateAssignmentOperation { } } - // actually assign the state + // distribute output channel states to downstream tasks if needed + // Note: it has to be called after assignAttemptState for all tasks since the + // redistributing of 
result subpartition states depends on the inputSubtaskMappings + // of downstream tasks. + if (recoverOutputOnDownstreamTask) { + for (TaskStateAssignment stateAssignment : vertexAssignments.values()) { + // If recoverOutputOnDownstreamTask is enabled, all upstream output buffers have to + // be distributed to downstream since the upstream task side doesn't deserialize + // records generally. It is easy to filter records and re-upload records if + // recovering output buffers on downstream task side directly. + stateAssignment.distributeOutputBuffersToDownstream(); + } + } + } + + private void applyStateAssignments() { for (TaskStateAssignment stateAssignment : vertexAssignments.values()) { // If upstream has output states or downstream has input states, even the empty task // state should be assigned for the current task in order to notify this task that the diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java index 4601b4716a1..a6db5e4837f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java @@ -21,6 +21,7 @@ import org.apache.flink.runtime.OperatorIDPair; import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor; import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType; import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo; +import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionDistributor; import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.IntermediateResult; @@ -43,6 +44,7 @@ 
import org.slf4j.LoggerFactory; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; @@ -91,12 +93,26 @@ class TaskStateAssignment { final Map<OperatorInstanceID, List<InputChannelStateHandle>> inputChannelStates; final Map<OperatorInstanceID, List<ResultSubpartitionStateHandle>> resultSubpartitionStates; + /** + * Stores input channel states that come from upstream task's output buffers. It takes effect + * when {@link + * org.apache.flink.configuration.CheckpointingOptions#UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM} + * is enabled. + */ + private final Map<OperatorInstanceID, List<InputChannelStateHandle>> upstreamOutputBufferStates; + /** The subtask mapping when the output operator was rescaled. */ private final Map<Integer, SubtasksRescaleMapping> outputSubtaskMappings = new HashMap<>(); /** The subtask mapping when the input operator was rescaled. */ private final Map<Integer, SubtasksRescaleMapping> inputSubtaskMappings = new HashMap<>(); + /** InflightDataRescalingDescriptor for each subtask. 
*/ + private final Map<OperatorInstanceID, InflightDataRescalingDescriptor> + outputRescalingDescriptors = new HashMap<>(); + + private final boolean recoverOutputOnDownstreamTask; + @Nullable private TaskStateAssignment[] downstreamAssignments; @Nullable private TaskStateAssignment[] upstreamAssignments; @Nullable private Boolean hasUpstreamOutputStates; @@ -109,7 +125,8 @@ class TaskStateAssignment { ExecutionJobVertex executionJobVertex, Map<OperatorID, OperatorState> oldState, Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment, - Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments) { + Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments, + boolean recoverOutputOnDownstreamTask) { this.executionJobVertex = executionJobVertex; this.oldState = oldState; @@ -126,6 +143,7 @@ class TaskStateAssignment { newParallelism = executionJobVertex.getParallelism(); this.consumerAssignment = checkNotNull(consumerAssignment); this.vertexAssignments = checkNotNull(vertexAssignments); + this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask; final int expectedNumberOfSubtasks = newParallelism * oldState.size(); subManagedOperatorState = @@ -134,6 +152,8 @@ class TaskStateAssignment { inputChannelStates = CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks); resultSubpartitionStates = CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks); + upstreamOutputBufferStates = + CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks); subManagedKeyedState = CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks); subRawKeyedState = CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks); @@ -238,8 +258,15 @@ class TaskStateAssignment { .setRawKeyedState(getState(instanceID, subRawKeyedState)) .setInputChannelState( castToInputStateCollection(inputChannelStates.get(instanceID))) + .setUpstreamOutputBufferState( + new 
StateObjectCollection<>(upstreamOutputBufferStates.get(instanceID))) .setResultSubpartitionState( - castToOutputStateCollection(resultSubpartitionStates.get(instanceID))) + // If recoverOutputOnDownstreamTask is enabled, clear own output buffers as + // they are migrated to downstream + recoverOutputOnDownstreamTask + ? castToOutputStateCollection(null) + : castToOutputStateCollection( + resultSubpartitionStates.get(instanceID))) .setInputRescalingDescriptor( createRescalingDescriptor( instanceID, @@ -254,20 +281,7 @@ class TaskStateAssignment { inputSubtaskMappings, this::getInputMapping, true)) - .setOutputRescalingDescriptor( - createRescalingDescriptor( - instanceID, - outputOperatorID, - getDownstreamAssignments(), - (assignment, recompute) -> { - int assignmentIndex = - getAssignmentIndex( - assignment.getUpstreamAssignments(), this); - return assignment.getInputMapping(assignmentIndex, recompute); - }, - outputSubtaskMappings, - this::getOutputMapping, - false)) + .setOutputRescalingDescriptor(getOutputRescalingDescriptor(instanceID)) .build(); } @@ -289,6 +303,33 @@ class TaskStateAssignment { return hasDownstreamInputStates; } + /** + * Gets the output rescaling descriptor for a specific instance with caching. Each descriptor is + * computed only once and cached for subsequent access. + */ + public InflightDataRescalingDescriptor getOutputRescalingDescriptor( + OperatorInstanceID instanceID) { + return outputRescalingDescriptors.computeIfAbsent( + instanceID, this::computeOutputRescalingDescriptor); + } + + /** Computes the output rescaling descriptor for a single subtask. 
*/ + private InflightDataRescalingDescriptor computeOutputRescalingDescriptor( + OperatorInstanceID instanceID) { + return createRescalingDescriptor( + instanceID, + outputOperatorID, + getDownstreamAssignments(), + (downstreamAssignment, recompute) -> { + int assignmentIndex = + getAssignmentIndex(downstreamAssignment.getUpstreamAssignments(), this); + return downstreamAssignment.getInputMapping(assignmentIndex, recompute); + }, + outputSubtaskMappings, + this::getOutputMapping, + false); + } + private InflightDataGateOrPartitionRescalingDescriptor log( InflightDataGateOrPartitionRescalingDescriptor descriptor, int subtask, int partition) { LOG.debug( @@ -540,6 +581,78 @@ class TaskStateAssignment { return false; } + void distributeOutputBuffersToDownstream() { + for (Map.Entry<OperatorInstanceID, List<ResultSubpartitionStateHandle>> entry : + resultSubpartitionStates.entrySet()) { + OperatorInstanceID operatorInstanceID = entry.getKey(); + List<ResultSubpartitionStateHandle> stateHandles = entry.getValue(); + + ResultSubpartitionDistributor distributor = + new ResultSubpartitionDistributor( + getOutputRescalingDescriptor(operatorInstanceID)); + + for (final ResultSubpartitionStateHandle stateHandle : stateHandles) { + distributeOutputBufferToDownstream(stateHandle, distributor); + } + } + } + + private void distributeOutputBufferToDownstream( + ResultSubpartitionStateHandle stateHandle, ResultSubpartitionDistributor distributor) { + // From the perspective of the downstream task, the oldUpstreamSubtaskIndex will be + // treated as the inputChannelIdx, and the info.getSubPartitionIdx() will be treated + // as the oldDownstreamSubtaskIndex. 
+ int oldUpstreamSubtaskIndex = stateHandle.getSubtaskIndex(); + ResultSubpartitionInfo info = stateHandle.getInfo(); + int partitionIdx = info.getPartitionIdx(); + int oldDownstreamSubtaskIndex = info.getSubPartitionIdx(); + + int gateIdxResultPartition = findInputGateIdxForResultPartition(partitionIdx); + TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIdx]; + + List<ResultSubpartitionInfo> mappedSubpartitions = distributor.getMappedSubpartitions(info); + for (final ResultSubpartitionInfo mappedSubpartition : mappedSubpartitions) { + int targetDownstreamSubtaskId = mappedSubpartition.getSubPartitionIdx(); + + OperatorInstanceID downstreamOperatorInstance = + new OperatorInstanceID( + targetDownstreamSubtaskId, downstreamAssignment.inputOperatorID); + + InputChannelInfo inputChannelInfo = + new InputChannelInfo(gateIdxResultPartition, oldUpstreamSubtaskIndex); + + InputChannelStateHandle upstreamOutputBufferHandle = + new InputChannelStateHandle( + oldDownstreamSubtaskIndex, + inputChannelInfo, + stateHandle.getDelegate(), + stateHandle.getOffsets(), + stateHandle.getStateSize()); + + List<InputChannelStateHandle> upstreamOutputBufferHandles = + downstreamAssignment.upstreamOutputBufferStates.computeIfAbsent( + downstreamOperatorInstance, k -> new ArrayList<>()); + upstreamOutputBufferHandles.add(upstreamOutputBufferHandle); + } + } + + private int findInputGateIdxForResultPartition(int partitionIndex) { + // Check downstream input state for this partition + TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIndex]; + + IntermediateResult producedResult = + executionJobVertex.getProducedDataSets()[partitionIndex]; + IntermediateDataSetID resultId = producedResult.getId(); + List<IntermediateResult> inputs = downstreamAssignment.executionJobVertex.getInputs(); + for (int i = 0; i < inputs.size(); i++) { + if (inputs.get(i).getId().equals(resultId)) { + return i; + } + } + throw new IllegalArgumentException( 
+ "No channel rescaler found during rescaling of channel state"); + } + @Override public String toString() { return "TaskStateAssignment for " + executionJobVertex.getName(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java index 6fe1f1a6e99..3daa4b4947a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java @@ -67,6 +67,11 @@ public class SequentialChannelStateReaderImpl implements SequentialChannelStateR groupByDelegate( streamSubtaskStates(), ChannelStateHelper::extractUnmergedInputHandles)); + read( + stateHandler, + groupByDelegate( + streamSubtaskStates(), + OperatorSubtaskState::getUpstreamOutputBufferState)); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/CheckpointCoordinatorConfiguration.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/CheckpointCoordinatorConfiguration.java index 2d100aabefa..cc755fc17a4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/CheckpointCoordinatorConfiguration.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/CheckpointCoordinatorConfiguration.java @@ -71,6 +71,8 @@ public class CheckpointCoordinatorConfiguration implements Serializable { private final boolean enableCheckpointsAfterTasksFinish; + private final boolean recoverOutputOnDownstreamTask; + /** * @deprecated use {@link #builder()}. 
*/ @@ -98,6 +100,7 @@ public class CheckpointCoordinatorConfiguration implements Serializable { isUnalignedCheckpoint, 0, checkpointIdOfIgnoredInFlightData, + false, false); } @@ -113,7 +116,8 @@ public class CheckpointCoordinatorConfiguration implements Serializable { boolean isUnalignedCheckpointsEnabled, long alignedCheckpointTimeout, long checkpointIdOfIgnoredInFlightData, - boolean enableCheckpointsAfterTasksFinish) { + boolean enableCheckpointsAfterTasksFinish, + boolean recoverOutputOnDownstreamTask) { if (checkpointIntervalDuringBacklog < MINIMAL_CHECKPOINT_TIME) { // interval of max value means disable periodic checkpoint @@ -144,6 +148,7 @@ public class CheckpointCoordinatorConfiguration implements Serializable { this.alignedCheckpointTimeout = alignedCheckpointTimeout; this.checkpointIdOfIgnoredInFlightData = checkpointIdOfIgnoredInFlightData; this.enableCheckpointsAfterTasksFinish = enableCheckpointsAfterTasksFinish; + this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask; } public long getCheckpointInterval() { @@ -198,6 +203,10 @@ public class CheckpointCoordinatorConfiguration implements Serializable { return enableCheckpointsAfterTasksFinish; } + public boolean isRecoverOutputOnDownstreamTask() { + return recoverOutputOnDownstreamTask; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -217,7 +226,8 @@ public class CheckpointCoordinatorConfiguration implements Serializable { && checkpointRetentionPolicy == that.checkpointRetentionPolicy && tolerableCheckpointFailureNumber == that.tolerableCheckpointFailureNumber && checkpointIdOfIgnoredInFlightData == that.checkpointIdOfIgnoredInFlightData - && enableCheckpointsAfterTasksFinish == that.enableCheckpointsAfterTasksFinish; + && enableCheckpointsAfterTasksFinish == that.enableCheckpointsAfterTasksFinish + && recoverOutputOnDownstreamTask == that.recoverOutputOnDownstreamTask; } @Override @@ -233,7 +243,8 @@ public class CheckpointCoordinatorConfiguration implements 
Serializable { alignedCheckpointTimeout, tolerableCheckpointFailureNumber, checkpointIdOfIgnoredInFlightData, - enableCheckpointsAfterTasksFinish); + enableCheckpointsAfterTasksFinish, + recoverOutputOnDownstreamTask); } @Override @@ -261,6 +272,8 @@ public class CheckpointCoordinatorConfiguration implements Serializable { + checkpointIdOfIgnoredInFlightData + ", enableCheckpointsAfterTasksFinish=" + enableCheckpointsAfterTasksFinish + + ", recoverOutputOnDownstreamTask=" + + recoverOutputOnDownstreamTask + '}'; } @@ -283,6 +296,7 @@ public class CheckpointCoordinatorConfiguration implements Serializable { private long alignedCheckpointTimeout = 0; private long checkpointIdOfIgnoredInFlightData; private boolean enableCheckpointsAfterTasksFinish; + private boolean recoverOutputOnDownstreamTask; public CheckpointCoordinatorConfiguration build() { return new CheckpointCoordinatorConfiguration( @@ -297,7 +311,8 @@ public class CheckpointCoordinatorConfiguration implements Serializable { isUnalignedCheckpointsEnabled, alignedCheckpointTimeout, checkpointIdOfIgnoredInFlightData, - enableCheckpointsAfterTasksFinish); + enableCheckpointsAfterTasksFinish, + recoverOutputOnDownstreamTask); } public CheckpointCoordinatorConfigurationBuilder setCheckpointInterval( @@ -370,5 +385,11 @@ public class CheckpointCoordinatorConfiguration implements Serializable { this.enableCheckpointsAfterTasksFinish = enableCheckpointsAfterTasksFinish; return this; } + + public CheckpointCoordinatorConfigurationBuilder setRecoverOutputOnDownstreamTask( + boolean recoverOutputOnDownstreamTask) { + this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask; + return this; + } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java index 49822d1c666..887c22a395e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java +++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java @@ -37,6 +37,7 @@ import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.typeutils.MissingTypeInfo; +import org.apache.flink.configuration.CheckpointingOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.ExecutionOptions; import org.apache.flink.configuration.ExternalizedCheckpointRetention; @@ -389,6 +390,10 @@ public class StreamGraph implements Pipeline, ExecutionPlan { cfg.getCheckpointIdOfIgnoredInFlightData()) .setAlignedCheckpointTimeout(cfg.getAlignedCheckpointTimeout().toMillis()) .setEnableCheckpointsAfterTasksFinish(isEnableCheckpointsAfterTasksFinish()) + .setRecoverOutputOnDownstreamTask( + jobConfiguration.get( + CheckpointingOptions + .UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM)) .build(), serializedStateBackend, getJobConfiguration() diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java index c13bcd0e321..711c3ef5cf3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java @@ -38,18 +38,25 @@ import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.OperatorInstanceID; +import org.apache.flink.runtime.state.InputChannelStateHandle; import org.apache.flink.runtime.state.KeyGroupRange; import org.apache.flink.runtime.state.KeyedStateHandle; import org.apache.flink.runtime.state.OperatorStateHandle; import 
org.apache.flink.runtime.state.OperatorStreamStateHandle; +import org.apache.flink.runtime.state.OutputStateHandle; +import org.apache.flink.runtime.state.ResultSubpartitionStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.runtime.testtasks.NoOpInvokable; import org.apache.flink.testutils.TestingUtils; import org.apache.flink.testutils.executor.TestExecutorExtension; +import org.apache.flink.shaded.guava33.com.google.common.collect.Iterables; + import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import javax.annotation.Nullable; @@ -506,7 +513,7 @@ class StateAssignmentOperationTest { buildVertices(operatorIds, numSubTasks, RANGE, ROUND_ROBIN); Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, numSubTasks); - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false) .assignStates(); for (OperatorID operatorId : operatorIds) { @@ -535,7 +542,7 @@ class StateAssignmentOperationTest { Map<OperatorID, ExecutionJobVertex> vertices = toExecutionVertices(upstream1, upstream2, downstream); - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false) .assignStates(); assertThat( @@ -599,7 +606,7 @@ class StateAssignmentOperationTest { Map<OperatorID, ExecutionJobVertex> vertices = toExecutionVertices(upstream1, upstream2, downstream); - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false) .assignStates(); assertThat( @@ -619,17 +626,32 @@ class StateAssignmentOperationTest { 
.isEqualTo(6); } - @Test - void testChannelStateAssignmentDownscaling() throws JobException, JobExecutionException { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testChannelStateAssignmentDownscaling(boolean recoverOutputOnDownstreamTask) + throws JobException, JobExecutionException { + int oldParallelism = 3; + int newParallelism = 2; + List<OperatorID> operatorIds = buildOperatorIds(2); - Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, 3); + OperatorID firstOperator = operatorIds.get(0); + JobResultSubpartitionHandlers jobResultSubpartitionHandlers = + new JobResultSubpartitionHandlers(operatorIds, oldParallelism); + Map<OperatorID, OperatorState> states = + buildOperatorStates(operatorIds, oldParallelism, jobResultSubpartitionHandlers); Map<OperatorID, ExecutionJobVertex> vertices = - buildVertices(operatorIds, 2, RANGE, ROUND_ROBIN); - - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + buildVertices(operatorIds, newParallelism, RANGE, ROUND_ROBIN); + + new StateAssignmentOperation( + 0, + new HashSet<>(vertices.values()), + states, + false, + recoverOutputOnDownstreamTask) .assignStates(); + OperatorID upstreamOperator = null; for (OperatorID operatorId : operatorIds) { // input is range partitioned, so there is an overlap assertState( @@ -648,22 +670,68 @@ class StateAssignmentOperationTest { OperatorSubtaskState::getInputChannelState, 1, 2); - // output is round robin redistributed - assertState( - vertices, - operatorId, - states, - 0, - OperatorSubtaskState::getResultSubpartitionState, - 0, - 2); - assertState( - vertices, - operatorId, - states, - 1, - OperatorSubtaskState::getResultSubpartitionState, - 1); + + if (recoverOutputOnDownstreamTask) { + // output buffer states are moved to downstream task when + // recoverOutputOnDownstreamTask is enabled + assertStateEmptyForAllSubtasks( + vertices, + operatorId, + newParallelism, + 
OperatorSubtaskState::getResultSubpartitionState); + + if (firstOperator == operatorId) { + // The first operator does not have any upstream. + assertStateEmptyForAllSubtasks( + vertices, + operatorId, + newParallelism, + OperatorSubtaskState::getUpstreamOutputBufferState); + } else { + ExecutionJobVertex executionJobVertex = vertices.get(operatorId); + int[] indexes0 = {0, 0, 0, 1, 1, 0, 1, 1, 2, 0, 2, 1}; + assertUpstreamOutputBufferState( + jobResultSubpartitionHandlers, + executionJobVertex, + operatorId, + 0, + upstreamOperator, + indexes0); + int[] indexes1 = {0, 1, 0, 2, 1, 1, 1, 2, 2, 1, 2, 2}; + assertUpstreamOutputBufferState( + jobResultSubpartitionHandlers, + executionJobVertex, + operatorId, + 1, + upstreamOperator, + indexes1); + } + } else { + // output is round robin redistributed + assertState( + vertices, + operatorId, + states, + 0, + OperatorSubtaskState::getResultSubpartitionState, + 0, + 2); + assertState( + vertices, + operatorId, + states, + 1, + OperatorSubtaskState::getResultSubpartitionState, + 1); + + // The upstream output buffer state is expected to be empty. 
+ assertStateEmptyForAllSubtasks( + vertices, + operatorId, + newParallelism, + OperatorSubtaskState::getUpstreamOutputBufferState); + } + upstreamOperator = operatorId; } assertThat( @@ -686,38 +754,98 @@ class StateAssignmentOperationTest { .isEqualTo(rescalingDescriptor(to(1, 2), array(mappings(to(0, 2), to(1))), set(1))); } - @Test - void testChannelStateAssignmentNoRescale() throws JobException, JobExecutionException { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testChannelStateAssignmentNoRescale(boolean recoverOutputOnDownstreamTask) + throws JobException, JobExecutionException { List<OperatorID> operatorIds = buildOperatorIds(2); - Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, 2); + OperatorID firstOperator = operatorIds.get(0); + int parallelism = 2; + JobResultSubpartitionHandlers jobResultSubpartitionHandlers = + new JobResultSubpartitionHandlers(operatorIds, parallelism); + + Map<OperatorID, OperatorState> states = + buildOperatorStates(operatorIds, parallelism, jobResultSubpartitionHandlers); Map<OperatorID, ExecutionJobVertex> vertices = buildVertices(operatorIds, 2, RANGE, ROUND_ROBIN); - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + new StateAssignmentOperation( + 0, + new HashSet<>(vertices.values()), + states, + false, + recoverOutputOnDownstreamTask) .assignStates(); + OperatorID upstreamOperator = null; for (OperatorID operatorId : operatorIds) { - // input is range partitioned, so there is an overlap + // input is range partitioned assertState( vertices, operatorId, states, 0, OperatorSubtaskState::getInputChannelState, 0); assertState( vertices, operatorId, states, 1, OperatorSubtaskState::getInputChannelState, 1); - // output is round robin redistributed - assertState( - vertices, - operatorId, - states, - 0, - OperatorSubtaskState::getResultSubpartitionState, - 0); - assertState( - vertices, - operatorId, - states, - 1, - 
OperatorSubtaskState::getResultSubpartitionState, - 1); + + if (recoverOutputOnDownstreamTask) { + // output buffer states are moved to downstream task when + // recoverOutputOnDownstreamTask is enabled + assertStateEmptyForAllSubtasks( + vertices, + operatorId, + parallelism, + OperatorSubtaskState::getResultSubpartitionState); + + if (firstOperator == operatorId) { + // The first operator does not have any upstream. + assertStateEmptyForAllSubtasks( + vertices, + operatorId, + parallelism, + OperatorSubtaskState::getUpstreamOutputBufferState); + } else { + ExecutionJobVertex executionJobVertex = vertices.get(operatorId); + int[] indexes0 = {0, 0, 1, 0}; + assertUpstreamOutputBufferState( + jobResultSubpartitionHandlers, + executionJobVertex, + operatorId, + 0, + upstreamOperator, + indexes0); + int[] indexes1 = {0, 1, 1, 1}; + assertUpstreamOutputBufferState( + jobResultSubpartitionHandlers, + executionJobVertex, + operatorId, + 1, + upstreamOperator, + indexes1); + } + } else { + // output is round robin redistributed + assertState( + vertices, + operatorId, + states, + 0, + OperatorSubtaskState::getResultSubpartitionState, + 0); + assertState( + vertices, + operatorId, + states, + 1, + OperatorSubtaskState::getResultSubpartitionState, + 1); + + // The upstream output buffer state is expected to be empty. 
+ assertStateEmptyForAllSubtasks( + vertices, + operatorId, + parallelism, + OperatorSubtaskState::getUpstreamOutputBufferState); + } + upstreamOperator = operatorId; } assertThat( @@ -739,17 +867,32 @@ class StateAssignmentOperationTest { .isEqualTo(InflightDataRescalingDescriptor.NO_RESCALE); } - @Test - void testChannelStateAssignmentUpscaling() throws JobException, JobExecutionException { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testChannelStateAssignmentUpscaling(boolean recoverOutputOnDownstreamTask) + throws JobException, JobExecutionException { + int oldParallelism = 2; + int newParallelism = 3; + List<OperatorID> operatorIds = buildOperatorIds(2); - Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, 2); + OperatorID firstOperator = operatorIds.get(0); + JobResultSubpartitionHandlers jobResultSubpartitionHandlers = + new JobResultSubpartitionHandlers(operatorIds, oldParallelism); + Map<OperatorID, OperatorState> states = + buildOperatorStates(operatorIds, oldParallelism, jobResultSubpartitionHandlers); Map<OperatorID, ExecutionJobVertex> vertices = - buildVertices(operatorIds, 3, RANGE, ROUND_ROBIN); - - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + buildVertices(operatorIds, newParallelism, RANGE, ROUND_ROBIN); + + new StateAssignmentOperation( + 0, + new HashSet<>(vertices.values()), + states, + false, + recoverOutputOnDownstreamTask) .assignStates(); + OperatorID upstreamOperator = null; for (OperatorID operatorId : operatorIds) { // input is range partitioned, so there is an overlap assertState( @@ -764,27 +907,81 @@ class StateAssignmentOperationTest { 1); assertState( vertices, operatorId, states, 2, OperatorSubtaskState::getInputChannelState, 1); - // output is round robin redistributed - assertState( - vertices, - operatorId, - states, - 0, - OperatorSubtaskState::getResultSubpartitionState, - 0); - assertState( - vertices, - operatorId, - states, - 1, - 
OperatorSubtaskState::getResultSubpartitionState, - 1); - assertState( - vertices, - operatorId, - states, - 2, - OperatorSubtaskState::getResultSubpartitionState); + + if (recoverOutputOnDownstreamTask) { + // output buffer states are moved to downstream task when + // recoverOutputOnDownstreamTask is enabled + assertStateEmptyForAllSubtasks( + vertices, + operatorId, + newParallelism, + OperatorSubtaskState::getResultSubpartitionState); + + if (firstOperator == operatorId) { + // The first operator does not have any upstream. + assertStateEmptyForAllSubtasks( + vertices, + operatorId, + newParallelism, + OperatorSubtaskState::getUpstreamOutputBufferState); + } else { + ExecutionJobVertex executionJobVertex = vertices.get(operatorId); + int[] indexes0 = {0, 0, 1, 0}; + assertUpstreamOutputBufferState( + jobResultSubpartitionHandlers, + executionJobVertex, + operatorId, + 0, + upstreamOperator, + indexes0); + int[] indexes1 = {0, 0, 0, 1, 1, 0, 1, 1}; + assertUpstreamOutputBufferState( + jobResultSubpartitionHandlers, + executionJobVertex, + operatorId, + 1, + upstreamOperator, + indexes1); + int[] indexes2 = {0, 1, 1, 1}; + assertUpstreamOutputBufferState( + jobResultSubpartitionHandlers, + executionJobVertex, + operatorId, + 2, + upstreamOperator, + indexes2); + } + } else { + // output is round robin redistributed + assertState( + vertices, + operatorId, + states, + 0, + OperatorSubtaskState::getResultSubpartitionState, + 0); + assertState( + vertices, + operatorId, + states, + 1, + OperatorSubtaskState::getResultSubpartitionState, + 1); + assertState( + vertices, + operatorId, + states, + 2, + OperatorSubtaskState::getResultSubpartitionState); + + // The upstream output buffer state is expected to be empty. 
+ assertStateEmptyForAllSubtasks( + vertices, + operatorId, + newParallelism, + OperatorSubtaskState::getUpstreamOutputBufferState); + } + upstreamOperator = operatorId; } assertThat( @@ -845,7 +1042,7 @@ class StateAssignmentOperationTest { buildVertices(operatorIds, 3, RANGE, ROUND_ROBIN); // when: States are assigned. - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false) .assignStates(); // then: All subtask have not null TaskRestore information(even if it is empty). @@ -942,7 +1139,7 @@ class StateAssignmentOperationTest { buildVertices(opIdWithParallelism, RANGE, ROUND_ROBIN); // Run state assignment - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false) .assignStates(); // Check results @@ -983,8 +1180,10 @@ class StateAssignmentOperationTest { .isEqualTo(expectedCount); } - @Test - void testStateWithFullyFinishedOperators() throws JobException, JobExecutionException { + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void testStateWithFullyFinishedOperators(boolean recoverOutputOnDownstreamTask) + throws JobException, JobExecutionException { List<OperatorID> operatorIds = buildOperatorIds(2); Map<OperatorID, OperatorState> states = buildOperatorStates(Collections.singletonList(operatorIds.get(1)), 3); @@ -996,7 +1195,12 @@ class StateAssignmentOperationTest { Map<OperatorID, ExecutionJobVertex> vertices = buildVertices(operatorIds, 2, RANGE, ROUND_ROBIN); - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + new StateAssignmentOperation( + 0, + new HashSet<>(vertices.values()), + states, + false, + recoverOutputOnDownstreamTask) .assignStates(); // Check the job vertex with only finished operator. 
@@ -1043,8 +1247,33 @@ class StateAssignmentOperationTest { .isTrue(); } - @Test - void assigningStatesShouldWorkWithUserDefinedOperatorIdsAsWell() { + private void assertStateEmptyForAllSubtasks( + Map<OperatorID, ExecutionJobVertex> vertices, + OperatorID operatorId, + int numSubtasks, + Function<OperatorSubtaskState, StateObjectCollection<?>> extractor) { + for (int i = 0; i < numSubtasks; i++) { + assertStateEmpty(vertices, operatorId, i, extractor); + } + } + + private void assertStateEmpty( + Map<OperatorID, ExecutionJobVertex> vertices, + OperatorID operatorId, + int newSubtaskIndex, + Function<OperatorSubtaskState, StateObjectCollection<?>> extractor) { + final OperatorSubtaskState subState = + getAssignedState(vertices.get(operatorId), operatorId, newSubtaskIndex); + + assertThat(extractor.apply(subState).hasState()) + .as("State should be empty for operator %s subtask %d", operatorId, newSubtaskIndex) + .isFalse(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void assigningStatesShouldWorkWithUserDefinedOperatorIdsAsWell( + boolean recoverOutputOnDownstreamTask) { int numSubTasks = 1; OperatorID operatorId = new OperatorID(); OperatorID userDefinedOperatorId = new OperatorID(); @@ -1054,7 +1283,12 @@ class StateAssignmentOperationTest { buildExecutionJobVertex(operatorId, userDefinedOperatorId, 1); Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, numSubTasks); - new StateAssignmentOperation(0, Collections.singleton(executionJobVertex), states, false) + new StateAssignmentOperation( + 0, + Collections.singleton(executionJobVertex), + states, + false, + recoverOutputOnDownstreamTask) .assignStates(); assertThat(getAssignedState(executionJobVertex, operatorId, 0)) @@ -1081,10 +1315,101 @@ class StateAssignmentOperationTest { .collect(Collectors.toList()); } + /** + * Asserts the upstream output buffer state for a specific subtask by verifying the expected + * upstream subtask and subpartition mappings. 
+ * + * <p>The indexes parameter contains pairs of (subtaskIndex, subpartitionIndex) that define the + * expected upstream sources for the current subtask's input channels. + * + * @param jobResultSubpartitionHandlers the job result subpartition handlers + * @param executionJobVertex the execution job vertex + * @param operatorId the operator ID to check + * @param subtaskIndex the subtask index to verify + * @param upstreamOperator the upstream operator ID + * @param indexes pairs of (subtaskIndex, subpartitionIndex) defining expected upstream sources + */ + private void assertUpstreamOutputBufferState( + JobResultSubpartitionHandlers jobResultSubpartitionHandlers, + ExecutionJobVertex executionJobVertex, + OperatorID operatorId, + int subtaskIndex, + OperatorID upstreamOperator, + int[] indexes) { + checkArgument( + indexes.length % 2 == 0, + "indexes must contain pairs of (subtaskIndex, subpartitionIndex)"); + + ArrayList<ResultSubpartitionStateHandle> originals = new ArrayList<>(indexes.length / 2); + for (int i = 0; i < indexes.length; i += 2) { + int upstreamSubtaskIndex = indexes[i]; + int upstreamSubpartitionId = indexes[i + 1]; + originals.add( + jobResultSubpartitionHandlers.getHandler( + upstreamOperator, upstreamSubtaskIndex, upstreamSubpartitionId)); + } + + List<InputChannelStateHandle> transformed = + getAssignedState(executionJobVertex, operatorId, subtaskIndex) + .getUpstreamOutputBufferState() + .asList(); + assertStrictPerfectOneToOneMatch(originals, transformed); + } + + /** + * Core logic: Assert that two collections have strictly equal sizes and each element in the + * Result collection can find a unique, valid match in the Input collection (perfect 1:1 match). 
+ */ + private void assertStrictPerfectOneToOneMatch( + List<ResultSubpartitionStateHandle> originals, + List<InputChannelStateHandle> transformed) { + + assertThat(originals) + .as( + "Verify that the Result collection has the same number of elements as the Input collection") + .hasSameSizeAs(transformed); + + if (originals.isEmpty()) { + return; + } + + // Used to store already matched Input elements to ensure 1:1 constraint + Set<InputChannelStateHandle> usedInputs = new HashSet<>(); + + originals.forEach( + r -> { + List<InputChannelStateHandle> matchedInputHandles = + transformed.stream() + .filter(t -> verifyTransformationConsistency(r, t)) + .collect(Collectors.toList()); + assertThat(matchedInputHandles).hasSize(1); + + InputChannelStateHandle matchedInputHandle = + Iterables.getOnlyElement(matchedInputHandles); + assertThat(usedInputs).doesNotContain(matchedInputHandle); + usedInputs.add(matchedInputHandle); + }); + } + + private boolean verifyTransformationConsistency( + ResultSubpartitionStateHandle original, InputChannelStateHandle transformed) { + return original.getDelegate() == transformed.getDelegate() + && original.getOffsets().equals(transformed.getOffsets()) + && original.getStateSize() == transformed.getStateSize() + && original.getSubtaskIndex() == transformed.getInfo().getInputChannelIdx(); + } + private Map<OperatorID, OperatorState> buildOperatorStates( List<OperatorID> operatorIDs, int numSubTasks) { + return buildOperatorStates( + operatorIDs, + numSubTasks, + new JobResultSubpartitionHandlers(operatorIDs, numSubTasks)); + } + + private Map<OperatorID, OperatorState> buildOperatorStates( + List<OperatorID> operatorIDs, int numSubTasks, JobResultSubpartitionHandlers handlers) { Random random = new Random(); - final OperatorID lastId = operatorIDs.get(operatorIDs.size() - 1); return operatorIDs.stream() .collect( Collectors.toMap( @@ -1136,17 +1461,8 @@ class StateAssignmentOperationTest { 10, random)))) .setResultSubpartitionState( - 
operatorID == lastId - ? StateObjectCollection - .empty() - : new StateObjectCollection<>( - asList( - createNewResultSubpartitionStateHandle( - 10, - random), - createNewResultSubpartitionStateHandle( - 10, - random)))) + handlers.getStateObjectCollection( + operatorID, i)) .build()); } return state; @@ -1441,7 +1757,7 @@ class StateAssignmentOperationTest { Map<OperatorID, ExecutionJobVertex> vertices = toExecutionVertices(source, map1, map2); // This should not throw UnsupportedOperationException - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false) .assignStates(); // Verify state assignment succeeded @@ -1521,7 +1837,7 @@ class StateAssignmentOperationTest { toExecutionVertices(upstream1, upstream2, upstream3, downstream); // This should not throw UnsupportedOperationException - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false) .assignStates(); // Verify downstream received state @@ -1582,7 +1898,7 @@ class StateAssignmentOperationTest { Map<OperatorID, ExecutionJobVertex> vertices = toExecutionVertices(source, sink); // This should succeed even with RESCALE partitioner when parallelism changes - new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false) + new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false) .assignStates(); // Verify state was assigned @@ -1590,4 +1906,78 @@ class StateAssignmentOperationTest { getAssignedState(vertices.get(operatorIds.get(1)), operatorIds.get(1), 0); assertThat(sinkAssignedState).isNotNull(); } + + /** + * Test utility class that manages ResultSubpartitionStateHandles for all operators in a job. 
+ */ + private static class JobResultSubpartitionHandlers { + + private final Map<OperatorID, Map<Integer, SubtaskResultSubpartitionHandlers>> handlers; + + public JobResultSubpartitionHandlers(List<OperatorID> operatorIDs, int numSubTasks) { + this.handlers = new HashMap<>(operatorIDs.size() - 1); + + Random random = new Random(); + final OperatorID lastId = operatorIDs.get(operatorIDs.size() - 1); + for (OperatorID operatorID : operatorIDs) { + if (operatorID == lastId) { + // The last operator does not contain output buffers. + return; + } + Map<Integer, SubtaskResultSubpartitionHandlers> operatorHandlers = new HashMap<>(); + for (int subtaskIndex = 0; subtaskIndex < numSubTasks; subtaskIndex++) { + operatorHandlers.put( + subtaskIndex, + new SubtaskResultSubpartitionHandlers(random, numSubTasks)); + } + handlers.put(operatorID, operatorHandlers); + } + } + + /** + * Returns the StateObjectCollection of OutputStateHandles for the specified operator and + * subtask. + */ + private StateObjectCollection<OutputStateHandle> getStateObjectCollection( + OperatorID operatorID, int subtaskIndex) { + Map<Integer, SubtaskResultSubpartitionHandlers> operatorHandlers = + handlers.get(operatorID); + if (operatorHandlers == null) { + return StateObjectCollection.empty(); + } + + SubtaskResultSubpartitionHandlers subtaskResultSubpartitionHandlers = + operatorHandlers.get(subtaskIndex); + checkArgument( + subtaskResultSubpartitionHandlers != null, + "Subtask result subpartition handler not found for subtask %s.", + subtaskIndex); + // Ensure there is output buffer in each subpartition. 
+ return new StateObjectCollection<>( + new ArrayList<>(subtaskResultSubpartitionHandlers.handlers.values())); + } + + private ResultSubpartitionStateHandle getHandler( + OperatorID operatorID, int subtaskIndex, int subpartitionIndex) { + return handlers.get(operatorID).get(subtaskIndex).getHandler(subpartitionIndex); + } + } + + /** Test utility class that manages ResultSubpartitionStateHandles for a single subtask. */ + private static class SubtaskResultSubpartitionHandlers { + + // The key is subpartition index + private final Map<Integer, ResultSubpartitionStateHandle> handlers; + + public SubtaskResultSubpartitionHandlers(Random random, int numSubpartitions) { + this.handlers = new HashMap<>(numSubpartitions); + for (int i = 0; i < numSubpartitions; i++) { + handlers.put(i, createNewResultSubpartitionStateHandle(10, 0, i, random)); + } + } + + private ResultSubpartitionStateHandle getHandler(int subpartitionIndex) { + return handlers.get(subpartitionIndex); + } + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java index 7007821a6fe..088dc6b7af7 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java @@ -162,8 +162,14 @@ public class StateHandleDummyUtil { public static ResultSubpartitionStateHandle createNewResultSubpartitionStateHandle( int numNamedStates, int partitionIndex, Random random) { + return createNewResultSubpartitionStateHandle( + numNamedStates, partitionIndex, random.nextInt(), random); + } + + public static ResultSubpartitionStateHandle createNewResultSubpartitionStateHandle( + int numNamedStates, int partitionIndex, int subPartitionIdx, Random random) { return new ResultSubpartitionStateHandle( - new ResultSubpartitionInfo(partitionIndex, random.nextInt()), + 
new ResultSubpartitionInfo(partitionIndex, subPartitionIdx), createStreamStateHandle(numNamedStates, random), genOffsets(numNamedStates, random)); } diff --git a/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java index c306a81a0c2..b6511bba78a 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java @@ -64,6 +64,7 @@ import static org.apache.flink.changelog.fs.FsStateChangelogOptions.PREEMPTIVE_P import static org.apache.flink.configuration.CheckpointingOptions.CHECKPOINTS_DIRECTORY; import static org.apache.flink.configuration.CheckpointingOptions.CHECKPOINT_STORAGE; import static org.apache.flink.configuration.CheckpointingOptions.FILE_MERGING_ENABLED; +import static org.apache.flink.configuration.CheckpointingOptions.UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM; import static org.apache.flink.configuration.CoreOptions.DEFAULT_PARALLELISM; import static org.apache.flink.configuration.ExternalizedCheckpointRetention.RETAIN_ON_CANCELLATION; import static org.apache.flink.configuration.RestartStrategyOptions.RESTART_STRATEGY; @@ -179,6 +180,9 @@ public class ChangelogRecoveryCachingITCase extends TestLogger { conf.set(PERIODIC_MATERIALIZATION_ENABLED, false); conf.set(CheckpointingOptions.ENABLE_UNALIGNED, true); // speedup + // Disable UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM to prevent randomization since the output buffer + // states file may be opened from multiple downstream subtasks + conf.set(UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM, false); conf.set( CheckpointingOptions.ALIGNED_CHECKPOINT_TIMEOUT, Duration.ZERO); // prevent randomization
