This is an automated email from the ASF dual-hosted git repository.

fanrui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 23ebbb5d2bb0f3ffb0e07798bd3d8b0c1808331e
Author: Rui Fan <[email protected]>
AuthorDate: Wed Feb 18 21:25:46 2026 +0100

    [FLINK-38930][checkpoint] Filter records before processing without spilling strategy
    
    Core filtering mechanism for recovered channel state buffers:
    - ChannelStateFilteringHandler with per-gate GateFilterHandler
    - RecordFilterContext with VirtualChannelRecordFilterFactory
    - Partial data check in SequentialChannelStateReaderImpl
    - Fix RecordFilterContext for Union downscale scenario
---
 .../channel/ChannelStateFilteringHandler.java      | 458 +++++++++++++++++++++
 .../channel/RecoveredChannelStateHandler.java      |  64 ++-
 .../channel/SequentialChannelStateReader.java      |  13 +-
 .../channel/SequentialChannelStateReaderImpl.java  |  27 +-
 .../io/StreamMultipleInputProcessorFactory.java    |   8 +-
 .../runtime/io/StreamTaskNetworkInputFactory.java  |   9 +-
 .../runtime/io/StreamTwoInputProcessorFactory.java |  11 +-
 .../runtime/io/recovery/RecordFilterContext.java   | 227 ++++++++++
 .../VirtualChannelRecordFilterFactory.java         | 123 ++++++
 .../runtime/tasks/OneInputStreamTask.java          |   6 +-
 .../flink/streaming/runtime/tasks/StreamTask.java  |  82 +++-
 .../GateFilterHandlerBufferOwnershipTest.java      | 230 +++++++++++
 .../checkpoint/channel/GateFilterHandlerTest.java  | 213 ++++++++++
 .../InputChannelRecoveredStateHandlerTest.java     |   6 +-
 .../SequentialChannelStateReaderImplTest.java      |   3 +-
 .../runtime/state/ChannelPersistenceITCase.java    |   3 +-
 .../io/recovery/RecordFilterContextTest.java       | 193 +++++++++
 .../VirtualChannelRecordFilterFactoryTest.java     |  90 ++++
 18 files changed, 1739 insertions(+), 27 deletions(-)

diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java
new file mode 100644
index 00000000000..b257c3b4054
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java
@@ -0,0 +1,458 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult;
+import 
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel;
+import 
org.apache.flink.streaming.runtime.io.recovery.VirtualChannelRecordFilterFactory;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+
+import javax.annotation.Nullable;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Filters recovered channel state buffers during the channel-state-unspilling phase, removing
+ * records that do not belong to the current subtask after rescaling.
+ *
+ * <p>Uses a per-gate architecture: each {@link InputGate} gets its own {@link GateFilterHandler}
+ * with the correct serializer, so multi-input tasks (e.g., TwoInputStreamTask) correctly
+ * deserialize different record types on different gates.
+ *
+ * <p>Buffer ownership: a source buffer passed to {@link #filterAndRewrite} is always consumed
+ * (handed to a virtual channel, or recycled if the call fails before that), and the returned
+ * buffers are owned by the caller.
+ */
+@Internal
+public class ChannelStateFilteringHandler implements Closeable {
+
+    // Wildcard allows heterogeneous record types across gates. An entry is null when the
+    // corresponding gate has no virtual channel.
+    private final GateFilterHandler<?>[] gateHandlers;
+
+    ChannelStateFilteringHandler(GateFilterHandler<?>[] gateHandlers) {
+        this.gateHandlers = checkNotNull(gateHandlers);
+    }
+
+    /**
+     * Creates a handler from the recovery context, building per-gate virtual channels based on
+     * the rescaling descriptor.
+     *
+     * @param filterContext context providing per-gate input configs and the rescaling descriptor
+     * @param inputGates the task's input gates, indexed by the same gate index as used for {@code
+     *     filterContext}
+     * @return the handler, or {@code null} if there is nothing to filter: the context has no gates
+     *     (e.g., source tasks) or no gate has any virtual channel
+     */
+    @Nullable
+    public static ChannelStateFilteringHandler createFromContext(
+            RecordFilterContext filterContext, InputGate[] inputGates) {
+        // Source tasks have no network inputs
+        if (filterContext.getNumberOfGates() == 0) {
+            return null;
+        }
+
+        InflightDataRescalingDescriptor rescalingDescriptor =
+                filterContext.getRescalingDescriptor();
+
+        GateFilterHandler<?>[] gateHandlers = new GateFilterHandler<?>[inputGates.length];
+        boolean hasAnyVirtualChannels = false;
+
+        for (int gateIndex = 0; gateIndex < inputGates.length; gateIndex++) {
+            gateHandlers[gateIndex] =
+                    createGateHandler(filterContext, inputGates, rescalingDescriptor, gateIndex);
+            if (gateHandlers[gateIndex] != null) {
+                hasAnyVirtualChannels = true;
+            }
+        }
+
+        if (!hasAnyVirtualChannels) {
+            return null;
+        }
+
+        return new ChannelStateFilteringHandler(gateHandlers);
+    }
+
+    /**
+     * Filters a recovered buffer from the specified virtual channel, returning new buffers
+     * containing only the records that belong to the current subtask.
+     *
+     * <p>One source buffer may produce 0 to N result buffers: 0 if all records are filtered out,
+     * and potentially more than 1 when a spanning record completes in this buffer. The
+     * deserializer caches partial record data from previous buffers, so the output may contain
+     * data that was not in the current source buffer, causing the total output size to exceed one
+     * buffer capacity. This can happen with any spanning record regardless of its size.
+     *
+     * <p>Ownership of {@code sourceBuffer} passes to this method: it is either handed to the
+     * virtual channel or recycled before an exception is thrown. The returned buffers are owned
+     * by the caller.
+     *
+     * @return filtered buffers, possibly empty if all records were filtered out.
+     * @throws IllegalStateException if {@code gateIndex} is out of range or has no handler, or if
+     *     no virtual channel exists for the given old subtask/channel pair
+     */
+    public List<Buffer> filterAndRewrite(
+            int gateIndex,
+            int oldSubtaskIndex,
+            int oldChannelIndex,
+            Buffer sourceBuffer,
+            BufferSupplier bufferSupplier)
+            throws IOException, InterruptedException {
+
+        if (gateIndex < 0 || gateIndex >= gateHandlers.length) {
+            // We own sourceBuffer here; release it so that a rejected call does not leak it.
+            sourceBuffer.recycleBuffer();
+            throw new IllegalStateException(
+                    "Invalid gateIndex: "
+                            + gateIndex
+                            + ", number of gates: "
+                            + gateHandlers.length);
+        }
+
+        GateFilterHandler<?> gateHandler = gateHandlers[gateIndex];
+        if (gateHandler == null) {
+            sourceBuffer.recycleBuffer();
+            throw new IllegalStateException(
+                    "No handler for gateIndex "
+                            + gateIndex
+                            + ". This gate is not a network input and should not have"
+                            + " recovered buffers.");
+        }
+        return gateHandler.filterAndRewrite(
+                oldSubtaskIndex, oldChannelIndex, sourceBuffer, bufferSupplier);
+    }
+
+    /** Returns {@code true} if any virtual channel has a partial (spanning) record pending. */
+    public boolean hasPartialData() {
+        for (GateFilterHandler<?> handler : gateHandlers) {
+            if (handler != null && handler.hasPartialData()) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /** Clears the virtual channels of every gate that has a handler. */
+    @Override
+    public void close() {
+        for (GateFilterHandler<?> handler : gateHandlers) {
+            if (handler != null) {
+                handler.clear();
+            }
+        }
+    }
+
+    // ------------------------------------------------------------------------------------------
+    // Private static helper methods
+    // ------------------------------------------------------------------------------------------
+
+    /**
+     * Creates a {@link GateFilterHandler} for a single gate. The method-level type parameter
+     * ensures type safety within each gate while allowing different gates to have different types.
+     *
+     * @return the handler, or {@code null} if the gate has no virtual channel
+     * @throws IllegalStateException if {@code filterContext} has no input config for the gate
+     */
+    @Nullable
+    private static <T> GateFilterHandler<T> createGateHandler(
+            RecordFilterContext filterContext,
+            InputGate[] inputGates,
+            InflightDataRescalingDescriptor rescalingDescriptor,
+            int gateIndex) {
+        RecordFilterContext.InputFilterConfig inputConfig = filterContext.getInputConfig(gateIndex);
+        if (inputConfig == null) {
+            throw new IllegalStateException(
+                    "No InputFilterConfig for gateIndex "
+                            + gateIndex
+                            + ". This indicates a bug in RecordFilterContext initialization.");
+        }
+
+        InputGate gate = inputGates[gateIndex];
+        int[] oldSubtaskIndexes = rescalingDescriptor.getOldSubtaskIndexes(gateIndex);
+        RescaleMappings channelMapping = rescalingDescriptor.getChannelMapping(gateIndex);
+        // The old channel indexes depend only on the gate, not on the old subtask: compute once.
+        int[] oldChannelIndexes =
+                getOldChannelIndexes(channelMapping, gate.getNumberOfInputChannels());
+
+        // The input config is assumed to carry the serializer of this gate's record type T.
+        @SuppressWarnings("unchecked")
+        TypeSerializer<T> typeSerializer = (TypeSerializer<T>) inputConfig.getTypeSerializer();
+        StreamElementSerializer<T> elementSerializer =
+                new StreamElementSerializer<>(typeSerializer);
+
+        VirtualChannelRecordFilterFactory<T> filterFactory =
+                VirtualChannelRecordFilterFactory.fromContext(filterContext, gateIndex);
+
+        Map<SubtaskConnectionDescriptor, VirtualChannel<T>> gateVirtualChannels = new HashMap<>();
+
+        for (int oldSubtaskIndex : oldSubtaskIndexes) {
+            for (int oldChannelIndex : oldChannelIndexes) {
+                SubtaskConnectionDescriptor key =
+                        new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex);
+
+                if (gateVirtualChannels.containsKey(key)) {
+                    continue;
+                }
+
+                // Only ambiguous channels need actual filtering; non-ambiguous ones pass through
+                boolean isAmbiguous = rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex);
+
+                RecordFilter<T> recordFilter =
+                        isAmbiguous
+                                ? filterFactory.createFilter()
+                                : VirtualChannelRecordFilterFactory.createPassThroughFilter();
+
+                RecordDeserializer<DeserializationDelegate<StreamElement>> deserializer =
+                        createDeserializer(filterContext.getTmpDirectories());
+
+                VirtualChannel<T> vc = new VirtualChannel<>(deserializer, recordFilter);
+                gateVirtualChannels.put(key, vc);
+            }
+        }
+
+        if (gateVirtualChannels.isEmpty()) {
+            return null;
+        }
+
+        return new GateFilterHandler<>(gateVirtualChannels, elementSerializer);
+    }
+
+    /**
+     * Collects all old channel indexes that are mapped from any new channel index in this gate.
+     * channelMapping is new-to-old, so we iterate new indexes and collect their old counterparts.
+     *
+     * @return the distinct old indexes, in order of first occurrence
+     */
+    private static int[] getOldChannelIndexes(RescaleMappings channelMapping, int numChannels) {
+        List<Integer> oldIndexes = new ArrayList<>();
+        for (int newIndex = 0; newIndex < numChannels; newIndex++) {
+            for (int oldIndex : channelMapping.getMappedIndexes(newIndex)) {
+                oldIndexes.add(oldIndex);
+            }
+        }
+        // distinct() is stable on this ordered sequential stream: the first occurrence is kept,
+        // matching the previous List.contains-based dedup but in linear instead of quadratic time.
+        return oldIndexes.stream().mapToInt(Integer::intValue).distinct().toArray();
+    }
+
+    /**
+     * Creates a spilling deserializer, falling back to the JVM temp directory ({@code
+     * java.io.tmpdir}) when no temporary directories are configured.
+     */
+    private static RecordDeserializer<DeserializationDelegate<StreamElement>> createDeserializer(
+            @Nullable String[] tmpDirectories) {
+        String[] spillingDirectories =
+                tmpDirectories != null && tmpDirectories.length > 0
+                        ? tmpDirectories
+                        : new String[] {System.getProperty("java.io.tmpdir")};
+        return new SpillingAdaptiveSpanningRecordDeserializer<>(spillingDirectories);
+    }
+
+    // ------------------------------------------------------------------------------------------
+    // Inner classes
+    // ------------------------------------------------------------------------------------------
+
+    /** Provides buffers for re-serializing filtered records. Implementations may block. */
+    @FunctionalInterface
+    public interface BufferSupplier {
+        Buffer requestBufferBlocking() throws IOException, InterruptedException;
+    }
+
+    /**
+     * Handles record filtering for a single input gate. Each gate has its own serializer and
+     * set of virtual channels, allowing different gates to handle different record types
+     * independently.
+     */
+    static class GateFilterHandler<T> {
+
+        private final Map<SubtaskConnectionDescriptor, VirtualChannel<T>> virtualChannels;
+        private final StreamElementSerializer<T> serializer;
+        private final DeserializationDelegate<StreamElement> deserializationDelegate;
+        private final DataOutputSerializer outputSerializer;
+        /** Scratch space for the 4-byte big-endian length prefix written before each record. */
+        private final byte[] lengthBuffer = new byte[4];
+
+        GateFilterHandler(
+                Map<SubtaskConnectionDescriptor, VirtualChannel<T>> virtualChannels,
+                StreamElementSerializer<T> serializer) {
+            this.virtualChannels = checkNotNull(virtualChannels);
+            this.serializer = checkNotNull(serializer);
+            this.deserializationDelegate = new NonReusingDeserializationDelegate<>(serializer);
+            this.outputSerializer = new DataOutputSerializer(128);
+        }
+
+        /**
+         * Deserializes records from {@code sourceBuffer}, applies the virtual channel's record
+         * filter, and immediately re-serializes each surviving record into output buffers.
+         *
+         * <p>Ownership: {@code sourceBuffer} is handed to the virtual channel, or recycled here if
+         * the call fails before that. On success the returned buffers belong to the caller; on
+         * failure every buffer obtained from {@code bufferSupplier} is recycled exactly once.
+         */
+        List<Buffer> filterAndRewrite(
+                int oldSubtaskIndex,
+                int oldChannelIndex,
+                Buffer sourceBuffer,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+
+            boolean sourceBufferOwnershipTransferred = false;
+            // Every buffer obtained from bufferSupplier is appended here as soon as it is
+            // received, so this list is the single owner of all output buffers; its last element
+            // is the buffer currently being written. This keeps the failure cleanup below free of
+            // leaks and double recycles regardless of where an exception is thrown.
+            List<Buffer> resultBuffers = new ArrayList<>();
+            try {
+                SubtaskConnectionDescriptor key =
+                        new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex);
+                VirtualChannel<T> vc = virtualChannels.get(key);
+                if (vc == null) {
+                    throw new IllegalStateException(
+                            "No VirtualChannel found for key: "
+                                    + key
+                                    + "; known channels are "
+                                    + virtualChannels.keySet());
+                }
+
+                vc.setNextBuffer(sourceBuffer);
+                sourceBufferOwnershipTransferred = true;
+
+                while (true) {
+                    DeserializationResult result = vc.getNextRecord(deserializationDelegate);
+                    if (result.isFullRecord()) {
+                        serializeElement(
+                                deserializationDelegate.getInstance(),
+                                resultBuffers,
+                                bufferSupplier);
+                    }
+                    if (result.isBufferConsumed()) {
+                        break;
+                    }
+                }
+
+                recycleTrailingEmptyBuffer(resultBuffers);
+                return resultBuffers;
+            } catch (Throwable t) {
+                if (!sourceBufferOwnershipTransferred) {
+                    sourceBuffer.recycleBuffer();
+                }
+                for (Buffer buf : resultBuffers) {
+                    buf.recycleBuffer();
+                }
+                resultBuffers.clear();
+                throw t;
+            }
+        }
+
+        /** Removes and recycles the last output buffer if nothing was written into it. */
+        private static void recycleTrailingEmptyBuffer(List<Buffer> resultBuffers) {
+            if (resultBuffers.isEmpty()) {
+                return;
+            }
+            int lastIndex = resultBuffers.size() - 1;
+            Buffer last = resultBuffers.get(lastIndex);
+            if (last.readableBytes() == 0) {
+                resultBuffers.remove(lastIndex);
+                last.recycleBuffer();
+            }
+        }
+
+        /**
+         * Serializes a single stream element using the length-prefixed format (4-byte big-endian
+         * length + record bytes) expected by Flink's record deserializers, appending it to the
+         * last buffer of {@code resultBuffers} and requesting new buffers from {@code
+         * bufferSupplier} as needed.
+         */
+        private void serializeElement(
+                StreamElement element, List<Buffer> resultBuffers, BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+            outputSerializer.clear();
+            serializer.serialize(element, outputSerializer);
+            int recordLength = outputSerializer.length();
+
+            writeLengthToBuffer(recordLength);
+            writeDataToBuffer(lengthBuffer, 0, lengthBuffer.length, resultBuffers, bufferSupplier);
+
+            byte[] serializedData = outputSerializer.getSharedBuffer();
+            writeDataToBuffer(serializedData, 0, recordLength, resultBuffers, bufferSupplier);
+        }
+
+        /** Encodes {@code length} big-endian into {@link #lengthBuffer}. */
+        private void writeLengthToBuffer(int length) {
+            lengthBuffer[0] = (byte) (length >> 24);
+            lengthBuffer[1] = (byte) (length >> 16);
+            lengthBuffer[2] = (byte) (length >> 8);
+            lengthBuffer[3] = (byte) length;
+        }
+
+        /**
+         * Appends {@code dataLength} bytes of {@code data} (from {@code dataOffset}) to the last
+         * buffer of {@code resultBuffers}. When there is no buffer yet or the last one is full, a
+         * new buffer is requested from {@code bufferSupplier} and appended to {@code
+         * resultBuffers} before anything is written into it.
+         */
+        private void writeDataToBuffer(
+                byte[] data,
+                int dataOffset,
+                int dataLength,
+                List<Buffer> resultBuffers,
+                BufferSupplier bufferSupplier)
+                throws IOException, InterruptedException {
+            int offset = dataOffset;
+            int remaining = dataLength;
+            Buffer currentBuffer =
+                    resultBuffers.isEmpty() ? null : resultBuffers.get(resultBuffers.size() - 1);
+
+            while (remaining > 0) {
+                if (currentBuffer == null
+                        || currentBuffer.getSize() >= currentBuffer.getMaxCapacity()) {
+                    // Register the buffer right away so the caller's cleanup sees it on failure.
+                    currentBuffer =
+                            checkNotNull(
+                                    bufferSupplier.requestBufferBlocking(),
+                                    "BufferSupplier returned null");
+                    resultBuffers.add(currentBuffer);
+                }
+
+                int writableBytes = currentBuffer.getMaxCapacity() - currentBuffer.getSize();
+                int bytesToWrite = Math.min(remaining, writableBytes);
+                currentBuffer
+                        .getMemorySegment()
+                        .put(
+                                currentBuffer.getMemorySegmentOffset() + currentBuffer.getSize(),
+                                data,
+                                offset,
+                                bytesToWrite);
+                currentBuffer.setSize(currentBuffer.getSize() + bytesToWrite);
+
+                offset += bytesToWrite;
+                remaining -= bytesToWrite;
+            }
+        }
+
+        boolean hasPartialData() {
+            return virtualChannels.values().stream().anyMatch(VirtualChannel::hasPartialData);
+        }
+
+        void clear() {
+            virtualChannels.values().forEach(VirtualChannel::clear);
+        }
+    }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
index 85fc31db4bc..31db728bc48 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java
@@ -31,6 +31,7 @@ import 
org.apache.flink.runtime.io.network.partition.consumer.InputGate;
 import 
org.apache.flink.runtime.io.network.partition.consumer.RecoveredInputChannel;
 
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.util.HashMap;
@@ -63,7 +64,7 @@ interface RecoveredChannelStateHandler<Info, Context> extends 
AutoCloseable {
      * case of an error.
+     *
+     * @throws InterruptedException if the calling thread is interrupted while an implementation
+     *     blocks, for example while waiting for a buffer
      */
     void recover(Info info, int oldSubtaskIndex, BufferWithContext<Context> 
bufferWithContext)
-            throws IOException;
+            throws IOException, InterruptedException;
 }
 
 class InputChannelRecoveredStateHandler
@@ -75,10 +76,19 @@ class InputChannelRecoveredStateHandler
     private final Map<InputChannelInfo, RecoveredInputChannel> 
rescaledChannels = new HashMap<>();
     private final Map<Integer, RescaleMappings> oldToNewMappings = new 
HashMap<>();
 
+    /**
+     * Optional handler for filtering recovered buffers. When non-null, filtering is performed
+     * during recovery, in the channel-state-unspilling phase.
+     */
+    @Nullable private final ChannelStateFilteringHandler filteringHandler;
+
+    /**
+     * Creates a handler that forwards recovered input buffers to the channels of {@code
+     * inputGates}.
+     *
+     * @param inputGates the input gates to recover state for
+     * @param channelMapping the rescaling descriptor of the recovered input channel state
+     * @param filteringHandler if non-null, every recovered buffer is passed through it and only
+     *     the re-serialized surviving records are forwarded; if null, the buffer is forwarded
+     *     unchanged, preceded by a {@link SubtaskConnectionDescriptor} event
+     */
     InputChannelRecoveredStateHandler(
-            InputGate[] inputGates, InflightDataRescalingDescriptor 
channelMapping) {
+            InputGate[] inputGates,
+            InflightDataRescalingDescriptor channelMapping,
+            @Nullable ChannelStateFilteringHandler filteringHandler) {
         this.inputGates = inputGates;
         this.channelMapping = channelMapping;
+        this.filteringHandler = filteringHandler;
     }
 
     @Override
@@ -95,23 +105,57 @@ class InputChannelRecoveredStateHandler
             InputChannelInfo channelInfo,
             int oldSubtaskIndex,
             BufferWithContext<Buffer> bufferWithContext)
-            throws IOException {
+            throws IOException, InterruptedException {
         Buffer buffer = bufferWithContext.context;
         try {
             if (buffer.readableBytes() > 0) {
                 RecoveredInputChannel channel = getMappedChannels(channelInfo);
-                channel.onRecoveredStateBuffer(
-                        EventSerializer.toBuffer(
-                                new SubtaskConnectionDescriptor(
-                                        oldSubtaskIndex, 
channelInfo.getInputChannelIdx()),
-                                false));
-                channel.onRecoveredStateBuffer(buffer.retainBuffer());
+
+                if (filteringHandler != null) {
+                    recoverWithFiltering(
+                            channel, channelInfo, oldSubtaskIndex, 
buffer.retainBuffer());
+                } else {
+                    channel.onRecoveredStateBuffer(
+                            EventSerializer.toBuffer(
+                                    new SubtaskConnectionDescriptor(
+                                            oldSubtaskIndex, 
channelInfo.getInputChannelIdx()),
+                                    false));
+                    channel.onRecoveredStateBuffer(buffer.retainBuffer());
+                }
             }
         } finally {
             buffer.recycleBuffer();
         }
     }
 
+    /**
+     * Filters {@code retainedBuffer} through the filtering handler and hands every resulting
+     * buffer to {@code channel}.
+     *
+     * <p>Ownership: {@code retainedBuffer} is passed on to the filtering handler. Each filtered
+     * buffer is passed on to {@link RecoveredInputChannel#onRecoveredStateBuffer}. If handing a
+     * buffer over fails, only the buffers that were not handed over yet are recycled here.
+     *
+     * <p>NOTE(review): this assumes {@code onRecoveredStateBuffer} releases its argument itself
+     * when it fails, which the unfiltered path in {@code recover} already relies on; confirm
+     * against {@link RecoveredInputChannel}.
+     */
+    private void recoverWithFiltering(
+            RecoveredInputChannel channel,
+            InputChannelInfo channelInfo,
+            int oldSubtaskIndex,
+            Buffer retainedBuffer)
+            throws IOException, InterruptedException {
+        checkState(filteringHandler != null, "filtering handler not set.");
+        List<Buffer> filteredBuffers =
+                filteringHandler.filterAndRewrite(
+                        channelInfo.getGateIdx(),
+                        oldSubtaskIndex,
+                        channelInfo.getInputChannelIdx(),
+                        retainedBuffer,
+                        channel::requestBufferBlocking);
+
+        // Index of the first buffer that has not been handed to the channel yet.
+        int next = 0;
+        try {
+            while (next < filteredBuffers.size()) {
+                // Advance before the call: once passed, the buffer belongs to the channel even if
+                // the call throws, so it must not be recycled again below.
+                Buffer filtered = filteredBuffers.get(next++);
+                channel.onRecoveredStateBuffer(filtered);
+            }
+        } catch (Throwable t) {
+            for (int j = next; j < filteredBuffers.size(); j++) {
+                filteredBuffers.get(j).recycleBuffer();
+            }
+            throw t;
+        }
+    }
+
     @Override
     public void close() throws IOException {
         // note that we need to finish all RecoveredInputChannels, not just 
those with state
@@ -191,7 +235,7 @@ class ResultSubpartitionRecoveredStateHandler
             ResultSubpartitionInfo subpartitionInfo,
             int oldSubtaskIndex,
             BufferWithContext<BufferBuilder> bufferWithContext)
-            throws IOException {
+            throws IOException, InterruptedException {
         try (BufferBuilder bufferBuilder = bufferWithContext.context;
                 BufferConsumer bufferConsumer = 
bufferBuilder.createBufferConsumerFromBeginning()) {
             bufferBuilder.finish();
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java
index 7adf6d62946..547b60ef93a 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.checkpoint.channel;
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
 
 import java.io.IOException;
 
@@ -27,7 +28,14 @@ import java.io.IOException;
 @Internal
 public interface SequentialChannelStateReader extends AutoCloseable {
 
-    void readInputData(InputGate[] inputGates) throws IOException, 
InterruptedException;
+    /**
+     * Reads the recovered input channel state and hands it to the channels of the given input
+     * gates.
+     *
+     * <p>The {@code filterContext} carries the per-gate input configs and rescaling information
+     * that an implementation may use to filter out recovered records that no longer belong to
+     * this subtask. For example, {@code SequentialChannelStateReaderImpl} only filters when
+     * {@link RecordFilterContext#isCheckpointingDuringRecoveryEnabled()} returns true.
+     *
+     * @param inputGates the input gates to recover state for
+     * @param filterContext the filter context containing input configs and rescaling info
+     * @throws IOException on failure while reading or forwarding the recovered state
+     * @throws InterruptedException if interrupted while blocking during recovery
+     */
+    void readInputData(InputGate[] inputGates, RecordFilterContext 
filterContext)
+            throws IOException, InterruptedException;
 
     void readOutputData(ResultPartitionWriter[] writers, boolean 
notifyAndBlockOnCompletion)
             throws IOException, InterruptedException;
@@ -39,7 +47,8 @@ public interface SequentialChannelStateReader extends 
AutoCloseable {
             new SequentialChannelStateReader() {
 
+                // Intentionally empty: this reader has no input channel state to recover.
                 @Override
-                public void readInputData(InputGate[] inputGates) {}
+                public void readInputData(
+                        InputGate[] inputGates, RecordFilterContext 
filterContext) {}
 
                 @Override
                 public void readOutputData(
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
index 3daa4b4947a..8aa8db2679f 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java
@@ -28,6 +28,7 @@ import 
org.apache.flink.runtime.io.network.partition.consumer.InputGate;
 import org.apache.flink.runtime.state.AbstractChannelStateHandle;
 import org.apache.flink.runtime.state.ChannelStateHelper;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
 
 import java.io.Closeable;
 import java.io.IOException;
@@ -43,6 +44,7 @@ import java.util.stream.Stream;
 import static java.util.Comparator.comparingLong;
 import static java.util.stream.Collectors.groupingBy;
 import static java.util.stream.Collectors.toList;
+import static org.apache.flink.util.Preconditions.checkState;
 
 /** {@link SequentialChannelStateReader} implementation. */
 public class SequentialChannelStateReaderImpl implements 
SequentialChannelStateReader {
@@ -58,10 +60,21 @@ public class SequentialChannelStateReaderImpl implements 
SequentialChannelStateR
     }
 
     @Override
-    public void readInputData(InputGate[] inputGates) throws IOException, 
InterruptedException {
-        try (InputChannelRecoveredStateHandler stateHandler =
-                new InputChannelRecoveredStateHandler(
-                        inputGates, 
taskStateSnapshot.getInputRescalingDescriptor())) {
+    public void readInputData(InputGate[] inputGates, RecordFilterContext 
filterContext)
+            throws IOException, InterruptedException {
+
+        // Create filtering handler if filtering is needed
+        ChannelStateFilteringHandler filteringHandler =
+                filterContext.isCheckpointingDuringRecoveryEnabled()
+                        ? 
ChannelStateFilteringHandler.createFromContext(filterContext, inputGates)
+                        : null;
+
+        try (ChannelStateFilteringHandler ignored = filteringHandler;
+                InputChannelRecoveredStateHandler stateHandler =
+                        new InputChannelRecoveredStateHandler(
+                                inputGates,
+                                
taskStateSnapshot.getInputRescalingDescriptor(),
+                                filteringHandler)) {
             read(
                     stateHandler,
                     groupByDelegate(
@@ -72,6 +85,12 @@ public class SequentialChannelStateReaderImpl implements 
SequentialChannelStateR
                     groupByDelegate(
                             streamSubtaskStates(),
                             
OperatorSubtaskState::getUpstreamOutputBufferState));
+
+            if (filteringHandler != null) {
+                checkState(
+                        !filteringHandler.hasPartialData(),
+                        "Not all data has been fully consumed during 
filtering");
+            }
         }
     }
 
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java
index 78877a3d62e..873f62c59ee 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.configuration.CheckpointingOptions;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.memory.ManagedMemoryUseCase;
 import org.apache.flink.metrics.Counter;
@@ -103,6 +104,10 @@ public class StreamMultipleInputProcessorFactory {
                 "Number of configured inputs in StreamConfig [%s] doesn't 
match the main operator's number of inputs [%s]",
                 configuredInputs.length,
                 inputsCount);
+
+        boolean checkpointingDuringRecoveryEnabled =
+                
CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(jobConfig);
+
         StreamTaskInput[] inputs = new StreamTaskInput[inputsCount];
         for (int i = 0; i < inputsCount; i++) {
             StreamConfig.InputConfig configuredInput = configuredInputs[i];
@@ -121,7 +126,8 @@ public class StreamMultipleInputProcessorFactory {
                                 gatePartitioners,
                                 taskInfo,
                                 canEmitBatchOfRecords,
-                                
streamConfig.getWatermarkDeclarations(userClassloader));
+                                
streamConfig.getWatermarkDeclarations(userClassloader),
+                                checkpointingDuringRecoveryEnabled);
             } else if (configuredInput instanceof 
StreamConfig.SourceInputConfig) {
                 StreamConfig.SourceInputConfig sourceInput =
                         (StreamConfig.SourceInputConfig) configuredInput;
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java
index 46c9cd96936..70f00eed829 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java
@@ -47,9 +47,14 @@ public class StreamTaskNetworkInputFactory {
             Function<Integer, StreamPartitioner<?>> gatePartitioners,
             TaskInfo taskInfo,
             CanEmitBatchOfRecordsChecker canEmitBatchOfRecords,
-            Set<AbstractInternalWatermarkDeclaration<?>> 
watermarkDeclarationSet) {
+            Set<AbstractInternalWatermarkDeclaration<?>> 
watermarkDeclarationSet,
+            boolean checkpointingDuringRecoveryEnabled) {
         return rescalingDescriptorinflightDataRescalingDescriptor.equals(
-                        InflightDataRescalingDescriptor.NO_RESCALE)
+                                InflightDataRescalingDescriptor.NO_RESCALE)
+                        // When filter during recovery is enabled, records are 
already filtered in
+                        // the channel-state-unspilling thread. Use 
StreamTaskNetworkInput to avoid
+                        // redundant demultiplexing/filtering in the Task 
thread.
+                        || checkpointingDuringRecoveryEnabled
                 ? new StreamTaskNetworkInput<>(
                         checkpointedInputGate,
                         inputSerializer,
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java
index 2a0c675710b..04cd6ccb410 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.configuration.CheckpointingOptions;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.memory.ManagedMemoryUseCase;
 import org.apache.flink.metrics.Counter;
@@ -84,6 +85,10 @@ public class StreamTwoInputProcessorFactory {
         checkNotNull(operatorChain);
 
         taskIOMetricGroup.reuseRecordsInputCounter(numRecordsIn);
+
+        boolean checkpointingDuringRecoveryEnabled =
+                
CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(jobConfig);
+
         TypeSerializer<IN1> typeSerializer1 = 
streamConfig.getTypeSerializerIn(0, userClassloader);
         StreamTaskInput<IN1> input1 =
                 StreamTaskNetworkInputFactory.create(
@@ -96,7 +101,8 @@ public class StreamTwoInputProcessorFactory {
                         gatePartitioners,
                         taskInfo,
                         canEmitBatchOfRecords,
-                        
streamConfig.getWatermarkDeclarations(userClassloader));
+                        streamConfig.getWatermarkDeclarations(userClassloader),
+                        checkpointingDuringRecoveryEnabled);
         TypeSerializer<IN2> typeSerializer2 = 
streamConfig.getTypeSerializerIn(1, userClassloader);
         StreamTaskInput<IN2> input2 =
                 StreamTaskNetworkInputFactory.create(
@@ -109,7 +115,8 @@ public class StreamTwoInputProcessorFactory {
                         gatePartitioners,
                         taskInfo,
                         canEmitBatchOfRecords,
-                        
streamConfig.getWatermarkDeclarations(userClassloader));
+                        streamConfig.getWatermarkDeclarations(userClassloader),
+                        checkpointingDuringRecoveryEnabled);
 
         InputSelectable inputSelectable =
                 streamOperator instanceof InputSelectable ? (InputSelectable) 
streamOperator : null;
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java
new file mode 100644
index 00000000000..f2568fe5854
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Context containing all information needed for filtering recovered channel state buffers.
+ *
+ * <p>This context encapsulates the input configurations, rescaling descriptor, and subtask
+ * information required by the channel-state-unspilling thread to perform record filtering during
+ * recovery.
+ *
+ * <p>Supports multiple inputs (e.g., TwoInputStreamTask, MultipleInputStreamTask) by storing an
+ * array of {@link InputFilterConfig} instances indexed by physical input gate index.
+ *
+ * <p>Use {@link #disabled()} (no input configurations, checkpointing during recovery disabled)
+ * when filtering is not needed.
+ *
+ * <p>The arrays passed to the constructor are defensively copied and {@link #getTmpDirectories()}
+ * returns a copy, so callers cannot mutate the arrays held by this context.
+ */
+@Internal
+public class RecordFilterContext {
+
+    /** Configuration for filtering records on a specific input gate. */
+    public static class InputFilterConfig {
+        private final TypeSerializer<?> typeSerializer;
+        private final StreamPartitioner<?> partitioner;
+        private final int numberOfChannels;
+
+        /**
+         * Creates a new InputFilterConfig.
+         *
+         * @param typeSerializer Serializer for the record type. Not null.
+         * @param partitioner Partitioner used to determine record ownership. Not null.
+         * @param numberOfChannels The parallelism of the current operator (used as the
+         *     partitioner's number of channels).
+         * @throws NullPointerException if {@code typeSerializer} or {@code partitioner} is null.
+         */
+        public InputFilterConfig(
+                TypeSerializer<?> typeSerializer,
+                StreamPartitioner<?> partitioner,
+                int numberOfChannels) {
+            this.typeSerializer = checkNotNull(typeSerializer, "typeSerializer");
+            this.partitioner = checkNotNull(partitioner, "partitioner");
+            this.numberOfChannels = numberOfChannels;
+        }
+
+        /** Returns the serializer for the records of this gate. */
+        public TypeSerializer<?> getTypeSerializer() {
+            return typeSerializer;
+        }
+
+        /** Returns the partitioner used to decide which subtask owns a record. */
+        public StreamPartitioner<?> getPartitioner() {
+            return partitioner;
+        }
+
+        /** Returns the parallelism of the current operator. */
+        public int getNumberOfChannels() {
+            return numberOfChannels;
+        }
+    }
+
+    /**
+     * Input configurations indexed by gate index. Array elements may be null for non-network inputs
+     * (e.g., SourceInputConfig). The array length equals the total number of input gates.
+     */
+    private final InputFilterConfig[] inputConfigs;
+
+    /** Descriptor containing rescaling information. Never null. */
+    private final InflightDataRescalingDescriptor rescalingDescriptor;
+
+    /** Current subtask index. */
+    private final int subtaskIndex;
+
+    /** Maximum parallelism for configuring partitioners. */
+    private final int maxParallelism;
+
+    /** Temporary directories for spilling spanning records. Can be empty but never null. */
+    private final String[] tmpDirectories;
+
+    /** Whether unaligned checkpointing during recovery is enabled. */
+    private final boolean checkpointingDuringRecoveryEnabled;
+
+    /**
+     * Creates a new RecordFilterContext.
+     *
+     * @param inputConfigs Input configurations indexed by gate index. Array elements may be null
+     *     for non-network inputs. Not null itself; the array is copied.
+     * @param rescalingDescriptor Descriptor containing rescaling information. Not null.
+     * @param subtaskIndex Current subtask index.
+     * @param maxParallelism Maximum parallelism.
+     * @param tmpDirectories Temporary directories for spilling spanning records. Can be null
+     *     (converted to an empty array); otherwise the array is copied.
+     * @param checkpointingDuringRecoveryEnabled Whether unaligned checkpointing during recovery is
+     *     enabled.
+     * @throws NullPointerException if {@code inputConfigs} or {@code rescalingDescriptor} is null.
+     */
+    public RecordFilterContext(
+            InputFilterConfig[] inputConfigs,
+            InflightDataRescalingDescriptor rescalingDescriptor,
+            int subtaskIndex,
+            int maxParallelism,
+            String[] tmpDirectories,
+            boolean checkpointingDuringRecoveryEnabled) {
+        this.inputConfigs = checkNotNull(inputConfigs, "inputConfigs").clone();
+        this.rescalingDescriptor = checkNotNull(rescalingDescriptor, "rescalingDescriptor");
+        this.subtaskIndex = subtaskIndex;
+        this.maxParallelism = maxParallelism;
+        // Copy like inputConfigs so later changes to the caller's array do not leak in.
+        this.tmpDirectories = tmpDirectories != null ? tmpDirectories.clone() : new String[0];
+        this.checkpointingDuringRecoveryEnabled = checkpointingDuringRecoveryEnabled;
+    }
+
+    /**
+     * Gets the input configuration for a specific gate.
+     *
+     * @param gateIndex The gate index (0-based).
+     * @return The input configuration for the specified gate, or null if the gate is not a network
+     *     input (e.g., SourceInputConfig).
+     * @throws IllegalArgumentException if gateIndex is out of bounds.
+     */
+    public InputFilterConfig getInputConfig(int gateIndex) {
+        checkArgument(
+                gateIndex >= 0 && gateIndex < inputConfigs.length,
+                "Invalid gate index: %s, number of gates: %s",
+                gateIndex,
+                inputConfigs.length);
+        return inputConfigs[gateIndex];
+    }
+
+    /**
+     * Gets the number of input gates.
+     *
+     * @return The number of input gates.
+     */
+    public int getNumberOfGates() {
+        return inputConfigs.length;
+    }
+
+    /**
+     * Checks whether unaligned checkpointing during recovery is enabled.
+     *
+     * @return {@code true} if enabled, {@code false} otherwise.
+     */
+    public boolean isCheckpointingDuringRecoveryEnabled() {
+        return checkpointingDuringRecoveryEnabled;
+    }
+
+    /**
+     * Gets the rescaling descriptor.
+     *
+     * @return The descriptor containing rescaling information.
+     */
+    public InflightDataRescalingDescriptor getRescalingDescriptor() {
+        return rescalingDescriptor;
+    }
+
+    /**
+     * Gets the current subtask index.
+     *
+     * @return The subtask index.
+     */
+    public int getSubtaskIndex() {
+        return subtaskIndex;
+    }
+
+    /**
+     * Gets the maximum parallelism.
+     *
+     * @return The maximum parallelism value.
+     */
+    public int getMaxParallelism() {
+        return maxParallelism;
+    }
+
+    /**
+     * Gets the temporary directories for spilling spanning records.
+     *
+     * @return A copy of the temporary directories; never null (may be an empty array).
+     */
+    public String[] getTmpDirectories() {
+        return tmpDirectories.clone();
+    }
+
+    /**
+     * Checks if a specific gate and old subtask combination is ambiguous (requires filtering).
+     *
+     * @param gateIndex The gate index.
+     * @param oldSubtaskIndex The old subtask index.
+     * @return {@code true} if checkpointing during recovery is enabled and the rescaling
+     *     descriptor reports this gate / old subtask pair as ambiguous.
+     */
+    public boolean isAmbiguous(int gateIndex, int oldSubtaskIndex) {
+        return checkpointingDuringRecoveryEnabled
+                && rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex);
+    }
+
+    /**
+     * Creates a disabled RecordFilterContext for testing or when filtering is not needed.
+     *
+     * <p>The returned context has no input configurations, uses {@link
+     * InflightDataRescalingDescriptor#NO_RESCALE}, and {@link
+     * #isCheckpointingDuringRecoveryEnabled()} returns {@code false}.
+     *
+     * @return A disabled RecordFilterContext.
+     */
+    public static RecordFilterContext disabled() {
+        return new RecordFilterContext(
+                new InputFilterConfig[0],
+                InflightDataRescalingDescriptor.NO_RESCALE,
+                0,
+                0,
+                new String[0],
+                false);
+    }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java
new file mode 100644
index 00000000000..a2093f11437
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.io.network.api.writer.ChannelSelector;
+import org.apache.flink.runtime.plugable.SerializationDelegate;
+import 
org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+/**
+ * Factory for creating record filters used in Virtual Channels during channel state recovery.
+ *
+ * <p>This factory provides methods to create {@link RecordFilter} instances that determine whether
+ * a record belongs to the current subtask based on the partitioner logic.
+ *
+ * @param <T> The type of record values.
+ */
+@Internal
+public class VirtualChannelRecordFilterFactory<T> {
+
+    private final TypeSerializer<T> typeSerializer;
+    private final StreamPartitioner<T> partitioner;
+    private final int subtaskIndex;
+    private final int numberOfChannels;
+    private final int maxParallelism;
+
+    /**
+     * Creates a new VirtualChannelRecordFilterFactory.
+     *
+     * @param typeSerializer Serializer for the record type.
+     * @param partitioner Partitioner used to determine record ownership. It is copied before being
+     *     configured, so this instance is never mutated.
+     * @param subtaskIndex Current subtask index.
+     * @param numberOfChannels Number of parallel subtasks.
+     * @param maxParallelism Maximum parallelism for configuring partitioners.
+     */
+    public VirtualChannelRecordFilterFactory(
+            TypeSerializer<T> typeSerializer,
+            StreamPartitioner<T> partitioner,
+            int subtaskIndex,
+            int numberOfChannels,
+            int maxParallelism) {
+        this.typeSerializer = typeSerializer;
+        this.partitioner = partitioner;
+        this.subtaskIndex = subtaskIndex;
+        this.numberOfChannels = numberOfChannels;
+        this.maxParallelism = maxParallelism;
+    }
+
+    /**
+     * Creates a new VirtualChannelRecordFilterFactory from a RecordFilterContext and gate index.
+     *
+     * @param context The record filter context.
+     * @param inputIndex The gate index whose input configuration is used.
+     * @param <T> The type of record values; the caller must choose it to match the gate's records.
+     * @return A new factory instance.
+     * @throws IllegalArgumentException if {@code inputIndex} is out of bounds or the gate has no
+     *     filter configuration (e.g., it is not a network input).
+     */
+    @SuppressWarnings("unchecked")
+    public static <T> VirtualChannelRecordFilterFactory<T> fromContext(
+            RecordFilterContext context, int inputIndex) {
+        RecordFilterContext.InputFilterConfig inputConfig = context.getInputConfig(inputIndex);
+        if (inputConfig == null) {
+            // getInputConfig returns null for non-network inputs; fail with a descriptive error
+            // instead of a NullPointerException when dereferencing the config below.
+            throw new IllegalArgumentException(
+                    "No input filter configuration for gate index "
+                            + inputIndex
+                            + "; record filtering is only supported for network inputs.");
+        }
+        // The context stores wildcard-typed serializers/partitioners; narrow them to T.
+        return new VirtualChannelRecordFilterFactory<>(
+                (TypeSerializer<T>) inputConfig.getTypeSerializer(),
+                (StreamPartitioner<T>) inputConfig.getPartitioner(),
+                context.getSubtaskIndex(),
+                inputConfig.getNumberOfChannels(),
+                context.getMaxParallelism());
+    }
+
+    /**
+     * Creates a record filter for ambiguous channels that requires actual filtering.
+     *
+     * @return A RecordFilter that tests if a record belongs to this subtask.
+     */
+    public RecordFilter<T> createFilter() {
+        // A StreamPartitioner<T> is assignable to this ChannelSelector type without a cast,
+        // so no unchecked suppression is needed.
+        ChannelSelector<SerializationDelegate<StreamRecord<T>>> channelSelector =
+                configurePartitioner();
+        return new PartitionerRecordFilter<>(channelSelector, typeSerializer, subtaskIndex);
+    }
+
+    /**
+     * Creates a pass-through filter that accepts all records.
+     *
+     * @param <T> The type of record values.
+     * @return A RecordFilter that always returns true.
+     */
+    public static <T> RecordFilter<T> createPassThroughFilter() {
+        return RecordFilter.acceptAll();
+    }
+
+    /**
+     * Copies the partitioner and configures the copy with the number of channels and, for {@link
+     * ConfigurableStreamPartitioner}s, the max parallelism.
+     *
+     * @return A configured copy of the partitioner.
+     */
+    private StreamPartitioner<T> configurePartitioner() {
+        StreamPartitioner<T> copy = partitioner.copy();
+        copy.setup(numberOfChannels);
+        if (copy instanceof ConfigurableStreamPartitioner) {
+            ((ConfigurableStreamPartitioner) copy).configure(maxParallelism);
+        }
+        return copy;
+    }
+}
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
index 353582674e0..009f48b082f 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
@@ -203,6 +203,9 @@ public class OneInputStreamTask<IN, OUT> extends 
StreamTask<OUT, OneInputStreamO
         Set<AbstractInternalWatermarkDeclaration<?>> watermarkDeclarationSet =
                 
configuration.getWatermarkDeclarations(getUserCodeClassLoader());
 
+        boolean checkpointingDuringRecoveryEnabled =
+                
CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration());
+
         return StreamTaskNetworkInputFactory.create(
                 inputGate,
                 inSerializer,
@@ -217,7 +220,8 @@ public class OneInputStreamTask<IN, OUT> extends 
StreamTask<OUT, OneInputStreamO
                                 .getPartitioner(),
                 getEnvironment().getTaskInfo(),
                 getCanEmitBatchOfRecords(),
-                watermarkDeclarationSet);
+                watermarkDeclarationSet,
+                checkpointingDuringRecoveryEnabled);
     }
 
     /**
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index a07d5ee3915..27823f7137f 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -21,6 +21,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.operators.MailboxExecutor;
 import 
org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.CheckpointingOptions;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
@@ -84,6 +85,7 @@ import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.util.ConfigurationParserUtils;
 import org.apache.flink.streaming.api.graph.NonChainedOutput;
 import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.graph.StreamEdge;
 import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
 import org.apache.flink.streaming.api.operators.InternalTimeServiceManagerImpl;
 import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -94,6 +96,7 @@ import 
org.apache.flink.streaming.runtime.io.RecordWriterOutput;
 import org.apache.flink.streaming.runtime.io.StreamInputProcessor;
 import 
org.apache.flink.streaming.runtime.io.checkpointing.BarrierAlignmentUtil;
 import 
org.apache.flink.streaming.runtime.io.checkpointing.CheckpointBarrierHandler;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
 import 
org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
@@ -881,7 +884,7 @@ public abstract class StreamTask<OUT, OP extends 
StreamOperator<OUT>>
         channelIOExecutor.execute(
                 () -> {
                     try {
-                        reader.readInputData(inputGates);
+                        reader.readInputData(inputGates, 
createRecordFilterContext());
                     } catch (Exception e) {
                         asyncExceptionHandler.handleAsyncException(
                                 "Unable to read channel state", e);
@@ -1956,6 +1959,83 @@ public abstract class StreamTask<OUT, OP extends 
StreamOperator<OUT>>
         return environment;
     }
 
+    /**
+     * Creates a RecordFilterContext for filtering recovered channel state buffers.
+     *
+     * <p>If checkpointing during recovery is disabled in the job configuration, this returns
+     * {@link RecordFilterContext#disabled()}. Otherwise it builds one {@link
+     * RecordFilterContext.InputFilterConfig} per physical input gate from the physical in-edges.
+     * Union inputs, where several gates feed one logical input, therefore get one config per gate.
+     *
+     * @return A RecordFilterContext. It has no input configurations for tasks without network
+     *     inputs (e.g. source tasks) and is disabled when filtering is not needed.
+     * @throws IllegalStateException if the number of input gates differs from the number of
+     *     physical in-edges, or an edge does not map to a network input.
+     */
+    protected RecordFilterContext createRecordFilterContext() {
+        if (!CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration())) {
+            return RecordFilterContext.disabled();
+        }
+
+        ClassLoader cl = getUserCodeClassLoader();
+        StreamConfig.InputConfig[] inputs = configuration.getInputs(cl);
+        List<StreamEdge> inEdges = configuration.getInPhysicalEdges(cl);
+
+        // Every physical input gate needs exactly one physical in-edge. For source tasks both
+        // counts are 0.
+        int numGates = getEnvironment().getAllInputGates().length;
+        Preconditions.checkState(
+                numGates == inEdges.size(),
+                "Number of input gates (%s) does not match number of physical edges (%s)",
+                numGates,
+                inEdges.size());
+
+        // Iterate over physical edges rather than logical inputs. This matters for Union, where
+        // several physical gates map to one logical input. The order of inEdges matches the
+        // order of physical input gates.
+        int numberOfChannels = getEnvironment().getTaskInfo().getNumberOfParallelSubtasks();
+        RecordFilterContext.InputFilterConfig[] inputConfigs =
+                new RecordFilterContext.InputFilterConfig[numGates];
+        for (int gateIndex = 0; gateIndex < numGates; gateIndex++) {
+            StreamEdge edge = inEdges.get(gateIndex);
+            // typeNumber 0 means a single-input task; typeNumber >= 1 is the 1-based index of
+            // the logical input in a multi-input task.
+            int inputIndex = edge.getTypeNumber() == 0 ? 0 : edge.getTypeNumber() - 1;
+
+            Preconditions.checkState(
+                    inputIndex >= 0
+                            && inputIndex < inputs.length
+                            && inputs[inputIndex] instanceof StreamConfig.NetworkInputConfig,
+                    "Physical edge at gateIndex %s has invalid inputIndex %s or non-network input",
+                    gateIndex,
+                    inputIndex);
+
+            StreamConfig.NetworkInputConfig networkInput =
+                    (StreamConfig.NetworkInputConfig) inputs[inputIndex];
+            TypeSerializer<?> typeSerializer = networkInput.getTypeSerializer();
+            StreamPartitioner<?> partitioner = edge.getPartitioner();
+
+            inputConfigs[gateIndex] =
+                    new RecordFilterContext.InputFilterConfig(
+                            typeSerializer, partitioner, numberOfChannels);
+        }
+        // No null check is needed afterwards: numGates == inEdges.size() was verified above,
+        // so the loop fills every slot of inputConfigs.
+
+        return new RecordFilterContext(
+                inputConfigs,
+                getEnvironment().getTaskStateManager().getInputRescalingDescriptor(),
+                getEnvironment().getTaskInfo().getIndexOfThisSubtask(),
+                getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks(),
+                getEnvironment().getIOManager().getSpillingDirectoriesPaths(),
+                true);
+    }
+
     /** Check whether records can be emitted in batch. */
     @FunctionalInterface
     public interface CanEmitBatchOfRecordsChecker {
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java
new file mode 100644
index 00000000000..85b4fd1d48e
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
+import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests buffer ownership semantics of {@link ChannelStateFilteringHandler.GateFilterHandler}. Each
+ * test verifies that buffers are properly recycled on both success and failure paths.
+ */
+class GateFilterHandlerBufferOwnershipTest {
+
+    private static final int BUFFER_SIZE = 1024;
+
+    /**
+     * Output buffer size handed out by {@link #failAfterFirstRequest(List)}. It fits roughly one
+     * length-prefixed Long record (4-byte length + ~9 bytes of data), so rewriting several records
+     * forces a second buffer request.
+     */
+    private static final int SMALL_BUFFER_SIZE = 13;
+
+    private static final SubtaskConnectionDescriptor KEY = new SubtaskConnectionDescriptor(0, 0);
+
+    @Test
+    void testSourceBufferRecycledOnSuccess() throws Exception {
+        ChannelStateFilteringHandler.GateFilterHandler<Long> handler =
+                createHandler(RecordFilter.acceptAll());
+
+        Buffer sourceBuffer = createBufferWithRecords(1L, 2L);
+        List<Buffer> result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer);
+
+        // sourceBuffer should be recycled by the deserializer after consumption
+        assertThat(sourceBuffer.isRecycled()).isTrue();
+
+        // Clean up result buffers
+        result.forEach(Buffer::recycleBuffer);
+    }
+
+    @Test
+    void testSourceBufferRecycledWhenAllRecordsFilteredOut() throws Exception {
+        RecordFilter<Long> rejectAll = record -> false;
+        ChannelStateFilteringHandler.GateFilterHandler<Long> handler = createHandler(rejectAll);
+
+        Buffer sourceBuffer = createBufferWithRecords(1L, 2L);
+        List<Buffer> result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer);
+
+        assertThat(result).isEmpty();
+        // sourceBuffer should still be recycled even though no output was produced
+        assertThat(sourceBuffer.isRecycled()).isTrue();
+    }
+
+    @Test
+    void testSourceBufferRecycledOnInvalidVirtualChannel() {
+        // Create handler with KEY=(0,0) but call with (1,1) to trigger IllegalStateException
+        ChannelStateFilteringHandler.GateFilterHandler<Long> handler =
+                createHandler(RecordFilter.acceptAll());
+
+        Buffer sourceBuffer = createBufferWithRecords(1L);
+
+        assertThatThrownBy(
+                        () -> handler.filterAndRewrite(1, 1, sourceBuffer, this::createEmptyBuffer))
+                .isInstanceOf(IllegalStateException.class);
+
+        // sourceBuffer must be recycled even when lookup fails before setNextBuffer
+        assertThat(sourceBuffer.isRecycled()).isTrue();
+    }
+
+    @Test
+    void testResultBuffersAndCurrentBufferRecycledOnSerializationError() throws Exception {
+        // Use a small buffer so that records span multiple buffers. The supplier fails on the
+        // second request, after the first output buffer has been filled and added to resultBuffers.
+        List<Buffer> allocatedBuffers = new ArrayList<>();
+        ChannelStateFilteringHandler.BufferSupplier failingSupplier =
+                failAfterFirstRequest(allocatedBuffers);
+
+        ChannelStateFilteringHandler.GateFilterHandler<Long> handler =
+                createHandler(RecordFilter.acceptAll());
+
+        Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L);
+
+        // The exception should propagate; no buffer leak (no IllegalReferenceCountException
+        // from double-recycle).
+        assertThatThrownBy(() -> handler.filterAndRewrite(0, 0, sourceBuffer, failingSupplier))
+                .isInstanceOf(IOException.class)
+                .hasMessage("Simulated buffer allocation failure");
+
+        // sourceBuffer ownership was transferred to the deserializer via setNextBuffer().
+        // The deserializer may still hold it if it hasn't fully consumed the buffer before the
+        // error. Calling clear() triggers the cleanup chain:
+        // GateFilterHandler#clear() -> VirtualChannel#clear() -> deserializer.clear()
+        handler.clear();
+        assertThat(sourceBuffer.isRecycled()).isTrue();
+
+        // The caller never received the output buffer handed out before the failure, so the
+        // handler must have released it. Exactly one allocation succeeded, because the asserted
+        // exception is only thrown on the second request.
+        assertThat(allocatedBuffers).hasSize(1).allMatch(Buffer::isRecycled);
+    }
+
+    /**
+     * Tests the production cleanup path: when filterAndRewrite throws mid-processing, the
+     * deserializer may still hold sourceBuffer. In production, ChannelStateFilteringHandler is used
+     * in a try-with-resources block (see {@code SequentialChannelStateReaderImpl#readInputData}),
+     * so its close() is guaranteed to be called, which triggers clear() on all GateFilterHandlers
+     * and their deserializers. This test simulates that exact pattern.
+     */
+    @Test
+    void testCloseRecyclesDeserializerHeldBufferAfterError() throws Exception {
+        List<Buffer> allocatedBuffers = new ArrayList<>();
+        ChannelStateFilteringHandler.BufferSupplier failingSupplier =
+                failAfterFirstRequest(allocatedBuffers);
+
+        ChannelStateFilteringHandler.GateFilterHandler<Long> gateHandler =
+                createHandler(RecordFilter.acceptAll());
+        // Wrap in ChannelStateFilteringHandler, the production-level owner
+        ChannelStateFilteringHandler filteringHandler =
+                new ChannelStateFilteringHandler(
+                        new ChannelStateFilteringHandler.GateFilterHandler<?>[] {gateHandler});
+
+        Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L);
+
+        // Simulate the production try-with-resources pattern
+        assertThatThrownBy(
+                        () -> {
+                            try (ChannelStateFilteringHandler ignored = filteringHandler) {
+                                filteringHandler.filterAndRewrite(
+                                        0, 0, 0, sourceBuffer, failingSupplier);
+                            }
+                        })
+                .isInstanceOf(IOException.class)
+                .hasMessage("Simulated buffer allocation failure");
+
+        // After close(), the entire cleanup chain has fired:
+        // ChannelStateFilteringHandler.close() -> GateFilterHandler.clear()
+        //   -> VirtualChannel.clear() -> deserializer.clear() -> sourceBuffer.recycleBuffer()
+        assertThat(sourceBuffer.isRecycled()).isTrue();
+        // No output buffer handed out by the supplier may outlive the failed call.
+        assertThat(allocatedBuffers).hasSize(1).allMatch(Buffer::isRecycled);
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Helper methods
+    // -------------------------------------------------------------------------------------------
+
+    private ChannelStateFilteringHandler.GateFilterHandler<Long> createHandler(
+            RecordFilter<Long> filter) {
+        RecordDeserializer<DeserializationDelegate<StreamElement>> deserializer =
+                new SpillingAdaptiveSpanningRecordDeserializer<>(
+                        new String[] {System.getProperty("java.io.tmpdir")});
+        VirtualChannel<Long> vc = new VirtualChannel<>(deserializer, filter);
+
+        Map<SubtaskConnectionDescriptor, VirtualChannel<Long>> channels = new HashMap<>();
+        channels.put(KEY, vc);
+
+        StreamElementSerializer<Long> serializer =
+                new StreamElementSerializer<>(LongSerializer.INSTANCE);
+        return new ChannelStateFilteringHandler.GateFilterHandler<>(channels, serializer);
+    }
+
+    /**
+     * Returns a supplier that hands out one {@link #SMALL_BUFFER_SIZE}-byte buffer and throws an
+     * {@link IOException} on every later request. Each buffer it hands out is appended to {@code
+     * allocated}, so the test can check afterwards that ownership was honoured.
+     */
+    private ChannelStateFilteringHandler.BufferSupplier failAfterFirstRequest(
+            List<Buffer> allocated) {
+        AtomicInteger bufferRequestCount = new AtomicInteger(0);
+        return () -> {
+            if (bufferRequestCount.incrementAndGet() > 1) {
+                throw new IOException("Simulated buffer allocation failure");
+            }
+            Buffer buffer = createEmptyBuffer(SMALL_BUFFER_SIZE);
+            allocated.add(buffer);
+            return buffer;
+        };
+    }
+
+    /**
+     * Serializes {@code values} as length-prefixed {@link StreamRecord}s into a single buffer.
+     * The backing segment is at least {@link #BUFFER_SIZE} bytes and grows to fit the payload.
+     */
+    private Buffer createBufferWithRecords(Long... values) {
+        try {
+            StreamElementSerializer<Long> serializer =
+                    new StreamElementSerializer<>(LongSerializer.INSTANCE);
+            DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE);
+
+            for (Long value : values) {
+                DataOutputSerializer recordOutput = new DataOutputSerializer(64);
+                serializer.serialize(new StreamRecord<>(value), recordOutput);
+                int recordLength = recordOutput.length();
+                output.writeInt(recordLength);
+                output.write(recordOutput.getSharedBuffer(), 0, recordLength);
+            }
+
+            byte[] data = output.getCopyOfBuffer();
+            // Never allocate less than the payload, otherwise segment.put would overflow.
+            MemorySegment segment =
+                    MemorySegmentFactory.allocateUnpooledSegment(
+                            Math.max(BUFFER_SIZE, data.length));
+            segment.put(0, data, 0, data.length);
+
+            NetworkBuffer buffer = new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE);
+            buffer.setSize(data.length);
+            return buffer;
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+    }
+
+    private Buffer createEmptyBuffer() {
+        return createEmptyBuffer(BUFFER_SIZE);
+    }
+
+    private Buffer createEmptyBuffer(int size) {
+        MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(size);
+        return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE);
+    }
+}
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java
new file mode 100644
index 00000000000..f02ce35fd86
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
+import 
org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.plugable.DeserializationDelegate;
+import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilter;
+import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
+import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link ChannelStateFilteringHandler.GateFilterHandler}. */
+class GateFilterHandlerTest {
+
+    private static final int BUFFER_SIZE = 1024;
+    private static final SubtaskConnectionDescriptor KEY = new SubtaskConnectionDescriptor(0, 0);
+
+    @Test
+    void testAllRecordsPassFilter() throws Exception {
+        ChannelStateFilteringHandler.GateFilterHandler<Long> handler =
+                createHandler(RecordFilter.acceptAll());
+
+        Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L);
+        List<Buffer> result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer);
+
+        // deserializeBuffers consumes (recycles) each buffer via the deserializer
+        List<Long> values = deserializeBuffers(result);
+        assertThat(values).containsExactly(1L, 2L, 3L);
+    }
+
+    @Test
+    void testAllRecordsFilteredOut() throws Exception {
+        RecordFilter<Long> rejectAll = record -> false;
+        ChannelStateFilteringHandler.GateFilterHandler<Long> handler = createHandler(rejectAll);
+
+        Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L);
+        List<Buffer> result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer);
+
+        assertThat(result).isEmpty();
+    }
+
+    @Test
+    void testPartialFiltering() throws Exception {
+        RecordFilter<Long> keepEven = record -> record.getValue() % 2 == 0;
+        ChannelStateFilteringHandler.GateFilterHandler<Long> handler = createHandler(keepEven);
+
+        Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L);
+        List<Buffer> result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer);
+
+        List<Long> values = deserializeBuffers(result);
+        assertThat(values).containsExactly(2L, 4L);
+    }
+
+    @Test
+    void testSmallOutputBufferProducesMultipleBuffers() throws Exception {
+        // Use a very small output buffer size so records must span multiple buffers
+        int smallBufferSize = 8;
+        ChannelStateFilteringHandler.GateFilterHandler<Long> handler =
+                createHandler(RecordFilter.acceptAll());
+
+        Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L);
+        List<Buffer> result =
+                handler.filterAndRewrite(
+                        0, 0, sourceBuffer, () -> createEmptyBuffer(smallBufferSize));
+
+        // Each Long record needs 4 bytes length + ~9 bytes data > 8-byte buffer
+        assertThat(result.size()).isGreaterThan(1);
+
+        List<Long> values = deserializeBuffers(result);
+        assertThat(values).containsExactly(1L, 2L, 3L);
+    }
+
+    @Test
+    void testEmptyBuffer() throws Exception {
+        ChannelStateFilteringHandler.GateFilterHandler<Long> handler =
+                createHandler(RecordFilter.acceptAll());
+
+        Buffer emptyBuffer = createEmptyBuffer();
+        emptyBuffer.setSize(0);
+
+        List<Buffer> result = handler.filterAndRewrite(0, 0, emptyBuffer, this::createEmptyBuffer);
+
+        assertThat(result).isEmpty();
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Helper methods
+    // -------------------------------------------------------------------------------------------
+
+    private ChannelStateFilteringHandler.GateFilterHandler<Long> createHandler(
+            RecordFilter<Long> filter) {
+        RecordDeserializer<DeserializationDelegate<StreamElement>> deserializer =
+                new SpillingAdaptiveSpanningRecordDeserializer<>(
+                        new String[] {System.getProperty("java.io.tmpdir")});
+        VirtualChannel<Long> vc = new VirtualChannel<>(deserializer, filter);
+
+        Map<SubtaskConnectionDescriptor, VirtualChannel<Long>> channels = new HashMap<>();
+        channels.put(KEY, vc);
+
+        StreamElementSerializer<Long> serializer =
+                new StreamElementSerializer<>(LongSerializer.INSTANCE);
+        return new ChannelStateFilteringHandler.GateFilterHandler<>(channels, serializer);
+    }
+
+    private Buffer createBufferWithRecords(Long... values) throws IOException {
+        StreamElementSerializer<Long> serializer =
+                new StreamElementSerializer<>(LongSerializer.INSTANCE);
+        return serializeRecordsToBuffer(serializer, values);
+    }
+
+    /**
+     * Serializes records into a buffer using Flink's length-prefixed format. The backing segment
+     * is at least {@link #BUFFER_SIZE} bytes and grows to fit the payload.
+     */
+    private Buffer serializeRecordsToBuffer(
+            StreamElementSerializer<Long> serializer, Long... values) throws IOException {
+        DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE);
+
+        for (Long value : values) {
+            // Serialize using the same length-prefixed format as Flink
+            DataOutputSerializer recordOutput = new DataOutputSerializer(64);
+            serializer.serialize(new StreamRecord<>(value), recordOutput);
+            int recordLength = recordOutput.length();
+
+            // Write 4-byte big-endian length prefix
+            output.writeInt(recordLength);
+            // Write record bytes
+            output.write(recordOutput.getSharedBuffer(), 0, recordLength);
+        }
+
+        byte[] data = output.getCopyOfBuffer();
+        // Never allocate less than the payload, otherwise segment.put would overflow.
+        MemorySegment segment =
+                MemorySegmentFactory.allocateUnpooledSegment(Math.max(BUFFER_SIZE, data.length));
+        segment.put(0, data, 0, data.length);
+
+        NetworkBuffer buffer = new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE);
+        buffer.setSize(data.length);
+        return buffer;
+    }
+
+    private Buffer createEmptyBuffer() {
+        return createEmptyBuffer(BUFFER_SIZE);
+    }
+
+    private Buffer createEmptyBuffer(int size) {
+        MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(size);
+        return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE);
+    }
+
+    /**
+     * Reads every element back out of {@code buffers} in order and returns the values of the
+     * {@link StreamRecord}s found. Each buffer is handed to the deserializer, which consumes it.
+     * The deserializer is always cleared at the end. Without that, a buffer it still holds (for
+     * example after a trailing partial record or an exception) would never be released.
+     */
+    private List<Long> deserializeBuffers(List<Buffer> buffers) throws IOException {
+        StreamElementSerializer<Long> serializer =
+                new StreamElementSerializer<>(LongSerializer.INSTANCE);
+        SpillingAdaptiveSpanningRecordDeserializer<DeserializationDelegate<StreamElement>>
+                deserializer =
+                        new SpillingAdaptiveSpanningRecordDeserializer<>(
+                                new String[] {System.getProperty("java.io.tmpdir")});
+        DeserializationDelegate<StreamElement> delegate =
+                new NonReusingDeserializationDelegate<>(serializer);
+
+        List<Long> values = new ArrayList<>();
+        try {
+            for (Buffer buffer : buffers) {
+                deserializer.setNextBuffer(buffer);
+                while (true) {
+                    RecordDeserializer.DeserializationResult result =
+                            deserializer.getNextRecord(delegate);
+                    if (result.isFullRecord()) {
+                        StreamElement element = delegate.getInstance();
+                        if (element.isRecord()) {
+                            StreamRecord<Long> record = element.asRecord();
+                            values.add(record.getValue());
+                        }
+                    }
+                    if (result.isBufferConsumed()) {
+                        break;
+                    }
+                }
+            }
+        } finally {
+            deserializer.clear();
+        }
+        return values;
+    }
+}
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
index e2b1d69c56a..39ce6c7d4bf 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java
@@ -77,7 +77,8 @@ class InputChannelRecoveredStateHandlerTest extends 
RecoveredChannelStateHandler
                                     InflightDataRescalingDescriptor
                                             
.InflightDataGateOrPartitionRescalingDescriptor
                                             .MappingType.IDENTITY)
-                        }));
+                        }),
+                null);
     }
 
     private InputChannelRecoveredStateHandler buildMultiChannelHandler() {
@@ -103,7 +104,8 @@ class InputChannelRecoveredStateHandlerTest extends 
RecoveredChannelStateHandler
                                     InflightDataRescalingDescriptor
                                             
.InflightDataGateOrPartitionRescalingDescriptor
                                             .MappingType.RESCALING)
-                        }));
+                        }),
+                null);
     }
 
     @Test
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java
index 5f05af0d92e..d80442b8a06 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java
@@ -40,6 +40,7 @@ import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.InputChannelStateHandle;
 import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
 import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
 import 
org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
 import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
@@ -143,7 +144,7 @@ public class SequentialChannelStateReaderImplTest {
 
         withInputGates(
                 gates -> {
-                    reader.readInputData(gates);
+                    reader.readInputData(gates, 
RecordFilterContext.disabled());
                     assertBuffersEquals(inputChannelsData, 
collectBuffers(gates));
                     assertConsumed(gates);
                 });
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
index 53c05b19613..f64a4d9fb9c 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
@@ -49,6 +49,7 @@ import 
org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBui
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
+import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext;
 import org.apache.flink.util.function.SupplierWithException;
 
 import org.junit.jupiter.api.Test;
@@ -119,7 +120,7 @@ class ChannelPersistenceITCase {
         try {
             int numChannels = 1;
             InputGate gate = buildGate(networkBufferPool, numChannels);
-            reader.readInputData(new InputGate[] {gate});
+            reader.readInputData(new InputGate[] {gate}, 
RecordFilterContext.disabled());
             assertThat(collectBytes(gate::pollNext, BufferOrEvent::getBuffer))
                     .isEqualTo(inputChannelInfoData);
 
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java
new file mode 100644
index 00000000000..15b97d433a2
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.runtime.checkpoint.RescaleMappings;
+import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
+
+import org.junit.jupiter.api.Test;
+
+import static 
org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.mappings;
+import static 
org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.rescalingDescriptor;
+import static 
org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.set;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link RecordFilterContext}. */
+class RecordFilterContextTest {
+
+    private static final int MAX_PARALLELISM = 128;
+
+    @Test
+    void testDisabledContextHasNoGates() {
+        RecordFilterContext context = RecordFilterContext.disabled();
+
+        assertThat(context.getNumberOfGates()).isZero();
+        assertThat(context.isCheckpointingDuringRecoveryEnabled()).isFalse();
+    }
+
+    @Test
+    void testGetInputConfigReturnsCorrectConfig() {
+        RecordFilterContext.InputFilterConfig onlyConfig = forwardConfig(4);
+
+        RecordFilterContext context =
+                new RecordFilterContext(
+                        new RecordFilterContext.InputFilterConfig[] {onlyConfig},
+                        InflightDataRescalingDescriptor.NO_RESCALE,
+                        0,
+                        MAX_PARALLELISM,
+                        new String[] {"/tmp"},
+                        true);
+
+        assertThat(context.getNumberOfGates()).isEqualTo(1);
+        assertThat(context.getInputConfig(0)).isSameAs(onlyConfig);
+        assertThat(context.getSubtaskIndex()).isZero();
+        assertThat(context.getMaxParallelism()).isEqualTo(MAX_PARALLELISM);
+        assertThat(context.isCheckpointingDuringRecoveryEnabled()).isTrue();
+    }
+
+    @Test
+    void testGetInputConfigThrowsForInvalidIndex() {
+        RecordFilterContext context =
+                contextWithoutGates(InflightDataRescalingDescriptor.NO_RESCALE, false);
+
+        assertThatThrownBy(() -> context.getInputConfig(0))
+                .isInstanceOf(IllegalArgumentException.class);
+        assertThatThrownBy(() -> context.getInputConfig(-1))
+                .isInstanceOf(IllegalArgumentException.class);
+    }
+
+    @Test
+    void testNullTmpDirectoriesConvertedToEmptyArray() {
+        RecordFilterContext context =
+                contextWithoutGates(InflightDataRescalingDescriptor.NO_RESCALE, false);
+
+        assertThat(context.getTmpDirectories()).isNotNull().isEmpty();
+    }
+
+    @Test
+    void testIsAmbiguousWhenDisabled() {
+        // oldSubtask 0 is marked ambiguous, but with checkpointingDuringRecoveryEnabled=false the
+        // context must never report ambiguity.
+        RecordFilterContext context =
+                contextWithoutGates(descriptorWithAmbiguousSubtaskZero(0), false);
+
+        assertThat(context.isAmbiguous(0, 0)).isFalse();
+    }
+
+    @Test
+    void testIsAmbiguousWhenEnabled() {
+        // With checkpointingDuringRecoveryEnabled=true, ambiguity follows the rescaling
+        // descriptor, in which oldSubtask 0 is marked ambiguous.
+        RecordFilterContext context =
+                contextWithoutGates(descriptorWithAmbiguousSubtaskZero(0), true);
+
+        assertThat(context.isAmbiguous(0, 0)).isTrue();
+    }
+
+    @Test
+    void testIsAmbiguousForNonAmbiguousSubtask() {
+        // Old subtasks 0 and 1 exist, but only oldSubtask 0 is in the ambiguous set.
+        RecordFilterContext context =
+                contextWithoutGates(descriptorWithAmbiguousSubtaskZero(0, 1), true);
+
+        assertThat(context.isAmbiguous(0, 0)).isTrue();
+        assertThat(context.isAmbiguous(0, 1)).isFalse();
+    }
+
+    @Test
+    void testInputFilterConfigGetters() {
+        ForwardPartitioner<Long> forward = new ForwardPartitioner<>();
+        RecordFilterContext.InputFilterConfig inputConfig =
+                new RecordFilterContext.InputFilterConfig(LongSerializer.INSTANCE, forward, 4);
+
+        assertThat(inputConfig.getTypeSerializer()).isSameAs(LongSerializer.INSTANCE);
+        assertThat(inputConfig.getPartitioner()).isSameAs(forward);
+        assertThat(inputConfig.getNumberOfChannels()).isEqualTo(4);
+    }
+
+    @Test
+    void testMultipleGateConfigs() {
+        RecordFilterContext.InputFilterConfig firstGate = forwardConfig(2);
+        RecordFilterContext.InputFilterConfig secondGate = forwardConfig(4);
+
+        RecordFilterContext context =
+                new RecordFilterContext(
+                        new RecordFilterContext.InputFilterConfig[] {firstGate, secondGate},
+                        InflightDataRescalingDescriptor.NO_RESCALE,
+                        1,
+                        256,
+                        new String[] {"/tmp"},
+                        false);
+
+        assertThat(context.getNumberOfGates()).isEqualTo(2);
+        assertThat(context.getInputConfig(0)).isSameAs(firstGate);
+        assertThat(context.getInputConfig(1)).isSameAs(secondGate);
+        assertThat(context.getInputConfig(0).getNumberOfChannels()).isEqualTo(2);
+        assertThat(context.getInputConfig(1).getNumberOfChannels()).isEqualTo(4);
+    }
+
+    // -------------------------------------------------------------------------------------------
+    // Helper methods
+    // -------------------------------------------------------------------------------------------
+
+    /** Builds an input config with a Long serializer and a forward partitioner. */
+    private static RecordFilterContext.InputFilterConfig forwardConfig(int numberOfChannels) {
+        return new RecordFilterContext.InputFilterConfig(
+                LongSerializer.INSTANCE, new ForwardPartitioner<>(), numberOfChannels);
+    }
+
+    /**
+     * Builds a context with no input gates for subtask 0, {@link #MAX_PARALLELISM} and a {@code
+     * null} tmp-directory array.
+     */
+    private static RecordFilterContext contextWithoutGates(
+            InflightDataRescalingDescriptor descriptor, boolean checkpointingDuringRecovery) {
+        return new RecordFilterContext(
+                new RecordFilterContext.InputFilterConfig[0],
+                descriptor,
+                0,
+                MAX_PARALLELISM,
+                null,
+                checkpointingDuringRecovery);
+    }
+
+    /**
+     * Builds a rescaling descriptor with a single rescale mapping over the given old subtask
+     * indexes, in which only old subtask 0 is marked ambiguous.
+     */
+    private static InflightDataRescalingDescriptor descriptorWithAmbiguousSubtaskZero(
+            int... oldSubtaskIndexes) {
+        RescaleMappings mapping = mappings(new int[] {0});
+        return rescalingDescriptor(oldSubtaskIndexes, new RescaleMappings[] {mapping}, set(0));
+    }
+}
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java
new file mode 100644
index 00000000000..bffc42e3329
--- /dev/null
+++ 
b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.io.recovery;
+
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor;
+import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link VirtualChannelRecordFilterFactory}. */
+class VirtualChannelRecordFilterFactoryTest {
+
+    @Test
+    void testCreatePassThroughFilter() {
+        RecordFilter<Long> filter = VirtualChannelRecordFilterFactory.createPassThroughFilter();
+        assertThat(filter.filter(new StreamRecord<>(0L))).isTrue();
+        assertThat(filter.filter(new StreamRecord<>(1L))).isTrue();
+        assertThat(filter.filter(new StreamRecord<>(42L))).isTrue();
+    }
+
+    @Test
+    void testCreateFilterProducesPartitionerBasedFilter() {
+        RebalancePartitioner<Long> partitioner = new RebalancePartitioner<>();
+
+        VirtualChannelRecordFilterFactory<Long> factory =
+                new VirtualChannelRecordFilterFactory<>(
+                        LongSerializer.INSTANCE, partitioner, 0, 2, 128);
+
+        RecordFilter<Long> filter = factory.createFilter();
+        // The filter should be a PartitionerRecordFilter that filters based on partitioner
+        assertThat(filter).isInstanceOf(PartitionerRecordFilter.class);
+    }
+
+    @Test
+    void testFromContextCreatesFactory() {
+        RebalancePartitioner<Long> partitioner = new RebalancePartitioner<>();
+        RecordFilterContext.InputFilterConfig config =
+                new RecordFilterContext.InputFilterConfig(LongSerializer.INSTANCE, partitioner, 4);
+
+        RecordFilterContext context =
+                new RecordFilterContext(
+                        new RecordFilterContext.InputFilterConfig[] {config},
+                        InflightDataRescalingDescriptor.NO_RESCALE,
+                        1,
+                        128,
+                        new String[] {"/tmp"},
+                        true);
+
+        VirtualChannelRecordFilterFactory<Long> factory =
+                VirtualChannelRecordFilterFactory.fromContext(context, 0);
+        RecordFilter<Long> filter = factory.createFilter();
+
+        // The context-derived factory must produce a partitioner-based filter (type check only).
+        assertThat(filter).isInstanceOf(PartitionerRecordFilter.class);
+    }
+
+    @Test
+    void testEachFilterCallCreatesIndependentFilter() {
+        RebalancePartitioner<Long> partitioner = new RebalancePartitioner<>();
+
+        VirtualChannelRecordFilterFactory<Long> factory =
+                new VirtualChannelRecordFilterFactory<>(
+                        LongSerializer.INSTANCE, partitioner, 0, 2, 128);
+
+        RecordFilter<Long> filter1 = factory.createFilter();
+        RecordFilter<Long> filter2 = factory.createFilter();
+        // Each call must yield a new partitioner-based filter rather than a cached instance.
+        assertThat(filter1).isNotSameAs(filter2).isInstanceOf(PartitionerRecordFilter.class);
+        assertThat(filter2).isInstanceOf(PartitionerRecordFilter.class);
+    }
+}

Reply via email to