This is an automated email from the ASF dual-hosted git repository.

fanrui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 686c00f8e3b727f15b7b922397207372913ec101
Author: Rui Fan <[email protected]>
AuthorDate: Mon Nov 3 22:03:34 2025 +0100

    [FLINK-38542][checkpoint] Recover output buffers of upstream task on downstream task side directly
---
 .../generated/checkpointing_configuration.html     |   6 +
 .../flink/configuration/CheckpointingOptions.java  |  10 +
 .../runtime/checkpoint/CheckpointCoordinator.java  |   9 +-
 .../runtime/checkpoint/OperatorSubtaskState.java   |  26 +-
 .../checkpoint/StateAssignmentOperation.java       |  36 +-
 .../runtime/checkpoint/TaskStateAssignment.java    | 145 +++++-
 .../channel/SequentialChannelStateReaderImpl.java  |   5 +
 .../tasks/CheckpointCoordinatorConfiguration.java  |  29 +-
 .../flink/streaming/api/graph/StreamGraph.java     |   5 +
 .../checkpoint/StateAssignmentOperationTest.java   | 580 +++++++++++++++++----
 .../runtime/checkpoint/StateHandleDummyUtil.java   |   8 +-
 .../test/state/ChangelogRecoveryCachingITCase.java |   4 +
 12 files changed, 742 insertions(+), 121 deletions(-)

diff --git a/docs/layouts/shortcodes/generated/checkpointing_configuration.html b/docs/layouts/shortcodes/generated/checkpointing_configuration.html
index 1766e2d7589..cb38355a92a 100644
--- a/docs/layouts/shortcodes/generated/checkpointing_configuration.html
+++ b/docs/layouts/shortcodes/generated/checkpointing_configuration.html
@@ -182,6 +182,12 @@
             <td>Integer</td>
             <td>Defines the maximum number of subtasks that share the same channel state file. It can reduce the number of small files when enable unaligned checkpoint. Each subtask will create a new channel state file when this is configured to 1.</td>
         </tr>
+        <tr>
+            <td><h5>execution.checkpointing.unaligned.recover-output-on-downstream.enabled</h5></td>
+            <td style="word-wrap: break-word;">false</td>
+            <td>Boolean</td>
+            <td>Whether to recover the output buffers of the upstream task directly on the downstream task when the job restores from an unaligned checkpoint.</td>
+        </tr>
         <tr>
             <td><h5>execution.checkpointing.write-buffer-size</h5></td>
             <td style="word-wrap: break-word;">4096</td>
diff --git a/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java
index 8e4ae89184e..9921ffc8f5f 100644
--- a/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java
+++ b/flink-core/src/main/java/org/apache/flink/configuration/CheckpointingOptions.java
@@ -646,6 +646,16 @@ public class CheckpointingOptions {
                                     + "It can reduce the number of small files 
when enable unaligned checkpoint. "
                                     + "Each subtask will create a new channel 
state file when this is configured to 1.");
 
+    @Experimental
+    public static final ConfigOption<Boolean> UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM =
+            ConfigOptions.key(
+                            "execution.checkpointing.unaligned.recover-output-on-downstream.enabled")
+                    .booleanType()
+                    .defaultValue(false)
+                    .withDescription(
+                            "Whether to recover the output buffers of the upstream task directly "
+                                    + "on the downstream task when the job restores from an unaligned checkpoint.");
+
     /**
      * Determines whether checkpointing is enabled based on the configuration.
      *
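
For reference, a job opts into the new behavior together with unaligned checkpoints. A minimal sketch, assuming the usual Configuration-based setup and the existing ENABLE_UNALIGNED option (only UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM is defined by this commit):

    import org.apache.flink.configuration.CheckpointingOptions;
    import org.apache.flink.configuration.Configuration;

    Configuration conf = new Configuration();
    // In-flight channel data only exists with unaligned checkpoints.
    conf.set(CheckpointingOptions.ENABLE_UNALIGNED, true);
    // Experimental flag added by this commit; defaults to false.
    conf.set(CheckpointingOptions.UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM, true);
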
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 85622387daa..bc65b09a7ef 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -241,6 +241,8 @@ public class CheckpointCoordinator {
 
     private final boolean isExactlyOnceMode;
 
+    private final boolean recoverOutputOnDownstreamTask;
+
     /** Flag represents there is an in-flight trigger request. */
     private boolean isTriggering = false;
 
@@ -344,6 +346,7 @@ public class CheckpointCoordinator {
         this.clock = checkNotNull(clock);
         this.isExactlyOnceMode = chkConfig.isExactlyOnce();
         this.unalignedCheckpointsEnabled = chkConfig.isUnalignedCheckpointsEnabled();
+        this.recoverOutputOnDownstreamTask = chkConfig.isRecoverOutputOnDownstreamTask();
         this.alignedCheckpointTimeout = chkConfig.getAlignedCheckpointTimeout();
         this.checkpointIdOfIgnoredInFlightData = chkConfig.getCheckpointIdOfIgnoredInFlightData();
 
@@ -1816,7 +1819,11 @@ public class CheckpointCoordinator {
 
             StateAssignmentOperation stateAssignmentOperation =
                     new StateAssignmentOperation(
-                            latest.getCheckpointID(), tasks, operatorStates, allowNonRestoredState);
+                            latest.getCheckpointID(),
+                            tasks,
+                            operatorStates,
+                            allowNonRestoredState,
+                            recoverOutputOnDownstreamTask);
 
             stateAssignmentOperation.assignStates();
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
index a70b85175d6..007bb5334d2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/OperatorSubtaskState.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.state.ChannelState;
 import org.apache.flink.runtime.state.CompositeStateHandle;
+import org.apache.flink.runtime.state.InputChannelStateHandle;
 import org.apache.flink.runtime.state.InputStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
@@ -89,6 +90,8 @@ public class OperatorSubtaskState implements CompositeStateHandle {
 
     private final StateObjectCollection<InputStateHandle> inputChannelState;
 
+    private final StateObjectCollection<InputChannelStateHandle> upstreamOutputBufferState;
+
     private final StateObjectCollection<OutputStateHandle> resultSubpartitionState;
 
     /**
@@ -123,6 +126,7 @@ public class OperatorSubtaskState implements CompositeStateHandle {
             StateObjectCollection<KeyedStateHandle> managedKeyedState,
             StateObjectCollection<KeyedStateHandle> rawKeyedState,
             StateObjectCollection<InputStateHandle> inputChannelState,
+            StateObjectCollection<InputChannelStateHandle> upstreamOutputBufferState,
             StateObjectCollection<OutputStateHandle> resultSubpartitionState,
             InflightDataRescalingDescriptor inputRescalingDescriptor,
             InflightDataRescalingDescriptor outputRescalingDescriptor) {
@@ -132,6 +136,7 @@ public class OperatorSubtaskState implements CompositeStateHandle {
         this.managedKeyedState = checkNotNull(managedKeyedState);
         this.rawKeyedState = checkNotNull(rawKeyedState);
         this.inputChannelState = checkNotNull(inputChannelState);
+        this.upstreamOutputBufferState = checkNotNull(upstreamOutputBufferState);
         this.resultSubpartitionState = checkNotNull(resultSubpartitionState);
         this.inputRescalingDescriptor = checkNotNull(inputRescalingDescriptor);
         this.outputRescalingDescriptor = checkNotNull(outputRescalingDescriptor);
@@ -152,7 +157,8 @@ public class OperatorSubtaskState implements CompositeStateHandle {
     }
 
     private Stream<StateObjectCollection<? extends ChannelState>> streamChannelStates() {
-        return Stream.of(inputChannelState, resultSubpartitionState).filter(Objects::nonNull);
+        return Stream.of(inputChannelState, upstreamOutputBufferState, resultSubpartitionState)
+                .filter(Objects::nonNull);
     }
 
     @VisibleForTesting
@@ -164,6 +170,7 @@ public class OperatorSubtaskState implements CompositeStateHandle {
                 StateObjectCollection.empty(),
                 StateObjectCollection.empty(),
                 StateObjectCollection.empty(),
+                StateObjectCollection.empty(),
                 InflightDataRescalingDescriptor.NO_RESCALE,
                 InflightDataRescalingDescriptor.NO_RESCALE);
     }
@@ -190,6 +197,10 @@ public class OperatorSubtaskState implements CompositeStateHandle {
         return inputChannelState;
     }
 
+    public StateObjectCollection<InputChannelStateHandle> getUpstreamOutputBufferState() {
+        return upstreamOutputBufferState;
+    }
+
     public StateObjectCollection<OutputStateHandle> getResultSubpartitionState() {
         return resultSubpartitionState;
     }
         return resultSubpartitionState;
     }
@@ -343,6 +354,8 @@ public class OperatorSubtaskState implements CompositeStateHandle {
                 + rawKeyedState
                 + ", inputChannelState="
                 + inputChannelState
+                + ", upstreamOutputBufferState="
+                + upstreamOutputBufferState
                 + ", resultSubpartitionState="
                 + resultSubpartitionState
                 + ", stateSize="
@@ -358,6 +371,7 @@ public class OperatorSubtaskState implements CompositeStateHandle {
                 || managedKeyedState.hasState()
                 || rawKeyedState.hasState()
                 || inputChannelState.hasState()
+                || upstreamOutputBufferState.hasState()
                 || resultSubpartitionState.hasState();
     }
 
@@ -368,6 +382,7 @@ public class OperatorSubtaskState implements CompositeStateHandle {
                 .setRawOperatorState(rawOperatorState)
                 .setRawKeyedState(rawKeyedState)
                 .setInputChannelState(inputChannelState)
+                .setUpstreamOutputBufferState(upstreamOutputBufferState)
                 .setResultSubpartitionState(resultSubpartitionState)
                 .setInputRescalingDescriptor(inputRescalingDescriptor)
                 .setOutputRescalingDescriptor(outputRescalingDescriptor);
@@ -392,6 +407,8 @@ public class OperatorSubtaskState implements CompositeStateHandle {
                 StateObjectCollection.empty();
         private StateObjectCollection<InputStateHandle> inputChannelState =
                 StateObjectCollection.empty();
+        private StateObjectCollection<InputChannelStateHandle> upstreamOutputBufferState =
+                StateObjectCollection.empty();
         private StateObjectCollection<OutputStateHandle> resultSubpartitionState =
                 StateObjectCollection.empty();
         private InflightDataRescalingDescriptor inputRescalingDescriptor =
@@ -449,6 +466,12 @@ public class OperatorSubtaskState implements CompositeStateHandle {
             return this;
         }
 
+        public Builder setUpstreamOutputBufferState(
+                StateObjectCollection<InputChannelStateHandle> upstreamOutputBufferState) {
+            this.upstreamOutputBufferState = checkNotNull(upstreamOutputBufferState);
+            return this;
+        }
+
         public Builder setResultSubpartitionState(
                 StateObjectCollection<OutputStateHandle> resultSubpartitionState) {
             this.resultSubpartitionState = checkNotNull(resultSubpartitionState);
@@ -474,6 +497,7 @@ public class OperatorSubtaskState implements CompositeStateHandle {
                     managedKeyedState,
                     rawKeyedState,
                     inputChannelState,
+                    upstreamOutputBufferState,
                     resultSubpartitionState,
                     inputRescalingDescriptor,
                     outputRescalingDescriptor);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 70611f7b40e..6bc30d488c4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -75,6 +75,7 @@ public class StateAssignmentOperation {
 
     private final long restoreCheckpointId;
     private final boolean allowNonRestoredState;
+    private final boolean recoverOutputOnDownstreamTask;
 
     /** The state assignments for each ExecutionJobVertex that will be filled in multiple passes. */
     private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;
@@ -90,18 +91,29 @@ public class StateAssignmentOperation {
             long restoreCheckpointId,
             Set<ExecutionJobVertex> tasks,
             Map<OperatorID, OperatorState> operatorStates,
-            boolean allowNonRestoredState) {
+            boolean allowNonRestoredState,
+            boolean recoverOutputOnDownstreamTask) {
 
         this.restoreCheckpointId = restoreCheckpointId;
         this.tasks = Preconditions.checkNotNull(tasks);
         this.operatorStates = Preconditions.checkNotNull(operatorStates);
         this.allowNonRestoredState = allowNonRestoredState;
+        this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask;
         this.vertexAssignments = CollectionUtil.newHashMapWithExpectedSize(tasks.size());
     }
 
     public void assignStates() {
         checkStateMappingCompleteness(allowNonRestoredState, operatorStates, tasks);
 
+        buildStateAssignments();
+
+        repartitionState();
+
+        // actually assign the state
+        applyStateAssignments();
+    }
+
+    private void buildStateAssignments() {
         Map<OperatorID, OperatorState> localOperators = new HashMap<>(operatorStates);
 
         // find the states of all operators belonging to this task and compute additional
@@ -135,13 +147,16 @@ public class StateAssignmentOperation {
                             executionJobVertex,
                             operatorStates,
                             consumerAssignment,
-                            vertexAssignments);
+                            vertexAssignments,
+                            recoverOutputOnDownstreamTask);
             vertexAssignments.put(executionJobVertex, stateAssignment);
             for (final IntermediateResult producedDataSet : executionJobVertex.getInputs()) {
                 consumerAssignment.put(producedDataSet.getId(), stateAssignment);
             }
         }
+    }
 
+    private void repartitionState() {
         // repartition state
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
             if (stateAssignment.hasNonFinishedState
@@ -153,7 +168,22 @@ public class StateAssignmentOperation {
             }
         }
 
-        // actually assign the state
+        // distribute output channel states to downstream tasks if needed
+        // Note: this has to be called after assignAttemptState has run for all tasks, since
+        // redistributing the result subpartition states depends on the inputSubtaskMappings
+        // of the downstream tasks.
+        if (recoverOutputOnDownstreamTask) {
+            for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
+                // If recoverOutputOnDownstreamTask is enabled, all upstream output buffers have
+                // to be distributed downstream: the upstream task side generally does not
+                // deserialize records, so it is easier to filter and re-upload records when the
+                // output buffers are recovered directly on the downstream task side.
+                stateAssignment.distributeOutputBuffersToDownstream();
+            }
+        }
+    }
+
+    private void applyStateAssignments() {
         for (TaskStateAssignment stateAssignment : vertexAssignments.values()) {
             // If upstream has output states or downstream has input states, even the empty task
             // state should be assigned for the current task in order to notify this task that the
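
The observable contract of the new repartitioning step, sketched in the style of the assertions in StateAssignmentOperationTest below (upstreamState and downstreamState are hypothetical restored OperatorSubtaskState instances, not identifiers from this patch):

    // With the flag enabled, an upstream subtask's own output buffers are cleared ...
    assertThat(upstreamState.getResultSubpartitionState().hasState()).isFalse();
    // ... and show up on the downstream side, re-keyed as input channel state.
    assertThat(downstreamState.getUpstreamOutputBufferState().hasState()).isTrue();
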
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
index 4601b4716a1..a6db5e4837f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskStateAssignment.java
@@ -21,6 +21,7 @@ import org.apache.flink.runtime.OperatorIDPair;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor.InflightDataGateOrPartitionRescalingDescriptor.MappingType;
 import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
+import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionDistributor;
 import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
@@ -43,6 +44,7 @@ import org.slf4j.LoggerFactory;
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
@@ -91,12 +93,26 @@ class TaskStateAssignment {
     final Map<OperatorInstanceID, List<InputChannelStateHandle>> inputChannelStates;
     final Map<OperatorInstanceID, List<ResultSubpartitionStateHandle>> resultSubpartitionStates;
 
+    /**
+     * Stores input channel states that come from the upstream task's output buffers. It takes
+     * effect when {@link
+     * org.apache.flink.configuration.CheckpointingOptions#UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM}
+     * is enabled.
+     */
+    private final Map<OperatorInstanceID, List<InputChannelStateHandle>> upstreamOutputBufferStates;
+
     /** The subtask mapping when the output operator was rescaled. */
     private final Map<Integer, SubtasksRescaleMapping> outputSubtaskMappings = new HashMap<>();
 
     /** The subtask mapping when the input operator was rescaled. */
     private final Map<Integer, SubtasksRescaleMapping> inputSubtaskMappings = new HashMap<>();
 
+    /** Cached output {@link InflightDataRescalingDescriptor} for each subtask. */
+    private final Map<OperatorInstanceID, InflightDataRescalingDescriptor>
+            outputRescalingDescriptors = new HashMap<>();
+
+    private final boolean recoverOutputOnDownstreamTask;
+
     @Nullable private TaskStateAssignment[] downstreamAssignments;
     @Nullable private TaskStateAssignment[] upstreamAssignments;
     @Nullable private Boolean hasUpstreamOutputStates;
@@ -109,7 +125,8 @@ class TaskStateAssignment {
             ExecutionJobVertex executionJobVertex,
             Map<OperatorID, OperatorState> oldState,
             Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment,
-            Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments) {
+            Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments,
+            boolean recoverOutputOnDownstreamTask) {
 
         this.executionJobVertex = executionJobVertex;
         this.oldState = oldState;
@@ -126,6 +143,7 @@ class TaskStateAssignment {
         newParallelism = executionJobVertex.getParallelism();
         this.consumerAssignment = checkNotNull(consumerAssignment);
         this.vertexAssignments = checkNotNull(vertexAssignments);
+        this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask;
         final int expectedNumberOfSubtasks = newParallelism * oldState.size();
 
         subManagedOperatorState =
@@ -134,6 +152,8 @@ class TaskStateAssignment {
         inputChannelStates = CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks);
         resultSubpartitionStates =
                 CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks);
+        upstreamOutputBufferStates =
+                CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks);
         subManagedKeyedState = CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks);
         subRawKeyedState = CollectionUtil.newHashMapWithExpectedSize(expectedNumberOfSubtasks);
 
@@ -238,8 +258,15 @@ class TaskStateAssignment {
                 .setRawKeyedState(getState(instanceID, subRawKeyedState))
                 .setInputChannelState(
                         castToInputStateCollection(inputChannelStates.get(instanceID)))
+                .setUpstreamOutputBufferState(
+                        new StateObjectCollection<>(upstreamOutputBufferStates.get(instanceID)))
                 .setResultSubpartitionState(
-                        castToOutputStateCollection(resultSubpartitionStates.get(instanceID)))
+                        // If recoverOutputOnDownstreamTask is enabled, clear the subtask's own
+                        // output buffers, as they have been migrated to the downstream task.
+                        recoverOutputOnDownstreamTask
+                                ? castToOutputStateCollection(null)
+                                : castToOutputStateCollection(
+                                        resultSubpartitionStates.get(instanceID)))
                 .setInputRescalingDescriptor(
                         createRescalingDescriptor(
                                 instanceID,
@@ -254,20 +281,7 @@ class TaskStateAssignment {
                                 inputSubtaskMappings,
                                 this::getInputMapping,
                                 true))
-                .setOutputRescalingDescriptor(
-                        createRescalingDescriptor(
-                                instanceID,
-                                outputOperatorID,
-                                getDownstreamAssignments(),
-                                (assignment, recompute) -> {
-                                    int assignmentIndex =
-                                            getAssignmentIndex(
-                                                    assignment.getUpstreamAssignments(), this);
-                                    return assignment.getInputMapping(assignmentIndex, recompute);
-                                },
-                                outputSubtaskMappings,
-                                this::getOutputMapping,
-                                false))
+                .setOutputRescalingDescriptor(getOutputRescalingDescriptor(instanceID))
                 .build();
     }
 
@@ -289,6 +303,33 @@ class TaskStateAssignment {
         return hasDownstreamInputStates;
     }
 
+    /**
+     * Gets the output rescaling descriptor for a specific instance with caching. Each descriptor is
+     * computed only once and cached for subsequent access.
+     */
+    public InflightDataRescalingDescriptor getOutputRescalingDescriptor(
+            OperatorInstanceID instanceID) {
+        return outputRescalingDescriptors.computeIfAbsent(
+                instanceID, this::computeOutputRescalingDescriptor);
+    }
+
+    /** Computes the output rescaling descriptor for a single subtask. */
+    private InflightDataRescalingDescriptor computeOutputRescalingDescriptor(
+            OperatorInstanceID instanceID) {
+        return createRescalingDescriptor(
+                instanceID,
+                outputOperatorID,
+                getDownstreamAssignments(),
+                (downstreamAssignment, recompute) -> {
+                    int assignmentIndex =
+                            getAssignmentIndex(downstreamAssignment.getUpstreamAssignments(), this);
+                    return downstreamAssignment.getInputMapping(assignmentIndex, recompute);
+                },
+                outputSubtaskMappings,
+                this::getOutputMapping,
+                false);
+    }
+
     private InflightDataGateOrPartitionRescalingDescriptor log(
             InflightDataGateOrPartitionRescalingDescriptor descriptor, int subtask, int partition) {
         LOG.debug(
@@ -540,6 +581,78 @@ class TaskStateAssignment {
         return false;
     }
 
+    void distributeOutputBuffersToDownstream() {
+        for (Map.Entry<OperatorInstanceID, List<ResultSubpartitionStateHandle>> entry :
+                resultSubpartitionStates.entrySet()) {
+            OperatorInstanceID operatorInstanceID = entry.getKey();
+            List<ResultSubpartitionStateHandle> stateHandles = entry.getValue();
+
+            ResultSubpartitionDistributor distributor =
+                    new ResultSubpartitionDistributor(
+                            getOutputRescalingDescriptor(operatorInstanceID));
+
+            for (final ResultSubpartitionStateHandle stateHandle : stateHandles) {
+                distributeOutputBufferToDownstream(stateHandle, distributor);
+            }
+        }
+    }
+
+    private void distributeOutputBufferToDownstream(
+            ResultSubpartitionStateHandle stateHandle, ResultSubpartitionDistributor distributor) {
+        // From the perspective of the downstream task, the oldUpstreamSubtaskIndex will be
+        // treated as the inputChannelIdx, and the info.getSubPartitionIdx() will be treated
+        // as the oldDownstreamSubtaskIndex.
+        int oldUpstreamSubtaskIndex = stateHandle.getSubtaskIndex();
+        ResultSubpartitionInfo info = stateHandle.getInfo();
+        int partitionIdx = info.getPartitionIdx();
+        int oldDownstreamSubtaskIndex = info.getSubPartitionIdx();
+
+        int gateIdxResultPartition = findInputGateIdxForResultPartition(partitionIdx);
+        TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIdx];
+
+        List<ResultSubpartitionInfo> mappedSubpartitions = distributor.getMappedSubpartitions(info);
+        for (final ResultSubpartitionInfo mappedSubpartition : mappedSubpartitions) {
+            int targetDownstreamSubtaskId = mappedSubpartition.getSubPartitionIdx();
+
+            OperatorInstanceID downstreamOperatorInstance =
+                    new OperatorInstanceID(
+                            targetDownstreamSubtaskId, downstreamAssignment.inputOperatorID);
+
+            InputChannelInfo inputChannelInfo =
+                    new InputChannelInfo(gateIdxResultPartition, oldUpstreamSubtaskIndex);
+
+            InputChannelStateHandle upstreamOutputBufferHandle =
+                    new InputChannelStateHandle(
+                            oldDownstreamSubtaskIndex,
+                            inputChannelInfo,
+                            stateHandle.getDelegate(),
+                            stateHandle.getOffsets(),
+                            stateHandle.getStateSize());
+
+            List<InputChannelStateHandle> upstreamOutputBufferHandles =
+                    downstreamAssignment.upstreamOutputBufferStates.computeIfAbsent(
+                            downstreamOperatorInstance, k -> new ArrayList<>());
+            upstreamOutputBufferHandles.add(upstreamOutputBufferHandle);
+        }
+    }
+
+    private int findInputGateIdxForResultPartition(int partitionIndex) {
+        // Find the index of the downstream input gate that consumes this produced partition.
+        TaskStateAssignment downstreamAssignment = getDownstreamAssignments()[partitionIndex];
+
+        IntermediateResult producedResult =
+                executionJobVertex.getProducedDataSets()[partitionIndex];
+        IntermediateDataSetID resultId = producedResult.getId();
+        List<IntermediateResult> inputs = downstreamAssignment.executionJobVertex.getInputs();
+        for (int i = 0; i < inputs.size(); i++) {
+            if (inputs.get(i).getId().equals(resultId)) {
+                return i;
+            }
+        }
+        throw new IllegalArgumentException(
+                "No input gate found for result partition " + partitionIndex);
+    }
+
     @Override
     public String toString() {
         return "TaskStateAssignment for " + executionJobVertex.getName();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
index 6fe1f1a6e99..3daa4b4947a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
@@ -67,6 +67,11 @@ public class SequentialChannelStateReaderImpl implements SequentialChannelStateR
                     groupByDelegate(
                             streamSubtaskStates(),
                             ChannelStateHelper::extractUnmergedInputHandles));
+            read(
+                    stateHandler,
+                    groupByDelegate(
+                            streamSubtaskStates(),
+                            OperatorSubtaskState::getUpstreamOutputBufferState));
         }
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/CheckpointCoordinatorConfiguration.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/CheckpointCoordinatorConfiguration.java
index 2d100aabefa..cc755fc17a4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/CheckpointCoordinatorConfiguration.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/tasks/CheckpointCoordinatorConfiguration.java
@@ -71,6 +71,8 @@ public class CheckpointCoordinatorConfiguration implements Serializable {
 
     private final boolean enableCheckpointsAfterTasksFinish;
 
+    private final boolean recoverOutputOnDownstreamTask;
+
     /**
      * @deprecated use {@link #builder()}.
      */
@@ -98,6 +100,7 @@ public class CheckpointCoordinatorConfiguration implements Serializable {
                 isUnalignedCheckpoint,
                 0,
                 checkpointIdOfIgnoredInFlightData,
+                false,
                 false);
     }
 
@@ -113,7 +116,8 @@ public class CheckpointCoordinatorConfiguration implements Serializable {
             boolean isUnalignedCheckpointsEnabled,
             long alignedCheckpointTimeout,
             long checkpointIdOfIgnoredInFlightData,
-            boolean enableCheckpointsAfterTasksFinish) {
+            boolean enableCheckpointsAfterTasksFinish,
+            boolean recoverOutputOnDownstreamTask) {
 
         if (checkpointIntervalDuringBacklog < MINIMAL_CHECKPOINT_TIME) {
             // interval of max value means disable periodic checkpoint
@@ -144,6 +148,7 @@ public class CheckpointCoordinatorConfiguration implements Serializable {
         this.alignedCheckpointTimeout = alignedCheckpointTimeout;
         this.checkpointIdOfIgnoredInFlightData = checkpointIdOfIgnoredInFlightData;
         this.enableCheckpointsAfterTasksFinish = enableCheckpointsAfterTasksFinish;
+        this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask;
     }
 
     public long getCheckpointInterval() {
@@ -198,6 +203,10 @@ public class CheckpointCoordinatorConfiguration implements Serializable {
         return enableCheckpointsAfterTasksFinish;
     }
 
+    public boolean isRecoverOutputOnDownstreamTask() {
+        return recoverOutputOnDownstreamTask;
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) {
@@ -217,7 +226,8 @@ public class CheckpointCoordinatorConfiguration implements Serializable {
                 && checkpointRetentionPolicy == that.checkpointRetentionPolicy
                 && tolerableCheckpointFailureNumber == that.tolerableCheckpointFailureNumber
                 && checkpointIdOfIgnoredInFlightData == that.checkpointIdOfIgnoredInFlightData
-                && enableCheckpointsAfterTasksFinish == that.enableCheckpointsAfterTasksFinish;
+                && enableCheckpointsAfterTasksFinish == that.enableCheckpointsAfterTasksFinish
+                && recoverOutputOnDownstreamTask == that.recoverOutputOnDownstreamTask;
     }
 
     @Override
@@ -233,7 +243,8 @@ public class CheckpointCoordinatorConfiguration implements Serializable {
                 alignedCheckpointTimeout,
                 tolerableCheckpointFailureNumber,
                 checkpointIdOfIgnoredInFlightData,
-                enableCheckpointsAfterTasksFinish);
+                enableCheckpointsAfterTasksFinish,
+                recoverOutputOnDownstreamTask);
     }
 
     @Override
@@ -261,6 +272,8 @@ public class CheckpointCoordinatorConfiguration implements Serializable {
                 + checkpointIdOfIgnoredInFlightData
                 + ", enableCheckpointsAfterTasksFinish="
                 + enableCheckpointsAfterTasksFinish
+                + ", recoverOutputOnDownstreamTask="
+                + recoverOutputOnDownstreamTask
                 + '}';
     }
 
@@ -283,6 +296,7 @@ public class CheckpointCoordinatorConfiguration implements Serializable {
         private long alignedCheckpointTimeout = 0;
         private long checkpointIdOfIgnoredInFlightData;
         private boolean enableCheckpointsAfterTasksFinish;
+        private boolean recoverOutputOnDownstreamTask;
 
         public CheckpointCoordinatorConfiguration build() {
             return new CheckpointCoordinatorConfiguration(
@@ -297,7 +311,8 @@ public class CheckpointCoordinatorConfiguration implements Serializable {
                     isUnalignedCheckpointsEnabled,
                     alignedCheckpointTimeout,
                     checkpointIdOfIgnoredInFlightData,
-                    enableCheckpointsAfterTasksFinish);
+                    enableCheckpointsAfterTasksFinish,
+                    recoverOutputOnDownstreamTask);
         }
 
         public CheckpointCoordinatorConfigurationBuilder setCheckpointInterval(
@@ -370,5 +385,11 @@ public class CheckpointCoordinatorConfiguration implements Serializable {
             this.enableCheckpointsAfterTasksFinish = enableCheckpointsAfterTasksFinish;
             return this;
         }
+
+        public CheckpointCoordinatorConfigurationBuilder setRecoverOutputOnDownstreamTask(
+                boolean recoverOutputOnDownstreamTask) {
+            this.recoverOutputOnDownstreamTask = recoverOutputOnDownstreamTask;
+            return this;
+        }
     }
 }
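
For completeness, wiring the flag through the builder added above; a sketch that leaves every other setting at its default (builder(), setRecoverOutputOnDownstreamTask(), and build() are the methods from this hunk):

    CheckpointCoordinatorConfiguration config =
            CheckpointCoordinatorConfiguration.builder()
                    // only meaningful when unaligned checkpoints carry in-flight data
                    .setRecoverOutputOnDownstreamTask(true)
                    .build();
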
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
index 49822d1c666..887c22a395e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/api/graph/StreamGraph.java
@@ -37,6 +37,7 @@ import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.api.java.typeutils.MissingTypeInfo;
+import org.apache.flink.configuration.CheckpointingOptions;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.ExecutionOptions;
 import org.apache.flink.configuration.ExternalizedCheckpointRetention;
@@ -389,6 +390,10 @@ public class StreamGraph implements Pipeline, ExecutionPlan {
                                 cfg.getCheckpointIdOfIgnoredInFlightData())
                         .setAlignedCheckpointTimeout(cfg.getAlignedCheckpointTimeout().toMillis())
                         .setEnableCheckpointsAfterTasksFinish(isEnableCheckpointsAfterTasksFinish())
+                        .setRecoverOutputOnDownstreamTask(
+                                jobConfiguration.get(
+                                        CheckpointingOptions
+                                                .UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM))
                         .build(),
                 serializedStateBackend,
                 getJobConfiguration()
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
index c13bcd0e321..711c3ef5cf3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java
@@ -38,18 +38,25 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
+import org.apache.flink.runtime.state.InputChannelStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.OperatorStreamStateHandle;
+import org.apache.flink.runtime.state.OutputStateHandle;
+import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorExtension;
 
+import org.apache.flink.shaded.guava33.com.google.common.collect.Iterables;
+
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
 
 import javax.annotation.Nullable;
 
@@ -506,7 +513,7 @@ class StateAssignmentOperationTest {
                 buildVertices(operatorIds, numSubTasks, RANGE, ROUND_ROBIN);
         Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, numSubTasks);
 
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false)
                 .assignStates();
 
         for (OperatorID operatorId : operatorIds) {
@@ -535,7 +542,7 @@ class StateAssignmentOperationTest {
         Map<OperatorID, ExecutionJobVertex> vertices =
                 toExecutionVertices(upstream1, upstream2, downstream);
 
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false)
                 .assignStates();
 
         assertThat(
@@ -599,7 +606,7 @@ class StateAssignmentOperationTest {
         Map<OperatorID, ExecutionJobVertex> vertices =
                 toExecutionVertices(upstream1, upstream2, downstream);
 
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false)
                 .assignStates();
 
         assertThat(
@@ -619,17 +626,32 @@ class StateAssignmentOperationTest {
                 .isEqualTo(6);
     }
 
-    @Test
-    void testChannelStateAssignmentDownscaling() throws JobException, JobExecutionException {
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void testChannelStateAssignmentDownscaling(boolean recoverOutputOnDownstreamTask)
+            throws JobException, JobExecutionException {
+        int oldParallelism = 3;
+        int newParallelism = 2;
+
         List<OperatorID> operatorIds = buildOperatorIds(2);
-        Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, 3);
+        OperatorID firstOperator = operatorIds.get(0);
+        JobResultSubpartitionHandlers jobResultSubpartitionHandlers =
+                new JobResultSubpartitionHandlers(operatorIds, oldParallelism);
+        Map<OperatorID, OperatorState> states =
+                buildOperatorStates(operatorIds, oldParallelism, jobResultSubpartitionHandlers);
 
         Map<OperatorID, ExecutionJobVertex> vertices =
-                buildVertices(operatorIds, 2, RANGE, ROUND_ROBIN);
-
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+                buildVertices(operatorIds, newParallelism, RANGE, ROUND_ROBIN);
+
+        new StateAssignmentOperation(
+                        0,
+                        new HashSet<>(vertices.values()),
+                        states,
+                        false,
+                        recoverOutputOnDownstreamTask)
                 .assignStates();
 
+        OperatorID upstreamOperator = null;
         for (OperatorID operatorId : operatorIds) {
             // input is range partitioned, so there is an overlap
             assertState(
@@ -648,22 +670,68 @@ class StateAssignmentOperationTest {
                     OperatorSubtaskState::getInputChannelState,
                     1,
                     2);
-            // output is round robin redistributed
-            assertState(
-                    vertices,
-                    operatorId,
-                    states,
-                    0,
-                    OperatorSubtaskState::getResultSubpartitionState,
-                    0,
-                    2);
-            assertState(
-                    vertices,
-                    operatorId,
-                    states,
-                    1,
-                    OperatorSubtaskState::getResultSubpartitionState,
-                    1);
+
+            if (recoverOutputOnDownstreamTask) {
+                // output buffer states are moved to downstream task when
+                // recoverOutputOnDownstreamTask is enabled
+                assertStateEmptyForAllSubtasks(
+                        vertices,
+                        operatorId,
+                        newParallelism,
+                        OperatorSubtaskState::getResultSubpartitionState);
+
+                if (firstOperator == operatorId) {
+                    // The first operator does not have any upstream.
+                    assertStateEmptyForAllSubtasks(
+                            vertices,
+                            operatorId,
+                            newParallelism,
+                            OperatorSubtaskState::getUpstreamOutputBufferState);
+                } else {
+                    ExecutionJobVertex executionJobVertex = vertices.get(operatorId);
+                    int[] indexes0 = {0, 0, 0, 1, 1, 0, 1, 1, 2, 0, 2, 1};
+                    assertUpstreamOutputBufferState(
+                            jobResultSubpartitionHandlers,
+                            executionJobVertex,
+                            operatorId,
+                            0,
+                            upstreamOperator,
+                            indexes0);
+                    int[] indexes1 = {0, 1, 0, 2, 1, 1, 1, 2, 2, 1, 2, 2};
+                    assertUpstreamOutputBufferState(
+                            jobResultSubpartitionHandlers,
+                            executionJobVertex,
+                            operatorId,
+                            1,
+                            upstreamOperator,
+                            indexes1);
+                }
+            } else {
+                // output is round robin redistributed
+                assertState(
+                        vertices,
+                        operatorId,
+                        states,
+                        0,
+                        OperatorSubtaskState::getResultSubpartitionState,
+                        0,
+                        2);
+                assertState(
+                        vertices,
+                        operatorId,
+                        states,
+                        1,
+                        OperatorSubtaskState::getResultSubpartitionState,
+                        1);
+
+                // The upstream output buffer state is expected to be empty.
+                assertStateEmptyForAllSubtasks(
+                        vertices,
+                        operatorId,
+                        newParallelism,
+                        OperatorSubtaskState::getUpstreamOutputBufferState);
+            }
+            upstreamOperator = operatorId;
         }
 
         assertThat(
@@ -686,38 +754,98 @@ class StateAssignmentOperationTest {
                 .isEqualTo(rescalingDescriptor(to(1, 2), array(mappings(to(0, 2), to(1))), set(1)));
     }
 
-    @Test
-    void testChannelStateAssignmentNoRescale() throws JobException, JobExecutionException {
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void testChannelStateAssignmentNoRescale(boolean recoverOutputOnDownstreamTask)
+            throws JobException, JobExecutionException {
         List<OperatorID> operatorIds = buildOperatorIds(2);
-        Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, 2);
+        OperatorID firstOperator = operatorIds.get(0);
+        int parallelism = 2;
+        JobResultSubpartitionHandlers jobResultSubpartitionHandlers =
+                new JobResultSubpartitionHandlers(operatorIds, parallelism);
+
+        Map<OperatorID, OperatorState> states =
+                buildOperatorStates(operatorIds, parallelism, jobResultSubpartitionHandlers);
 
         Map<OperatorID, ExecutionJobVertex> vertices =
                 buildVertices(operatorIds, 2, RANGE, ROUND_ROBIN);
 
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+        new StateAssignmentOperation(
+                        0,
+                        new HashSet<>(vertices.values()),
+                        states,
+                        false,
+                        recoverOutputOnDownstreamTask)
                 .assignStates();
 
+        OperatorID upstreamOperator = null;
         for (OperatorID operatorId : operatorIds) {
-            // input is range partitioned, so there is an overlap
+            // input is range partitioned
             assertState(
                     vertices, operatorId, states, 0, OperatorSubtaskState::getInputChannelState, 0);
             assertState(
                     vertices, operatorId, states, 1, OperatorSubtaskState::getInputChannelState, 1);
-            // output is round robin redistributed
-            assertState(
-                    vertices,
-                    operatorId,
-                    states,
-                    0,
-                    OperatorSubtaskState::getResultSubpartitionState,
-                    0);
-            assertState(
-                    vertices,
-                    operatorId,
-                    states,
-                    1,
-                    OperatorSubtaskState::getResultSubpartitionState,
-                    1);
+
+            if (recoverOutputOnDownstreamTask) {
+                // output buffer states are moved to downstream task when
+                // recoverOutputOnDownstreamTask is enabled
+                assertStateEmptyForAllSubtasks(
+                        vertices,
+                        operatorId,
+                        parallelism,
+                        OperatorSubtaskState::getResultSubpartitionState);
+
+                if (firstOperator == operatorId) {
+                    // The first operator does not have any upstream.
+                    assertStateEmptyForAllSubtasks(
+                            vertices,
+                            operatorId,
+                            parallelism,
+                            OperatorSubtaskState::getUpstreamOutputBufferState);
+                } else {
+                    ExecutionJobVertex executionJobVertex = vertices.get(operatorId);
+                    int[] indexes0 = {0, 0, 1, 0};
+                    assertUpstreamOutputBufferState(
+                            jobResultSubpartitionHandlers,
+                            executionJobVertex,
+                            operatorId,
+                            0,
+                            upstreamOperator,
+                            indexes0);
+                    int[] indexes1 = {0, 1, 1, 1};
+                    assertUpstreamOutputBufferState(
+                            jobResultSubpartitionHandlers,
+                            executionJobVertex,
+                            operatorId,
+                            1,
+                            upstreamOperator,
+                            indexes1);
+                }
+            } else {
+                // output is round robin redistributed
+                assertState(
+                        vertices,
+                        operatorId,
+                        states,
+                        0,
+                        OperatorSubtaskState::getResultSubpartitionState,
+                        0);
+                assertState(
+                        vertices,
+                        operatorId,
+                        states,
+                        1,
+                        OperatorSubtaskState::getResultSubpartitionState,
+                        1);
+
+                // The upstream output buffer state is expected to be empty.
+                assertStateEmptyForAllSubtasks(
+                        vertices,
+                        operatorId,
+                        parallelism,
+                        OperatorSubtaskState::getUpstreamOutputBufferState);
+            }
+            upstreamOperator = operatorId;
         }
 
         assertThat(
@@ -739,17 +867,32 @@ class StateAssignmentOperationTest {
                 .isEqualTo(InflightDataRescalingDescriptor.NO_RESCALE);
     }
 
-    @Test
-    void testChannelStateAssignmentUpscaling() throws JobException, JobExecutionException {
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void testChannelStateAssignmentUpscaling(boolean recoverOutputOnDownstreamTask)
+            throws JobException, JobExecutionException {
+        int oldParallelism = 2;
+        int newParallelism = 3;
+
         List<OperatorID> operatorIds = buildOperatorIds(2);
-        Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, 2);
+        OperatorID firstOperator = operatorIds.get(0);
+        JobResultSubpartitionHandlers jobResultSubpartitionHandlers =
+                new JobResultSubpartitionHandlers(operatorIds, oldParallelism);
+        Map<OperatorID, OperatorState> states =
+                buildOperatorStates(operatorIds, oldParallelism, jobResultSubpartitionHandlers);
 
         Map<OperatorID, ExecutionJobVertex> vertices =
-                buildVertices(operatorIds, 3, RANGE, ROUND_ROBIN);
-
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+                buildVertices(operatorIds, newParallelism, RANGE, ROUND_ROBIN);
+
+        new StateAssignmentOperation(
+                        0,
+                        new HashSet<>(vertices.values()),
+                        states,
+                        false,
+                        recoverOutputOnDownstreamTask)
                 .assignStates();
 
+        OperatorID upstreamOperator = null;
         for (OperatorID operatorId : operatorIds) {
             // input is range partitioned, so there is an overlap
             assertState(
@@ -764,27 +907,81 @@ class StateAssignmentOperationTest {
                     1);
             assertState(
                     vertices, operatorId, states, 2, OperatorSubtaskState::getInputChannelState, 1);
-            // output is round robin redistributed
-            assertState(
-                    vertices,
-                    operatorId,
-                    states,
-                    0,
-                    OperatorSubtaskState::getResultSubpartitionState,
-                    0);
-            assertState(
-                    vertices,
-                    operatorId,
-                    states,
-                    1,
-                    OperatorSubtaskState::getResultSubpartitionState,
-                    1);
-            assertState(
-                    vertices,
-                    operatorId,
-                    states,
-                    2,
-                    OperatorSubtaskState::getResultSubpartitionState);
+
+            if (recoverOutputOnDownstreamTask) {
+                // output buffer states are moved to downstream task when
+                // recoverOutputOnDownstreamTask is enabled
+                assertStateEmptyForAllSubtasks(
+                        vertices,
+                        operatorId,
+                        newParallelism,
+                        OperatorSubtaskState::getResultSubpartitionState);
+
+                if (firstOperator == operatorId) {
+                    // The first operator does not have any upstream.
+                    assertStateEmptyForAllSubtasks(
+                            vertices,
+                            operatorId,
+                            newParallelism,
+                            OperatorSubtaskState::getUpstreamOutputBufferState);
+                } else {
+                    ExecutionJobVertex executionJobVertex = vertices.get(operatorId);
+                    int[] indexes0 = {0, 0, 1, 0};
+                    assertUpstreamOutputBufferState(
+                            jobResultSubpartitionHandlers,
+                            executionJobVertex,
+                            operatorId,
+                            0,
+                            upstreamOperator,
+                            indexes0);
+                    int[] indexes1 = {0, 0, 0, 1, 1, 0, 1, 1};
+                    assertUpstreamOutputBufferState(
+                            jobResultSubpartitionHandlers,
+                            executionJobVertex,
+                            operatorId,
+                            1,
+                            upstreamOperator,
+                            indexes1);
+                    int[] indexes2 = {0, 1, 1, 1};
+                    assertUpstreamOutputBufferState(
+                            jobResultSubpartitionHandlers,
+                            executionJobVertex,
+                            operatorId,
+                            2,
+                            upstreamOperator,
+                            indexes2);
+                }
+            } else {
+                // output is round robin redistributed
+                assertState(
+                        vertices,
+                        operatorId,
+                        states,
+                        0,
+                        OperatorSubtaskState::getResultSubpartitionState,
+                        0);
+                assertState(
+                        vertices,
+                        operatorId,
+                        states,
+                        1,
+                        OperatorSubtaskState::getResultSubpartitionState,
+                        1);
+                assertState(
+                        vertices,
+                        operatorId,
+                        states,
+                        2,
+                        OperatorSubtaskState::getResultSubpartitionState);
+
+                // The upstream output buffer state is expected to be empty.
+                assertStateEmptyForAllSubtasks(
+                        vertices,
+                        operatorId,
+                        newParallelism,
+                        OperatorSubtaskState::getUpstreamOutputBufferState);
+            }
+            upstreamOperator = operatorId;
         }
 
         assertThat(
@@ -845,7 +1042,7 @@ class StateAssignmentOperationTest {
                 buildVertices(operatorIds, 3, RANGE, ROUND_ROBIN);
 
         // when: States are assigned.
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false)
                 .assignStates();
 
         // then: All subtasks have non-null TaskRestore information (even if it is empty).
@@ -942,7 +1139,7 @@ class StateAssignmentOperationTest {
                 buildVertices(opIdWithParallelism, RANGE, ROUND_ROBIN);
 
         // Run state assignment
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false)
                 .assignStates();
 
         // Check results
@@ -983,8 +1180,10 @@ class StateAssignmentOperationTest {
                 .isEqualTo(expectedCount);
     }
 
-    @Test
-    void testStateWithFullyFinishedOperators() throws JobException, JobExecutionException {
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void testStateWithFullyFinishedOperators(boolean recoverOutputOnDownstreamTask)
+            throws JobException, JobExecutionException {
         List<OperatorID> operatorIds = buildOperatorIds(2);
         Map<OperatorID, OperatorState> states =
                 buildOperatorStates(Collections.singletonList(operatorIds.get(1)), 3);
@@ -996,7 +1195,12 @@ class StateAssignmentOperationTest {
 
         Map<OperatorID, ExecutionJobVertex> vertices =
                 buildVertices(operatorIds, 2, RANGE, ROUND_ROBIN);
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+        new StateAssignmentOperation(
+                        0,
+                        new HashSet<>(vertices.values()),
+                        states,
+                        false,
+                        recoverOutputOnDownstreamTask)
                 .assignStates();
 
         // Check the job vertex with only finished operators.
@@ -1043,8 +1247,33 @@ class StateAssignmentOperationTest {
                 .isTrue();
     }
 
-    @Test
-    void assigningStatesShouldWorkWithUserDefinedOperatorIdsAsWell() {
+    private void assertStateEmptyForAllSubtasks(
+            Map<OperatorID, ExecutionJobVertex> vertices,
+            OperatorID operatorId,
+            int numSubtasks,
+            Function<OperatorSubtaskState, StateObjectCollection<?>> extractor) {
+        for (int i = 0; i < numSubtasks; i++) {
+            assertStateEmpty(vertices, operatorId, i, extractor);
+        }
+    }
+
+    private void assertStateEmpty(
+            Map<OperatorID, ExecutionJobVertex> vertices,
+            OperatorID operatorId,
+            int newSubtaskIndex,
+            Function<OperatorSubtaskState, StateObjectCollection<?>> extractor) {
+        final OperatorSubtaskState subState =
+                getAssignedState(vertices.get(operatorId), operatorId, newSubtaskIndex);
+
+        assertThat(extractor.apply(subState).hasState())
+                .as("State should be empty for operator %s subtask %d", operatorId, newSubtaskIndex)
+                .isFalse();
+    }
+
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    void assigningStatesShouldWorkWithUserDefinedOperatorIdsAsWell(
+            boolean recoverOutputOnDownstreamTask) {
         int numSubTasks = 1;
         OperatorID operatorId = new OperatorID();
         OperatorID userDefinedOperatorId = new OperatorID();
@@ -1054,7 +1283,12 @@ class StateAssignmentOperationTest {
                 buildExecutionJobVertex(operatorId, userDefinedOperatorId, 1);
         Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, numSubTasks);
 
-        new StateAssignmentOperation(0, Collections.singleton(executionJobVertex), states, false)
+        new StateAssignmentOperation(
+                        0,
+                        Collections.singleton(executionJobVertex),
+                        states,
+                        false,
+                        recoverOutputOnDownstreamTask)
                 .assignStates();
 
         assertThat(getAssignedState(executionJobVertex, operatorId, 0))
@@ -1081,10 +1315,101 @@ class StateAssignmentOperationTest {
                 .collect(Collectors.toList());
     }
 
+    /**
+     * Asserts the upstream output buffer state for a specific subtask by verifying the expected
+     * upstream subtask and subpartition mappings.
+     *
+     * <p>The indexes parameter contains pairs of (subtaskIndex, subpartitionIndex) that define the
+     * expected upstream sources for the current subtask's input channels.
+     *
+     * @param jobResultSubpartitionHandlers the job result subpartition handlers
+     * @param executionJobVertex the execution job vertex
+     * @param operatorId the operator ID to check
+     * @param subtaskIndex the subtask index to verify
+     * @param upstreamOperator the upstream operator ID
+     * @param indexes pairs of (subtaskIndex, subpartitionIndex) defining expected upstream sources
+     */
+    private void assertUpstreamOutputBufferState(
+            JobResultSubpartitionHandlers jobResultSubpartitionHandlers,
+            ExecutionJobVertex executionJobVertex,
+            OperatorID operatorId,
+            int subtaskIndex,
+            OperatorID upstreamOperator,
+            int[] indexes) {
+        checkArgument(
+                indexes.length % 2 == 0,
+                "indexes must contain pairs of (subtaskIndex, subpartitionIndex)");
+
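+        // Collect the expected upstream handles referenced by the (subtask, subpartition) pairs.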
+        ArrayList<ResultSubpartitionStateHandle> originals = new ArrayList<>(indexes.length / 2);
+        for (int i = 0; i < indexes.length; i += 2) {
+            int upstreamSubtaskIndex = indexes[i];
+            int upstreamSubpartitionId = indexes[i + 1];
+            originals.add(
+                    jobResultSubpartitionHandlers.getHandler(
+                            upstreamOperator, upstreamSubtaskIndex, upstreamSubpartitionId));
+        }
+
+        List<InputChannelStateHandle> transformed =
+                getAssignedState(executionJobVertex, operatorId, subtaskIndex)
+                        .getUpstreamOutputBufferState()
+                        .asList();
+        assertStrictPerfectOneToOneMatch(originals, transformed);
+    }
+
+    /**
+     * Asserts that both collections have exactly the same size and that every original handle has
+     * a unique, valid match in the transformed collection (a perfect 1:1 match).
+     */
+    private void assertStrictPerfectOneToOneMatch(
+            List<ResultSubpartitionStateHandle> originals,
+            List<InputChannelStateHandle> transformed) {
+
+        assertThat(originals)
+                .as(
+                        "The transformed collection must have the same number of elements as the originals")
+                .hasSameSizeAs(transformed);
+
+        if (originals.isEmpty()) {
+            return;
+        }
+
+        // Tracks already matched transformed handles to enforce the 1:1 constraint.
+        Set<InputChannelStateHandle> usedInputs = new HashSet<>();
+
+        originals.forEach(
+                r -> {
+                    List<InputChannelStateHandle> matchedInputHandles =
+                            transformed.stream()
+                                    .filter(t -> verifyTransformationConsistency(r, t))
+                                    .collect(Collectors.toList());
+                    assertThat(matchedInputHandles).hasSize(1);
+
+                    InputChannelStateHandle matchedInputHandle =
+                            Iterables.getOnlyElement(matchedInputHandles);
+                    assertThat(usedInputs).doesNotContain(matchedInputHandle);
+                    usedInputs.add(matchedInputHandle);
+                });
+    }
+
+    private boolean verifyTransformationConsistency(
+            ResultSubpartitionStateHandle original, InputChannelStateHandle transformed) {
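+        // The delegate is compared by identity: redistribution is expected to rewrap the same
+        // underlying stream state handle rather than copy its contents.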
+        return original.getDelegate() == transformed.getDelegate()
+                && original.getOffsets().equals(transformed.getOffsets())
+                && original.getStateSize() == transformed.getStateSize()
+                && original.getSubtaskIndex() == transformed.getInfo().getInputChannelIdx();
+    }
+
     private Map<OperatorID, OperatorState> buildOperatorStates(
             List<OperatorID> operatorIDs, int numSubTasks) {
+        return buildOperatorStates(
+                operatorIDs,
+                numSubTasks,
+                new JobResultSubpartitionHandlers(operatorIDs, numSubTasks));
+    }
+
+    private Map<OperatorID, OperatorState> buildOperatorStates(
+            List<OperatorID> operatorIDs, int numSubTasks, JobResultSubpartitionHandlers handlers) {
         Random random = new Random();
-        final OperatorID lastId = operatorIDs.get(operatorIDs.size() - 1);
         return operatorIDs.stream()
                 .collect(
                         Collectors.toMap(
@@ -1136,17 +1461,8 @@ class StateAssignmentOperationTest {
                                                                                 10,
                                                                                 random))))
                                                         .setResultSubpartitionState(
-                                                                operatorID == lastId
-                                                                        ? StateObjectCollection
-                                                                                .empty()
-                                                                        : new StateObjectCollection<>(
-                                                                                asList(
-                                                                                        createNewResultSubpartitionStateHandle(
-                                                                                                10,
-                                                                                                random),
-                                                                                        createNewResultSubpartitionStateHandle(
-                                                                                                10,
-                                                                                                random))))
+                                                                handlers.getStateObjectCollection(
+                                                                        operatorID, i))
                                                         .build());
                                     }
                                     return state;
@@ -1441,7 +1757,7 @@ class StateAssignmentOperationTest {
         Map<OperatorID, ExecutionJobVertex> vertices = toExecutionVertices(source, map1, map2);
 
         // This should not throw UnsupportedOperationException
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false)
                 .assignStates();
 
         // Verify state assignment succeeded
@@ -1521,7 +1837,7 @@ class StateAssignmentOperationTest {
                 toExecutionVertices(upstream1, upstream2, upstream3, downstream);
 
         // This should not throw UnsupportedOperationException
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false)
                 .assignStates();
 
         // Verify downstream received state
@@ -1582,7 +1898,7 @@ class StateAssignmentOperationTest {
         Map<OperatorID, ExecutionJobVertex> vertices = toExecutionVertices(source, sink);
 
         // This should succeed even with RESCALE partitioner when parallelism changes
-        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false, false)
                 .assignStates();
 
         // Verify state was assigned
@@ -1590,4 +1906,78 @@ class StateAssignmentOperationTest {
                 getAssignedState(vertices.get(operatorIds.get(1)), operatorIds.get(1), 0);
         assertThat(sinkAssignedState).isNotNull();
     }
+
+    /**
+     * Test utility class that manages ResultSubpartitionStateHandles for all operators in a job.
+     */
+    private static class JobResultSubpartitionHandlers {
+
+        private final Map<OperatorID, Map<Integer, SubtaskResultSubpartitionHandlers>> handlers;
+
+        public JobResultSubpartitionHandlers(List<OperatorID> operatorIDs, int numSubTasks) {
+            this.handlers = new HashMap<>(operatorIDs.size() - 1);
+
+            Random random = new Random();
+            final OperatorID lastId = operatorIDs.get(operatorIDs.size() - 1);
+            for (OperatorID operatorID : operatorIDs) {
+                if (operatorID == lastId) {
+                    // The last operator does not contain output buffers.
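+                    // (lastId is the final element of operatorIDs, so returning here still
+                    // registers handlers for every upstream operator)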
+                    return;
+                }
+                Map<Integer, SubtaskResultSubpartitionHandlers> operatorHandlers = new HashMap<>();
+                for (int subtaskIndex = 0; subtaskIndex < numSubTasks; subtaskIndex++) {
+                    operatorHandlers.put(
+                            subtaskIndex,
+                            new SubtaskResultSubpartitionHandlers(random, numSubTasks));
+                }
+                handlers.put(operatorID, operatorHandlers);
+            }
+        }
+
+        /**
+         * Returns the StateObjectCollection of OutputStateHandles for the specified operator and
+         * subtask.
+         */
+        private StateObjectCollection<OutputStateHandle> getStateObjectCollection(
+                OperatorID operatorID, int subtaskIndex) {
+            Map<Integer, SubtaskResultSubpartitionHandlers> operatorHandlers =
+                    handlers.get(operatorID);
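+            // Only the last operator has no entry here, because it produces no output buffers.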
+            if (operatorHandlers == null) {
+                return StateObjectCollection.empty();
+            }
+
+            SubtaskResultSubpartitionHandlers subtaskResultSubpartitionHandlers =
+                    operatorHandlers.get(subtaskIndex);
+            checkArgument(
+                    subtaskResultSubpartitionHandlers != null,
+                    "Subtask result subpartition handler not found for subtask %s.",
+                    subtaskIndex);
+            // Ensure there is an output buffer in each subpartition.
+            return new StateObjectCollection<>(
+                    new ArrayList<>(subtaskResultSubpartitionHandlers.handlers.values()));
+        }
+
+        private ResultSubpartitionStateHandle getHandler(
+                OperatorID operatorID, int subtaskIndex, int subpartitionIndex) {
+            return handlers.get(operatorID).get(subtaskIndex).getHandler(subpartitionIndex);
+        }
+    }
+
+    /** Test utility class that manages ResultSubpartitionStateHandles for a single subtask. */
+    private static class SubtaskResultSubpartitionHandlers {
+
+        // The key is the subpartition index.
+        private final Map<Integer, ResultSubpartitionStateHandle> handlers;
+
+        public SubtaskResultSubpartitionHandlers(Random random, int numSubpartitions) {
+            this.handlers = new HashMap<>(numSubpartitions);
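+            // One handle per subpartition with a deterministic subpartition index so the test
+            // can map each recovered input-channel handle back to its upstream subpartition.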
+            for (int i = 0; i < numSubpartitions; i++) {
+                handlers.put(i, createNewResultSubpartitionStateHandle(10, 0, i, random));
+            }
+        }
+
+        private ResultSubpartitionStateHandle getHandler(int subpartitionIndex) {
+            return handlers.get(subpartitionIndex);
+        }
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
index 7007821a6fe..088dc6b7af7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
@@ -162,8 +162,14 @@ public class StateHandleDummyUtil {
 
     public static ResultSubpartitionStateHandle createNewResultSubpartitionStateHandle(
             int numNamedStates, int partitionIndex, Random random) {
+        return createNewResultSubpartitionStateHandle(
+                numNamedStates, partitionIndex, random.nextInt(), random);
+    }
+
+    public static ResultSubpartitionStateHandle createNewResultSubpartitionStateHandle(
+            int numNamedStates, int partitionIndex, int subPartitionIdx, Random random) {
         return new ResultSubpartitionStateHandle(
-                new ResultSubpartitionInfo(partitionIndex, random.nextInt()),
+                new ResultSubpartitionInfo(partitionIndex, subPartitionIdx),
                 createStreamStateHandle(numNamedStates, random),
                 genOffsets(numNamedStates, random));
     }
diff --git a/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java
index c306a81a0c2..b6511bba78a 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java
@@ -64,6 +64,7 @@ import static org.apache.flink.changelog.fs.FsStateChangelogOptions.PREEMPTIVE_P
 import static org.apache.flink.configuration.CheckpointingOptions.CHECKPOINTS_DIRECTORY;
 import static org.apache.flink.configuration.CheckpointingOptions.CHECKPOINT_STORAGE;
 import static org.apache.flink.configuration.CheckpointingOptions.FILE_MERGING_ENABLED;
+import static org.apache.flink.configuration.CheckpointingOptions.UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM;
 import static org.apache.flink.configuration.CoreOptions.DEFAULT_PARALLELISM;
 import static org.apache.flink.configuration.ExternalizedCheckpointRetention.RETAIN_ON_CANCELLATION;
 import static org.apache.flink.configuration.RestartStrategyOptions.RESTART_STRATEGY;
@@ -179,6 +180,9 @@ public class ChangelogRecoveryCachingITCase extends TestLogger {
         conf.set(PERIODIC_MATERIALIZATION_ENABLED, false);
 
         conf.set(CheckpointingOptions.ENABLE_UNALIGNED, true); // speedup
+        // Disable UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM to prevent randomization, since the
+        // output buffer state file may be opened by multiple downstream subtasks.
+        conf.set(UNALIGNED_RECOVER_OUTPUT_ON_DOWNSTREAM, false);
         conf.set(
                 CheckpointingOptions.ALIGNED_CHECKPOINT_TIMEOUT,
                 Duration.ZERO); // prevent randomization
