This is an automated email from the ASF dual-hosted git repository.

wanglijie pushed a commit to branch release-1.17
in repository https://gitbox.apache.org/repos/asf/flink.git

commit c66ef2540c3cb53f4cf3218ff07f5b440511ad84
Author: Lijie Wang <wangdachui9...@gmail.com>
AuthorDate: Wed Feb 22 11:44:06 2023 +0800

    [FLINK-31114][runtime] Set parallelism of job vertices in forward group at compilation phase
    
    This closes #21963
---
 .../jobgraph/forwardgroup/ForwardGroup.java        |  33 ++-
 .../forwardgroup/ForwardGroupComputeUtil.java      |  34 ++-
 .../AdaptiveBatchSchedulerFactory.java             |   2 +-
 .../forwardgroup/ForwardGroupComputeUtilTest.java  |  55 +---
 .../runtime/scheduler/DefaultSchedulerBuilder.java |   4 +-
 .../api/graph/StreamingJobGraphGenerator.java      | 321 ++++++++++++++++-----
 .../scheduling/AdaptiveBatchSchedulerITCase.java   |  31 ++
 7 files changed, 327 insertions(+), 153 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java
index cf4f2e67194..922acf231e2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.Set;
 import java.util.stream.Collectors;
@@ -38,12 +39,13 @@ public class ForwardGroup {
 
     private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT;
 
+    private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT;
     private final Set<JobVertexID> jobVertexIds = new HashSet<>();
 
     public ForwardGroup(final Set<JobVertex> jobVertices) {
         checkNotNull(jobVertices);
 
-        Set<Integer> decidedParallelisms =
+        Set<Integer> configuredParallelisms =
                 jobVertices.stream()
                         .filter(
                                 jobVertex -> {
@@ -53,9 +55,23 @@ public class ForwardGroup {
                         .map(JobVertex::getParallelism)
                         .collect(Collectors.toSet());
 
-        checkState(decidedParallelisms.size() <= 1);
-        if (decidedParallelisms.size() == 1) {
-            this.parallelism = decidedParallelisms.iterator().next();
+        checkState(configuredParallelisms.size() <= 1);
+        if (configuredParallelisms.size() == 1) {
+            this.parallelism = configuredParallelisms.iterator().next();
+        }
+
+        Set<Integer> configuredMaxParallelisms =
+                jobVertices.stream()
+                        .map(JobVertex::getMaxParallelism)
+                        .filter(val -> val > 0)
+                        .collect(Collectors.toSet());
+
+        if (!configuredMaxParallelisms.isEmpty()) {
+            this.maxParallelism = Collections.min(configuredMaxParallelisms);
+            checkState(
+                    parallelism == ExecutionConfig.PARALLELISM_DEFAULT
+                            || maxParallelism >= parallelism,
+                    "There is a job vertex in the forward group whose maximum parallelism is smaller than the group's parallelism");
         }
     }
 
@@ -73,6 +89,15 @@ public class ForwardGroup {
         return parallelism;
     }
 
+    public boolean isMaxParallelismDecided() {
+        return maxParallelism > 0;
+    }
+
+    public int getMaxParallelism() {
+        checkState(isMaxParallelismDecided());
+        return maxParallelism;
+    }
+
     public int size() {
         return jobVertexIds.size();
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java
index c8a3395ae51..dc2dc702e34 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java
@@ -19,8 +19,8 @@
 
 package org.apache.flink.runtime.jobgraph.forwardgroup;
 
-import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.runtime.executiongraph.VertexGroupComputeUtil;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
 import org.apache.flink.runtime.jobgraph.JobEdge;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
@@ -30,31 +30,34 @@ import java.util.HashSet;
 import java.util.IdentityHashMap;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
+import static org.apache.flink.util.Preconditions.checkState;
+
 /** Common utils for computing forward groups. */
 public class ForwardGroupComputeUtil {
 
-    public static Map<JobVertexID, ForwardGroup>
-            computeForwardGroupsAndSetVertexParallelismsIfNecessary(
-                    final Iterable<JobVertex> topologicallySortedVertices) {
+    public static Map<JobVertexID, ForwardGroup> computeForwardGroupsAndCheckParallelism(
+            final Iterable<JobVertex> topologicallySortedVertices) {
         final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId =
-                computeForwardGroups(topologicallySortedVertices);
-        // set parallelism for vertices in parallelism-decided forward groups
+                computeForwardGroups(
+                        topologicallySortedVertices, ForwardGroupComputeUtil::getForwardProducers);
+        // the vertex's parallelism in parallelism-decided forward group should have been set at
+        // compilation phase
         topologicallySortedVertices.forEach(
                 jobVertex -> {
                     ForwardGroup forwardGroup = forwardGroupsByJobVertexId.get(jobVertex.getID());
-                    if (jobVertex.getParallelism() == ExecutionConfig.PARALLELISM_DEFAULT
-                            && forwardGroup != null
-                            && forwardGroup.isParallelismDecided()) {
-                        jobVertex.setParallelism(forwardGroup.getParallelism());
+                    if (forwardGroup != null && forwardGroup.isParallelismDecided()) {
+                        checkState(jobVertex.getParallelism() == forwardGroup.getParallelism());
                     }
                 });
         return forwardGroupsByJobVertexId;
     }
 
-    static Map<JobVertexID, ForwardGroup> computeForwardGroups(
-            final Iterable<JobVertex> topologicallySortedVertices) {
+    public static Map<JobVertexID, ForwardGroup> computeForwardGroups(
+            final Iterable<JobVertex> topologicallySortedVertices,
+            final Function<JobVertex, Set<JobVertex>> forwardProducersRetriever) {
 
         final Map<JobVertex, Set<JobVertex>> vertexToGroup = new IdentityHashMap<>();
 
@@ -64,8 +67,7 @@ public class ForwardGroupComputeUtil {
             currentGroup.add(vertex);
             vertexToGroup.put(vertex, currentGroup);
 
-            for (JobEdge input : getForwardInputs(vertex)) {
-                final JobVertex producerVertex = input.getSource().getProducer();
+            for (JobVertex producerVertex : forwardProducersRetriever.apply(vertex)) {
                 final Set<JobVertex> producerGroup = vertexToGroup.get(producerVertex);
 
                 if (producerGroup == null) {
@@ -99,9 +101,11 @@ public class ForwardGroupComputeUtil {
         return ret;
     }
 
-    static Iterable<JobEdge> getForwardInputs(JobVertex jobVertex) {
+    static Set<JobVertex> getForwardProducers(final JobVertex jobVertex) {
         return jobVertex.getInputs().stream()
                 .filter(JobEdge::isForward)
+                .map(JobEdge::getSource)
+                .map(IntermediateDataSet::getProducer)
                 .collect(Collectors.toSet());
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java
index bdac05c1e4d..3746b549ed1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java
@@ -181,7 +181,7 @@ public class AdaptiveBatchSchedulerFactory implements SchedulerNGFactory {
                 getDefaultMaxParallelism(jobMasterConfiguration, executionConfig);
 
         final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId =
-                ForwardGroupComputeUtil.computeForwardGroupsAndSetVertexParallelismsIfNecessary(
+                ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism(
                         jobGraph.getVerticesSortedTopologicallyFromSources());
 
         if (enableSpeculativeExecution) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java
index 1f2bd4ead9e..e9a90e98918 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java
@@ -22,16 +22,12 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
-import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorExtension;
 
 import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Set;
-import java.util.concurrent.ScheduledExecutorService;
 import java.util.stream.Collectors;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -40,9 +36,6 @@ import static org.assertj.core.api.Assertions.assertThat;
  * Unit tests for {@link org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil}.
  */
 class ForwardGroupComputeUtilTest {
-    @RegisterExtension
-    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorExtension();
 
     /**
      * Tests that the computation of the job graph with isolated vertices works correctly.
@@ -178,58 +171,14 @@ class ForwardGroupComputeUtilTest {
         checkGroupSize(groups, 1, 3);
     }
 
-    /**
-     * Tests whether the parallelism of job vertices in forward group are correctly set.
-     *
-     * <pre>
-     *
-     *     (v1) -> (v2)
-     *
-     *     (v3) -> (v4)
-     *
-     * </pre>
-     */
-    @Test
-    void testComputeForwardGroupsAndSetVertexParallelismsIfNecessary() throws Exception {
-        JobVertex v1 = new JobVertex("v1");
-        JobVertex v2 = new JobVertex("v2");
-        JobVertex v3 = new JobVertex("v3");
-        JobVertex v4 = new JobVertex("v4");
-
-        v2.setParallelism(8);
-
-        v2.connectNewDataSetAsInput(
-                v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
-        v4.connectNewDataSetAsInput(
-                v3, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
-
-        v1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
-        v3.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
-
-        Set<ForwardGroup> groups =
-                computeForwardGroupsAndSetVertexParallelismsIfNecessary(v1, v2, v3, v4);
-        checkGroupSize(groups, 2, 2, 2);
-        assertThat(v1.getParallelism()).isEqualTo(8);
-        assertThat(v2.getParallelism()).isEqualTo(8);
-        assertThat(v3.getParallelism()).isEqualTo(-1);
-        assertThat(v4.getParallelism()).isEqualTo(-1);
-    }
-
-    private static Set<ForwardGroup> computeForwardGroupsAndSetVertexParallelismsIfNecessary(
-            JobVertex... vertices) throws Exception {
+    private static Set<ForwardGroup> computeForwardGroups(JobVertex... vertices) {
         Arrays.asList(vertices).forEach(vertex -> vertex.setInvokableClass(NoOpInvokable.class));
         return new HashSet<>(
-                ForwardGroupComputeUtil.computeForwardGroupsAndSetVertexParallelismsIfNecessary(
+                ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism(
                                 Arrays.asList(vertices))
                         .values());
     }
 
-    private static Set<ForwardGroup> computeForwardGroups(JobVertex... vertices) throws Exception {
-        Arrays.asList(vertices).forEach(vertex -> vertex.setInvokableClass(NoOpInvokable.class));
-        return new HashSet<>(
-                ForwardGroupComputeUtil.computeForwardGroups(Arrays.asList(vertices)).values());
-    }
-
     private static void checkGroupSize(
             Set<ForwardGroup> groups, int numOfGroups, Integer... sizes) {
         assertThat(groups.size()).isEqualTo(numOfGroups);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java
index aef96c2469d..56c7b29ebd3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java
@@ -335,7 +335,7 @@ public class DefaultSchedulerBuilder {
                 vertexParallelismAndInputInfosDecider,
                 defaultMaxParallelism,
                 hybridPartitionDataConsumeConstraint,
-                ForwardGroupComputeUtil.computeForwardGroupsAndSetVertexParallelismsIfNecessary(
+                ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism(
                         jobGraph.getVerticesSortedTopologicallyFromSources()));
     }
 
@@ -367,7 +367,7 @@ public class DefaultSchedulerBuilder {
                 defaultMaxParallelism,
                 blocklistOperations,
                 HybridPartitionDataConsumeConstraint.ALL_PRODUCERS_FINISHED,
-                ForwardGroupComputeUtil.computeForwardGroupsAndSetVertexParallelismsIfNecessary(
+                ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism(
                         jobGraph.getVerticesSortedTopologicallyFromSources()));
     }
 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index 3eceb27dc6a..3f818b6bd9a 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -46,6 +46,8 @@ import org.apache.flink.runtime.jobgraph.JobType;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup;
+import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
 import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings;
 import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable;
@@ -80,6 +82,7 @@ import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.runtime.tasks.StreamIterationHead;
 import org.apache.flink.streaming.runtime.tasks.StreamIterationTail;
 import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.IterableUtils;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.concurrent.ExecutorThreadFactory;
@@ -100,6 +103,7 @@ import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.IdentityHashMap;
+import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.LinkedList;
 import java.util.List;
@@ -193,7 +197,14 @@ public class StreamingJobGraphGenerator {
                     List<CompletableFuture<SerializedValue<OperatorCoordinator.Provider>>>>
             coordinatorSerializationFuturesPerJobVertex = new HashMap<>();
 
-    private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputs;
+    /** The {@link OperatorChainInfo}s, key is the start node id of the chain. */
+    private final Map<Integer, OperatorChainInfo> chainInfos;
+
+    /**
+     * This is used to cache the non-chainable outputs, to set the non-chainable outputs config
+     * after all job vertices are created.
+     */
+    private final Map<Integer, List<StreamEdge>> opNonChainableOutputsCache;
 
     private StreamingJobGraphGenerator(
             ClassLoader userClassloader,
@@ -205,7 +216,7 @@ public class StreamingJobGraphGenerator {
         this.defaultStreamGraphHasher = new StreamGraphHasherV2();
         this.legacyStreamGraphHashers = Arrays.asList(new StreamGraphUserHashHasher());
 
-        this.jobVertices = new HashMap<>();
+        this.jobVertices = new LinkedHashMap<>();
         this.builtVertices = new HashSet<>();
         this.chainedConfigs = new HashMap<>();
         this.vertexConfigs = new HashMap<>();
@@ -215,7 +226,8 @@ public class StreamingJobGraphGenerator {
         this.chainedInputOutputFormats = new HashMap<>();
         this.physicalEdgesInOrder = new ArrayList<>();
         this.serializationExecutor = Preconditions.checkNotNull(serializationExecutor);
-        this.opIntermediateOutputs = new HashMap<>();
+        this.chainInfos = new HashMap<>();
+        this.opNonChainableOutputsCache = new LinkedHashMap<>();
 
         jobGraph = new JobGraph(jobID, streamGraph.getJobName());
     }
@@ -241,6 +253,18 @@ public class StreamingJobGraphGenerator {
 
         setChaining(hashes, legacyHashes);
 
+        if (jobGraph.isDynamic()) {
+            setVertexParallelismsForDynamicGraphIfNecessary();
+        }
+
+        // Note that we set all the non-chainable outputs configuration here because the
+        // "setVertexParallelismsForDynamicGraphIfNecessary" may affect the parallelism of job
+        // vertices and partition-reuse
+        final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputs =
+                new HashMap<>();
+        setAllOperatorNonChainedOutputsConfigs(opIntermediateOutputs);
+        setAllVertexNonChainedOutputsConfigs(opIntermediateOutputs);
+
         setPhysicalEdges();
 
         markSupportingConcurrentExecutionAttempts();
@@ -568,12 +592,11 @@ public class StreamingJobGraphGenerator {
                     final StreamConfig.SourceInputConfig inputConfig =
                             new StreamConfig.SourceInputConfig(sourceOutEdge);
                     final StreamConfig operatorConfig = new StreamConfig(new Configuration());
-                    setVertexConfig(
-                            sourceNodeId,
-                            operatorConfig,
-                            Collections.emptyList(),
-                            Collections.emptyList(),
-                            Collections.emptyMap());
+                    setOperatorConfig(sourceNodeId, operatorConfig, Collections.emptyMap());
+                    setOperatorChainedOutputsConfig(operatorConfig, Collections.emptyList());
+                    // we cache the non-chainable outputs here, and set the non-chained config later
+                    opNonChainableOutputsCache.put(sourceNodeId, Collections.emptyList());
+
                     operatorConfig.setChainIndex(0); // sources are always first
                     operatorConfig.setOperatorID(opId);
                     operatorConfig.setOperatorName(sourceNode.getOperatorName());
@@ -712,28 +735,22 @@ public class StreamingJobGraphGenerator {
                             ? createJobVertex(startNodeId, chainInfo)
                             : new StreamConfig(new Configuration());
 
-            setVertexConfig(
-                    currentNodeId,
-                    config,
-                    chainableOutputs,
-                    nonChainableOutputs,
-                    chainInfo.getChainedSources());
+            tryConvertPartitionerForDynamicGraph(chainableOutputs, nonChainableOutputs);
+
+            setOperatorConfig(currentNodeId, config, chainInfo.getChainedSources());
+
+            setOperatorChainedOutputsConfig(config, chainableOutputs);
+
+            // we cache the non-chainable outputs here, and set the non-chained config later
+            opNonChainableOutputsCache.put(currentNodeId, nonChainableOutputs);
 
             if (currentNodeId.equals(startNodeId)) {
+                chainInfo.setTransitiveOutEdges(transitiveOutEdges);
+                chainInfos.put(startNodeId, chainInfo);
 
                 config.setChainStart();
                 config.setChainIndex(chainIndex);
                 config.setOperatorName(streamGraph.getStreamNode(currentNodeId).getOperatorName());
-
-                LinkedHashSet<NonChainedOutput> transitiveOutputs = new LinkedHashSet<>();
-                for (StreamEdge edge : transitiveOutEdges) {
-                    NonChainedOutput output =
-                            opIntermediateOutputs.get(edge.getSourceId()).get(edge);
-                    transitiveOutputs.add(output);
-                    connect(startNodeId, edge, output);
-                }
-
-                config.setVertexNonChainedOutputs(new ArrayList<>(transitiveOutputs));
                 config.setTransitiveChainedTaskConfigs(chainedConfigs.get(startNodeId));
 
             } else {
@@ -758,6 +775,102 @@ public class StreamingJobGraphGenerator {
         }
     }
 
+    /**
+     * This method is used to reset or set job vertices' parallelism for dynamic graph:
+     *
+     * <p>1. Reset parallelism for job vertices whose parallelism is not configured.
+     *
+     * <p>2. Set parallelism and maxParallelism for job vertices in forward group, to ensure the
+     * parallelism and maxParallelism of vertices in the same forward group to be the same; set the
+     * parallelism at early stage if possible, to avoid invalid partition reuse.
+     */
+    private void setVertexParallelismsForDynamicGraphIfNecessary() {
+        // Note that the jobVertices are reverse topological order
+        final List<JobVertex> topologicalOrderVertices =
+                IterableUtils.toStream(jobVertices.values()).collect(Collectors.toList());
+        Collections.reverse(topologicalOrderVertices);
+
+        // reset parallelism for job vertices whose parallelism is not configured
+        jobVertices.forEach(
+                (startNodeId, jobVertex) -> {
+                    final OperatorChainInfo chainInfo = chainInfos.get(startNodeId);
+                    if (!jobVertex.isParallelismConfigured()
+                            && streamGraph.isAutoParallelismEnabled()) {
+                        jobVertex.setParallelism(ExecutionConfig.PARALLELISM_DEFAULT);
+                        chainInfo
+                                .getAllChainedNodes()
+                                .forEach(
+                                        n ->
+                                                n.setParallelism(
+                                                        ExecutionConfig.PARALLELISM_DEFAULT,
+                                                        false));
+                    }
+                });
+
+        final Map<JobVertex, Set<JobVertex>> forwardProducersByJobVertex = new HashMap<>();
+        jobVertices.forEach(
+                (startNodeId, jobVertex) -> {
+                    Set<JobVertex> forwardConsumers =
+                            chainInfos.get(startNodeId).getTransitiveOutEdges().stream()
+                                    .filter(
+                                            edge ->
+                                                    edge.getPartitioner()
+                                                            instanceof ForwardPartitioner)
+                                    .map(StreamEdge::getTargetId)
+                                    .map(jobVertices::get)
+                                    .collect(Collectors.toSet());
+
+                    for (JobVertex forwardConsumer : forwardConsumers) {
+                        forwardProducersByJobVertex.compute(
+                                forwardConsumer,
+                                (ignored, producers) -> {
+                                    if (producers == null) {
+                                        producers = new HashSet<>();
+                                    }
+                                    producers.add(jobVertex);
+                                    return producers;
+                                });
+                    }
+                });
+
+        // compute forward groups
+        final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId =
+                ForwardGroupComputeUtil.computeForwardGroups(
+                        topologicalOrderVertices,
+                        jobVertex ->
+                                forwardProducersByJobVertex.getOrDefault(
+                                        jobVertex, Collections.emptySet()));
+
+        jobVertices.forEach(
+                (startNodeId, jobVertex) -> {
+                    ForwardGroup forwardGroup = forwardGroupsByJobVertexId.get(jobVertex.getID());
+                    // set parallelism for vertices in forward group
+                    if (forwardGroup != null && forwardGroup.isParallelismDecided()) {
+                        jobVertex.setParallelism(forwardGroup.getParallelism());
+                        jobVertex.setParallelismConfigured(true);
+                        chainInfos
+                                .get(startNodeId)
+                                .getAllChainedNodes()
+                                .forEach(
+                                        streamNode ->
+                                                streamNode.setParallelism(
+                                                        forwardGroup.getParallelism(), true));
+                    }
+
+                    // set max parallelism for vertices in forward group
+                    if (forwardGroup != null && forwardGroup.isMaxParallelismDecided()) {
+                        jobVertex.setMaxParallelism(forwardGroup.getMaxParallelism());
+                        chainInfos
+                                .get(startNodeId)
+                                .getAllChainedNodes()
+                                .forEach(
+                                        streamNode ->
+                                                streamNode.setMaxParallelism(
+                                                        forwardGroup.getMaxParallelism()));
+                    }
+                });
+    }
+
     private void checkAndReplaceReusableHybridPartitionType(NonChainedOutput reusableOutput) {
         if (reusableOutput.getPartitionType() == ResultPartitionType.HYBRID_SELECTIVE) {
             // for can be reused hybrid output, it can be optimized to always use full
@@ -916,30 +1029,16 @@ public class StreamingJobGraphGenerator {
         jobVertex.setParallelismConfigured(
                 chainInfo.getAllChainedNodes().stream()
                         .anyMatch(StreamNode::isParallelismConfigured));
-        if (streamGraph.isDynamic()
-                && !jobVertex.isParallelismConfigured()
-                && streamGraph.isAutoParallelismEnabled()) {
-            jobVertex.setParallelism(ExecutionConfig.PARALLELISM_DEFAULT);
-            chainInfo
-                    .getAllChainedNodes()
-                    .forEach(n -> n.setParallelism(ExecutionConfig.PARALLELISM_DEFAULT, false));
-        }
 
         return new StreamConfig(jobVertex.getConfiguration());
     }
 
-    private void setVertexConfig(
-            Integer vertexID,
-            StreamConfig config,
-            List<StreamEdge> chainableOutputs,
-            List<StreamEdge> nonChainableOutputs,
-            Map<Integer, ChainedSourceInfo> chainedSources) {
-
-        tryConvertPartitionerForDynamicGraph(chainableOutputs, nonChainableOutputs);
+    private void setOperatorConfig(
+            Integer vertexId, StreamConfig config, Map<Integer, ChainedSourceInfo> chainedSources) {
 
-        StreamNode vertex = streamGraph.getStreamNode(vertexID);
+        StreamNode vertex = streamGraph.getStreamNode(vertexId);
 
-        config.setVertexID(vertexID);
+        config.setVertexID(vertexId);
 
         // build the inputs as a combination of source and network inputs
         final List<StreamEdge> inEdges = vertex.getInEdges();
@@ -965,7 +1064,7 @@ public class StreamingJobGraphGenerator {
                 }
                 inputConfigs[inputIndex] = chainedSource.getInputConfig();
                 chainedConfigs
-                        .computeIfAbsent(vertexID, (key) -> new HashMap<>())
+                        .computeIfAbsent(vertexId, (key) -> new HashMap<>())
                         .put(inEdge.getSourceId(), chainedSource.getOperatorConfig());
             } else {
                 // network input. null if we move to a new input, non-null if this is a further edge
@@ -996,34 +1095,8 @@ public class StreamingJobGraphGenerator {
 
         config.setTypeSerializerOut(vertex.getTypeSerializerOut());
 
-        // iterate edges, find sideOutput edges create and save serializers for each outputTag type
-        for (StreamEdge edge : chainableOutputs) {
-            if (edge.getOutputTag() != null) {
-                config.setTypeSerializerSideOut(
-                        edge.getOutputTag(),
-                        edge.getOutputTag()
-                                .getTypeInfo()
-                                .createSerializer(streamGraph.getExecutionConfig()));
-            }
-        }
-        for (StreamEdge edge : nonChainableOutputs) {
-            if (edge.getOutputTag() != null) {
-                config.setTypeSerializerSideOut(
-                        edge.getOutputTag(),
-                        edge.getOutputTag()
-                                .getTypeInfo()
-                                .createSerializer(streamGraph.getExecutionConfig()));
-            }
-        }
-
         config.setStreamOperatorFactory(vertex.getOperatorFactory());
 
-        List<NonChainedOutput> deduplicatedOutputs =
-                mayReuseNonChainedOutputs(vertexID, nonChainableOutputs);
-        config.setNumberOfOutputs(deduplicatedOutputs.size());
-        config.setOperatorNonChainedOutputs(deduplicatedOutputs);
-        config.setChainedOutputs(chainableOutputs);
-
         config.setTimeCharacteristic(streamGraph.getTimeCharacteristic());
 
         final CheckpointConfig checkpointCfg = streamGraph.getCheckpointConfig();
@@ -1053,21 +1126,103 @@ public class StreamingJobGraphGenerator {
 
         if (vertexClass.equals(StreamIterationHead.class)
                 || vertexClass.equals(StreamIterationTail.class)) {
-            config.setIterationId(streamGraph.getBrokerID(vertexID));
-            config.setIterationWaitTime(streamGraph.getLoopTimeout(vertexID));
+            config.setIterationId(streamGraph.getBrokerID(vertexId));
+            config.setIterationWaitTime(streamGraph.getLoopTimeout(vertexId));
+        }
+
+        vertexConfigs.put(vertexId, config);
+    }
+
+    private void setOperatorChainedOutputsConfig(
+            StreamConfig config, List<StreamEdge> chainableOutputs) {
+        // iterate edges, find sideOutput edges create and save serializers for each outputTag type
+        for (StreamEdge edge : chainableOutputs) {
+            if (edge.getOutputTag() != null) {
+                config.setTypeSerializerSideOut(
+                        edge.getOutputTag(),
+                        edge.getOutputTag()
+                                .getTypeInfo()
+                                .createSerializer(streamGraph.getExecutionConfig()));
+            }
+        }
+        config.setChainedOutputs(chainableOutputs);
+    }
+
+    private void setOperatorNonChainedOutputsConfig(
+            Integer vertexId,
+            StreamConfig config,
+            List<StreamEdge> nonChainableOutputs,
+            Map<StreamEdge, NonChainedOutput> outputsConsumedByEdge) {
+        // iterate edges, find sideOutput edges create and save serializers for each outputTag type
+        for (StreamEdge edge : nonChainableOutputs) {
+            if (edge.getOutputTag() != null) {
+                config.setTypeSerializerSideOut(
+                        edge.getOutputTag(),
+                        edge.getOutputTag()
+                                .getTypeInfo()
+                                .createSerializer(streamGraph.getExecutionConfig()));
+            }
+        }
+
+        List<NonChainedOutput> deduplicatedOutputs =
+                mayReuseNonChainedOutputs(vertexId, nonChainableOutputs, outputsConsumedByEdge);
+        config.setNumberOfOutputs(deduplicatedOutputs.size());
+        config.setOperatorNonChainedOutputs(deduplicatedOutputs);
+    }
+
+    private void setVertexNonChainedOutputsConfig(
+            Integer startNodeId,
+            StreamConfig config,
+            List<StreamEdge> transitiveOutEdges,
+            final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputs) {
+
+        LinkedHashSet<NonChainedOutput> transitiveOutputs = new LinkedHashSet<>();
+        for (StreamEdge edge : transitiveOutEdges) {
+            NonChainedOutput output = opIntermediateOutputs.get(edge.getSourceId()).get(edge);
+            transitiveOutputs.add(output);
+            connect(startNodeId, edge, output);
         }
 
-        vertexConfigs.put(vertexID, config);
+        config.setVertexNonChainedOutputs(new ArrayList<>(transitiveOutputs));
+    }
+
+    private void setAllOperatorNonChainedOutputsConfigs(
+            final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputs) {
+        // set non chainable output config
+        opNonChainableOutputsCache.forEach(
+                (vertexId, nonChainableOutputs) -> {
+                    Map<StreamEdge, NonChainedOutput> outputsConsumedByEdge =
+                            opIntermediateOutputs.computeIfAbsent(
+                                    vertexId, ignored -> new HashMap<>());
+                    setOperatorNonChainedOutputsConfig(
+                            vertexId,
+                            vertexConfigs.get(vertexId),
+                            nonChainableOutputs,
+                            outputsConsumedByEdge);
+                });
+    }
+
+    private void setAllVertexNonChainedOutputsConfigs(
+            final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputs) {
+        jobVertices
+                .keySet()
+                .forEach(
+                        startNodeId ->
+                                setVertexNonChainedOutputsConfig(
+                                        startNodeId,
+                                        vertexConfigs.get(startNodeId),
+                                        chainInfos.get(startNodeId).getTransitiveOutEdges(),
+                                        opIntermediateOutputs));
     }
 
     private List<NonChainedOutput> mayReuseNonChainedOutputs(
-            int vertexId, List<StreamEdge> consumerEdges) {
+            int vertexId,
+            List<StreamEdge> consumerEdges,
+            Map<StreamEdge, NonChainedOutput> outputsConsumedByEdge) {
         if (consumerEdges.isEmpty()) {
             return new ArrayList<>();
         }
         List<NonChainedOutput> outputs = new ArrayList<>(consumerEdges.size());
-        Map<StreamEdge, NonChainedOutput> outputsConsumedByEdge =
-                opIntermediateOutputs.computeIfAbsent(vertexId, ignored -> new HashMap<>());
         for (StreamEdge consumerEdge : consumerEdges) {
             checkState(vertexId == consumerEdge.getSourceId(), "Vertex id must be the same.");
             ResultPartitionType partitionType = getResultPartitionType(consumerEdge);
@@ -1251,7 +1406,7 @@ public class StreamingJobGraphGenerator {
                             headVertex,
                             DistributionPattern.POINTWISE,
                             resultPartitionType,
-                            opIntermediateOutputs.get(edge.getSourceId()).get(edge).getDataSetId(),
+                            output.getDataSetId(),
                             partitioner.isBroadcast());
         } else {
             jobEdge =
@@ -1259,7 +1414,7 @@ public class StreamingJobGraphGenerator {
                             headVertex,
                             DistributionPattern.ALL_TO_ALL,
                             resultPartitionType,
-                            opIntermediateOutputs.get(edge.getSourceId()).get(edge).getDataSetId(),
+                            output.getDataSetId(),
                             partitioner.isBroadcast());
         }
 
@@ -1894,6 +2049,7 @@ public class StreamingJobGraphGenerator {
         private final List<OperatorCoordinator.Provider> coordinatorProviders;
         private final StreamGraph streamGraph;
         private final List<StreamNode> chainedNodes;
+        private final List<StreamEdge> transitiveOutEdges;
 
         private OperatorChainInfo(
                 int startNodeId,
@@ -1909,6 +2065,7 @@ public class StreamingJobGraphGenerator {
             this.chainedSources = chainedSources;
             this.streamGraph = streamGraph;
             this.chainedNodes = new ArrayList<>();
+            this.transitiveOutEdges = new ArrayList<>();
         }
 
         byte[] getHash(Integer streamNodeId) {
@@ -1955,6 +2112,14 @@ public class StreamingJobGraphGenerator {
             return new OperatorID(primaryHashBytes);
         }
 
+        private void setTransitiveOutEdges(final List<StreamEdge> transitiveOutEdges) {
+            this.transitiveOutEdges.addAll(transitiveOutEdges);
+        }
+
+        private List<StreamEdge> getTransitiveOutEdges() {
+            return transitiveOutEdges;
+        }
+
         private void recordChainedNode(int currentNodeId) {
             StreamNode streamNode = streamGraph.getStreamNode(currentNodeId);
             chainedNodes.add(streamNode);
diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java
index cab7aa9ba74..771755abc5c 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java
@@ -96,6 +96,37 @@ class AdaptiveBatchSchedulerITCase {
         env.execute();
     }
 
+    @Test
+    void testDifferentConsumerParallelism() throws Exception {
+        final Configuration configuration = createConfiguration();
+        final StreamExecutionEnvironment env =
+                StreamExecutionEnvironment.createLocalEnvironment(configuration);
+        env.setRuntimeMode(RuntimeExecutionMode.BATCH);
+        env.setParallelism(8);
+
+        final DataStream<Long> source2 =
+                env.fromSequence(0, NUMBERS_TO_PRODUCE - 1)
+                        .setParallelism(8)
+                        .name("source2")
+                        .slotSharingGroup("group2");
+
+        final DataStream<Long> source1 =
+                env.fromSequence(0, NUMBERS_TO_PRODUCE - 1)
+                        .setParallelism(8)
+                        .name("source1")
+                        .slotSharingGroup("group1");
+
+        source1.forward()
+                .union(source2)
+                .map(new NumberCounter())
+                .name("map1")
+                .slotSharingGroup("group3");
+
+        source2.map(new NumberCounter()).name("map2").slotSharingGroup("group4");
+
+        env.execute();
+    }
+
     private void testScheduling(Boolean isFineGrained) throws Exception {
         executeJob(isFineGrained);
 

