This is an automated email from the ASF dual-hosted git repository. gyfora pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 61d6e78e1f6dff093e005c834e1a008f146c02b9 Author: Maximilian Michels <m...@apache.org> AuthorDate: Mon Nov 28 11:37:24 2022 +0100 [FLINK-30213] Change ForwardPartitioner to RebalancePartitioner on parallelism changes In case of parallelism changes to the JobGraph, as done via the AdaptiveScheduler or through providing JobVertexId overrides in PipelineOptions#PARALLELISM_OVERRIDES, the inner serialized PartitionStrategy of a StreamTask may not be suitable anymore. This is the case for the ForwardPartitioner strategy which uses a fixed local channel for transmitting data. Whenever the consumer parallelism doesn't match the local parallelism, we should be replacing it with the RebalancePartitioner. --- .../flink/runtime/dispatcher/Dispatcher.java | 3 +- .../api/writer/ChannelSelectorRecordWriter.java | 6 ++ .../streaming/api/graph/NonChainedOutput.java | 6 +- .../flink/streaming/runtime/tasks/StreamTask.java | 12 ++++ .../runtime/tasks/StreamConfigChainer.java | 8 ++- .../tasks/StreamTaskFinalCheckpointsTest.java | 4 +- .../tasks/StreamTaskMailboxTestHarnessBuilder.java | 17 ++++-- .../streaming/runtime/tasks/StreamTaskTest.java | 69 ++++++++++++++++++++++ 8 files changed, 115 insertions(+), 10 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java index c733af51334..cfa36bad1df 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java @@ -1342,11 +1342,12 @@ public abstract class Dispatcher extends FencedRpcEndpoint<DispatcherId> for (JobVertex vertex : jobGraph.getVertices()) { String override = overrides.get(vertex.getID().toHexString()); if (override != null) { + int currentParallelism = vertex.getParallelism(); int overrideParallelism = Integer.parseInt(override); log.info( "Changing job vertex {} parallelism from {} to {}", vertex.getID(), - vertex.getParallelism(), + currentParallelism, overrideParallelism); vertex.setParallelism(overrideParallelism); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java index 07181bd01f1..5b756b693b3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.io.network.api.writer; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.core.io.IOReadableWritable; import java.io.IOException; @@ -71,4 +72,9 @@ public final class ChannelSelectorRecordWriter<T extends IOReadableWritable> flushAll(); } } + + @VisibleForTesting + public ChannelSelector<T> getChannelSelector() { + return channelSelector; + } } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java index 20d357c14e4..1042b396276 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java @@ -59,7 +59,7 @@ public class NonChainedOutput implements Serializable { private final OutputTag<?> outputTag; /** The corresponding data partitioner. */ - private final StreamPartitioner<?> partitioner; + private StreamPartitioner<?> partitioner; /** Target {@link ResultPartitionType}. */ private final ResultPartitionType partitionType; @@ -119,6 +119,10 @@ public class NonChainedOutput implements Serializable { return outputTag; } + public void setPartitioner(StreamPartitioner<?> partitioner) { + this.partitioner = partitioner; + } + public StreamPartitioner<?> getPartitioner() { return partitioner; } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index dc6a9dff6e4..16666b71804 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -91,6 +91,8 @@ import org.apache.flink.streaming.runtime.io.StreamInputProcessor; import org.apache.flink.streaming.runtime.io.checkpointing.BarrierAlignmentUtil; import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointBarrierHandler; import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.mailbox.GaugePeriodTimer; @@ -1603,6 +1605,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>> int index = 0; for (NonChainedOutput streamOutput : outputsInOrder) { + replaceForwardPartitionerIfConsumerParallelismDoesNotMatch(environment, streamOutput); recordWriters.add( createRecordWriter( streamOutput, @@ -1614,6 +1617,15 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>> return recordWriters; } + private static void replaceForwardPartitionerIfConsumerParallelismDoesNotMatch( + Environment environment, NonChainedOutput streamOutput) { + if (streamOutput.getPartitioner() instanceof ForwardPartitioner + && streamOutput.getConsumerParallelism() + != environment.getTaskInfo().getNumberOfParallelSubtasks()) { + streamOutput.setPartitioner(new RescalePartitioner<>()); + } + } + @SuppressWarnings("unchecked") private static <OUT> RecordWriter<SerializationDelegate<StreamRecord<OUT>>> createRecordWriter( NonChainedOutput streamOutput, diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java index 37520e0515c..100340a6ac5 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java @@ -35,6 +35,7 @@ import org.apache.flink.streaming.api.operators.SimpleOperatorFactory; import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import java.util.Collections; import java.util.HashMap; @@ -203,6 +204,11 @@ public class StreamConfigChainer<OWNER> { } public <OUT> OWNER finishForSingletonOperatorChain(TypeSerializer<OUT> outputSerializer) { + return finishForSingletonOperatorChain(outputSerializer, new BroadcastPartitioner<>()); + } + + public <OUT> OWNER finishForSingletonOperatorChain( + TypeSerializer<OUT> outputSerializer, StreamPartitioner<?> partitioner) { checkState(chainIndex == 0, "Use finishForSingletonOperatorChain"); checkState(headConfig == tailConfig); @@ -231,7 +237,7 @@ public class StreamConfigChainer<OWNER> { false, new IntermediateDataSetID(), null, - new BroadcastPartitioner<>(), + partitioner, ResultPartitionType.PIPELINED_BOUNDED)); } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java index 08ab5d99e60..637784d132f 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java @@ -53,6 +53,7 @@ import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.StreamSource; import org.apache.flink.streaming.api.watermark.Watermark; +import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.util.CompletingCheckpointResponder; import org.apache.flink.util.FlinkRuntimeException; @@ -125,7 +126,8 @@ public class StreamTaskFinalCheckpointsTest { .addInput(STRING_TYPE_INFO) .addAdditionalOutput(partitionWriters) .setupOperatorChain(new EmptyOperator()) - .finishForSingletonOperatorChain(StringSerializer.INSTANCE) + .finishForSingletonOperatorChain( + StringSerializer.INSTANCE, new BroadcastPartitioner<>()) .build()) { testHarness.endInput(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java index 7fb875ada35..dcf9ffb3b8c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java @@ -58,6 +58,7 @@ import org.apache.flink.streaming.api.operators.StreamOperator; import org.apache.flink.streaming.api.operators.StreamOperatorFactory; import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner; import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer; import org.apache.flink.util.function.FunctionWithException; @@ -108,6 +109,8 @@ public class StreamTaskMailboxTestHarnessBuilder<OUT> { private Function<SingleInputGateBuilder, SingleInputGateBuilder> modifyGateBuilder = Function.identity(); + private StreamPartitioner<?> partitioner = new BroadcastPartitioner<>(); + public StreamTaskMailboxTestHarnessBuilder( FunctionWithException<Environment, ? extends StreamTask<OUT, ?>, Exception> taskFactory, TypeInformation<OUT> outputType) { @@ -324,11 +327,7 @@ public class StreamTaskMailboxTestHarnessBuilder<OUT> { 0, null, null, (StreamOperator<?>) null, null, SourceStreamTask.class); StreamEdge streamEdge = new StreamEdge( - sourceVertexDummy, - targetVertexDummy, - gateIndex + 1, - new BroadcastPartitioner<>(), - null); + sourceVertexDummy, targetVertexDummy, gateIndex + 1, partitioner, null); inPhysicalEdges.add(streamEdge); streamMockEnvironment.addInputGate(inputGates[gateIndex].getInputGate()); @@ -415,7 +414,7 @@ public class StreamTaskMailboxTestHarnessBuilder<OUT> { StreamOperatorFactory<?> factory, OperatorID operatorID) { checkState(!setupCalled, "This harness was already setup."); return setupOperatorChain(operatorID, factory) - .finishForSingletonOperatorChain(outputSerializer); + .finishForSingletonOperatorChain(outputSerializer, partitioner); } public StreamConfigChainer<StreamTaskMailboxTestHarnessBuilder<OUT>> setupOperatorChain( @@ -462,6 +461,12 @@ public class StreamTaskMailboxTestHarnessBuilder<OUT> { return this; } + public StreamTaskMailboxTestHarnessBuilder<OUT> setOutputPartitioner( + StreamPartitioner partitioner) { + this.partitioner = partitioner; + return this; + } + /** * A place holder representation of a {@link SourceInputConfig}. When building the test harness * it is replaced with {@link SourceInputConfig}. diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java index 9f5bbf227e9..30063b0a09e 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java @@ -24,6 +24,7 @@ import org.apache.flink.api.common.operators.MailboxExecutor; import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback; import org.apache.flink.api.common.state.CheckpointListener; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.common.typeinfo.IntegerTypeInfo; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.ReadableConfig; @@ -53,10 +54,15 @@ import org.apache.flink.runtime.io.network.NettyShuffleEnvironment; import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder; import org.apache.flink.runtime.io.network.api.StopMode; import org.apache.flink.runtime.io.network.api.writer.AvailabilityTestResultPartitionWriter; +import org.apache.flink.runtime.io.network.api.writer.ChannelSelectorRecordWriter; +import org.apache.flink.runtime.io.network.api.writer.RecordWriterDelegate; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.api.writer.SingleRecordWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.io.network.partition.consumer.TestInputChannel; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable; import org.apache.flink.runtime.metrics.TimerGauge; @@ -65,6 +71,7 @@ import org.apache.flink.runtime.operators.testutils.DummyEnvironment; import org.apache.flink.runtime.operators.testutils.ExpectedTestException; import org.apache.flink.runtime.operators.testutils.MockEnvironment; import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder; +import org.apache.flink.runtime.plugable.SerializationDelegate; import org.apache.flink.runtime.shuffle.ShuffleEnvironment; import org.apache.flink.runtime.state.AbstractKeyedStateBackend; import org.apache.flink.runtime.state.AbstractStateBackend; @@ -109,6 +116,7 @@ import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction; import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.streaming.api.graph.NonChainedOutput; import org.apache.flink.streaming.api.graph.StreamConfig; import org.apache.flink.streaming.api.operators.AbstractStreamOperator; import org.apache.flink.streaming.api.operators.InternalTimeServiceManager; @@ -122,6 +130,8 @@ import org.apache.flink.streaming.api.operators.StreamSource; import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer; import org.apache.flink.streaming.runtime.io.DataInputStatus; import org.apache.flink.streaming.runtime.io.StreamInputProcessor; +import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner; +import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxDefaultAction; import org.apache.flink.streaming.util.MockStreamConfig; @@ -132,6 +142,7 @@ import org.apache.flink.util.CloseableIterable; import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.FatalExitExceptionHandler; import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.OutputTag; import org.apache.flink.util.TestLogger; import org.apache.flink.util.clock.SystemClock; import org.apache.flink.util.concurrent.FutureUtils; @@ -1799,6 +1810,64 @@ public class StreamTaskTest extends TestLogger { } } + @Test + public void testForwardPartitionerIsConvertedToRebalanceOnParallelismChanges() + throws Exception { + StreamTaskMailboxTestHarnessBuilder<Integer> builder = + new StreamTaskMailboxTestHarnessBuilder<>( + OneInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO) + .addInput(BasicTypeInfo.INT_TYPE_INFO) + .setOutputPartitioner(new ForwardPartitioner<>()) + .setupOutputForSingletonOperatorChain( + new TestBoundedOneInputStreamOperator()); + + try (StreamTaskMailboxTestHarness<Integer> harness = builder.build()) { + + RecordWriterDelegate<SerializationDelegate<StreamRecord<Object>>> recordWriterDelegate = + harness.streamTask.createRecordWriterDelegate( + harness.streamTask.configuration, harness.streamMockEnvironment); + // Prerequisite: We are using the ForwardPartitioner + assertTrue( + ((ChannelSelectorRecordWriter) + ((SingleRecordWriter) recordWriterDelegate) + .getRecordWriter(0)) + .getChannelSelector() + instanceof ForwardPartitioner); + + // Change consumer parallelism + harness.streamTask.configuration.setVertexNonChainedOutputs( + List.of( + new NonChainedOutput( + false, + 0, + // Set a different consumer parallelism to force trigger + // replacing the ForwardPartitioner + 42, + 100, + 1000, + false, + new IntermediateDataSetID(), + new OutputTag<>("output", IntegerTypeInfo.INT_TYPE_INFO), + // Use forward partitioner + new ForwardPartitioner<>(), + ResultPartitionType.PIPELINED))); + harness.streamTask.configuration.serializeAllConfigs(); + + // Re-create outputs + recordWriterDelegate = + harness.streamTask.createRecordWriterDelegate( + harness.streamTask.configuration, harness.streamMockEnvironment); + // We should now have a RescalePartitioner to distribute the load + // for the non-matching downstream parallelism + assertTrue( + ((ChannelSelectorRecordWriter) + ((SingleRecordWriter) recordWriterDelegate) + .getRecordWriter(0)) + .getChannelSelector() + instanceof RescalePartitioner); + } + } + private int getCurrentBufferSize(InputGate inputGate) { return getTestChannel(inputGate, 0).getCurrentBufferSize(); }