This is an automated email from the ASF dual-hosted git repository.

fanrui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit d1914c63c95715086884fee20cfe08fe8388547b
Author: Rui Fan <[email protected]>
AuthorDate: Wed Feb 18 21:26:32 2026 +0100

    [FLINK-39018][network] Buffer migration from RecoveredInputChannel to physical channels
---
 .../partition/consumer/LocalInputChannel.java      | 28 ++++++++++++-
 .../consumer/LocalRecoveredInputChannel.java       |  8 +++-
 .../partition/consumer/RecoveredInputChannel.java  | 21 +++++++++-
 .../partition/consumer/RemoteInputChannel.java     | 48 +++++++++++++++++++---
 .../consumer/RemoteRecoveredInputChannel.java      |  8 +++-
 .../partition/consumer/SingleInputGate.java        | 41 ++++++++++++++----
 .../partition/consumer/UnknownInputChannel.java    |  7 +++-
 ...editBasedPartitionRequestClientHandlerTest.java |  4 +-
 .../netty/PartitionRequestRegistrationTest.java    |  4 +-
 .../partition/consumer/InputChannelBuilder.java    |  7 +++-
 .../partition/consumer/LocalInputChannelTest.java  | 37 +++++++++++++++++
 .../consumer/RecoveredInputChannelTest.java        |  5 ++-
 .../partition/consumer/RemoteInputChannelTest.java | 46 +++++++++++++++++++++
 .../benchmark/SingleInputGateBenchmarkFactory.java |  7 +++-
 14 files changed, 242 insertions(+), 29 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
index 2833adecb58..661e4b063c7 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
@@ -98,7 +98,8 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit
             int maxBackoff,
             Counter numBytesIn,
             Counter numBuffersIn,
-            ChannelStateWriter stateWriter) {
+            ChannelStateWriter stateWriter,
+            ArrayDeque<Buffer> initialRecoveredBuffers) {
 
         super(
                 inputGate,
@@ -113,6 +114,31 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit
         this.partitionManager = checkNotNull(partitionManager);
         this.taskEventPublisher = checkNotNull(taskEventPublisher);
         this.channelStatePersister = new ChannelStatePersister(stateWriter, 
getChannelInfo());
+
+        // Migrate recovered buffers from RecoveredInputChannel if provided.
+        // These buffers have been filtered but not yet consumed by the Task.
+        if (!initialRecoveredBuffers.isEmpty()) {
+            final int expectedCount = initialRecoveredBuffers.size();
+            // Sequence number starts at Integer.MIN_VALUE, consistent with 
RecoveredInputChannel.
+            int seqNum = Integer.MIN_VALUE;
+            while (!initialRecoveredBuffers.isEmpty()) {
+                Buffer buffer = initialRecoveredBuffers.poll();
+                // Determine next data type based on the next buffer in the 
queue
+                Buffer.DataType nextDataType =
+                        initialRecoveredBuffers.isEmpty()
+                                ? Buffer.DataType.NONE
+                                : initialRecoveredBuffers.peek().getDataType();
+                // buffersInBacklog is set to 0 as these are recovered buffers
+                BufferAndBacklog bufferAndBacklog =
+                        new BufferAndBacklog(buffer, 0, nextDataType, 
seqNum++);
+                toBeConsumedBuffers.add(bufferAndBacklog);
+            }
+            checkState(
+                    toBeConsumedBuffers.size() == expectedCount,
+                    "Buffer migration failed: expected %s buffers but got %s",
+                    expectedCount,
+                    toBeConsumedBuffers.size());
+        }
     }
 
     // ------------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java
index 784444d63e8..bdde2244f38 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalRecoveredInputChannel.java
@@ -19,11 +19,14 @@
 package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.runtime.io.network.TaskEventPublisher;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
 import 
org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
 
+import java.util.ArrayDeque;
+
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -61,7 +64,7 @@ public class LocalRecoveredInputChannel extends RecoveredInputChannel {
     }
 
     @Override
-    protected InputChannel toInputChannelInternal() {
+    protected InputChannel toInputChannelInternal(ArrayDeque<Buffer> 
remainingBuffers) {
         return new LocalInputChannel(
                 inputGate,
                 getChannelIndex(),
@@ -73,6 +76,7 @@ public class LocalRecoveredInputChannel extends RecoveredInputChannel {
                 maxBackoff,
                 numBytesIn,
                 numBuffersIn,
-                channelStateWriter);
+                channelStateWriter,
+                remainingBuffers);
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java
index e809e952a28..d2a7a07137d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannel.java
@@ -111,7 +111,16 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan
     public final InputChannel toInputChannel() throws IOException {
         Preconditions.checkState(
                 stateConsumedFuture.isDone(), "recovered state is not fully 
consumed");
-        final InputChannel inputChannel = toInputChannelInternal();
+
+        // Extract remaining buffers before conversion.
+        // These buffers have been filtered but not yet consumed by the Task.
+        final ArrayDeque<Buffer> remainingBuffers;
+        synchronized (receivedBuffers) {
+            remainingBuffers = new ArrayDeque<>(receivedBuffers);
+            receivedBuffers.clear();
+        }
+
+        final InputChannel inputChannel = 
toInputChannelInternal(remainingBuffers);
         inputChannel.checkpointStopped(lastStoppedCheckpointId);
         return inputChannel;
     }
@@ -121,7 +130,15 @@ public abstract class RecoveredInputChannel extends InputChannel implements Chan
         this.lastStoppedCheckpointId = checkpointId;
     }
 
-    protected abstract InputChannel toInputChannelInternal() throws 
IOException;
+    /**
+     * Creates the physical InputChannel from this recovered channel.
+     *
+     * @param remainingBuffers buffers that have been filtered but not yet 
consumed by the Task.
+     *     These buffers will be migrated to the new physical channel.
+     * @return the physical InputChannel (LocalInputChannel or 
RemoteInputChannel)
+     */
+    protected abstract InputChannel toInputChannelInternal(ArrayDeque<Buffer> 
remainingBuffers)
+            throws IOException;
 
     CompletableFuture<?> getStateConsumedFuture() {
         return stateConsumedFuture;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
index 3430196775c..66a7d500140 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
@@ -138,7 +138,8 @@ public class RemoteInputChannel extends InputChannel {
             int networkBuffersPerChannel,
             Counter numBytesIn,
             Counter numBuffersIn,
-            ChannelStateWriter stateWriter) {
+            ChannelStateWriter stateWriter,
+            ArrayDeque<Buffer> initialRecoveredBuffers) {
 
         super(
                 inputGate,
@@ -157,6 +158,29 @@ public class RemoteInputChannel extends InputChannel {
         this.connectionManager = checkNotNull(connectionManager);
         this.bufferManager = new 
BufferManager(inputGate.getMemorySegmentProvider(), this, 0);
         this.channelStatePersister = new ChannelStatePersister(stateWriter, 
getChannelInfo());
+
+        // Migrate recovered buffers from RecoveredInputChannel if provided.
+        // These buffers have been filtered but not yet consumed by the Task.
+        if (!initialRecoveredBuffers.isEmpty()) {
+            final int expectedCount = initialRecoveredBuffers.size();
+            // Sequence number starts at Integer.MIN_VALUE, consistent with 
RecoveredInputChannel.
+            int seqNum = Integer.MIN_VALUE;
+            for (Buffer buffer : initialRecoveredBuffers) {
+                // subpartitionId is set to 0 for recovered buffers. This is 
correct because:
+                // 1) For single-subpartition channels, the only valid 
subpartition is 0.
+                // 2) For multi-subpartition channels 
(consumedSubpartitionIndexSet.size() > 1),
+                //    RecoveryMetadata events embedded in the recovered buffer 
sequence track
+                //    the actual subpartition context for proper routing.
+                SequenceBuffer sequenceBuffer = new SequenceBuffer(buffer, 
seqNum++, 0);
+                receivedBuffers.add(sequenceBuffer);
+                totalQueueSizeInBytes += buffer.getSize();
+            }
+            checkState(
+                    receivedBuffers.size() == expectedCount,
+                    "Buffer migration failed: expected %s buffers but got %s",
+                    expectedCount,
+                    receivedBuffers.size());
+        }
     }
 
     @VisibleForTesting
@@ -239,9 +263,9 @@ public class RemoteInputChannel extends InputChannel {
 
     @Override
     protected int peekNextBufferSubpartitionIdInternal() throws IOException {
-        checkPartitionRequestQueueInitialized();
-
         synchronized (receivedBuffers) {
+            checkReadability();
+
             final SequenceBuffer next = receivedBuffers.peek();
 
             if (next != null) {
@@ -254,12 +278,12 @@ public class RemoteInputChannel extends InputChannel {
 
     @Override
     public Optional<BufferAndAvailability> getNextBuffer() throws IOException {
-        checkPartitionRequestQueueInitialized();
-
         final SequenceBuffer next;
         final DataType nextDataType;
 
         synchronized (receivedBuffers) {
+            checkReadability();
+
             next = receivedBuffers.poll();
 
             if (next != null) {
@@ -879,6 +903,20 @@ public class RemoteInputChannel extends InputChannel {
         setError(cause);
     }
 
+    /**
+     * When receivedBuffers contains migrated buffers from 
RecoveredInputChannel, they can be read
+     * before requestSubpartitions(). In that case only check for errors. Once 
migrated buffers are
+     * drained, require full client initialization check.
+     */
+    private void checkReadability() throws IOException {
+        assert Thread.holdsLock(receivedBuffers);
+        if (receivedBuffers.isEmpty()) {
+            checkPartitionRequestQueueInitialized();
+        } else {
+            checkError();
+        }
+    }
+
     private void checkPartitionRequestQueueInitialized() throws IOException {
         checkError();
         checkState(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java
index cbaddbcbfa0..2cfff6f5e79 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteRecoveredInputChannel.java
@@ -20,11 +20,13 @@ package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.runtime.io.network.ConnectionID;
 import org.apache.flink.runtime.io.network.ConnectionManager;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.metrics.InputChannelMetrics;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import 
org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
 
 import java.io.IOException;
+import java.util.ArrayDeque;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
@@ -66,7 +68,8 @@ public class RemoteRecoveredInputChannel extends RecoveredInputChannel {
     }
 
     @Override
-    protected InputChannel toInputChannelInternal() throws IOException {
+    protected InputChannel toInputChannelInternal(ArrayDeque<Buffer> 
remainingBuffers)
+            throws IOException {
         RemoteInputChannel remoteInputChannel =
                 new RemoteInputChannel(
                         inputGate,
@@ -81,7 +84,8 @@ public class RemoteRecoveredInputChannel extends RecoveredInputChannel {
                         networkBuffersPerChannel,
                         numBytesIn,
                         numBuffersIn,
-                        channelStateWriter);
+                        channelStateWriter,
+                        remainingBuffers);
         remoteInputChannel.setup();
         return remoteInputChannel;
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index d845eef6294..2847e36fcc2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -375,6 +375,10 @@ public class SingleInputGate extends IndexedInputGate {
         }
     }
 
+    /**
+     * Converts all {@link RecoveredInputChannel}s to their real channel types 
({@link
+     * LocalInputChannel} or {@link RemoteInputChannel}).
+     */
     @VisibleForTesting
     public void convertRecoveredInputChannels() {
         LOG.debug("Converting recovered input channels ({} channels)", 
getNumberOfInputChannels());
@@ -384,19 +388,40 @@ public class SingleInputGate extends IndexedInputGate {
                     new HashSet<>(inputChannelsForCurrentPartition.keySet());
             for (InputChannelInfo inputChannelInfo : oldInputChannelInfos) {
                 InputChannel inputChannel = 
inputChannelsForCurrentPartition.get(inputChannelInfo);
-                if (inputChannel instanceof RecoveredInputChannel) {
-                    try {
-                        InputChannel realInputChannel =
-                                ((RecoveredInputChannel) 
inputChannel).toInputChannel();
-                        inputChannel.releaseAllResources();
+                if (!(inputChannel instanceof RecoveredInputChannel)) {
+                    continue;
+                }
+                try {
+                    // Phase 1: Convert channel and release resources outside 
the lock.
+                    // These calls may acquire the receivedBuffers lock 
internally, so they
+                    // run outside inputChannelsWithData lock to maintain a 
consistent lock
+                    // order with onRecoveredStateBuffer() which acquires 
receivedBuffers
+                    // first and then inputChannelsWithData.
+                    InputChannel realInputChannel =
+                            ((RecoveredInputChannel) 
inputChannel).toInputChannel();
+                    inputChannel.releaseAllResources();
+                    int buffersInUseCount = 
realInputChannel.getBuffersInUseCount();
+
+                    // Phase 2: Atomically update data structures under the 
lock.
+                    synchronized (inputChannelsWithData) {
+                        if (inputChannelsWithData.contains(inputChannel)) {
+                            inputChannelsWithData.getAndRemove(ch -> ch == 
inputChannel);
+                        }
+                        
enqueuedInputChannelsWithData.clear(inputChannel.getChannelIndex());
+
                         
inputChannelsForCurrentPartition.remove(inputChannelInfo);
                         inputChannelsForCurrentPartition.put(
                                 realInputChannel.getChannelInfo(), 
realInputChannel);
                         channels[inputChannel.getChannelIndex()] = 
realInputChannel;
-                    } catch (Throwable t) {
-                        inputChannel.setError(t);
-                        return;
+
+                        if (buffersInUseCount > 0) {
+                            inputChannelsWithData.add(realInputChannel);
+                            
enqueuedInputChannelsWithData.set(realInputChannel.getChannelIndex());
+                        }
                     }
+                } catch (Throwable t) {
+                    inputChannel.setError(t);
+                    return;
                 }
             }
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java
index 2ff8aa73bcd..15182cedadb 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnknownInputChannel.java
@@ -35,6 +35,7 @@ import org.apache.flink.util.Preconditions;
 import javax.annotation.Nullable;
 
 import java.io.IOException;
+import java.util.ArrayDeque;
 import java.util.Optional;
 
 import static 
org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_TASK_NOT_READY;
@@ -183,7 +184,8 @@ class UnknownInputChannel extends InputChannel implements ChannelStateHolder {
                 networkBuffersPerChannel,
                 metrics.getNumBytesInRemoteCounter(),
                 metrics.getNumBuffersInRemoteCounter(),
-                channelStateWriter == null ? ChannelStateWriter.NO_OP : 
channelStateWriter);
+                channelStateWriter == null ? ChannelStateWriter.NO_OP : 
channelStateWriter,
+                new ArrayDeque<>());
     }
 
     public LocalInputChannel toLocalInputChannel(ResultPartitionID 
resultPartitionID) {
@@ -198,7 +200,8 @@ class UnknownInputChannel extends InputChannel implements ChannelStateHolder {
                 maxBackoff,
                 metrics.getNumBytesInLocalCounter(),
                 metrics.getNumBuffersInLocalCounter(),
-                channelStateWriter == null ? ChannelStateWriter.NO_OP : 
channelStateWriter);
+                channelStateWriter == null ? ChannelStateWriter.NO_OP : 
channelStateWriter,
+                new ArrayDeque<>());
     }
 
     @Override
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
index 48b29aeaeaa..d96ed78b6a0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/CreditBasedPartitionRequestClientHandlerTest.java
@@ -68,6 +68,7 @@ import org.junit.jupiter.params.provider.MethodSource;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.util.ArrayDeque;
 import java.util.stream.Stream;
 
 import static 
org.apache.flink.runtime.io.network.netty.PartitionRequestQueueTest.blockChannel;
@@ -951,7 +952,8 @@ class CreditBasedPartitionRequestClientHandlerTest {
                     2,
                     new SimpleCounter(),
                     new SimpleCounter(),
-                    ChannelStateWriter.NO_OP);
+                    ChannelStateWriter.NO_OP,
+                    new ArrayDeque<>());
             this.expectedMessage = expectedMessage;
         }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java
index 3b590d3a256..e3cfb55e340 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestRegistrationTest.java
@@ -44,6 +44,7 @@ import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
 
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayDeque;
 import java.util.Optional;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
@@ -248,7 +249,8 @@ class PartitionRequestRegistrationTest {
                     2,
                     new SimpleCounter(),
                     new SimpleCounter(),
-                    ChannelStateWriter.NO_OP);
+                    ChannelStateWriter.NO_OP,
+                    new ArrayDeque<>());
             this.latch = latch;
         }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java
index f5e810bb605..08f65d9fe72 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/InputChannelBuilder.java
@@ -34,6 +34,7 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
 import 
org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
 
 import java.net.InetSocketAddress;
+import java.util.ArrayDeque;
 
 import static 
org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateTest.TestingResultPartitionManager;
 
@@ -164,7 +165,8 @@ public class InputChannelBuilder {
                 maxBackoff,
                 metrics.getNumBytesInLocalCounter(),
                 metrics.getNumBuffersInLocalCounter(),
-                stateWriter);
+                stateWriter,
+                new ArrayDeque<>());
     }
 
     public RemoteInputChannel buildRemoteChannel(SingleInputGate inputGate) {
@@ -181,7 +183,8 @@ public class InputChannelBuilder {
                 networkBuffersPerChannel,
                 metrics.getNumBytesInRemoteCounter(),
                 metrics.getNumBuffersInRemoteCounter(),
-                stateWriter);
+                stateWriter,
+                new ArrayDeque<>());
     }
 
     public LocalRecoveredInputChannel 
buildLocalRecoveredChannel(SingleInputGate inputGate) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
index aeb765f79b9..86bda9866d2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
@@ -713,6 +713,43 @@ class LocalInputChannelTest {
         
assertThat(localChannel.unsynchronizedGetNumberOfQueuedBuffers()).isEqualTo(5);
     }
 
+    @Test
+    void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception {
+        // given: LocalInputChannel with recovered buffers migrated from 
RecoveredInputChannel
+        SingleInputGate inputGate = createSingleInputGate(1);
+
+        ArrayDeque<Buffer> recoveredBuffers = new ArrayDeque<>();
+        recoveredBuffers.add(TestBufferFactory.createBuffer(10));
+        recoveredBuffers.add(TestBufferFactory.createBuffer(20));
+
+        LocalInputChannel channel =
+                new LocalInputChannel(
+                        inputGate,
+                        0,
+                        new ResultPartitionID(),
+                        new ResultSubpartitionIndexSet(0),
+                        new ResultPartitionManager(),
+                        new TaskEventDispatcher(),
+                        0,
+                        0,
+                        new SimpleCounter(),
+                        new SimpleCounter(),
+                        ChannelStateWriter.NO_OP,
+                        recoveredBuffers);
+
+        inputGate.setInputChannels(channel);
+
+        // then: Can read recovered buffers even before requestSubpartitions()
+        Optional<InputChannel.BufferAndAvailability> first = 
channel.getNextBuffer();
+        assertThat(first).isPresent();
+        assertThat(first.get().buffer().getSize()).isEqualTo(10);
+        assertThat(first.get().moreAvailable()).isTrue();
+
+        Optional<InputChannel.BufferAndAvailability> second = 
channel.getNextBuffer();
+        assertThat(second).isPresent();
+        assertThat(second.get().buffer().getSize()).isEqualTo(20);
+    }
+
     @Test
     void testCheckpointStartedPersistsRecoveredBuffers() throws Exception {
         // given: Local input channel with recovered buffers
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java
index ab7dd142c1e..5985a81e8ca 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RecoveredInputChannelTest.java
@@ -22,11 +22,14 @@ import org.apache.flink.metrics.SimpleCounter;
 import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import 
org.apache.flink.runtime.io.network.partition.ResultSubpartitionIndexSet;
 
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayDeque;
+
 import static org.apache.flink.runtime.checkpoint.CheckpointOptions.unaligned;
 import static 
org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -74,7 +77,7 @@ class RecoveredInputChannelTest {
                     new SimpleCounter(),
                     10) {
                 @Override
-                protected InputChannel toInputChannelInternal() {
+                protected InputChannel 
toInputChannelInternal(ArrayDeque<Buffer> remainingBuffers) {
                     throw new AssertionError("channel conversion succeeded");
                 }
             };
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
index 4f5abdd4271..e47de93c9e8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.io.network.partition.consumer;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.metrics.SimpleCounter;
 import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
@@ -2073,6 +2074,51 @@ class RemoteInputChannelTest {
         }
     }
 
+    @Test
+    void testGetNextBufferWithMigratedRecoveredBuffers() throws Exception {
+        // given: RemoteInputChannel with recovered buffers migrated from 
RecoveredInputChannel
+        SingleInputGate inputGate = createSingleInputGate(1);
+
+        ArrayDeque<Buffer> recoveredBuffers = new ArrayDeque<>();
+        recoveredBuffers.add(TestBufferFactory.createBuffer(10));
+        recoveredBuffers.add(TestBufferFactory.createBuffer(20));
+
+        ConnectionID connectionId =
+                new ConnectionID(
+                        
org.apache.flink.runtime.clusterframework.types.ResourceID.generate(),
+                        new java.net.InetSocketAddress("localhost", 0),
+                        0);
+        RemoteInputChannel channel =
+                new RemoteInputChannel(
+                        inputGate,
+                        0,
+                        new ResultPartitionID(),
+                        new ResultSubpartitionIndexSet(0),
+                        connectionId,
+                        
InputChannelTestUtils.mockConnectionManagerWithPartitionRequestClient(
+                                mock(PartitionRequestClient.class)),
+                        0,
+                        0,
+                        0,
+                        2,
+                        new SimpleCounter(),
+                        new SimpleCounter(),
+                        ChannelStateWriter.NO_OP,
+                        recoveredBuffers);
+
+        inputGate.setInputChannels(channel);
+
+        // then: Can read recovered buffers even before requestSubpartitions()
+        Optional<BufferAndAvailability> first = channel.getNextBuffer();
+        assertThat(first).isPresent();
+        assertThat(first.get().buffer().getSize()).isEqualTo(10);
+        assertThat(first.get().moreAvailable()).isTrue();
+
+        Optional<BufferAndAvailability> second = channel.getNextBuffer();
+        assertThat(second).isPresent();
+        assertThat(second.get().buffer().getSize()).isEqualTo(20);
+    }
+
     private static final class TestBufferPool extends NoOpBufferPool {
 
         @Override
diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java
index 49ecb6c6645..b850a7cc553 100644
--- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java
+++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor;
 import 
org.apache.flink.runtime.taskmanager.NettyShuffleEnvironmentConfiguration;
 
 import java.io.IOException;
+import java.util.ArrayDeque;
 
 /**
  * A benchmark-specific input gate factory which overrides the respective 
methods of creating {@link
@@ -128,7 +129,8 @@ public class SingleInputGateBenchmarkFactory extends SingleInputGateFactory {
                     maxBackoff,
                     metrics.getNumBytesInLocalCounter(),
                     metrics.getNumBuffersInLocalCounter(),
-                    ChannelStateWriter.NO_OP);
+                    ChannelStateWriter.NO_OP,
+                    new ArrayDeque<>());
         }
 
         @Override
@@ -183,7 +185,8 @@ public class SingleInputGateBenchmarkFactory extends SingleInputGateFactory {
                     networkBuffersPerChannel,
                     metrics.getNumBytesInRemoteCounter(),
                     metrics.getNumBuffersInRemoteCounter(),
-                    ChannelStateWriter.NO_OP);
+                    ChannelStateWriter.NO_OP,
+                    new ArrayDeque<>());
         }
 
         @Override

Reply via email to