This is an automated email from the ASF dual-hosted git repository.

1996fanrui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 38a29c14fe3668f7ba20d4e8bb52ebeb135d8e18
Author: Rui Fan <[email protected]>
AuthorDate: Wed Apr 22 21:57:58 2026 +0200

    [FLINK-39519][checkpoint] Allocate pre-filter source buffers from a reusable heap segment
    
    In filtering-mode channel state recovery, InputChannelRecoveredStateHandler
    now allocates the pre-filter source buffer from heap instead of the Network
    Buffer Pool, eliminating the deadlock where pre-filter and post-filter
    buffers compete for the same pool in the single-threaded
    channel-state-unspilling recovery loop.
    
    Memory bound:
    - One MemorySegment per task, sized to memorySegmentSize, lazily allocated
      on the first getBuffer() call, reused across every subsequent call, and
      freed in close(). Worst-case extra footprint per task is therefore a
      single segment of memorySegmentSize bytes (32 KB by default).
    - The one-buffer-at-a-time invariant is guaranteed structurally by Flink's
      serial recovery loop and the deserializer's isBufferConsumed contract,
      so no semaphore or per-gate counter is needed.
    - A runtime check asserts !preFilterBufferInUse before each buffer is issued.
      The custom BufferRecycler flips the flag back on recycle without freeing
      the segment. Any future regression of the invariant fails loudly with
      IllegalStateException instead of silently corrupting memory.
    
    Non-filtering mode is unchanged: getBuffer() still delegates to the
    channel's Network Buffer Pool via requestBufferBlocking().
    
    Wiring:
    - RecordFilterContext gains a required memorySegmentSize parameter
      (validated with checkArgument(memorySegmentSize > 0)) and a
      getMemorySegmentSize() accessor. The disabled() factory uses
      MemoryManager.DEFAULT_PAGE_SIZE.
    - StreamTask.createRecordFilterContext() passes
      ConfigurationParserUtils.getPageSize(jobConfiguration).
    - SequentialChannelStateReaderImpl passes
      filterContext.getMemorySegmentSize() to the handler constructor.
    
    Tests:
    - testPreFilterBufferIsolationFromNetworkBufferPool: filtering-mode
      getBuffer() returns a heap-backed buffer; the Network Buffer Pool is
      untouched.
    - testNonFilteringModeUsesNetworkBufferPool: non-filtering path preserved.
    - testPreFilterSegmentReusedAcrossCalls: successive getBuffer()/recycle
      cycles return the same MemorySegment instance.
    - testGetBufferThrowsWhenPriorBufferNotRecycled: runtime invariant check.
    - testPreFilterSegmentFreedOnClose: segment freed on handler close.
    - testMemorySegmentSizeExposedAndValidated: context validation and getter.
---
 .../channel/RecoveredChannelStateHandler.java      |  85 +++++++++++-
 .../channel/SequentialChannelStateReaderImpl.java  |   3 +-
 .../runtime/io/recovery/RecordFilterContext.java   |  27 +++-
 .../flink/streaming/runtime/tasks/StreamTask.java  |   3 +-
 .../InputChannelRecoveredStateHandlerTest.java     | 151 ++++++++++++++++++++-
 .../io/recovery/RecordFilterContextTest.java       |  61 ++++++++-
 .../VirtualChannelRecordFilterFactoryTest.java     |   4 +-
 7 files changed, 317 insertions(+), 17 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
index 31db728bc48..ca01ff37bd3 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
@@ -17,6 +17,9 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
 import org.apache.flink.runtime.checkpoint.RescaleMappings;
 import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
@@ -25,6 +28,8 @@ import 
org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
+import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import 
org.apache.flink.runtime.io.network.partition.CheckpointedResultPartition;
 import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
@@ -39,6 +44,7 @@ import java.util.List;
 import java.util.Map;
 
 import static 
org.apache.flink.runtime.checkpoint.channel.ChannelStateByteBuffer.wrap;
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkState;
 
 interface RecoveredChannelStateHandler<Info, Context> extends AutoCloseable {
@@ -82,24 +88,94 @@ class InputChannelRecoveredStateHandler
      */
     @Nullable private final ChannelStateFilteringHandler filteringHandler;
 
+    /** Network buffer memory segment size in bytes. Used to size the reusable 
pre-filter buffer. */
+    private final int memorySegmentSize;
+
+    /**
+     * Reusable heap memory segment backing the pre-filter buffer in filtering 
mode. Lazily
+     * allocated on the first {@link #getPreFilterBuffer} call, reused for 
every subsequent call,
+     * and freed in {@link #close()}.
+     *
+     * <p>Reuse is safe because at most one pre-filter buffer is in flight per 
task at any moment.
+     * This invariant is enforced at runtime by {@link #preFilterBufferInUse}.
+     */
+    @Nullable private MemorySegment preFilterSegment;
+
+    /**
+     * Tracks whether {@link #preFilterSegment} is currently wrapped by a live 
{@link Buffer} that
+     * has not yet been recycled. Flipped to {@code true} when a new buffer is 
issued, and flipped
+     * back to {@code false} by the custom {@link BufferRecycler} when the 
buffer is recycled.
+     */
+    private boolean preFilterBufferInUse;
+
     InputChannelRecoveredStateHandler(
             InputGate[] inputGates,
             InflightDataRescalingDescriptor channelMapping,
-            @Nullable ChannelStateFilteringHandler filteringHandler) {
+            @Nullable ChannelStateFilteringHandler filteringHandler,
+            int memorySegmentSize) {
         this.inputGates = inputGates;
         this.channelMapping = channelMapping;
         this.filteringHandler = filteringHandler;
+        checkArgument(
+                memorySegmentSize > 0, "memorySegmentSize must be positive: 
%s", memorySegmentSize);
+        this.memorySegmentSize = memorySegmentSize;
     }
 
     @Override
     public BufferWithContext<Buffer> getBuffer(InputChannelInfo channelInfo)
             throws IOException, InterruptedException {
-        // request the buffer from any mapped channel as they all will receive 
the same buffer
+        if (filteringHandler != null) {
+            return getPreFilterBuffer();
+        }
+        // Non-filtering mode: use existing network buffer pool allocation.
         RecoveredInputChannel channel = getMappedChannels(channelInfo);
         Buffer buffer = channel.requestBufferBlocking();
         return new BufferWithContext<>(wrap(buffer), buffer);
     }
 
+    /**
+     * Allocates a pre-filter buffer from a reusable heap segment (isolated 
from the Network Buffer
+     * Pool) in filtering mode.
+     *
+     * <p>Memory management: a single {@link MemorySegment} per task is lazily 
allocated on first
+     * invocation and reused across every subsequent call. The custom {@link 
BufferRecycler} does
+     * not free the segment — it only flips {@link #preFilterBufferInUse} back 
to {@code false} so
+     * the next call can reuse it. The segment itself is freed in {@link 
#close()}.
+     *
+     * <p>Runtime invariant check: the one-at-a-time invariant on pre-filter 
buffers is guaranteed
+     * by Flink's serial recovery loop and the deserializer's ownership 
contract. This method
+     * asserts the invariant before issuing a buffer: if a previously issued 
buffer has not yet been
+     * recycled, it throws {@link IllegalStateException} so any future 
regression fails loudly
+     * instead of silently corrupting memory.
+     */
+    private BufferWithContext<Buffer> getPreFilterBuffer() {
+        checkState(
+                !preFilterBufferInUse,
+                "Previous pre-filter buffer has not been recycled. This 
violates the "
+                        + "one-buffer-at-a-time invariant of pre-filter 
buffers.");
+
+        if (preFilterSegment == null) {
+            preFilterSegment = 
MemorySegmentFactory.allocateUnpooledSegment(memorySegmentSize);
+        }
+        preFilterBufferInUse = true;
+
+        // The recycler keeps the segment alive for reuse; only flips the 
in-use flag.
+        BufferRecycler recycler = segment -> preFilterBufferInUse = false;
+        Buffer buffer = new NetworkBuffer(preFilterSegment, recycler);
+        return new BufferWithContext<>(wrap(buffer), buffer);
+    }
+
+    @VisibleForTesting
+    boolean isPreFilterBufferInUse() {
+        return preFilterBufferInUse;
+    }
+
+    @VisibleForTesting
+    @Nullable
+    MemorySegment getPreFilterSegmentForTesting() {
+        return preFilterSegment;
+    }
+
     @Override
     public void recover(
             InputChannelInfo channelInfo,
@@ -162,6 +238,11 @@ class InputChannelRecoveredStateHandler
         for (final InputGate inputGate : inputGates) {
             inputGate.finishReadRecoveredState();
         }
+        if (preFilterSegment != null) {
+            preFilterSegment.free();
+            preFilterSegment = null;
+            preFilterBufferInUse = false;
+        }
     }
 
     private RecoveredInputChannel getChannel(int gateIndex, int 
subPartitionIndex) {
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
index 8aa8db2679f..c52572e52fa 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
@@ -74,7 +74,8 @@ public class SequentialChannelStateReaderImpl implements 
SequentialChannelStateR
                         new InputChannelRecoveredStateHandler(
                                 inputGates,
                                 
taskStateSnapshot.getInputRescalingDescriptor(),
-                                filteringHandler)) {
+                                filteringHandler,
+                                filterContext.getMemorySegmentSize())) {
             read(
                     stateHandler,
                     groupByDelegate(
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java
index f2568fe5854..e207eb6213e 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.io.recovery;
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
@@ -96,6 +97,13 @@ public class RecordFilterContext {
     /** Whether unaligned checkpointing during recovery is enabled. */
     private final boolean checkpointingDuringRecoveryEnabled;
 
+    /**
+     * Network buffer memory segment size in bytes (from 
taskmanager.memory.segment-size). Used to
+     * size the reusable heap source buffer in {@code 
InputChannelRecoveredStateHandler} so it
+     * matches the network buffer size.
+     */
+    private final int memorySegmentSize;
+
     /**
      * Creates a new RecordFilterContext.
      *
@@ -108,6 +116,7 @@ public class RecordFilterContext {
      *     (converted to empty array).
      * @param checkpointingDuringRecoveryEnabled Whether unaligned 
checkpointing during recovery is
      *     enabled.
+     * @param memorySegmentSize Network buffer memory segment size in bytes. 
Must be positive.
      */
     public RecordFilterContext(
             InputFilterConfig[] inputConfigs,
@@ -115,13 +124,17 @@ public class RecordFilterContext {
             int subtaskIndex,
             int maxParallelism,
             String[] tmpDirectories,
-            boolean checkpointingDuringRecoveryEnabled) {
+            boolean checkpointingDuringRecoveryEnabled,
+            int memorySegmentSize) {
         this.inputConfigs = checkNotNull(inputConfigs).clone();
         this.rescalingDescriptor = checkNotNull(rescalingDescriptor);
         this.subtaskIndex = subtaskIndex;
         this.maxParallelism = maxParallelism;
         this.tmpDirectories = tmpDirectories != null ? tmpDirectories : new 
String[0];
         this.checkpointingDuringRecoveryEnabled = 
checkpointingDuringRecoveryEnabled;
+        checkArgument(
+                memorySegmentSize > 0, "memorySegmentSize must be positive: 
%s", memorySegmentSize);
+        this.memorySegmentSize = memorySegmentSize;
     }
 
     /**
@@ -195,6 +208,15 @@ public class RecordFilterContext {
         return tmpDirectories;
     }
 
+    /**
+     * Gets the network buffer memory segment size in bytes.
+     *
+     * @return The memory segment size. Always positive.
+     */
+    public int getMemorySegmentSize() {
+        return memorySegmentSize;
+    }
+
     /**
      * Checks if a specific gate and subtask combination is ambiguous 
(requires filtering).
      *
@@ -222,6 +244,7 @@ public class RecordFilterContext {
                 0,
                 0,
                 new String[0],
-                false);
+                false,
+                MemoryManager.DEFAULT_PAGE_SIZE);
     }
 }
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 7938a1ef278..e06164125fd 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -2057,7 +2057,8 @@ public abstract class StreamTask<OUT, OP extends 
StreamOperator<OUT>>
                 getEnvironment().getTaskInfo().getIndexOfThisSubtask(),
                 
getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks(),
                 getEnvironment().getIOManager().getSpillingDirectoriesPaths(),
-                true);
+                true,
+                
ConfigurationParserUtils.getPageSize(getEnvironment().getJobConfiguration()));
     }
 
     /** Check whether records can be emitted in batch. */
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
index 39ce6c7d4bf..9c4aab0bc7a 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
@@ -18,14 +18,17 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
+import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
 import org.apache.flink.runtime.checkpoint.RescaleMappings;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import 
org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilder;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
 import 
org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
+import org.apache.flink.runtime.memory.MemoryManager;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -48,7 +51,8 @@ class InputChannelRecoveredStateHandlerTest extends 
RecoveredChannelStateHandler
     @BeforeEach
     void setUp() {
         // given: Segment provider with defined number of allocated segments.
-        networkBufferPool = new NetworkBufferPool(preAllocatedSegments, 1024);
+        networkBufferPool =
+                new NetworkBufferPool(preAllocatedSegments, 
MemoryManager.DEFAULT_PAGE_SIZE);
 
         // and: Configured input gate with recovered channel.
         inputGate =
@@ -78,7 +82,8 @@ class InputChannelRecoveredStateHandlerTest extends 
RecoveredChannelStateHandler
                                             
.InflightDataGateOrPartitionRescalingDescriptor
                                             .MappingType.IDENTITY)
                         }),
-                null);
+                null,
+                MemoryManager.DEFAULT_PAGE_SIZE);
     }
 
     private InputChannelRecoveredStateHandler buildMultiChannelHandler() {
@@ -105,7 +110,33 @@ class InputChannelRecoveredStateHandlerTest extends 
RecoveredChannelStateHandler
                                             
.InflightDataGateOrPartitionRescalingDescriptor
                                             .MappingType.RESCALING)
                         }),
-                null);
+                null,
+                MemoryManager.DEFAULT_PAGE_SIZE);
+    }
+
+    /** Builds a handler in filtering mode (non-null filtering handler, no-op 
stub). */
+    private InputChannelRecoveredStateHandler 
buildFilteringInputChannelStateHandler() {
+        // Empty GateFilterHandler array: filtering is "enabled" structurally, 
but no gate-level
+        // filter logic runs. Suitable for exercising getBuffer() routing only.
+        ChannelStateFilteringHandler stubFilteringHandler =
+                new ChannelStateFilteringHandler(
+                        new ChannelStateFilteringHandler.GateFilterHandler[0]);
+        return new InputChannelRecoveredStateHandler(
+                new InputGate[] {inputGate},
+                new InflightDataRescalingDescriptor(
+                        new InflightDataRescalingDescriptor
+                                        
.InflightDataGateOrPartitionRescalingDescriptor[] {
+                            new InflightDataRescalingDescriptor
+                                    
.InflightDataGateOrPartitionRescalingDescriptor(
+                                    new int[] {1},
+                                    RescaleMappings.identity(1, 1),
+                                    new HashSet<>(),
+                                    InflightDataRescalingDescriptor
+                                            
.InflightDataGateOrPartitionRescalingDescriptor
+                                            .MappingType.IDENTITY)
+                        }),
+                stubFilteringHandler,
+                MemoryManager.DEFAULT_PAGE_SIZE);
     }
 
     @Test
@@ -153,4 +184,118 @@ class InputChannelRecoveredStateHandlerTest extends 
RecoveredChannelStateHandler
         assertThat(networkBufferPool.getNumberOfAvailableMemorySegments())
                 .isEqualTo(preAllocatedSegments);
     }
+
+    @Test
+    void testPreFilterBufferIsolationFromNetworkBufferPool() throws Exception {
+        try (InputChannelRecoveredStateHandler filteringHandler =
+                buildFilteringInputChannelStateHandler()) {
+            int availableBefore = 
networkBufferPool.getNumberOfAvailableMemorySegments();
+
+            RecoveredChannelStateHandler.BufferWithContext<Buffer> 
bufferWithContext =
+                    filteringHandler.getBuffer(channelInfo);
+            try {
+                Buffer buffer = bufferWithContext.context;
+                // Heap-backed: NetworkBuffer wrapping a heap (non off-heap) 
segment.
+                assertThat(buffer).isInstanceOf(NetworkBuffer.class);
+                assertThat(buffer.getMemorySegment().isOffHeap()).isFalse();
+                assertThat(buffer.getMemorySegment().size())
+                        .isEqualTo(MemoryManager.DEFAULT_PAGE_SIZE);
+                // Pool is untouched.
+                
assertThat(networkBufferPool.getNumberOfAvailableMemorySegments())
+                        .isEqualTo(availableBefore);
+            } finally {
+                bufferWithContext.context.recycleBuffer();
+            }
+        }
+    }
+
+    @Test
+    void testNonFilteringModeUsesNetworkBufferPool() throws Exception {
+        int availableBefore = 
networkBufferPool.getNumberOfAvailableMemorySegments();
+
+        RecoveredChannelStateHandler.BufferWithContext<Buffer> 
bufferWithContext =
+                icsHandler.getBuffer(channelInfo);
+        try {
+            Buffer buffer = bufferWithContext.context;
+            // Pool allocation reduces available count.
+            assertThat(networkBufferPool.getNumberOfAvailableMemorySegments())
+                    .isLessThan(availableBefore);
+            // Memory segment comes from pre-allocated pool (off-heap).
+            assertThat(buffer.getMemorySegment().isOffHeap()).isTrue();
+        } finally {
+            bufferWithContext.context.recycleBuffer();
+        }
+    }
+
+    @Test
+    void testPreFilterSegmentReusedAcrossCalls() throws Exception {
+        try (InputChannelRecoveredStateHandler filteringHandler =
+                buildFilteringInputChannelStateHandler()) {
+            // First getBuffer() lazily allocates the segment.
+            RecoveredChannelStateHandler.BufferWithContext<Buffer> first =
+                    filteringHandler.getBuffer(channelInfo);
+            MemorySegment segment1 = first.context.getMemorySegment();
+            first.context.recycleBuffer();
+
+            // Second getBuffer() must reuse the same segment instance.
+            RecoveredChannelStateHandler.BufferWithContext<Buffer> second =
+                    filteringHandler.getBuffer(channelInfo);
+            MemorySegment segment2 = second.context.getMemorySegment();
+            second.context.recycleBuffer();
+
+            assertThat(segment2).isSameAs(segment1);
+            // Third call, same check.
+            RecoveredChannelStateHandler.BufferWithContext<Buffer> third =
+                    filteringHandler.getBuffer(channelInfo);
+            MemorySegment segment3 = third.context.getMemorySegment();
+            third.context.recycleBuffer();
+
+            assertThat(segment3).isSameAs(segment1);
+
+            // Internal assertion: inUse flipped back to false after each 
recycle.
+            assertThat(filteringHandler.isPreFilterBufferInUse()).isFalse();
+        }
+    }
+
+    @Test
+    void testGetBufferThrowsWhenPriorBufferNotRecycled() throws Exception {
+        try (InputChannelRecoveredStateHandler filteringHandler =
+                buildFilteringInputChannelStateHandler()) {
+            RecoveredChannelStateHandler.BufferWithContext<Buffer> first =
+                    filteringHandler.getBuffer(channelInfo);
+            try {
+                assertThat(filteringHandler.isPreFilterBufferInUse()).isTrue();
+
+                // Without recycling, requesting another buffer must fail.
+                assertThatThrownBy(() -> 
filteringHandler.getBuffer(channelInfo))
+                        .isInstanceOf(IllegalStateException.class)
+                        .hasMessageContaining("Previous pre-filter buffer has 
not been recycled");
+            } finally {
+                first.context.recycleBuffer();
+            }
+
+            // After recycling, a new getBuffer() succeeds.
+            RecoveredChannelStateHandler.BufferWithContext<Buffer> second =
+                    filteringHandler.getBuffer(channelInfo);
+            second.context.recycleBuffer();
+        }
+    }
+
+    @Test
+    void testPreFilterSegmentFreedOnClose() throws Exception {
+        InputChannelRecoveredStateHandler filteringHandler =
+                buildFilteringInputChannelStateHandler();
+        RecoveredChannelStateHandler.BufferWithContext<Buffer> 
bufferWithContext =
+                filteringHandler.getBuffer(channelInfo);
+        bufferWithContext.context.recycleBuffer();
+
+        MemorySegment segment = 
filteringHandler.getPreFilterSegmentForTesting();
+        assertThat(segment).isNotNull();
+        assertThat(segment.isFreed()).isFalse();
+
+        filteringHandler.close();
+
+        assertThat(segment.isFreed()).isTrue();
+        assertThat(filteringHandler.getPreFilterSegmentForTesting()).isNull();
+    }
 }
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java
index 15b97d433a2..60494f7acbe 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.io.recovery;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
 import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
 
 import org.junit.jupiter.api.Test;
@@ -53,7 +54,8 @@ class RecordFilterContextTest {
                         0,
                         128,
                         new String[] {"/tmp"},
-                        true);
+                        true,
+                        MemoryManager.DEFAULT_PAGE_SIZE);
 
         assertThat(context.getNumberOfGates()).isEqualTo(1);
         assertThat(context.getInputConfig(0)).isSameAs(config);
@@ -71,7 +73,8 @@ class RecordFilterContextTest {
                         0,
                         128,
                         null,
-                        false);
+                        false,
+                        MemoryManager.DEFAULT_PAGE_SIZE);
 
         assertThatThrownBy(() -> context.getInputConfig(0))
                 .isInstanceOf(IllegalArgumentException.class);
@@ -88,7 +91,8 @@ class RecordFilterContextTest {
                         0,
                         128,
                         null,
-                        false);
+                        false,
+                        MemoryManager.DEFAULT_PAGE_SIZE);
 
         assertThat(context.getTmpDirectories()).isNotNull().isEmpty();
     }
@@ -108,7 +112,8 @@ class RecordFilterContextTest {
                         0,
                         128,
                         null,
-                        false);
+                        false,
+                        MemoryManager.DEFAULT_PAGE_SIZE);
 
         assertThat(context.isAmbiguous(0, 0)).isFalse();
     }
@@ -128,7 +133,8 @@ class RecordFilterContextTest {
                         0,
                         128,
                         null,
-                        true);
+                        true,
+                        MemoryManager.DEFAULT_PAGE_SIZE);
 
         assertThat(context.isAmbiguous(0, 0)).isTrue();
     }
@@ -147,7 +153,8 @@ class RecordFilterContextTest {
                         0,
                         128,
                         null,
-                        true);
+                        true,
+                        MemoryManager.DEFAULT_PAGE_SIZE);
 
         // oldSubtask 0 is ambiguous
         assertThat(context.isAmbiguous(0, 0)).isTrue();
@@ -155,6 +162,45 @@ class RecordFilterContextTest {
         assertThat(context.isAmbiguous(0, 1)).isFalse();
     }
 
+    @Test
+    void testMemorySegmentSizeExposedAndValidated() {
+        RecordFilterContext context =
+                new RecordFilterContext(
+                        new RecordFilterContext.InputFilterConfig[0],
+                        InflightDataRescalingDescriptor.NO_RESCALE,
+                        0,
+                        128,
+                        null,
+                        false,
+                        MemoryManager.DEFAULT_PAGE_SIZE * 2);
+
+        
assertThat(context.getMemorySegmentSize()).isEqualTo(MemoryManager.DEFAULT_PAGE_SIZE
 * 2);
+
+        // Non-positive sizes are rejected.
+        assertThatThrownBy(
+                        () ->
+                                new RecordFilterContext(
+                                        new 
RecordFilterContext.InputFilterConfig[0],
+                                        
InflightDataRescalingDescriptor.NO_RESCALE,
+                                        0,
+                                        128,
+                                        null,
+                                        false,
+                                        0))
+                .isInstanceOf(IllegalArgumentException.class);
+        assertThatThrownBy(
+                        () ->
+                                new RecordFilterContext(
+                                        new 
RecordFilterContext.InputFilterConfig[0],
+                                        
InflightDataRescalingDescriptor.NO_RESCALE,
+                                        0,
+                                        128,
+                                        null,
+                                        false,
+                                        -1))
+                .isInstanceOf(IllegalArgumentException.class);
+    }
+
     @Test
     void testInputFilterConfigGetters() {
         ForwardPartitioner<Long> partitioner = new ForwardPartitioner<>();
@@ -182,7 +228,8 @@ class RecordFilterContextTest {
                         1,
                         256,
                         new String[] {"/tmp"},
-                        false);
+                        false,
+                        MemoryManager.DEFAULT_PAGE_SIZE);
 
         assertThat(context.getNumberOfGates()).isEqualTo(2);
         assertThat(context.getInputConfig(0)).isSameAs(config0);
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java
index bffc42e3329..3b82587d301 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java
@@ -19,6 +19,7 @@ package org.apache.flink.streaming.runtime.io.recovery;
 
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
@@ -63,7 +64,8 @@ class VirtualChannelRecordFilterFactoryTest {
                         1,
                         128,
                         new String[] {"/tmp"},
-                        true);
+                        true,
+                        MemoryManager.DEFAULT_PAGE_SIZE);
 
         VirtualChannelRecordFilterFactory<Long> factory =
                 VirtualChannelRecordFilterFactory.fromContext(context, 0);

Reply via email to