This is an automated email from the ASF dual-hosted git repository.

gyfora pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 61d6e78e1f6dff093e005c834e1a008f146c02b9
Author: Maximilian Michels <m...@apache.org>
AuthorDate: Mon Nov 28 11:37:24 2022 +0100

    [FLINK-30213] Change ForwardPartitioner to RebalancePartitioner on parallelism changes
    
    When the parallelism of the JobGraph changes, either through the AdaptiveScheduler or
    through JobVertexId overrides provided via PipelineOptions#PARALLELISM_OVERRIDES, the
    partitioner serialized into the StreamTask configuration may no longer be suitable.
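
    For example, overrides can be supplied via the configuration, as sketched below. This is
    an illustrative snippet only: the vertex id is a placeholder, and the string key is
    assumed to be the one backing PipelineOptions#PARALLELISM_OVERRIDES.

        import org.apache.flink.configuration.Configuration;

        class ParallelismOverrideExample {
            public static void main(String[] args) {
                Configuration conf = new Configuration();
                // Maps a JobVertexID hex string to its new parallelism
                // ("<vertex-id>:<parallelism>"); the id below is a dummy value.
                conf.setString(
                        "pipeline.jobvertex-parallelism-overrides",
                        "a1b2c3d4e5f60718293a4b5c6d7e8f90:4");
            }
        }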
    
    This is the case for the ForwardPartitioner, which always transmits data over a fixed
    local channel. Whenever the consumer parallelism does not match the local parallelism,
    it is replaced with the RescalePartitioner, which distributes records round-robin
    across the available output channels.
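
    The difference in channel selection can be sketched with the simplified stand-ins below
    (not the actual Flink partitioner classes):

        // Illustrative only: mimics the channel-selection behaviour of the two partitioners.
        class ForwardLikeSelector {
            int selectChannel() {
                // Always the single local channel; only valid while producer and
                // consumer run with matching parallelism.
                return 0;
            }
        }

        class RescaleLikeSelector {
            private int nextChannel = -1;

            int selectChannel(int numberOfChannels) {
                // Round-robin across all available output channels.
                nextChannel = (nextChannel + 1) % numberOfChannels;
                return nextChannel;
            }
        }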
---
 .../flink/runtime/dispatcher/Dispatcher.java       |  3 +-
 .../api/writer/ChannelSelectorRecordWriter.java    |  6 ++
 .../streaming/api/graph/NonChainedOutput.java      |  6 +-
 .../flink/streaming/runtime/tasks/StreamTask.java  | 12 ++++
 .../runtime/tasks/StreamConfigChainer.java         |  8 ++-
 .../tasks/StreamTaskFinalCheckpointsTest.java      |  4 +-
 .../tasks/StreamTaskMailboxTestHarnessBuilder.java | 17 ++++--
 .../streaming/runtime/tasks/StreamTaskTest.java    | 69 ++++++++++++++++++++++
 8 files changed, 115 insertions(+), 10 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java
index c733af51334..cfa36bad1df 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/dispatcher/Dispatcher.java
@@ -1342,11 +1342,12 @@ public abstract class Dispatcher extends FencedRpcEndpoint<DispatcherId>
         for (JobVertex vertex : jobGraph.getVertices()) {
             String override = overrides.get(vertex.getID().toHexString());
             if (override != null) {
+                int currentParallelism = vertex.getParallelism();
                 int overrideParallelism = Integer.parseInt(override);
                 log.info(
                         "Changing job vertex {} parallelism from {} to {}",
                         vertex.getID(),
-                        vertex.getParallelism(),
+                        currentParallelism,
                         overrideParallelism);
                 vertex.setParallelism(overrideParallelism);
             }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java
index 07181bd01f1..5b756b693b3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelectorRecordWriter.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.api.writer;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.core.io.IOReadableWritable;
 
 import java.io.IOException;
@@ -71,4 +72,9 @@ public final class ChannelSelectorRecordWriter<T extends IOReadableWritable>
             flushAll();
         }
     }
+
+    @VisibleForTesting
+    public ChannelSelector<T> getChannelSelector() {
+        return channelSelector;
+    }
 }
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java
index 20d357c14e4..1042b396276 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/NonChainedOutput.java
@@ -59,7 +59,7 @@ public class NonChainedOutput implements Serializable {
     private final OutputTag<?> outputTag;
 
     /** The corresponding data partitioner. */
-    private final StreamPartitioner<?> partitioner;
+    private StreamPartitioner<?> partitioner;
 
     /** Target {@link ResultPartitionType}. */
     private final ResultPartitionType partitionType;
@@ -119,6 +119,10 @@ public class NonChainedOutput implements Serializable {
         return outputTag;
     }
 
+    public void setPartitioner(StreamPartitioner<?> partitioner) {
+        this.partitioner = partitioner;
+    }
+
     public StreamPartitioner<?> getPartitioner() {
         return partitioner;
     }
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index dc6a9dff6e4..16666b71804 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -91,6 +91,8 @@ import org.apache.flink.streaming.runtime.io.StreamInputProcessor;
 import org.apache.flink.streaming.runtime.io.checkpointing.BarrierAlignmentUtil;
 import org.apache.flink.streaming.runtime.io.checkpointing.CheckpointBarrierHandler;
 import org.apache.flink.streaming.runtime.partitioner.ConfigurableStreamPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
 import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.mailbox.GaugePeriodTimer;
@@ -1603,6 +1605,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
         int index = 0;
         for (NonChainedOutput streamOutput : outputsInOrder) {
+            replaceForwardPartitionerIfConsumerParallelismDoesNotMatch(environment, streamOutput);
             recordWriters.add(
                     createRecordWriter(
                             streamOutput,
@@ -1614,6 +1617,15 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
         return recordWriters;
     }
 
+    private static void replaceForwardPartitionerIfConsumerParallelismDoesNotMatch(
+            Environment environment, NonChainedOutput streamOutput) {
+        if (streamOutput.getPartitioner() instanceof ForwardPartitioner
+                && streamOutput.getConsumerParallelism()
+                        != environment.getTaskInfo().getNumberOfParallelSubtasks()) {
+            streamOutput.setPartitioner(new RescalePartitioner<>());
+        }
+    }
+
     @SuppressWarnings("unchecked")
     private static <OUT> RecordWriter<SerializationDelegate<StreamRecord<OUT>>> createRecordWriter(
             NonChainedOutput streamOutput,
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java
index 37520e0515c..100340a6ac5 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamConfigChainer.java
@@ -35,6 +35,7 @@ import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 
 import java.util.Collections;
 import java.util.HashMap;
@@ -203,6 +204,11 @@ public class StreamConfigChainer<OWNER> {
     }
 
     public <OUT> OWNER finishForSingletonOperatorChain(TypeSerializer<OUT> outputSerializer) {
+        return finishForSingletonOperatorChain(outputSerializer, new BroadcastPartitioner<>());
+    }
+
+    public <OUT> OWNER finishForSingletonOperatorChain(
+            TypeSerializer<OUT> outputSerializer, StreamPartitioner<?> partitioner) {
 
         checkState(chainIndex == 0, "Use finishForSingletonOperatorChain");
         checkState(headConfig == tailConfig);
@@ -231,7 +237,7 @@ public class StreamConfigChainer<OWNER> {
                             false,
                             new IntermediateDataSetID(),
                             null,
-                            new BroadcastPartitioner<>(),
+                            partitioner,
                             ResultPartitionType.PIPELINED_BOUNDED));
         }
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
index 08ab5d99e60..637784d132f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
@@ -53,6 +53,7 @@ import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.util.CompletingCheckpointResponder;
 import org.apache.flink.util.FlinkRuntimeException;
@@ -125,7 +126,8 @@ public class StreamTaskFinalCheckpointsTest {
                             .addInput(STRING_TYPE_INFO)
                             .addAdditionalOutput(partitionWriters)
                             .setupOperatorChain(new EmptyOperator())
-                            .finishForSingletonOperatorChain(StringSerializer.INSTANCE)
+                            .finishForSingletonOperatorChain(
+                                    StringSerializer.INSTANCE, new BroadcastPartitioner<>())
                             .build()) {
                 testHarness.endInput();
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java
index 7fb875ada35..dcf9ffb3b8c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskMailboxTestHarnessBuilder.java
@@ -58,6 +58,7 @@ import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
 import org.apache.flink.util.function.FunctionWithException;
 
@@ -108,6 +109,8 @@ public class StreamTaskMailboxTestHarnessBuilder<OUT> {
     private Function<SingleInputGateBuilder, SingleInputGateBuilder> modifyGateBuilder =
             Function.identity();
 
+    private StreamPartitioner<?> partitioner = new BroadcastPartitioner<>();
+
     public StreamTaskMailboxTestHarnessBuilder(
             FunctionWithException<Environment, ? extends StreamTask<OUT, ?>, Exception> taskFactory,
             TypeInformation<OUT> outputType) {
@@ -324,11 +327,7 @@ public class StreamTaskMailboxTestHarnessBuilder<OUT> {
                         0, null, null, (StreamOperator<?>) null, null, SourceStreamTask.class);
         StreamEdge streamEdge =
                 new StreamEdge(
-                        sourceVertexDummy,
-                        targetVertexDummy,
-                        gateIndex + 1,
-                        new BroadcastPartitioner<>(),
-                        null);
+                        sourceVertexDummy, targetVertexDummy, gateIndex + 1, partitioner, null);
 
         inPhysicalEdges.add(streamEdge);
         streamMockEnvironment.addInputGate(inputGates[gateIndex].getInputGate());
@@ -415,7 +414,7 @@ public class StreamTaskMailboxTestHarnessBuilder<OUT> {
             StreamOperatorFactory<?> factory, OperatorID operatorID) {
         checkState(!setupCalled, "This harness was already setup.");
         return setupOperatorChain(operatorID, factory)
-                .finishForSingletonOperatorChain(outputSerializer);
+                .finishForSingletonOperatorChain(outputSerializer, partitioner);
     }
 
     public StreamConfigChainer<StreamTaskMailboxTestHarnessBuilder<OUT>> setupOperatorChain(
@@ -462,6 +461,12 @@ public class StreamTaskMailboxTestHarnessBuilder<OUT> {
         return this;
     }
 
+    public StreamTaskMailboxTestHarnessBuilder<OUT> setOutputPartitioner(
+            StreamPartitioner partitioner) {
+        this.partitioner = partitioner;
+        return this;
+    }
+
     /**
      * A place holder representation of a {@link SourceInputConfig}. When building the test harness
      * it is replaced with {@link SourceInputConfig}.
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index 9f5bbf227e9..30063b0a09e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.operators.MailboxExecutor;
 import org.apache.flink.api.common.operators.ProcessingTimeService.ProcessingTimeCallback;
 import org.apache.flink.api.common.state.CheckpointListener;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.IntegerTypeInfo;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.ReadableConfig;
@@ -53,10 +54,15 @@ import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
 import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.AvailabilityTestResultPartitionWriter;
+import org.apache.flink.runtime.io.network.api.writer.ChannelSelectorRecordWriter;
+import org.apache.flink.runtime.io.network.api.writer.RecordWriterDelegate;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.api.writer.SingleRecordWriter;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.TestInputChannel;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable;
 import org.apache.flink.runtime.metrics.TimerGauge;
@@ -65,6 +71,7 @@ import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder;
+import org.apache.flink.runtime.plugable.SerializationDelegate;
 import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
@@ -109,6 +116,7 @@ import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.graph.NonChainedOutput;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.InternalTimeServiceManager;
@@ -122,6 +130,8 @@ import org.apache.flink.streaming.api.operators.StreamSource;
 import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
 import org.apache.flink.streaming.runtime.io.DataInputStatus;
 import org.apache.flink.streaming.runtime.io.StreamInputProcessor;
+import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.RescalePartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxDefaultAction;
 import org.apache.flink.streaming.util.MockStreamConfig;
@@ -132,6 +142,7 @@ import org.apache.flink.util.CloseableIterable;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.FatalExitExceptionHandler;
 import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.OutputTag;
 import org.apache.flink.util.TestLogger;
 import org.apache.flink.util.clock.SystemClock;
 import org.apache.flink.util.concurrent.FutureUtils;
@@ -1799,6 +1810,64 @@ public class StreamTaskTest extends TestLogger {
         }
     }
 
+    @Test
+    public void testForwardPartitionerIsConvertedToRebalanceOnParallelismChanges()
+            throws Exception {
+        StreamTaskMailboxTestHarnessBuilder<Integer> builder =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                OneInputStreamTask::new, BasicTypeInfo.INT_TYPE_INFO)
+                        .addInput(BasicTypeInfo.INT_TYPE_INFO)
+                        .setOutputPartitioner(new ForwardPartitioner<>())
+                        .setupOutputForSingletonOperatorChain(
+                                new TestBoundedOneInputStreamOperator());
+
+        try (StreamTaskMailboxTestHarness<Integer> harness = builder.build()) {
+
+            RecordWriterDelegate<SerializationDelegate<StreamRecord<Object>>> recordWriterDelegate =
+                    harness.streamTask.createRecordWriterDelegate(
+                            harness.streamTask.configuration, harness.streamMockEnvironment);
+            // Prerequisite: We are using the ForwardPartitioner
+            assertTrue(
+                    ((ChannelSelectorRecordWriter)
+                                            ((SingleRecordWriter) recordWriterDelegate)
+                                                    .getRecordWriter(0))
+                                    .getChannelSelector()
+                            instanceof ForwardPartitioner);
+
+            // Change consumer parallelism
+            harness.streamTask.configuration.setVertexNonChainedOutputs(
+                    List.of(
+                            new NonChainedOutput(
+                                    false,
+                                    0,
+                                    // Set a different consumer parallelism to trigger
+                                    // replacing the ForwardPartitioner
+                                    42,
+                                    100,
+                                    1000,
+                                    false,
+                                    new IntermediateDataSetID(),
+                                    new OutputTag<>("output", IntegerTypeInfo.INT_TYPE_INFO),
+                                    // Use forward partitioner
+                                    new ForwardPartitioner<>(),
+                                    ResultPartitionType.PIPELINED)));
+            harness.streamTask.configuration.serializeAllConfigs();
+
+            // Re-create outputs
+            recordWriterDelegate =
+                    harness.streamTask.createRecordWriterDelegate(
+                            harness.streamTask.configuration, harness.streamMockEnvironment);
+            // We should now have a RescalePartitioner to distribute the load
+            // for the non-matching downstream parallelism
+            assertTrue(
+                    ((ChannelSelectorRecordWriter)
+                                            ((SingleRecordWriter) recordWriterDelegate)
+                                                    .getRecordWriter(0))
+                                    .getChannelSelector()
+                            instanceof RescalePartitioner);
+        }
+    }
+
     private int getCurrentBufferSize(InputGate inputGate) {
         return getTestChannel(inputGate, 0).getCurrentBufferSize();
     }
