This is an automated email from the ASF dual-hosted git repository. fanrui pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 23ebbb5d2bb0f3ffb0e07798bd3d8b0c1808331e Author: Rui Fan <[email protected]> AuthorDate: Wed Feb 18 21:25:46 2026 +0100 [FLINK-38930][checkpoint] Filtering record before processing without spilling strategy Core filtering mechanism for recovered channel state buffers: - ChannelStateFilteringHandler with per-gate GateFilterHandler - RecordFilterContext with VirtualChannelRecordFilterFactory - Partial data check in SequentialChannelStateReaderImpl - Fix RecordFilterContext for Union downscale scenario --- .../channel/ChannelStateFilteringHandler.java | 458 +++++++++++++++++++++ .../channel/RecoveredChannelStateHandler.java | 64 ++- .../channel/SequentialChannelStateReader.java | 13 +- .../channel/SequentialChannelStateReaderImpl.java | 27 +- .../io/StreamMultipleInputProcessorFactory.java | 8 +- .../runtime/io/StreamTaskNetworkInputFactory.java | 9 +- .../runtime/io/StreamTwoInputProcessorFactory.java | 11 +- .../runtime/io/recovery/RecordFilterContext.java | 227 ++++++++++ .../VirtualChannelRecordFilterFactory.java | 123 ++++++ .../runtime/tasks/OneInputStreamTask.java | 6 +- .../flink/streaming/runtime/tasks/StreamTask.java | 82 +++- .../GateFilterHandlerBufferOwnershipTest.java | 230 +++++++++++ .../checkpoint/channel/GateFilterHandlerTest.java | 213 ++++++++++ .../InputChannelRecoveredStateHandlerTest.java | 6 +- .../SequentialChannelStateReaderImplTest.java | 3 +- .../runtime/state/ChannelPersistenceITCase.java | 3 +- .../io/recovery/RecordFilterContextTest.java | 193 +++++++++ .../VirtualChannelRecordFilterFactoryTest.java | 90 ++++ 18 files changed, 1739 insertions(+), 27 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java new file mode 100644 index 00000000000..b257c3b4054 --- /dev/null +++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateFilteringHandler.java @@ -0,0 +1,458 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.RescaleMappings; +import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer.DeserializationResult; +import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.runtime.plugable.DeserializationDelegate; +import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate; +import 
org.apache.flink.streaming.runtime.io.recovery.RecordFilter; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; +import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel; +import org.apache.flink.streaming.runtime.io.recovery.VirtualChannelRecordFilterFactory; +import org.apache.flink.streaming.runtime.streamrecord.StreamElement; +import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; + +import javax.annotation.Nullable; + +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Filters recovered channel state buffers during the channel-state-unspilling phase, removing + * records that do not belong to the current subtask after rescaling. + * + * <p>Uses a per-gate architecture: each {@link InputGate} gets its own {@link GateFilterHandler} + * with the correct serializer, so multi-input tasks (e.g., TwoInputStreamTask) correctly + * deserialize different record types on different gates. + */ +@Internal +public class ChannelStateFilteringHandler implements Closeable { + + // Wildcard allows heterogeneous record types across gates. + private final GateFilterHandler<?>[] gateHandlers; + + ChannelStateFilteringHandler(GateFilterHandler<?>[] gateHandlers) { + this.gateHandlers = checkNotNull(gateHandlers); + } + + /** + * Creates a handler from the recovery context, building per-gate virtual channels based on + * rescaling descriptors. Returns {@code null} if no filtering is needed (e.g., source tasks or + * no rescaling). 
+ */ + @Nullable + public static ChannelStateFilteringHandler createFromContext( + RecordFilterContext filterContext, InputGate[] inputGates) { + // Source tasks have no network inputs + if (filterContext.getNumberOfGates() == 0) { + return null; + } + + InflightDataRescalingDescriptor rescalingDescriptor = + filterContext.getRescalingDescriptor(); + + GateFilterHandler<?>[] gateHandlers = new GateFilterHandler<?>[inputGates.length]; + boolean hasAnyVirtualChannels = false; + + for (int gateIndex = 0; gateIndex < inputGates.length; gateIndex++) { + gateHandlers[gateIndex] = + createGateHandler(filterContext, inputGates, rescalingDescriptor, gateIndex); + if (gateHandlers[gateIndex] != null) { + hasAnyVirtualChannels = true; + } + } + + if (!hasAnyVirtualChannels) { + return null; + } + + return new ChannelStateFilteringHandler(gateHandlers); + } + + /** + * Filters a recovered buffer from the specified virtual channel, returning new buffers + * containing only the records that belong to the current subtask. + * + * <p>One source buffer may produce 0 to N result buffers: 0 if all records are filtered out, + * and potentially more than 1 when a spanning record completes in this buffer. The deserializer + * caches partial record data from previous buffers, so the output may contain data that was not + * in the current source buffer, causing the total output size to exceed one buffer capacity. + * This can happen with any spanning record regardless of its size. + * + * @return filtered buffers, possibly empty if all records were filtered out. 
+ */ + public List<Buffer> filterAndRewrite( + int gateIndex, + int oldSubtaskIndex, + int oldChannelIndex, + Buffer sourceBuffer, + BufferSupplier bufferSupplier) + throws IOException, InterruptedException { + + if (gateIndex < 0 || gateIndex >= gateHandlers.length) { + throw new IllegalStateException( + "Invalid gateIndex: " + + gateIndex + + ", number of gates: " + + gateHandlers.length); + } + + GateFilterHandler<?> gateHandler = gateHandlers[gateIndex]; + if (gateHandler == null) { + throw new IllegalStateException( + "No handler for gateIndex " + + gateIndex + + ". This gate is not a network input and should not have recovered buffers."); + } + return gateHandler.filterAndRewrite( + oldSubtaskIndex, oldChannelIndex, sourceBuffer, bufferSupplier); + } + + /** Returns {@code true} if any virtual channel has a partial (spanning) record pending. */ + public boolean hasPartialData() { + for (GateFilterHandler<?> handler : gateHandlers) { + if (handler != null && handler.hasPartialData()) { + return true; + } + } + return false; + } + + @Override + public void close() { + for (GateFilterHandler<?> handler : gateHandlers) { + if (handler != null) { + handler.clear(); + } + } + } + + // ------------------------------------------------------------------------------------------- + // Private static helper methods + // ------------------------------------------------------------------------------------------- + + /** + * Creates a {@link GateFilterHandler} for a single gate. The method-level type parameter + * ensures type safety within each gate while allowing different gates to have different types. 
+ */ + @SuppressWarnings("unchecked") + @Nullable + private static <T> GateFilterHandler<T> createGateHandler( + RecordFilterContext filterContext, + InputGate[] inputGates, + InflightDataRescalingDescriptor rescalingDescriptor, + int gateIndex) { + RecordFilterContext.InputFilterConfig inputConfig = filterContext.getInputConfig(gateIndex); + if (inputConfig == null) { + throw new IllegalStateException( + "No InputFilterConfig for gateIndex " + + gateIndex + + ". This indicates a bug in RecordFilterContext initialization."); + } + + InputGate gate = inputGates[gateIndex]; + int[] oldSubtaskIndexes = rescalingDescriptor.getOldSubtaskIndexes(gateIndex); + RescaleMappings channelMapping = rescalingDescriptor.getChannelMapping(gateIndex); + + TypeSerializer<T> typeSerializer = (TypeSerializer<T>) inputConfig.getTypeSerializer(); + StreamElementSerializer<T> elementSerializer = + new StreamElementSerializer<>(typeSerializer); + + VirtualChannelRecordFilterFactory<T> filterFactory = + VirtualChannelRecordFilterFactory.fromContext(filterContext, gateIndex); + + Map<SubtaskConnectionDescriptor, VirtualChannel<T>> gateVirtualChannels = new HashMap<>(); + + for (int oldSubtaskIndex : oldSubtaskIndexes) { + int numChannels = gate.getNumberOfInputChannels(); + int[] oldChannelIndexes = getOldChannelIndexes(channelMapping, numChannels); + + for (int oldChannelIndex : oldChannelIndexes) { + SubtaskConnectionDescriptor key = + new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex); + + if (gateVirtualChannels.containsKey(key)) { + continue; + } + + // Only ambiguous channels need actual filtering; non-ambiguous ones pass through + boolean isAmbiguous = rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex); + + RecordFilter<T> recordFilter = + isAmbiguous + ? 
filterFactory.createFilter() + : VirtualChannelRecordFilterFactory.createPassThroughFilter(); + + RecordDeserializer<DeserializationDelegate<StreamElement>> deserializer = + createDeserializer(filterContext.getTmpDirectories()); + + VirtualChannel<T> vc = new VirtualChannel<>(deserializer, recordFilter); + gateVirtualChannels.put(key, vc); + } + } + + if (gateVirtualChannels.isEmpty()) { + return null; + } + + return new GateFilterHandler<>(gateVirtualChannels, elementSerializer); + } + + /** + * Collects all old channel indexes that are mapped from any new channel index in this gate. + * channelMapping is new-to-old, so we iterate new indexes and collect their old counterparts. + */ + private static int[] getOldChannelIndexes(RescaleMappings channelMapping, int numChannels) { + List<Integer> oldIndexes = new ArrayList<>(); + for (int newIndex = 0; newIndex < numChannels; newIndex++) { + int[] mapped = channelMapping.getMappedIndexes(newIndex); + for (int oldIndex : mapped) { + if (!oldIndexes.contains(oldIndex)) { + oldIndexes.add(oldIndex); + } + } + } + return oldIndexes.stream().mapToInt(Integer::intValue).toArray(); + } + + private static RecordDeserializer<DeserializationDelegate<StreamElement>> createDeserializer( + String[] tmpDirectories) { + if (tmpDirectories != null && tmpDirectories.length > 0) { + return new SpillingAdaptiveSpanningRecordDeserializer<>(tmpDirectories); + } else { + String[] defaultDirs = new String[] {System.getProperty("java.io.tmpdir")}; + return new SpillingAdaptiveSpanningRecordDeserializer<>(defaultDirs); + } + } + + // ------------------------------------------------------------------------------------------- + // Inner classes + // ------------------------------------------------------------------------------------------- + + /** Provides buffers for re-serializing filtered records. Implementations may block. 
*/ + @FunctionalInterface + public interface BufferSupplier { + Buffer requestBufferBlocking() throws IOException, InterruptedException; + } + + /** + * Handles record filtering for a single input gate. Each gate has its own serializer and set of + * virtual channels, allowing different gates to handle different record types independently. + */ + static class GateFilterHandler<T> { + + private final Map<SubtaskConnectionDescriptor, VirtualChannel<T>> virtualChannels; + private final StreamElementSerializer<T> serializer; + private final DeserializationDelegate<StreamElement> deserializationDelegate; + private final DataOutputSerializer outputSerializer; + private final byte[] lengthBuffer = new byte[4]; + + GateFilterHandler( + Map<SubtaskConnectionDescriptor, VirtualChannel<T>> virtualChannels, + StreamElementSerializer<T> serializer) { + this.virtualChannels = checkNotNull(virtualChannels); + this.serializer = checkNotNull(serializer); + this.deserializationDelegate = new NonReusingDeserializationDelegate<>(serializer); + this.outputSerializer = new DataOutputSerializer(128); + } + + /** + * Deserializes records from {@code sourceBuffer}, applies the virtual channel's record + * filter, and immediately re-serializes each surviving record into output buffers. 
+ */ + List<Buffer> filterAndRewrite( + int oldSubtaskIndex, + int oldChannelIndex, + Buffer sourceBuffer, + BufferSupplier bufferSupplier) + throws IOException, InterruptedException { + + boolean sourceBufferOwnershipTransferred = false; + List<Buffer> resultBuffers = new ArrayList<>(); + Buffer currentBuffer = null; + try { + SubtaskConnectionDescriptor key = + new SubtaskConnectionDescriptor(oldSubtaskIndex, oldChannelIndex); + VirtualChannel<T> vc = virtualChannels.get(key); + if (vc == null) { + throw new IllegalStateException( + "No VirtualChannel found for key: " + + key + + "; known channels are " + + virtualChannels.keySet()); + } + + vc.setNextBuffer(sourceBuffer); + sourceBufferOwnershipTransferred = true; + + while (true) { + DeserializationResult result = vc.getNextRecord(deserializationDelegate); + if (result.isFullRecord()) { + if (currentBuffer == null) { + currentBuffer = bufferSupplier.requestBufferBlocking(); + } + currentBuffer = + serializeElement( + deserializationDelegate.getInstance(), + currentBuffer, + resultBuffers, + bufferSupplier); + } + if (result.isBufferConsumed()) { + break; + } + } + + if (currentBuffer != null) { + if (currentBuffer.readableBytes() > 0) { + resultBuffers.add(currentBuffer); + } else { + currentBuffer.recycleBuffer(); + } + currentBuffer = null; + } + + return resultBuffers; + } catch (Throwable t) { + if (!sourceBufferOwnershipTransferred) { + sourceBuffer.recycleBuffer(); + } + // Avoid double-recycle: currentBuffer may already be the last element in + // resultBuffers if serializeElement added it before the exception. 
+ if (currentBuffer != null + && (resultBuffers.isEmpty() + || resultBuffers.get(resultBuffers.size() - 1) != currentBuffer)) { + currentBuffer.recycleBuffer(); + } + for (Buffer buf : resultBuffers) { + buf.recycleBuffer(); + } + resultBuffers.clear(); + throw t; + } + } + + /** + * Serializes a single stream element into the current buffer using the length-prefixed + * format (4-byte big-endian length + record bytes) expected by Flink's record + * deserializers. Spills into new buffers from {@code bufferSupplier} when needed. + * + * @return the buffer to continue writing into (may differ from the input buffer). + */ + private Buffer serializeElement( + StreamElement element, + Buffer currentBuffer, + List<Buffer> resultBuffers, + BufferSupplier bufferSupplier) + throws IOException, InterruptedException { + outputSerializer.clear(); + serializer.serialize(element, outputSerializer); + int recordLength = outputSerializer.length(); + + writeLengthToBuffer(recordLength); + currentBuffer = + writeDataToBuffer( + lengthBuffer, 0, 4, currentBuffer, resultBuffers, bufferSupplier); + + byte[] serializedData = outputSerializer.getSharedBuffer(); + currentBuffer = + writeDataToBuffer( + serializedData, + 0, + recordLength, + currentBuffer, + resultBuffers, + bufferSupplier); + return currentBuffer; + } + + private void writeLengthToBuffer(int length) { + lengthBuffer[0] = (byte) (length >> 24); + lengthBuffer[1] = (byte) (length >> 16); + lengthBuffer[2] = (byte) (length >> 8); + lengthBuffer[3] = (byte) length; + } + + /** + * Writes data to the current buffer, spilling into new buffers from {@code bufferSupplier} + * when the current one is full. + * + * @return the buffer to continue writing into (may differ from the input buffer). 
+ */ + private Buffer writeDataToBuffer( + byte[] data, + int dataOffset, + int dataLength, + Buffer currentBuffer, + List<Buffer> resultBuffers, + BufferSupplier bufferSupplier) + throws IOException, InterruptedException { + int offset = dataOffset; + int remaining = dataLength; + + while (remaining > 0) { + int writableBytes = currentBuffer.getMaxCapacity() - currentBuffer.getSize(); + + if (writableBytes == 0) { + // Buffer is full, transfer ownership to resultBuffers + resultBuffers.add(currentBuffer); + currentBuffer = bufferSupplier.requestBufferBlocking(); + writableBytes = currentBuffer.getMaxCapacity(); + } + + int bytesToWrite = Math.min(remaining, writableBytes); + currentBuffer + .getMemorySegment() + .put( + currentBuffer.getMemorySegmentOffset() + currentBuffer.getSize(), + data, + offset, + bytesToWrite); + currentBuffer.setSize(currentBuffer.getSize() + bytesToWrite); + + offset += bytesToWrite; + remaining -= bytesToWrite; + } + return currentBuffer; + } + + boolean hasPartialData() { + return virtualChannels.values().stream().anyMatch(VirtualChannel::hasPartialData); + } + + void clear() { + virtualChannels.values().forEach(VirtualChannel::clear); + } + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java index 85fc31db4bc..31db728bc48 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/RecoveredChannelStateHandler.java @@ -31,6 +31,7 @@ import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.io.network.partition.consumer.RecoveredInputChannel; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import java.io.IOException; import java.util.HashMap; @@ -63,7 +64,7 @@ 
interface RecoveredChannelStateHandler<Info, Context> extends AutoCloseable { * case of an error. */ void recover(Info info, int oldSubtaskIndex, BufferWithContext<Context> bufferWithContext) - throws IOException; + throws IOException, InterruptedException; } class InputChannelRecoveredStateHandler @@ -75,10 +76,19 @@ class InputChannelRecoveredStateHandler private final Map<InputChannelInfo, RecoveredInputChannel> rescaledChannels = new HashMap<>(); private final Map<Integer, RescaleMappings> oldToNewMappings = new HashMap<>(); + /** + * Optional filtering handler for filtering recovered buffers. When non-null, filtering is + * performed during recovery in the channel-state-unspilling thread. + */ + @Nullable private final ChannelStateFilteringHandler filteringHandler; + InputChannelRecoveredStateHandler( - InputGate[] inputGates, InflightDataRescalingDescriptor channelMapping) { + InputGate[] inputGates, + InflightDataRescalingDescriptor channelMapping, + @Nullable ChannelStateFilteringHandler filteringHandler) { this.inputGates = inputGates; this.channelMapping = channelMapping; + this.filteringHandler = filteringHandler; } @Override @@ -95,23 +105,57 @@ class InputChannelRecoveredStateHandler InputChannelInfo channelInfo, int oldSubtaskIndex, BufferWithContext<Buffer> bufferWithContext) - throws IOException { + throws IOException, InterruptedException { Buffer buffer = bufferWithContext.context; try { if (buffer.readableBytes() > 0) { RecoveredInputChannel channel = getMappedChannels(channelInfo); - channel.onRecoveredStateBuffer( - EventSerializer.toBuffer( - new SubtaskConnectionDescriptor( - oldSubtaskIndex, channelInfo.getInputChannelIdx()), - false)); - channel.onRecoveredStateBuffer(buffer.retainBuffer()); + + if (filteringHandler != null) { + recoverWithFiltering( + channel, channelInfo, oldSubtaskIndex, buffer.retainBuffer()); + } else { + channel.onRecoveredStateBuffer( + EventSerializer.toBuffer( + new SubtaskConnectionDescriptor( + oldSubtaskIndex, 
channelInfo.getInputChannelIdx()), + false)); + channel.onRecoveredStateBuffer(buffer.retainBuffer()); + } } } finally { buffer.recycleBuffer(); } } + private void recoverWithFiltering( + RecoveredInputChannel channel, + InputChannelInfo channelInfo, + int oldSubtaskIndex, + Buffer retainedBuffer) + throws IOException, InterruptedException { + checkState(filteringHandler != null, "filtering handler not set."); + List<Buffer> filteredBuffers = + filteringHandler.filterAndRewrite( + channelInfo.getGateIdx(), + oldSubtaskIndex, + channelInfo.getInputChannelIdx(), + retainedBuffer, + channel::requestBufferBlocking); + + int i = 0; + try { + for (; i < filteredBuffers.size(); i++) { + channel.onRecoveredStateBuffer(filteredBuffers.get(i)); + } + } catch (Throwable t) { + for (int j = i; j < filteredBuffers.size(); j++) { + filteredBuffers.get(j).recycleBuffer(); + } + throw t; + } + } + @Override public void close() throws IOException { // note that we need to finish all RecoveredInputChannels, not just those with state @@ -191,7 +235,7 @@ class ResultSubpartitionRecoveredStateHandler ResultSubpartitionInfo subpartitionInfo, int oldSubtaskIndex, BufferWithContext<BufferBuilder> bufferWithContext) - throws IOException { + throws IOException, InterruptedException { try (BufferBuilder bufferBuilder = bufferWithContext.context; BufferConsumer bufferConsumer = bufferBuilder.createBufferConsumerFromBeginning()) { bufferBuilder.finish(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java index 7adf6d62946..547b60ef93a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReader.java @@ -20,6 +20,7 @@ package 
org.apache.flink.runtime.checkpoint.channel; import org.apache.flink.annotation.Internal; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; import java.io.IOException; @@ -27,7 +28,14 @@ import java.io.IOException; @Internal public interface SequentialChannelStateReader extends AutoCloseable { - void readInputData(InputGate[] inputGates) throws IOException, InterruptedException; + /** + * Reads input channel state with filtering support. + * + * @param inputGates The input gates to recover state for. + * @param filterContext The filter context containing input configs and rescaling info. + */ + void readInputData(InputGate[] inputGates, RecordFilterContext filterContext) + throws IOException, InterruptedException; void readOutputData(ResultPartitionWriter[] writers, boolean notifyAndBlockOnCompletion) throws IOException, InterruptedException; @@ -39,7 +47,8 @@ public interface SequentialChannelStateReader extends AutoCloseable { new SequentialChannelStateReader() { @Override - public void readInputData(InputGate[] inputGates) {} + public void readInputData( + InputGate[] inputGates, RecordFilterContext filterContext) {} @Override public void readOutputData( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java index 3daa4b4947a..8aa8db2679f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImpl.java @@ -28,6 +28,7 @@ import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import 
org.apache.flink.runtime.state.AbstractChannelStateHandle; import org.apache.flink.runtime.state.ChannelStateHelper; import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; import java.io.Closeable; import java.io.IOException; @@ -43,6 +44,7 @@ import java.util.stream.Stream; import static java.util.Comparator.comparingLong; import static java.util.stream.Collectors.groupingBy; import static java.util.stream.Collectors.toList; +import static org.apache.flink.util.Preconditions.checkState; /** {@link SequentialChannelStateReader} implementation. */ public class SequentialChannelStateReaderImpl implements SequentialChannelStateReader { @@ -58,10 +60,21 @@ public class SequentialChannelStateReaderImpl implements SequentialChannelStateR } @Override - public void readInputData(InputGate[] inputGates) throws IOException, InterruptedException { - try (InputChannelRecoveredStateHandler stateHandler = - new InputChannelRecoveredStateHandler( - inputGates, taskStateSnapshot.getInputRescalingDescriptor())) { + public void readInputData(InputGate[] inputGates, RecordFilterContext filterContext) + throws IOException, InterruptedException { + + // Create filtering handler if filtering is needed + ChannelStateFilteringHandler filteringHandler = + filterContext.isCheckpointingDuringRecoveryEnabled() + ? 
ChannelStateFilteringHandler.createFromContext(filterContext, inputGates) + : null; + + try (ChannelStateFilteringHandler ignored = filteringHandler; + InputChannelRecoveredStateHandler stateHandler = + new InputChannelRecoveredStateHandler( + inputGates, + taskStateSnapshot.getInputRescalingDescriptor(), + filteringHandler)) { read( stateHandler, groupByDelegate( @@ -72,6 +85,12 @@ public class SequentialChannelStateReaderImpl implements SequentialChannelStateR groupByDelegate( streamSubtaskStates(), OperatorSubtaskState::getUpstreamOutputBufferState)); + + if (filteringHandler != null) { + checkState( + !filteringHandler.hasPartialData(), + "Not all data has been fully consumed during filtering"); + } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java index 78877a3d62e..873f62c59ee 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamMultipleInputProcessorFactory.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.TaskInfo; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.CheckpointingOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.memory.ManagedMemoryUseCase; import org.apache.flink.metrics.Counter; @@ -103,6 +104,10 @@ public class StreamMultipleInputProcessorFactory { "Number of configured inputs in StreamConfig [%s] doesn't match the main operator's number of inputs [%s]", configuredInputs.length, inputsCount); + + boolean checkpointingDuringRecoveryEnabled = + CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(jobConfig); + 
StreamTaskInput[] inputs = new StreamTaskInput[inputsCount]; for (int i = 0; i < inputsCount; i++) { StreamConfig.InputConfig configuredInput = configuredInputs[i]; @@ -121,7 +126,8 @@ public class StreamMultipleInputProcessorFactory { gatePartitioners, taskInfo, canEmitBatchOfRecords, - streamConfig.getWatermarkDeclarations(userClassloader)); + streamConfig.getWatermarkDeclarations(userClassloader), + checkpointingDuringRecoveryEnabled); } else if (configuredInput instanceof StreamConfig.SourceInputConfig) { StreamConfig.SourceInputConfig sourceInput = (StreamConfig.SourceInputConfig) configuredInput; diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java index 46c9cd96936..70f00eed829 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputFactory.java @@ -47,9 +47,14 @@ public class StreamTaskNetworkInputFactory { Function<Integer, StreamPartitioner<?>> gatePartitioners, TaskInfo taskInfo, CanEmitBatchOfRecordsChecker canEmitBatchOfRecords, - Set<AbstractInternalWatermarkDeclaration<?>> watermarkDeclarationSet) { + Set<AbstractInternalWatermarkDeclaration<?>> watermarkDeclarationSet, + boolean checkpointingDuringRecoveryEnabled) { return rescalingDescriptorinflightDataRescalingDescriptor.equals( - InflightDataRescalingDescriptor.NO_RESCALE) + InflightDataRescalingDescriptor.NO_RESCALE) + // With checkpointing during recovery, records are already filtered in + // the channel-state-unspilling thread. Use StreamTaskNetworkInput to avoid + // redundant demultiplexing/filtering in the Task thread. + || checkpointingDuringRecoveryEnabled ?
new StreamTaskNetworkInput<>( checkpointedInputGate, inputSerializer, diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java index 2a0c675710b..04cd6ccb410 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/StreamTwoInputProcessorFactory.java @@ -22,6 +22,7 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.TaskInfo; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.configuration.CheckpointingOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.core.memory.ManagedMemoryUseCase; import org.apache.flink.metrics.Counter; @@ -84,6 +85,10 @@ public class StreamTwoInputProcessorFactory { checkNotNull(operatorChain); taskIOMetricGroup.reuseRecordsInputCounter(numRecordsIn); + + boolean checkpointingDuringRecoveryEnabled = + CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(jobConfig); + TypeSerializer<IN1> typeSerializer1 = streamConfig.getTypeSerializerIn(0, userClassloader); StreamTaskInput<IN1> input1 = StreamTaskNetworkInputFactory.create( @@ -96,7 +101,8 @@ public class StreamTwoInputProcessorFactory { gatePartitioners, taskInfo, canEmitBatchOfRecords, - streamConfig.getWatermarkDeclarations(userClassloader)); + streamConfig.getWatermarkDeclarations(userClassloader), + checkpointingDuringRecoveryEnabled); TypeSerializer<IN2> typeSerializer2 = streamConfig.getTypeSerializerIn(1, userClassloader); StreamTaskInput<IN2> input2 = StreamTaskNetworkInputFactory.create( @@ -109,7 +115,8 @@ public class StreamTwoInputProcessorFactory { gatePartitioners, taskInfo, canEmitBatchOfRecords, -
streamConfig.getWatermarkDeclarations(userClassloader)); + streamConfig.getWatermarkDeclarations(userClassloader), + checkpointingDuringRecoveryEnabled); InputSelectable inputSelectable = streamOperator instanceof InputSelectable ? (InputSelectable) streamOperator : null; diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java new file mode 100644 index 00000000000..f2568fe5854 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContext.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.streaming.runtime.io.recovery; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; + +import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; + +/** + * Context containing all information needed for filtering recovered channel state buffers. + * + * <p>This context encapsulates the input configurations, rescaling descriptor, and subtask + * information required by the channel-state-unspilling thread to perform record filtering during + * recovery. + * + * <p>Supports multiple inputs (e.g., TwoInputStreamTask, MultipleInputStreamTask) by storing + * an array of {@link InputFilterConfig} instances indexed by physical input gate index. + * + * <p>Use {@link #disabled()} when filtering is not needed. + */ +@Internal +public class RecordFilterContext { + + /** Configuration for filtering records on a specific input. */ + public static class InputFilterConfig { + private final TypeSerializer<?> typeSerializer; + private final StreamPartitioner<?> partitioner; + private final int numberOfChannels; + + /** + * Creates a new InputFilterConfig. + * + * @param typeSerializer Serializer for the record type. + * @param partitioner Partitioner used to determine record ownership. + * @param numberOfChannels The parallelism of the current operator.
+ */ + public InputFilterConfig( + TypeSerializer<?> typeSerializer, + StreamPartitioner<?> partitioner, + int numberOfChannels) { + this.typeSerializer = checkNotNull(typeSerializer); + this.partitioner = checkNotNull(partitioner); + this.numberOfChannels = numberOfChannels; + } + + public TypeSerializer<?> getTypeSerializer() { + return typeSerializer; + } + + public StreamPartitioner<?> getPartitioner() { + return partitioner; + } + + public int getNumberOfChannels() { + return numberOfChannels; + } + } + + /** + * Input configurations indexed by gate index. Array elements may be null for non-network inputs + * (e.g., SourceInputConfig). The array length equals the total number of input gates. + */ + private final InputFilterConfig[] inputConfigs; + + /** Descriptor containing rescaling information. Never null. */ + private final InflightDataRescalingDescriptor rescalingDescriptor; + + /** Current subtask index. */ + private final int subtaskIndex; + + /** Maximum parallelism for configuring partitioners. */ + private final int maxParallelism; + + /** Temporary directories for spilling spanning records. Can be empty but never null. */ + private final String[] tmpDirectories; + + /** Whether unaligned checkpointing during recovery is enabled. */ + private final boolean checkpointingDuringRecoveryEnabled; + + /** + * Creates a new RecordFilterContext. + * + * @param inputConfigs Input configurations indexed by gate index. Array elements may be null + * for non-network inputs. Not null itself. + * @param rescalingDescriptor Descriptor containing rescaling information. Not null. + * @param subtaskIndex Current subtask index. + * @param maxParallelism Maximum parallelism. + * @param tmpDirectories Temporary directories for spilling spanning records. Can be null + * (converted to empty array). The array is copied. + * @param checkpointingDuringRecoveryEnabled Whether unaligned checkpointing during recovery is + * enabled.
+ */ + public RecordFilterContext( + InputFilterConfig[] inputConfigs, + InflightDataRescalingDescriptor rescalingDescriptor, + int subtaskIndex, + int maxParallelism, + String[] tmpDirectories, + boolean checkpointingDuringRecoveryEnabled) { + this.inputConfigs = checkNotNull(inputConfigs).clone(); + this.rescalingDescriptor = checkNotNull(rescalingDescriptor); + this.subtaskIndex = subtaskIndex; + this.maxParallelism = maxParallelism; + this.tmpDirectories = tmpDirectories != null ? tmpDirectories.clone() : new String[0]; + this.checkpointingDuringRecoveryEnabled = checkpointingDuringRecoveryEnabled; + } + + /** + * Gets the input configuration for a specific gate. + * + * @param gateIndex The gate index (0-based). + * @return The input configuration for the specified gate, or null if the gate is not a network + * input (e.g., SourceInputConfig). + * @throws IllegalArgumentException if gateIndex is out of bounds. + */ + public InputFilterConfig getInputConfig(int gateIndex) { + checkArgument( + gateIndex >= 0 && gateIndex < inputConfigs.length, + "Invalid gate index: %s, number of gates: %s", + gateIndex, + inputConfigs.length); + return inputConfigs[gateIndex]; + } + + /** + * Gets the number of input gates. + * + * @return The number of input gates. + */ + public int getNumberOfGates() { + return inputConfigs.length; + } + + /** + * Checks whether unaligned checkpointing during recovery is enabled. + * + * @return {@code true} if enabled, {@code false} otherwise. + */ + public boolean isCheckpointingDuringRecoveryEnabled() { + return checkpointingDuringRecoveryEnabled; + } + + /** + * Gets the rescaling descriptor. + * + * @return The descriptor containing rescaling information. + */ + public InflightDataRescalingDescriptor getRescalingDescriptor() { + return rescalingDescriptor; + } + + /** + * Gets the current subtask index. + * + * @return The subtask index. + */ + public int getSubtaskIndex() { + return subtaskIndex; + } + + /** + * Gets the maximum parallelism.
+ * + * @return The maximum parallelism value. + */ + public int getMaxParallelism() { + return maxParallelism; + } + + /** + * Gets the temporary directories for spilling spanning records. + * + * @return A copy of the temporary directories, never null (may be an empty array). + */ + public String[] getTmpDirectories() { + return tmpDirectories.clone(); + } + + /** + * Checks if a specific gate and subtask combination is ambiguous (requires filtering). + * + * @param gateIndex The gate index. + * @param oldSubtaskIndex The old subtask index. + * @return true if enabled and the channel is ambiguous and records need filtering. + */ + public boolean isAmbiguous(int gateIndex, int oldSubtaskIndex) { + return checkpointingDuringRecoveryEnabled + && rescalingDescriptor.isAmbiguous(gateIndex, oldSubtaskIndex); + } + + /** + * Creates a disabled RecordFilterContext for testing or when filtering is not needed. + * + * <p>The returned context has empty inputConfigs and enabled=false, so {@link + * #isCheckpointingDuringRecoveryEnabled()} will always return false. + * + * @return A disabled RecordFilterContext. + */ + public static RecordFilterContext disabled() { + return new RecordFilterContext( + new InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 0, + new String[0], + false); + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java new file mode 100644 index 00000000000..a2093f11437 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactory.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.io.recovery; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.runtime.io.network.api.writer.ChannelSelector; +import org.apache.flink.runtime.plugable.SerializationDelegate; +import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +/** + * Factory for creating record filters used in Virtual Channels during channel state recovery. + * + * <p>This factory provides methods to create {@link RecordFilter} instances that determine whether + * a record belongs to the current subtask based on the partitioner logic. + * + * @param <T> The type of record values. + */ +@Internal +public class VirtualChannelRecordFilterFactory<T> { + + private final TypeSerializer<T> typeSerializer; + private final StreamPartitioner<T> partitioner; + private final int subtaskIndex; + private final int numberOfChannels; + private final int maxParallelism; + + /** + * Creates a new VirtualChannelRecordFilterFactory. + * + * @param typeSerializer Serializer for the record type. + * @param partitioner Partitioner used to determine record ownership. + * @param subtaskIndex Current subtask index.
+ * @param numberOfChannels Number of channels the partitioner copy is set up with. + * @param maxParallelism Maximum parallelism for configuring partitioners. + */ + public VirtualChannelRecordFilterFactory( + TypeSerializer<T> typeSerializer, + StreamPartitioner<T> partitioner, + int subtaskIndex, + int numberOfChannels, + int maxParallelism) { + this.typeSerializer = typeSerializer; + this.partitioner = partitioner; + this.subtaskIndex = subtaskIndex; + this.numberOfChannels = numberOfChannels; + this.maxParallelism = maxParallelism; + } + + /** + * Creates a new VirtualChannelRecordFilterFactory from a RecordFilterContext and input index. + * + * @param context The record filter context. + * @param inputIndex Gate index whose (non-null, network input) configuration is used. + * @param <T> The type of record values. + * @return A new factory instance. + */ + @SuppressWarnings("unchecked") + public static <T> VirtualChannelRecordFilterFactory<T> fromContext( + RecordFilterContext context, int inputIndex) { + RecordFilterContext.InputFilterConfig inputConfig = context.getInputConfig(inputIndex); + return new VirtualChannelRecordFilterFactory<>( + (TypeSerializer<T>) inputConfig.getTypeSerializer(), + (StreamPartitioner<T>) inputConfig.getPartitioner(), + context.getSubtaskIndex(), + inputConfig.getNumberOfChannels(), + context.getMaxParallelism()); + } + + /** + * Creates a filter for ambiguous channels, backed by a freshly configured partitioner copy. + * + * @return A RecordFilter that tests if a record belongs to this subtask. + */ + public RecordFilter<T> createFilter() { + StreamPartitioner<T> configuredPartitioner = configurePartitioner(); + @SuppressWarnings("unchecked") + ChannelSelector<SerializationDelegate<StreamRecord<T>>> channelSelector = + configuredPartitioner; + return new PartitionerRecordFilter<>(channelSelector, typeSerializer, subtaskIndex); + } + + /** + * Creates a pass-through filter that accepts all records. + * + * @param <T> The type of record values.
+ * @return A RecordFilter that always returns true. + */ + public static <T> RecordFilter<T> createPassThroughFilter() { + return RecordFilter.acceptAll(); + } + + /** + * Copies the partitioner, sets up numberOfChannels and (if configurable) maxParallelism. + * + * @return A configured copy of the partitioner. + */ + private StreamPartitioner<T> configurePartitioner() { + StreamPartitioner<T> copy = partitioner.copy(); + copy.setup(numberOfChannels); + if (copy instanceof ConfigurableStreamPartitioner) { + ((ConfigurableStreamPartitioner) copy).configure(maxParallelism); + } + return copy; + } +} diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java index 353582674e0..009f48b082f 100644 --- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java @@ -203,6 +203,9 @@ public class OneInputStreamTask<IN, OUT> extends StreamTask<OUT, OneInputStreamO Set<AbstractInternalWatermarkDeclaration<?>> watermarkDeclarationSet = configuration.getWatermarkDeclarations(getUserCodeClassLoader()); + boolean checkpointingDuringRecoveryEnabled = + CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()); + return StreamTaskNetworkInputFactory.create( inputGate, inSerializer, @@ -217,7 +220,8 @@ public class OneInputStreamTask<IN, OUT> extends StreamTask<OUT, OneInputStreamO .getPartitioner(), getEnvironment().getTaskInfo(), getCanEmitBatchOfRecords(), - watermarkDeclarationSet); + watermarkDeclarationSet, + checkpointingDuringRecoveryEnabled); } /** diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index a07d5ee3915..27823f7137f 100644 ---
a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -21,6 +21,7 @@ import org.apache.flink.annotation.Internal; import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.api.common.operators.MailboxExecutor; import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback; +import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.CheckpointingOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.NettyShuffleEnvironmentOptions; @@ -84,6 +85,7 @@ import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.util.ConfigurationParserUtils; import org.apache.flink.streaming.api.graph.NonChainedOutput; import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.graph.StreamEdge; import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; import org.apache.flink.streaming.api.operators.InternalTimeServiceManagerImpl; import org.apache.flink.streaming.api.operators.StreamOperator; @@ -94,6 +96,7 @@ import org.apache.flink.streaming.runtime.io.RecordWriterOutput; import org.apache.flink.streaming.runtime.io.StreamInputProcessor; import org.apache.flink.streaming.runtime.io.checkpointing.BarrierAlignmentUtil; import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointBarrierHandler; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner; import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner; @@ -881,7 +884,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>> channelIOExecutor.execute( () -> { try { -
reader.readInputData(inputGates); + reader.readInputData(inputGates, createRecordFilterContext()); } catch (Exception e) { asyncExceptionHandler.handleAsyncException( "Unable to read channel state", e); @@ -1956,6 +1959,83 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>> return environment; } + /** + * Creates a RecordFilterContext for filtering recovered channel state buffers. + * + * <p>This method builds the complete context using information available in StreamTask, + * including input configurations for all network inputs. + * + * @return A RecordFilterContext with input configurations. The context may have empty + * inputConfigs (e.g., for source tasks) or be {@link RecordFilterContext#disabled()}. + */ + protected RecordFilterContext createRecordFilterContext() { + boolean checkpointingDuringRecoveryEnabled = + CheckpointingOptions.isCheckpointingDuringRecoveryEnabled(getJobConfiguration()); + if (!checkpointingDuringRecoveryEnabled) { + return RecordFilterContext.disabled(); + } + + ClassLoader cl = getUserCodeClassLoader(); + StreamConfig.InputConfig[] inputs = configuration.getInputs(cl); + List<StreamEdge> inEdges = configuration.getInPhysicalEdges(cl); + + // Create array sized to match the number of physical input gates. + // For source tasks, this will be 0. For tasks with network inputs, each physical gate + // must have a corresponding config entry. + int numGates = getEnvironment().getAllInputGates().length; + RecordFilterContext.InputFilterConfig[] inputConfigs = + new RecordFilterContext.InputFilterConfig[numGates]; + + Preconditions.checkState( + numGates == inEdges.size(), + "Number of input gates (%s) does not match number of physical edges (%s)", + numGates, + inEdges.size()); + + // Iterate through all physical edges (inEdges) instead of logical inputs. + // This is critical for Union scenarios where multiple physical gates map to one logical + // input. The order of inEdges matches the order of physical input gates.
+ int numberOfChannels = getEnvironment().getTaskInfo().getNumberOfParallelSubtasks(); + for (int gateIndex = 0; gateIndex < inEdges.size(); gateIndex++) { + StreamEdge edge = inEdges.get(gateIndex); + // Calculate logical input index from typeNumber + // typeNumber = 0 means single input, typeNumber >= 1 means multi-input (1-indexed) + int inputIndex = edge.getTypeNumber() == 0 ? 0 : edge.getTypeNumber() - 1; + + Preconditions.checkState( + inputIndex < inputs.length + && inputs[inputIndex] instanceof StreamConfig.NetworkInputConfig, + "Physical edge at gateIndex %s has invalid inputIndex %s or non-network input", + gateIndex, + inputIndex); + + StreamConfig.NetworkInputConfig networkInput = + (StreamConfig.NetworkInputConfig) inputs[inputIndex]; + TypeSerializer<?> typeSerializer = networkInput.getTypeSerializer(); + StreamPartitioner<?> partitioner = edge.getPartitioner(); + + inputConfigs[gateIndex] = + new RecordFilterContext.InputFilterConfig( + typeSerializer, partitioner, numberOfChannels); + } + + for (int i = 0; i < inputConfigs.length; i++) { // defensive; every index is set above + Preconditions.checkState( + inputConfigs[i] != null, + "InputFilterConfig at index %s is null. " + + "All physical gates must have corresponding configurations.", + i); + } + + return new RecordFilterContext( + inputConfigs, + getEnvironment().getTaskStateManager().getInputRescalingDescriptor(), + getEnvironment().getTaskInfo().getIndexOfThisSubtask(), + getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks(), + getEnvironment().getIOManager().getSpillingDirectoriesPaths(), + true); + } + /** Check whether records can be emitted in batch.
*/ @FunctionalInterface public interface CanEmitBatchOfRecordsChecker { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java new file mode 100644 index 00000000000..85b4fd1d48e --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerBufferOwnershipTest.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; +import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.plugable.DeserializationDelegate; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilter; +import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel; +import org.apache.flink.streaming.runtime.streamrecord.StreamElement; +import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Tests buffer ownership semantics of {@link ChannelStateFilteringHandler.GateFilterHandler}. Each + * test verifies that buffers are properly recycled on both success and failure paths.
+ */ +class GateFilterHandlerBufferOwnershipTest { + + private static final int BUFFER_SIZE = 1024; + private static final SubtaskConnectionDescriptor KEY = new SubtaskConnectionDescriptor(0, 0); + + @Test + void testSourceBufferRecycledOnSuccess() throws Exception { + ChannelStateFilteringHandler.GateFilterHandler<Long> handler = + createHandler(RecordFilter.acceptAll()); + + Buffer sourceBuffer = createBufferWithRecords(1L, 2L); + List<Buffer> result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + + // sourceBuffer should be recycled by the deserializer after consumption + assertThat(sourceBuffer.isRecycled()).isTrue(); + + // Clean up result buffers + result.forEach(Buffer::recycleBuffer); + } + + @Test + void testSourceBufferRecycledWhenAllRecordsFilteredOut() throws Exception { + RecordFilter<Long> rejectAll = record -> false; + ChannelStateFilteringHandler.GateFilterHandler<Long> handler = createHandler(rejectAll); + + Buffer sourceBuffer = createBufferWithRecords(1L, 2L); + List<Buffer> result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + + assertThat(result).isEmpty(); + // sourceBuffer should still be recycled even though no output was produced + assertThat(sourceBuffer.isRecycled()).isTrue(); + } + + @Test + void testSourceBufferRecycledOnInvalidVirtualChannel() { + // Create handler with KEY=(0,0) but call with (1,1) to trigger IllegalStateException + ChannelStateFilteringHandler.GateFilterHandler<Long> handler = + createHandler(RecordFilter.acceptAll()); + + Buffer sourceBuffer = createBufferWithRecords(1L); + + assertThatThrownBy( + () -> handler.filterAndRewrite(1, 1, sourceBuffer, this::createEmptyBuffer)) + .isInstanceOf(IllegalStateException.class); + + // sourceBuffer must be recycled even when lookup fails before setNextBuffer + assertThat(sourceBuffer.isRecycled()).isTrue(); + } + + @Test + void testResultBuffersAndCurrentBufferRecycledOnSerializationError() throws Exception { + // Use a
small buffer so that records span multiple buffers. The supplier fails on the + // second request, after the first output buffer has been filled and added to resultBuffers. + AtomicInteger bufferRequestCount = new AtomicInteger(0); + ChannelStateFilteringHandler.BufferSupplier failingSupplier = + () -> { + if (bufferRequestCount.incrementAndGet() > 1) { + throw new IOException("Simulated buffer allocation failure"); + } + return createEmptyBuffer(13); // ~one record: 4-byte length + ~9 bytes + }; + + ChannelStateFilteringHandler.GateFilterHandler<Long> handler = + createHandler(RecordFilter.acceptAll()); + + Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); + + // The exception should propagate; no buffer leak (no IllegalReferenceCountException + // from double-recycle). NOTE: only sourceBuffer recycling is asserted below. + assertThatThrownBy(() -> handler.filterAndRewrite(0, 0, sourceBuffer, failingSupplier)) + .isInstanceOf(IOException.class) + .hasMessage("Simulated buffer allocation failure"); + + // sourceBuffer ownership was transferred to the deserializer via setNextBuffer(). + // The deserializer may still hold it if it hasn't fully consumed the buffer before the + // error. Calling clear() triggers the cleanup chain: + // GateFilterHandler#clear() -> VirtualChannel#clear() -> deserializer.clear() + handler.clear(); + assertThat(sourceBuffer.isRecycled()).isTrue(); + } + + /** + * Tests the production cleanup path: when filterAndRewrite throws mid-processing, the + * deserializer may still hold sourceBuffer. In production, ChannelStateFilteringHandler is used + * in a try-with-resources block (see {@code SequentialChannelStateReaderImpl#readInputData}), + * so its close() is guaranteed to be called, which triggers clear() on all GateFilterHandlers + * and their deserializers. This test simulates that exact pattern.
+ */ + @Test + void testCloseRecyclesDeserializerHeldBufferAfterError() throws Exception { + AtomicInteger bufferRequestCount = new AtomicInteger(0); + ChannelStateFilteringHandler.BufferSupplier failingSupplier = + () -> { + if (bufferRequestCount.incrementAndGet() > 1) { + throw new IOException("Simulated buffer allocation failure"); + } + return createEmptyBuffer(13); + }; + + ChannelStateFilteringHandler.GateFilterHandler<Long> gateHandler = + createHandler(RecordFilter.acceptAll()); + // Wrap in ChannelStateFilteringHandler, the production-level owner + ChannelStateFilteringHandler filteringHandler = + new ChannelStateFilteringHandler( + new ChannelStateFilteringHandler.GateFilterHandler<?>[] {gateHandler}); + + Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); + + // Simulate the production try-with-resources pattern + assertThatThrownBy( + () -> { + try (ChannelStateFilteringHandler ignored = filteringHandler) { + filteringHandler.filterAndRewrite( + 0, 0, 0, sourceBuffer, failingSupplier); + } + }) + .isInstanceOf(IOException.class) + .hasMessage("Simulated buffer allocation failure"); + + // After close(), the entire cleanup chain has fired: + // ChannelStateFilteringHandler.close() -> GateFilterHandler.clear() + // -> VirtualChannel.clear() -> deserializer.clear() -> sourceBuffer.recycleBuffer() + assertThat(sourceBuffer.isRecycled()).isTrue(); + } + + // ------------------------------------------------------------------------------------------- + // Helper methods + // ------------------------------------------------------------------------------------------- + + private ChannelStateFilteringHandler.GateFilterHandler<Long> createHandler( + RecordFilter<Long> filter) { + RecordDeserializer<DeserializationDelegate<StreamElement>> deserializer = + new SpillingAdaptiveSpanningRecordDeserializer<>( + new String[] {System.getProperty("java.io.tmpdir")}); + VirtualChannel<Long> vc = new VirtualChannel<>(deserializer, filter); +
Map<SubtaskConnectionDescriptor, VirtualChannel<Long>> channels = new HashMap<>(); + channels.put(KEY, vc); + + StreamElementSerializer<Long> serializer = + new StreamElementSerializer<>(LongSerializer.INSTANCE); + return new ChannelStateFilteringHandler.GateFilterHandler<>(channels, serializer); + } + + private Buffer createBufferWithRecords(Long... values) { + try { + StreamElementSerializer<Long> serializer = + new StreamElementSerializer<>(LongSerializer.INSTANCE); + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + + for (Long value : values) { + DataOutputSerializer recordOutput = new DataOutputSerializer(64); + serializer.serialize(new StreamRecord<>(value), recordOutput); + int recordLength = recordOutput.length(); + output.writeInt(recordLength); + output.write(recordOutput.getSharedBuffer(), 0, recordLength); + } + + byte[] data = output.getCopyOfBuffer(); + MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(BUFFER_SIZE); + segment.put(0, data, 0, data.length); + + NetworkBuffer buffer = new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); + buffer.setSize(data.length); + return buffer; + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private Buffer createEmptyBuffer() { + return createEmptyBuffer(BUFFER_SIZE); + } + + private Buffer createEmptyBuffer(int size) { + MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(size); + return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java new file mode 100644 index 00000000000..f02ce35fd86 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/GateFilterHandlerTest.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more +
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.checkpoint.channel; + +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.core.memory.DataOutputSerializer; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor; +import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; +import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.buffer.NetworkBuffer; +import org.apache.flink.runtime.plugable.DeserializationDelegate; +import org.apache.flink.runtime.plugable.NonReusingDeserializationDelegate; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilter; +import org.apache.flink.streaming.runtime.io.recovery.VirtualChannel; +import org.apache.flink.streaming.runtime.streamrecord.StreamElement; +import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + 
+import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link ChannelStateFilteringHandler.GateFilterHandler}. */ +class GateFilterHandlerTest { + + private static final int BUFFER_SIZE = 1024; + private static final SubtaskConnectionDescriptor KEY = new SubtaskConnectionDescriptor(0, 0); + + @Test + void testAllRecordsPassFilter() throws Exception { + ChannelStateFilteringHandler.GateFilterHandler<Long> handler = + createHandler(RecordFilter.acceptAll()); + + Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); + List<Buffer> result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + + // deserializeBuffers consumes (recycles) each buffer via the deserializer + List<Long> values = deserializeBuffers(result); + assertThat(values).containsExactly(1L, 2L, 3L); + } + + @Test + void testAllRecordsFilteredOut() throws Exception { + RecordFilter<Long> rejectAll = record -> false; + ChannelStateFilteringHandler.GateFilterHandler<Long> handler = createHandler(rejectAll); + + Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); + List<Buffer> result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + + assertThat(result).isEmpty(); + } + + @Test + void testPartialFiltering() throws Exception { + RecordFilter<Long> keepEven = record -> record.getValue() % 2 == 0; + ChannelStateFilteringHandler.GateFilterHandler<Long> handler = createHandler(keepEven); + + Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L, 4L, 5L); + List<Buffer> result = handler.filterAndRewrite(0, 0, sourceBuffer, this::createEmptyBuffer); + + List<Long> values = deserializeBuffers(result); + assertThat(values).containsExactly(2L, 4L); + } + + @Test + void testSmallOutputBufferProducesMultipleBuffers() throws Exception { + // Use a very small output 
buffer size so records must span multiple buffers + int smallBufferSize = 8; + ChannelStateFilteringHandler.GateFilterHandler<Long> handler = + createHandler(RecordFilter.acceptAll()); + + Buffer sourceBuffer = createBufferWithRecords(1L, 2L, 3L); + List<Buffer> result = + handler.filterAndRewrite( + 0, 0, sourceBuffer, () -> createEmptyBuffer(smallBufferSize)); + + // Each Long record needs 4 bytes length + ~9 bytes data > 8-byte buffer + assertThat(result.size()).isGreaterThan(1); + + List<Long> values = deserializeBuffers(result); + assertThat(values).containsExactly(1L, 2L, 3L); + } + + @Test + void testEmptyBuffer() throws Exception { + ChannelStateFilteringHandler.GateFilterHandler<Long> handler = + createHandler(RecordFilter.acceptAll()); + + Buffer emptyBuffer = createEmptyBuffer(); + emptyBuffer.setSize(0); + + List<Buffer> result = handler.filterAndRewrite(0, 0, emptyBuffer, this::createEmptyBuffer); + + assertThat(result).isEmpty(); + } + + // ------------------------------------------------------------------------------------------- + // Helper methods + // ------------------------------------------------------------------------------------------- + + private ChannelStateFilteringHandler.GateFilterHandler<Long> createHandler( + RecordFilter<Long> filter) { + RecordDeserializer<DeserializationDelegate<StreamElement>> deserializer = + new SpillingAdaptiveSpanningRecordDeserializer<>( + new String[] {System.getProperty("java.io.tmpdir")}); + VirtualChannel<Long> vc = new VirtualChannel<>(deserializer, filter); + + Map<SubtaskConnectionDescriptor, VirtualChannel<Long>> channels = new HashMap<>(); + channels.put(KEY, vc); + + StreamElementSerializer<Long> serializer = + new StreamElementSerializer<>(LongSerializer.INSTANCE); + return new ChannelStateFilteringHandler.GateFilterHandler<>(channels, serializer); + } + + private Buffer createBufferWithRecords(Long... 
values) throws IOException { + StreamElementSerializer<Long> serializer = + new StreamElementSerializer<>(LongSerializer.INSTANCE); + return serializeRecordsToBuffer(serializer, values); + } + + /** Serializes records into a buffer using Flink's length-prefixed format. */ + private Buffer serializeRecordsToBuffer( + StreamElementSerializer<Long> serializer, Long... values) throws IOException { + DataOutputSerializer output = new DataOutputSerializer(BUFFER_SIZE); + + for (Long value : values) { + // Serialize using the same length-prefixed format as Flink + DataOutputSerializer recordOutput = new DataOutputSerializer(64); + serializer.serialize(new StreamRecord<>(value), recordOutput); + int recordLength = recordOutput.length(); + + // Write 4-byte big-endian length prefix + output.writeInt(recordLength); + // Write record bytes + output.write(recordOutput.getSharedBuffer(), 0, recordLength); + } + + byte[] data = output.getCopyOfBuffer(); + MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(BUFFER_SIZE); + segment.put(0, data, 0, data.length); + + NetworkBuffer buffer = new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); + buffer.setSize(data.length); + return buffer; + } + + private Buffer createEmptyBuffer() { + return createEmptyBuffer(BUFFER_SIZE); + } + + private Buffer createEmptyBuffer(int size) { + MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(size); + return new NetworkBuffer(segment, FreeingBufferRecycler.INSTANCE); + } + + private List<Long> deserializeBuffers(List<Buffer> buffers) throws IOException { + StreamElementSerializer<Long> serializer = + new StreamElementSerializer<>(LongSerializer.INSTANCE); + SpillingAdaptiveSpanningRecordDeserializer<DeserializationDelegate<StreamElement>> + deserializer = + new SpillingAdaptiveSpanningRecordDeserializer<>( + new String[] {System.getProperty("java.io.tmpdir")}); + DeserializationDelegate<StreamElement> delegate = + new 
NonReusingDeserializationDelegate<>(serializer); + + List<Long> values = new ArrayList<>(); + for (Buffer buffer : buffers) { + deserializer.setNextBuffer(buffer); + while (true) { + RecordDeserializer.DeserializationResult result = + deserializer.getNextRecord(delegate); + if (result.isFullRecord()) { + StreamElement element = delegate.getInstance(); + if (element.isRecord()) { + @SuppressWarnings("unchecked") + StreamRecord<Long> record = (StreamRecord<Long>) element; + values.add(record.getValue()); + } + } + if (result.isBufferConsumed()) { + break; + } + } + } + return values; + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java index e2b1d69c56a..39ce6c7d4bf 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/InputChannelRecoveredStateHandlerTest.java @@ -77,7 +77,8 @@ class InputChannelRecoveredStateHandlerTest extends RecoveredChannelStateHandler InflightDataRescalingDescriptor .InflightDataGateOrPartitionRescalingDescriptor .MappingType.IDENTITY) - })); + }), + null); } private InputChannelRecoveredStateHandler buildMultiChannelHandler() { @@ -103,7 +104,8 @@ class InputChannelRecoveredStateHandlerTest extends RecoveredChannelStateHandler InflightDataRescalingDescriptor .InflightDataGateOrPartitionRescalingDescriptor .MappingType.RESCALING) - })); + }), + null); } @Test diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java index 5f05af0d92e..d80442b8a06 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/SequentialChannelStateReaderImplTest.java @@ -40,6 +40,7 @@ import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.InputChannelStateHandle; import org.apache.flink.runtime.state.ResultSubpartitionStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; import org.apache.flink.testutils.junit.extensions.parameterized.Parameter; import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension; import org.apache.flink.testutils.junit.extensions.parameterized.Parameters; @@ -143,7 +144,7 @@ public class SequentialChannelStateReaderImplTest { withInputGates( gates -> { - reader.readInputData(gates); + reader.readInputData(gates, RecordFilterContext.disabled()); assertBuffersEquals(inputChannelsData, collectBuffers(gates)); assertConsumed(gates); }); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java index 53c05b19613..f64a4d9fb9c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java @@ -49,6 +49,7 @@ import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBui import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage; +import org.apache.flink.streaming.runtime.io.recovery.RecordFilterContext; import org.apache.flink.util.function.SupplierWithException; import org.junit.jupiter.api.Test; @@ -119,7 
+120,7 @@ class ChannelPersistenceITCase { try { int numChannels = 1; InputGate gate = buildGate(networkBufferPool, numChannels); - reader.readInputData(new InputGate[] {gate}); + reader.readInputData(new InputGate[] {gate}, RecordFilterContext.disabled()); assertThat(collectBytes(gate::pollNext, BufferOrEvent::getBuffer)) .isEqualTo(inputChannelInfoData); diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java new file mode 100644 index 00000000000..15b97d433a2 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/RecordFilterContextTest.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.streaming.runtime.io.recovery; + +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.runtime.checkpoint.RescaleMappings; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; + +import org.junit.jupiter.api.Test; + +import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.mappings; +import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.rescalingDescriptor; +import static org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptorUtil.set; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link RecordFilterContext}. */ +class RecordFilterContextTest { + + @Test + void testDisabledContextHasNoGates() { + RecordFilterContext disabled = RecordFilterContext.disabled(); + assertThat(disabled.getNumberOfGates()).isEqualTo(0); + assertThat(disabled.isCheckpointingDuringRecoveryEnabled()).isFalse(); + } + + @Test + void testGetInputConfigReturnsCorrectConfig() { + RecordFilterContext.InputFilterConfig config = + new RecordFilterContext.InputFilterConfig( + LongSerializer.INSTANCE, new ForwardPartitioner<>(), 4); + + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[] {config}, + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + new String[] {"/tmp"}, + true); + + assertThat(context.getNumberOfGates()).isEqualTo(1); + assertThat(context.getInputConfig(0)).isSameAs(config); + assertThat(context.getSubtaskIndex()).isEqualTo(0); + assertThat(context.getMaxParallelism()).isEqualTo(128); + assertThat(context.isCheckpointingDuringRecoveryEnabled()).isTrue(); + } + + @Test + void testGetInputConfigThrowsForInvalidIndex() { + RecordFilterContext context = + new RecordFilterContext( + new 
RecordFilterContext.InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + null, + false); + + assertThatThrownBy(() -> context.getInputConfig(0)) + .isInstanceOf(IllegalArgumentException.class); + assertThatThrownBy(() -> context.getInputConfig(-1)) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void testNullTmpDirectoriesConvertedToEmptyArray() { + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + InflightDataRescalingDescriptor.NO_RESCALE, + 0, + 128, + null, + false); + + assertThat(context.getTmpDirectories()).isNotNull().isEmpty(); + } + + @Test + void testIsAmbiguousWhenDisabled() { + // Create a rescaling descriptor with an ambiguous subtask (oldSubtask 0 is ambiguous) + RescaleMappings mapping = mappings(new int[] {0}); + InflightDataRescalingDescriptor descriptor = + rescalingDescriptor(new int[] {0}, new RescaleMappings[] {mapping}, set(0)); + + // When checkpointingDuringRecoveryEnabled is false, isAmbiguous should always return false + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + descriptor, + 0, + 128, + null, + false); + + assertThat(context.isAmbiguous(0, 0)).isFalse(); + } + + @Test + void testIsAmbiguousWhenEnabled() { + // Create a rescaling descriptor with an ambiguous subtask (oldSubtask 0 is ambiguous) + RescaleMappings mapping = mappings(new int[] {0}); + InflightDataRescalingDescriptor descriptor = + rescalingDescriptor(new int[] {0}, new RescaleMappings[] {mapping}, set(0)); + + // When checkpointingDuringRecoveryEnabled is true, isAmbiguous follows rescalingDescriptor + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + descriptor, + 0, + 128, + null, + true); + + assertThat(context.isAmbiguous(0, 0)).isTrue(); + } + + @Test + void testIsAmbiguousForNonAmbiguousSubtask() { + // Create a rescaling descriptor where oldSubtask 
0 is ambiguous but oldSubtask 1 is not + RescaleMappings mapping = mappings(new int[] {0}); + InflightDataRescalingDescriptor descriptor = + rescalingDescriptor(new int[] {0, 1}, new RescaleMappings[] {mapping}, set(0)); + + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[0], + descriptor, + 0, + 128, + null, + true); + + // oldSubtask 0 is ambiguous + assertThat(context.isAmbiguous(0, 0)).isTrue(); + // oldSubtask 1 is NOT in the ambiguous set + assertThat(context.isAmbiguous(0, 1)).isFalse(); + } + + @Test + void testInputFilterConfigGetters() { + ForwardPartitioner<Long> partitioner = new ForwardPartitioner<>(); + RecordFilterContext.InputFilterConfig config = + new RecordFilterContext.InputFilterConfig(LongSerializer.INSTANCE, partitioner, 4); + + assertThat(config.getTypeSerializer()).isSameAs(LongSerializer.INSTANCE); + assertThat(config.getPartitioner()).isSameAs(partitioner); + assertThat(config.getNumberOfChannels()).isEqualTo(4); + } + + @Test + void testMultipleGateConfigs() { + RecordFilterContext.InputFilterConfig config0 = + new RecordFilterContext.InputFilterConfig( + LongSerializer.INSTANCE, new ForwardPartitioner<>(), 2); + RecordFilterContext.InputFilterConfig config1 = + new RecordFilterContext.InputFilterConfig( + LongSerializer.INSTANCE, new ForwardPartitioner<>(), 4); + + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[] {config0, config1}, + InflightDataRescalingDescriptor.NO_RESCALE, + 1, + 256, + new String[] {"/tmp"}, + false); + + assertThat(context.getNumberOfGates()).isEqualTo(2); + assertThat(context.getInputConfig(0)).isSameAs(config0); + assertThat(context.getInputConfig(1)).isSameAs(config1); + assertThat(context.getInputConfig(0).getNumberOfChannels()).isEqualTo(2); + assertThat(context.getInputConfig(1).getNumberOfChannels()).isEqualTo(4); + } +} diff --git 
a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java new file mode 100644 index 00000000000..bffc42e3329 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/io/recovery/VirtualChannelRecordFilterFactoryTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.streaming.runtime.io.recovery; + +import org.apache.flink.api.common.typeutils.base.LongSerializer; +import org.apache.flink.runtime.checkpoint.InflightDataRescalingDescriptor; +import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** Tests for {@link VirtualChannelRecordFilterFactory}. 
*/ +class VirtualChannelRecordFilterFactoryTest { + + @Test + void testCreatePassThroughFilter() { + RecordFilter<Long> filter = VirtualChannelRecordFilterFactory.createPassThroughFilter(); + assertThat(filter.filter(new StreamRecord<>(0L))).isTrue(); + assertThat(filter.filter(new StreamRecord<>(1L))).isTrue(); + assertThat(filter.filter(new StreamRecord<>(42L))).isTrue(); + } + + @Test + void testCreateFilterProducesPartitionerBasedFilter() { + RebalancePartitioner<Long> partitioner = new RebalancePartitioner<>(); + + VirtualChannelRecordFilterFactory<Long> factory = + new VirtualChannelRecordFilterFactory<>( + LongSerializer.INSTANCE, partitioner, 0, 2, 128); + + RecordFilter<Long> filter = factory.createFilter(); + // The filter should be a PartitionerRecordFilter that filters based on partitioner + assertThat(filter).isInstanceOf(PartitionerRecordFilter.class); + } + + @Test + void testFromContextCreatesFactory() { + RebalancePartitioner<Long> partitioner = new RebalancePartitioner<>(); + RecordFilterContext.InputFilterConfig config = + new RecordFilterContext.InputFilterConfig(LongSerializer.INSTANCE, partitioner, 4); + + RecordFilterContext context = + new RecordFilterContext( + new RecordFilterContext.InputFilterConfig[] {config}, + InflightDataRescalingDescriptor.NO_RESCALE, + 1, + 128, + new String[] {"/tmp"}, + true); + + VirtualChannelRecordFilterFactory<Long> factory = + VirtualChannelRecordFilterFactory.fromContext(context, 0); + RecordFilter<Long> filter = factory.createFilter(); + + // The filter should be a functional PartitionerRecordFilter + assertThat(filter).isInstanceOf(PartitionerRecordFilter.class); + } + + @Test + void testEachFilterCallCreatesIndependentFilter() { + RebalancePartitioner<Long> partitioner = new RebalancePartitioner<>(); + + VirtualChannelRecordFilterFactory<Long> factory = + new VirtualChannelRecordFilterFactory<>( + LongSerializer.INSTANCE, partitioner, 0, 2, 128); + + RecordFilter<Long> filter1 = factory.createFilter(); + 
RecordFilter<Long> filter2 = factory.createFilter(); + + // Each call should produce a distinct filter instance (using a copy of the partitioner) + assertThat(filter1).isNotSameAs(filter2); + } +}
