This is an automated email from the ASF dual-hosted git repository. wanglijie pushed a commit to branch release-1.17 in repository https://gitbox.apache.org/repos/asf/flink.git
commit c66ef2540c3cb53f4cf3218ff07f5b440511ad84 Author: Lijie Wang <wangdachui9...@gmail.com> AuthorDate: Wed Feb 22 11:44:06 2023 +0800 [FLINK-31114][runtime] Set parallelism of job vertices in forward group at compilation phase This closes #21963 --- .../jobgraph/forwardgroup/ForwardGroup.java | 33 ++- .../forwardgroup/ForwardGroupComputeUtil.java | 34 ++- .../AdaptiveBatchSchedulerFactory.java | 2 +- .../forwardgroup/ForwardGroupComputeUtilTest.java | 55 +--- .../runtime/scheduler/DefaultSchedulerBuilder.java | 4 +- .../api/graph/StreamingJobGraphGenerator.java | 321 ++++++++++++++++----- .../scheduling/AdaptiveBatchSchedulerITCase.java | 31 ++ 7 files changed, 327 insertions(+), 153 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java index cf4f2e67194..922acf231e2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroup.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; +import java.util.Collections; import java.util.HashSet; import java.util.Set; import java.util.stream.Collectors; @@ -38,12 +39,13 @@ public class ForwardGroup { private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT; + private int maxParallelism = JobVertex.MAX_PARALLELISM_DEFAULT; private final Set<JobVertexID> jobVertexIds = new HashSet<>(); public ForwardGroup(final Set<JobVertex> jobVertices) { checkNotNull(jobVertices); - Set<Integer> decidedParallelisms = + Set<Integer> configuredParallelisms = jobVertices.stream() .filter( jobVertex -> { @@ -53,9 +55,23 @@ public class ForwardGroup { .map(JobVertex::getParallelism) .collect(Collectors.toSet()); - checkState(decidedParallelisms.size() <= 1); - if (decidedParallelisms.size() == 1) { - this.parallelism = decidedParallelisms.iterator().next(); + checkState(configuredParallelisms.size() <= 1); + if (configuredParallelisms.size() == 1) { + this.parallelism = configuredParallelisms.iterator().next(); + } + + Set<Integer> configuredMaxParallelisms = + jobVertices.stream() + .map(JobVertex::getMaxParallelism) + .filter(val -> val > 0) + .collect(Collectors.toSet()); + + if (!configuredMaxParallelisms.isEmpty()) { + this.maxParallelism = Collections.min(configuredMaxParallelisms); + checkState( + parallelism == ExecutionConfig.PARALLELISM_DEFAULT + || maxParallelism >= parallelism, + "There is a job vertex in the forward group whose maximum parallelism is smaller than the group's parallelism"); } } @@ -73,6 +89,15 @@ public class ForwardGroup { return parallelism; } + public boolean isMaxParallelismDecided() { + return maxParallelism > 0; + } + + public int getMaxParallelism() { + checkState(isMaxParallelismDecided()); + return maxParallelism; + } + public int size() { return jobVertexIds.size(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java index c8a3395ae51..dc2dc702e34 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtil.java @@ -19,8 +19,8 @@ package org.apache.flink.runtime.jobgraph.forwardgroup; -import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.runtime.executiongraph.VertexGroupComputeUtil; +import org.apache.flink.runtime.jobgraph.IntermediateDataSet; import org.apache.flink.runtime.jobgraph.JobEdge; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -30,31 +30,34 @@ import java.util.HashSet; import java.util.IdentityHashMap; import java.util.Map; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; +import static org.apache.flink.util.Preconditions.checkState; + /** Common utils for computing forward groups. */ public class ForwardGroupComputeUtil { - public static Map<JobVertexID, ForwardGroup> - computeForwardGroupsAndSetVertexParallelismsIfNecessary( - final Iterable<JobVertex> topologicallySortedVertices) { + public static Map<JobVertexID, ForwardGroup> computeForwardGroupsAndCheckParallelism( + final Iterable<JobVertex> topologicallySortedVertices) { final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId = - computeForwardGroups(topologicallySortedVertices); - // set parallelism for vertices in parallelism-decided forward groups + computeForwardGroups( + topologicallySortedVertices, ForwardGroupComputeUtil::getForwardProducers); + // the vertex's parallelism in parallelism-decided forward group should have been set at + // compilation phase topologicallySortedVertices.forEach( jobVertex -> { ForwardGroup forwardGroup = forwardGroupsByJobVertexId.get(jobVertex.getID()); - if (jobVertex.getParallelism() == ExecutionConfig.PARALLELISM_DEFAULT - && forwardGroup != null - && forwardGroup.isParallelismDecided()) { - jobVertex.setParallelism(forwardGroup.getParallelism()); + if (forwardGroup != null && forwardGroup.isParallelismDecided()) { + checkState(jobVertex.getParallelism() == forwardGroup.getParallelism()); } }); return forwardGroupsByJobVertexId; } - static Map<JobVertexID, ForwardGroup> computeForwardGroups( - final Iterable<JobVertex> topologicallySortedVertices) { + public static Map<JobVertexID, ForwardGroup> computeForwardGroups( + final Iterable<JobVertex> topologicallySortedVertices, + final Function<JobVertex, Set<JobVertex>> forwardProducersRetriever) { final Map<JobVertex, Set<JobVertex>> vertexToGroup = new IdentityHashMap<>(); @@ -64,8 +67,7 @@ public class ForwardGroupComputeUtil { currentGroup.add(vertex); vertexToGroup.put(vertex, currentGroup); - for (JobEdge input : getForwardInputs(vertex)) { - final JobVertex producerVertex = input.getSource().getProducer(); + for (JobVertex producerVertex : forwardProducersRetriever.apply(vertex)) { final Set<JobVertex> producerGroup = vertexToGroup.get(producerVertex); if (producerGroup == null) { @@ -99,9 +101,11 @@ public class ForwardGroupComputeUtil { return ret; } - static Iterable<JobEdge> getForwardInputs(JobVertex jobVertex) { + static Set<JobVertex> getForwardProducers(final JobVertex jobVertex) { return jobVertex.getInputs().stream() .filter(JobEdge::isForward) + .map(JobEdge::getSource) + .map(IntermediateDataSet::getProducer) .collect(Collectors.toSet()); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java index bdac05c1e4d..3746b549ed1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerFactory.java @@ -181,7 +181,7 @@ public class AdaptiveBatchSchedulerFactory implements SchedulerNGFactory { getDefaultMaxParallelism(jobMasterConfiguration, executionConfig); final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId = - ForwardGroupComputeUtil.computeForwardGroupsAndSetVertexParallelismsIfNecessary( + ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism( jobGraph.getVerticesSortedTopologicallyFromSources()); if (enableSpeculativeExecution) { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java index 1f2bd4ead9e..e9a90e98918 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/forwardgroup/ForwardGroupComputeUtilTest.java @@ -22,16 +22,12 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.testtasks.NoOpInvokable; -import org.apache.flink.testutils.TestingUtils; -import org.apache.flink.testutils.executor.TestExecutorExtension; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.RegisterExtension; import java.util.Arrays; import java.util.HashSet; import java.util.Set; -import java.util.concurrent.ScheduledExecutorService; import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -40,9 +36,6 @@ import static org.assertj.core.api.Assertions.assertThat; * Unit tests for {@link org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil}. */ class ForwardGroupComputeUtilTest { - @RegisterExtension - static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE = - TestingUtils.defaultExecutorExtension(); /** * Tests that the computation of the job graph with isolated vertices works correctly. @@ -178,58 +171,14 @@ class ForwardGroupComputeUtilTest { checkGroupSize(groups, 1, 3); } - /** - * Tests whether the parallelism of job vertices in forward group are correctly set. - * - * <pre> - * - * (v1) -> (v2) - * - * (v3) -> (v4) - * - * </pre> - */ - @Test - void testComputeForwardGroupsAndSetVertexParallelismsIfNecessary() throws Exception { - JobVertex v1 = new JobVertex("v1"); - JobVertex v2 = new JobVertex("v2"); - JobVertex v3 = new JobVertex("v3"); - JobVertex v4 = new JobVertex("v4"); - - v2.setParallelism(8); - - v2.connectNewDataSetAsInput( - v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING); - v4.connectNewDataSetAsInput( - v3, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING); - - v1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true); - v3.getProducedDataSets().get(0).getConsumers().get(0).setForward(true); - - Set<ForwardGroup> groups = - computeForwardGroupsAndSetVertexParallelismsIfNecessary(v1, v2, v3, v4); - checkGroupSize(groups, 2, 2, 2); - assertThat(v1.getParallelism()).isEqualTo(8); - assertThat(v2.getParallelism()).isEqualTo(8); - assertThat(v3.getParallelism()).isEqualTo(-1); - assertThat(v4.getParallelism()).isEqualTo(-1); - } - - private static Set<ForwardGroup> computeForwardGroupsAndSetVertexParallelismsIfNecessary( - JobVertex... vertices) throws Exception { + private static Set<ForwardGroup> computeForwardGroups(JobVertex... vertices) { Arrays.asList(vertices).forEach(vertex -> vertex.setInvokableClass(NoOpInvokable.class)); return new HashSet<>( - ForwardGroupComputeUtil.computeForwardGroupsAndSetVertexParallelismsIfNecessary( + ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism( Arrays.asList(vertices)) .values()); } - private static Set<ForwardGroup> computeForwardGroups(JobVertex... vertices) throws Exception { - Arrays.asList(vertices).forEach(vertex -> vertex.setInvokableClass(NoOpInvokable.class)); - return new HashSet<>( - ForwardGroupComputeUtil.computeForwardGroups(Arrays.asList(vertices)).values()); - } - private static void checkGroupSize( Set<ForwardGroup> groups, int numOfGroups, Integer... sizes) { assertThat(groups.size()).isEqualTo(numOfGroups); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java index aef96c2469d..56c7b29ebd3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSchedulerBuilder.java @@ -335,7 +335,7 @@ public class DefaultSchedulerBuilder { vertexParallelismAndInputInfosDecider, defaultMaxParallelism, hybridPartitionDataConsumeConstraint, - ForwardGroupComputeUtil.computeForwardGroupsAndSetVertexParallelismsIfNecessary( + ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism( jobGraph.getVerticesSortedTopologicallyFromSources())); } @@ -367,7 +367,7 @@ public class DefaultSchedulerBuilder { defaultMaxParallelism, blocklistOperations, HybridPartitionDataConsumeConstraint.ALL_PRODUCERS_FINISHED, - ForwardGroupComputeUtil.computeForwardGroupsAndSetVertexParallelismsIfNecessary( + ForwardGroupComputeUtil.computeForwardGroupsAndCheckParallelism( jobGraph.getVerticesSortedTopologicallyFromSources())); } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java index 3eceb27dc6a..3f818b6bd9a 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java @@ -46,6 +46,8 @@ import org.apache.flink.runtime.jobgraph.JobType; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.OperatorID; +import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroup; +import org.apache.flink.runtime.jobgraph.forwardgroup.ForwardGroupComputeUtil; import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration; import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings; import org.apache.flink.runtime.jobgraph.tasks.TaskInvokable; @@ -80,6 +82,7 @@ import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner; import org.apache.flink.streaming.runtime.tasks.StreamIterationHead; import org.apache.flink.streaming.runtime.tasks.StreamIterationTail; import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.IterableUtils; import org.apache.flink.util.Preconditions; import org.apache.flink.util.SerializedValue; import org.apache.flink.util.concurrent.ExecutorThreadFactory; @@ -100,6 +103,7 @@ import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.IdentityHashMap; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; @@ -193,7 +197,14 @@ public class StreamingJobGraphGenerator { List<CompletableFuture<SerializedValue<OperatorCoordinator.Provider>>>> coordinatorSerializationFuturesPerJobVertex = new HashMap<>(); - private final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputs; + /** The {@link OperatorChainInfo}s, key is the start node id of the chain. */ + private final Map<Integer, OperatorChainInfo> chainInfos; + + /** + * This is used to cache the non-chainable outputs, to set the non-chainable outputs config + * after all job vertices are created. + */ + private final Map<Integer, List<StreamEdge>> opNonChainableOutputsCache; private StreamingJobGraphGenerator( ClassLoader userClassloader, @@ -205,7 +216,7 @@ public class StreamingJobGraphGenerator { this.defaultStreamGraphHasher = new StreamGraphHasherV2(); this.legacyStreamGraphHashers = Arrays.asList(new StreamGraphUserHashHasher()); - this.jobVertices = new HashMap<>(); + this.jobVertices = new LinkedHashMap<>(); this.builtVertices = new HashSet<>(); this.chainedConfigs = new HashMap<>(); this.vertexConfigs = new HashMap<>(); @@ -215,7 +226,8 @@ public class StreamingJobGraphGenerator { this.chainedInputOutputFormats = new HashMap<>(); this.physicalEdgesInOrder = new ArrayList<>(); this.serializationExecutor = Preconditions.checkNotNull(serializationExecutor); - this.opIntermediateOutputs = new HashMap<>(); + this.chainInfos = new HashMap<>(); + this.opNonChainableOutputsCache = new LinkedHashMap<>(); jobGraph = new JobGraph(jobID, streamGraph.getJobName()); } @@ -241,6 +253,18 @@ public class StreamingJobGraphGenerator { setChaining(hashes, legacyHashes); + if (jobGraph.isDynamic()) { + setVertexParallelismsForDynamicGraphIfNecessary(); + } + + // Note that we set all the non-chainable outputs configuration here because the + // "setVertexParallelismsForDynamicGraphIfNecessary" may affect the parallelism of job + // vertices and partition-reuse + final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputs = + new HashMap<>(); + setAllOperatorNonChainedOutputsConfigs(opIntermediateOutputs); + setAllVertexNonChainedOutputsConfigs(opIntermediateOutputs); + setPhysicalEdges(); markSupportingConcurrentExecutionAttempts(); @@ -568,12 +592,11 @@ public class StreamingJobGraphGenerator { final StreamConfig.SourceInputConfig inputConfig = new StreamConfig.SourceInputConfig(sourceOutEdge); final StreamConfig operatorConfig = new StreamConfig(new Configuration()); - setVertexConfig( - sourceNodeId, - operatorConfig, - Collections.emptyList(), - Collections.emptyList(), - Collections.emptyMap()); + setOperatorConfig(sourceNodeId, operatorConfig, Collections.emptyMap()); + setOperatorChainedOutputsConfig(operatorConfig, Collections.emptyList()); + // we cache the non-chainable outputs here, and set the non-chained config later + opNonChainableOutputsCache.put(sourceNodeId, Collections.emptyList()); + operatorConfig.setChainIndex(0); // sources are always first operatorConfig.setOperatorID(opId); operatorConfig.setOperatorName(sourceNode.getOperatorName()); @@ -712,28 +735,22 @@ public class StreamingJobGraphGenerator { ? createJobVertex(startNodeId, chainInfo) : new StreamConfig(new Configuration()); - setVertexConfig( - currentNodeId, - config, - chainableOutputs, - nonChainableOutputs, - chainInfo.getChainedSources()); + tryConvertPartitionerForDynamicGraph(chainableOutputs, nonChainableOutputs); + + setOperatorConfig(currentNodeId, config, chainInfo.getChainedSources()); + + setOperatorChainedOutputsConfig(config, chainableOutputs); + + // we cache the non-chainable outputs here, and set the non-chained config later + opNonChainableOutputsCache.put(currentNodeId, nonChainableOutputs); if (currentNodeId.equals(startNodeId)) { + chainInfo.setTransitiveOutEdges(transitiveOutEdges); + chainInfos.put(startNodeId, chainInfo); config.setChainStart(); config.setChainIndex(chainIndex); config.setOperatorName(streamGraph.getStreamNode(currentNodeId).getOperatorName()); - - LinkedHashSet<NonChainedOutput> transitiveOutputs = new LinkedHashSet<>(); - for (StreamEdge edge : transitiveOutEdges) { - NonChainedOutput output = - opIntermediateOutputs.get(edge.getSourceId()).get(edge); - transitiveOutputs.add(output); - connect(startNodeId, edge, output); - } - - config.setVertexNonChainedOutputs(new ArrayList<>(transitiveOutputs)); config.setTransitiveChainedTaskConfigs(chainedConfigs.get(startNodeId)); } else { @@ -758,6 +775,102 @@ public class StreamingJobGraphGenerator { } } + /** + * This method is used to reset or set job vertices' parallelism for dynamic graph: + * + * <p>1. Reset parallelism for job vertices whose parallelism is not configured. + * + * <p>2. Set parallelism and maxParallelism for job vertices in forward group, to ensure the + * parallelism and maxParallelism of vertices in the same forward group to be the same; set the + * parallelism at early stage if possible, to avoid invalid partition reuse. + */ + private void setVertexParallelismsForDynamicGraphIfNecessary() { + // Note that the jobVertices are reverse topological order + final List<JobVertex> topologicalOrderVertices = + IterableUtils.toStream(jobVertices.values()).collect(Collectors.toList()); + Collections.reverse(topologicalOrderVertices); + + // reset parallelism for job vertices whose parallelism is not configured + jobVertices.forEach( + (startNodeId, jobVertex) -> { + final OperatorChainInfo chainInfo = chainInfos.get(startNodeId); + if (!jobVertex.isParallelismConfigured() + && streamGraph.isAutoParallelismEnabled()) { + jobVertex.setParallelism(ExecutionConfig.PARALLELISM_DEFAULT); + chainInfo + .getAllChainedNodes() + .forEach( + n -> + n.setParallelism( + ExecutionConfig.PARALLELISM_DEFAULT, + false)); + } + }); + + final Map<JobVertex, Set<JobVertex>> forwardProducersByJobVertex = new HashMap<>(); + jobVertices.forEach( + (startNodeId, jobVertex) -> { + Set<JobVertex> forwardConsumers = + chainInfos.get(startNodeId).getTransitiveOutEdges().stream() + .filter( + edge -> + edge.getPartitioner() + instanceof ForwardPartitioner) + .map(StreamEdge::getTargetId) + .map(jobVertices::get) + .collect(Collectors.toSet()); + + for (JobVertex forwardConsumer : forwardConsumers) { + forwardProducersByJobVertex.compute( + forwardConsumer, + (ignored, producers) -> { + if (producers == null) { + producers = new HashSet<>(); + } + producers.add(jobVertex); + return producers; + }); + } + }); + + // compute forward groups + final Map<JobVertexID, ForwardGroup> forwardGroupsByJobVertexId = + ForwardGroupComputeUtil.computeForwardGroups( + topologicalOrderVertices, + jobVertex -> + forwardProducersByJobVertex.getOrDefault( + jobVertex, Collections.emptySet())); + + jobVertices.forEach( + (startNodeId, jobVertex) -> { + ForwardGroup forwardGroup = forwardGroupsByJobVertexId.get(jobVertex.getID()); + // set parallelism for vertices in forward group + if (forwardGroup != null && forwardGroup.isParallelismDecided()) { + jobVertex.setParallelism(forwardGroup.getParallelism()); + jobVertex.setParallelismConfigured(true); + chainInfos + .get(startNodeId) + .getAllChainedNodes() + .forEach( + streamNode -> + streamNode.setParallelism( + forwardGroup.getParallelism(), true)); + } + + // set max parallelism for vertices in forward group + if (forwardGroup != null && forwardGroup.isMaxParallelismDecided()) { + jobVertex.setMaxParallelism(forwardGroup.getMaxParallelism()); + chainInfos + .get(startNodeId) + .getAllChainedNodes() + .forEach( + streamNode -> + streamNode.setMaxParallelism( + forwardGroup.getMaxParallelism())); + } + }); + } + private void checkAndReplaceReusableHybridPartitionType(NonChainedOutput reusableOutput) { if (reusableOutput.getPartitionType() == ResultPartitionType.HYBRID_SELECTIVE) { // for can be reused hybrid output, it can be optimized to always use full @@ -916,30 +1029,16 @@ public class StreamingJobGraphGenerator { jobVertex.setParallelismConfigured( chainInfo.getAllChainedNodes().stream() .anyMatch(StreamNode::isParallelismConfigured)); - if (streamGraph.isDynamic() - && !jobVertex.isParallelismConfigured() - && streamGraph.isAutoParallelismEnabled()) { - jobVertex.setParallelism(ExecutionConfig.PARALLELISM_DEFAULT); - chainInfo - .getAllChainedNodes() - .forEach(n -> n.setParallelism(ExecutionConfig.PARALLELISM_DEFAULT, false)); - } return new StreamConfig(jobVertex.getConfiguration()); } - private void setVertexConfig( - Integer vertexID, - StreamConfig config, - List<StreamEdge> chainableOutputs, - List<StreamEdge> nonChainableOutputs, - Map<Integer, ChainedSourceInfo> chainedSources) { - - tryConvertPartitionerForDynamicGraph(chainableOutputs, nonChainableOutputs); + private void setOperatorConfig( + Integer vertexId, StreamConfig config, Map<Integer, ChainedSourceInfo> chainedSources) { - StreamNode vertex = streamGraph.getStreamNode(vertexID); + StreamNode vertex = streamGraph.getStreamNode(vertexId); - config.setVertexID(vertexID); + config.setVertexID(vertexId); // build the inputs as a combination of source and network inputs final List<StreamEdge> inEdges = vertex.getInEdges(); @@ -965,7 +1064,7 @@ public class StreamingJobGraphGenerator { } inputConfigs[inputIndex] = chainedSource.getInputConfig(); chainedConfigs - .computeIfAbsent(vertexID, (key) -> new HashMap<>()) + .computeIfAbsent(vertexId, (key) -> new HashMap<>()) .put(inEdge.getSourceId(), chainedSource.getOperatorConfig()); } else { // network input. null if we move to a new input, non-null if this is a further edge @@ -996,34 +1095,8 @@ public class StreamingJobGraphGenerator { config.setTypeSerializerOut(vertex.getTypeSerializerOut()); - // iterate edges, find sideOutput edges create and save serializers for each outputTag type - for (StreamEdge edge : chainableOutputs) { - if (edge.getOutputTag() != null) { - config.setTypeSerializerSideOut( - edge.getOutputTag(), - edge.getOutputTag() - .getTypeInfo() - .createSerializer(streamGraph.getExecutionConfig())); - } - } - for (StreamEdge edge : nonChainableOutputs) { - if (edge.getOutputTag() != null) { - config.setTypeSerializerSideOut( - edge.getOutputTag(), - edge.getOutputTag() - .getTypeInfo() - .createSerializer(streamGraph.getExecutionConfig())); - } - } - config.setStreamOperatorFactory(vertex.getOperatorFactory()); - List<NonChainedOutput> deduplicatedOutputs = - mayReuseNonChainedOutputs(vertexID, nonChainableOutputs); - config.setNumberOfOutputs(deduplicatedOutputs.size()); - config.setOperatorNonChainedOutputs(deduplicatedOutputs); - config.setChainedOutputs(chainableOutputs); - config.setTimeCharacteristic(streamGraph.getTimeCharacteristic()); final CheckpointConfig checkpointCfg = streamGraph.getCheckpointConfig(); @@ -1053,21 +1126,103 @@ public class StreamingJobGraphGenerator { if (vertexClass.equals(StreamIterationHead.class) || vertexClass.equals(StreamIterationTail.class)) { - config.setIterationId(streamGraph.getBrokerID(vertexID)); - config.setIterationWaitTime(streamGraph.getLoopTimeout(vertexID)); + config.setIterationId(streamGraph.getBrokerID(vertexId)); + config.setIterationWaitTime(streamGraph.getLoopTimeout(vertexId)); + } + + vertexConfigs.put(vertexId, config); + } + + private void setOperatorChainedOutputsConfig( + StreamConfig config, List<StreamEdge> chainableOutputs) { + // iterate edges, find sideOutput edges create and save serializers for each outputTag type + for (StreamEdge edge : chainableOutputs) { + if (edge.getOutputTag() != null) { + config.setTypeSerializerSideOut( + edge.getOutputTag(), + edge.getOutputTag() + .getTypeInfo() + .createSerializer(streamGraph.getExecutionConfig())); + } + } + config.setChainedOutputs(chainableOutputs); + } + + private void setOperatorNonChainedOutputsConfig( + Integer vertexId, + StreamConfig config, + List<StreamEdge> nonChainableOutputs, + Map<StreamEdge, NonChainedOutput> outputsConsumedByEdge) { + // iterate edges, find sideOutput edges create and save serializers for each outputTag type + for (StreamEdge edge : nonChainableOutputs) { + if (edge.getOutputTag() != null) { + config.setTypeSerializerSideOut( + edge.getOutputTag(), + edge.getOutputTag() + .getTypeInfo() + .createSerializer(streamGraph.getExecutionConfig())); + } + } + + List<NonChainedOutput> deduplicatedOutputs = + mayReuseNonChainedOutputs(vertexId, nonChainableOutputs, outputsConsumedByEdge); + config.setNumberOfOutputs(deduplicatedOutputs.size()); + config.setOperatorNonChainedOutputs(deduplicatedOutputs); + } + + private void setVertexNonChainedOutputsConfig( + Integer startNodeId, + StreamConfig config, + List<StreamEdge> transitiveOutEdges, + final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputs) { + + LinkedHashSet<NonChainedOutput> transitiveOutputs = new LinkedHashSet<>(); + for (StreamEdge edge : transitiveOutEdges) { + NonChainedOutput output = opIntermediateOutputs.get(edge.getSourceId()).get(edge); + transitiveOutputs.add(output); + connect(startNodeId, edge, output); } - vertexConfigs.put(vertexID, config); + config.setVertexNonChainedOutputs(new ArrayList<>(transitiveOutputs)); + } + + private void setAllOperatorNonChainedOutputsConfigs( + final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputs) { + // set non chainable output config + opNonChainableOutputsCache.forEach( + (vertexId, nonChainableOutputs) -> { + Map<StreamEdge, NonChainedOutput> outputsConsumedByEdge = + opIntermediateOutputs.computeIfAbsent( + vertexId, ignored -> new HashMap<>()); + setOperatorNonChainedOutputsConfig( + vertexId, + vertexConfigs.get(vertexId), + nonChainableOutputs, + outputsConsumedByEdge); + }); + } + + private void setAllVertexNonChainedOutputsConfigs( + final Map<Integer, Map<StreamEdge, NonChainedOutput>> opIntermediateOutputs) { + jobVertices + .keySet() + .forEach( + startNodeId -> + setVertexNonChainedOutputsConfig( + startNodeId, + vertexConfigs.get(startNodeId), + chainInfos.get(startNodeId).getTransitiveOutEdges(), + opIntermediateOutputs)); } private List<NonChainedOutput> mayReuseNonChainedOutputs( - int vertexId, List<StreamEdge> consumerEdges) { + int vertexId, + List<StreamEdge> consumerEdges, + Map<StreamEdge, NonChainedOutput> outputsConsumedByEdge) { if (consumerEdges.isEmpty()) { return new ArrayList<>(); } List<NonChainedOutput> outputs = new ArrayList<>(consumerEdges.size()); - Map<StreamEdge, NonChainedOutput> outputsConsumedByEdge = - opIntermediateOutputs.computeIfAbsent(vertexId, ignored -> new HashMap<>()); for (StreamEdge consumerEdge : consumerEdges) { checkState(vertexId == consumerEdge.getSourceId(), "Vertex id must be the same."); ResultPartitionType partitionType = getResultPartitionType(consumerEdge); @@ -1251,7 +1406,7 @@ public class StreamingJobGraphGenerator { headVertex, DistributionPattern.POINTWISE, resultPartitionType, - opIntermediateOutputs.get(edge.getSourceId()).get(edge).getDataSetId(), + output.getDataSetId(), partitioner.isBroadcast()); } else { jobEdge = @@ -1259,7 +1414,7 @@ public class StreamingJobGraphGenerator { headVertex, DistributionPattern.ALL_TO_ALL, resultPartitionType, - opIntermediateOutputs.get(edge.getSourceId()).get(edge).getDataSetId(), + output.getDataSetId(), partitioner.isBroadcast()); } @@ -1894,6 +2049,7 @@ public class StreamingJobGraphGenerator { private final List<OperatorCoordinator.Provider> coordinatorProviders; private final StreamGraph streamGraph; private final List<StreamNode> chainedNodes; + private final List<StreamEdge> transitiveOutEdges; private OperatorChainInfo( int startNodeId, @@ -1909,6 +2065,7 @@ public class StreamingJobGraphGenerator { this.chainedSources = chainedSources; this.streamGraph = streamGraph; this.chainedNodes = new ArrayList<>(); + this.transitiveOutEdges = new ArrayList<>(); } byte[] getHash(Integer streamNodeId) { @@ -1955,6 +2112,14 @@ public class StreamingJobGraphGenerator { return new OperatorID(primaryHashBytes); } + private void setTransitiveOutEdges(final List<StreamEdge> transitiveOutEdges) { + this.transitiveOutEdges.addAll(transitiveOutEdges); + } + + private List<StreamEdge> getTransitiveOutEdges() { + return transitiveOutEdges; + } + private void recordChainedNode(int currentNodeId) { StreamNode streamNode = streamGraph.getStreamNode(currentNodeId); chainedNodes.add(streamNode); diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java index cab7aa9ba74..771755abc5c 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveBatchSchedulerITCase.java @@ -96,6 +96,37 @@ class AdaptiveBatchSchedulerITCase { env.execute(); } + @Test + void testDifferentConsumerParallelism() throws Exception { + final Configuration configuration = createConfiguration(); + final StreamExecutionEnvironment env = + StreamExecutionEnvironment.createLocalEnvironment(configuration); + env.setRuntimeMode(RuntimeExecutionMode.BATCH); + env.setParallelism(8); + + final DataStream<Long> source2 = + env.fromSequence(0, NUMBERS_TO_PRODUCE - 1) + .setParallelism(8) + .name("source2") + .slotSharingGroup("group2"); + + final DataStream<Long> source1 = + env.fromSequence(0, NUMBERS_TO_PRODUCE - 1) + .setParallelism(8) + .name("source1") + .slotSharingGroup("group1"); + + source1.forward() + .union(source2) + .map(new NumberCounter()) + .name("map1") + .slotSharingGroup("group3"); + + source2.map(new NumberCounter()).name("map2").slotSharingGroup("group4"); + + env.execute(); + } + private void testScheduling(Boolean isFineGrained) throws Exception { executeJob(isFineGrained);