This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 7962a5d [FLINK-25668][runtime] Support to compute network memory for dynamic graph. 7962a5d is described below commit 7962a5d9c47f49f435f822a1af2c4141c42a849b Author: Lijie Wang <wangdachui9...@gmail.com> AuthorDate: Tue Dec 14 22:04:20 2021 +0800 [FLINK-25668][runtime] Support to compute network memory for dynamic graph. This closes #18376. --- .../runtime/deployment/SubpartitionIndexRange.java | 4 + .../TaskDeploymentDescriptorFactory.java | 24 ++- .../executiongraph/DefaultExecutionGraph.java | 20 ++ .../runtime/executiongraph/IntermediateResult.java | 2 +- .../flink/runtime/scheduler/DefaultScheduler.java | 11 - .../SsgNetworkMemoryCalculationUtils.java | 63 +++++- .../executiongraph/ExecutionJobVertexTest.java | 4 +- .../IntermediateResultPartitionTest.java | 2 +- .../SsgNetworkMemoryCalculationUtilsTest.java | 228 +++++++++++++++++---- 9 files changed, 291 insertions(+), 67 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java index 1fb1d52..19484a6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java @@ -43,6 +43,10 @@ public class SubpartitionIndexRange implements Serializable { return endIndex; } + public int size() { + return endIndex - startIndex + 1; + } + @Override public String toString() { return String.format("[%d, %d]", startIndex, endIndex); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java index 528a954..bd0f5b3 100644 --- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java @@ -128,17 +128,9 @@ public class TaskDeploymentDescriptorFactory { IntermediateResultPartition resultPartition = resultPartitionRetriever.apply(consumedPartitionGroup.getFirst()); - int numConsumers = resultPartition.getConsumerVertexGroup().size(); IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult(); - int consumerIndex = subtaskIndex % numConsumers; - int numSubpartitions = resultPartition.getNumberOfSubpartitions(); SubpartitionIndexRange consumedSubpartitionRange = - computeConsumedSubpartitionRange( - consumerIndex, - numConsumers, - numSubpartitions, - consumedIntermediateResult.getProducer().getGraph().isDynamic(), - consumedIntermediateResult.isBroadcast()); + computeConsumedSubpartitionRange(resultPartition, subtaskIndex); IntermediateDataSetID resultId = consumedIntermediateResult.getId(); ResultPartitionType partitionType = consumedIntermediateResult.getResultType(); @@ -155,6 +147,20 @@ public class TaskDeploymentDescriptorFactory { return inputGates; } + public static SubpartitionIndexRange computeConsumedSubpartitionRange( + IntermediateResultPartition resultPartition, int consumerSubtaskIndex) { + int numConsumers = resultPartition.getConsumerVertexGroup().size(); + int consumerIndex = consumerSubtaskIndex % numConsumers; + IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult(); + int numSubpartitions = resultPartition.getNumberOfSubpartitions(); + return computeConsumedSubpartitionRange( + consumerIndex, + numConsumers, + numSubpartitions, + consumedIntermediateResult.getProducer().getGraph().isDynamic(), + consumedIntermediateResult.isBroadcast()); + } + @VisibleForTesting static SubpartitionIndexRange computeConsumedSubpartitionRange( int consumerIndex, diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java index 8796bc7..77f0e19 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java @@ -59,8 +59,10 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration; +import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; import org.apache.flink.runtime.query.KvStateLocationRegistry; import org.apache.flink.runtime.scheduler.InternalFailuresListener; +import org.apache.flink.runtime.scheduler.SsgNetworkMemoryCalculationUtils; import org.apache.flink.runtime.scheduler.VertexParallelismInformation; import org.apache.flink.runtime.scheduler.VertexParallelismStore; import org.apache.flink.runtime.scheduler.adapter.DefaultExecutionTopology; @@ -854,6 +856,24 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG } registerExecutionVerticesAndResultPartitionsFor(ejv); + + // enrich network memory. 
+ SlotSharingGroup slotSharingGroup = ejv.getSlotSharingGroup(); + if (areJobVerticesAllInitialized(slotSharingGroup)) { + SsgNetworkMemoryCalculationUtils.enrichNetworkMemory( + slotSharingGroup, this::getJobVertex, shuffleMaster); + } + } + + private boolean areJobVerticesAllInitialized(final SlotSharingGroup group) { + for (JobVertexID jobVertexId : group.getJobVertexIds()) { + final ExecutionJobVertex jobVertex = getJobVertex(jobVertexId); + checkNotNull(jobVertex, "Unknown job vertex %s", jobVertexId); + if (!jobVertex.isInitialized()) { + return false; + } + } + return true; } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java index 508e974..4b666b5 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java @@ -168,7 +168,7 @@ public class IntermediateResult { return checkNotNull(getProducer().getGraph().getJobVertex(consumerJobVertexId)); } - DistributionPattern getConsumingDistributionPattern() { + public DistributionPattern getConsumingDistributionPattern() { final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer()); return consumer.getDistributionPattern(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java index bd202ba..1a2fa36 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java @@ -173,8 +173,6 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio jobGraph.getName(), jobGraph.getJobID()); - enrichResourceProfile(); - 
this.executionFailureHandler = new ExecutionFailureHandler( getSchedulingTopology(), failoverStrategy, restartBackoffTimeStrategy); @@ -723,13 +721,4 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio return reservedAllocationRefCounters.keySet(); } } - - private void enrichResourceProfile() { - Set<SlotSharingGroup> ssgs = new HashSet<>(); - getJobGraph().getVertices().forEach(jv -> ssgs.add(jv.getSlotSharingGroup())); - ssgs.forEach( - ssg -> - SsgNetworkMemoryCalculationUtils.enrichNetworkMemory( - ssg, this::getExecutionJobVertex, shuffleMaster)); - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java index 13c6172..0fac885 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java @@ -18,11 +18,16 @@ package org.apache.flink.runtime.scheduler; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.configuration.MemorySize; import org.apache.flink.runtime.clusterframework.types.ResourceProfile; +import org.apache.flink.runtime.deployment.SubpartitionIndexRange; +import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory; import org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil; import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.executiongraph.IntermediateResult; +import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.IntermediateDataSet; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; @@ -30,9 
+35,11 @@ import org.apache.flink.runtime.jobgraph.JobEdge; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.shuffle.ShuffleMaster; import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -51,7 +58,7 @@ public class SsgNetworkMemoryCalculationUtils { * Calculates network memory requirement of {@link ExecutionJobVertex} and update {@link * ResourceProfile} of corresponding slot sharing group. */ - static void enrichNetworkMemory( + public static void enrichNetworkMemory( SlotSharingGroup ssg, Function<JobVertexID, ExecutionJobVertex> ejvs, ShuffleMaster<?> shuffleMaster) { @@ -88,8 +95,17 @@ public class SsgNetworkMemoryCalculationUtils { private static TaskInputsOutputsDescriptor buildTaskInputsOutputsDescriptor( ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) { - Map<IntermediateDataSetID, Integer> maxInputChannelNums = getMaxInputChannelNums(ejv); - Map<IntermediateDataSetID, Integer> maxSubpartitionNums = getMaxSubpartitionNums(ejv, ejvs); + Map<IntermediateDataSetID, Integer> maxInputChannelNums; + Map<IntermediateDataSetID, Integer> maxSubpartitionNums; + + if (ejv.getGraph().isDynamic()) { + maxInputChannelNums = getMaxInputChannelNumsForDynamicGraph(ejv); + maxSubpartitionNums = getMaxSubpartitionNumsForDynamicGraph(ejv); + } else { + maxInputChannelNums = getMaxInputChannelNums(ejv); + maxSubpartitionNums = getMaxSubpartitionNums(ejv, ejvs); + } + JobVertex jv = ejv.getJobVertex(); Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = getPartitionTypes(jv); @@ -148,6 +164,47 @@ public class SsgNetworkMemoryCalculationUtils { return ret; } + @VisibleForTesting + static 
Map<IntermediateDataSetID, Integer> getMaxInputChannelNumsForDynamicGraph( + ExecutionJobVertex ejv) { + + Map<IntermediateDataSetID, Integer> ret = new HashMap<>(); + + for (ExecutionVertex vertex : ejv.getTaskVertices()) { + for (ConsumedPartitionGroup partitionGroup : vertex.getAllConsumedPartitionGroups()) { + + IntermediateResultPartition resultPartition = + ejv.getGraph().getResultPartitionOrThrow((partitionGroup.getFirst())); + SubpartitionIndexRange subpartitionIndexRange = + TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange( + resultPartition, vertex.getParallelSubtaskIndex()); + + ret.merge( + partitionGroup.getIntermediateDataSetID(), + subpartitionIndexRange.size() * partitionGroup.size(), + Integer::max); + } + } + + return ret; + } + + private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNumsForDynamicGraph( + ExecutionJobVertex ejv) { + + Map<IntermediateDataSetID, Integer> ret = new HashMap<>(); + + for (IntermediateResult intermediateResult : ejv.getProducedDataSets()) { + final int maxNum = + Arrays.stream(intermediateResult.getPartitions()) + .map(IntermediateResultPartition::getNumberOfSubpartitions) + .reduce(0, Integer::max); + ret.put(intermediateResult.getId(), maxNum); + } + + return ret; + } + /** Private default constructor to avoid being instantiated. 
*/ private SsgNetworkMemoryCalculationUtils() {} } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java index 186cbf5..240427b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java @@ -193,7 +193,7 @@ public class ExecutionJobVertexTest { return createDynamicExecutionJobVertex(-1, -1, 1); } - private static ExecutionJobVertex createDynamicExecutionJobVertex( + public static ExecutionJobVertex createDynamicExecutionJobVertex( int parallelism, int maxParallelism, int defaultMaxParallelism) throws Exception { JobVertex jobVertex = new JobVertex("testVertex"); jobVertex.setInvokableClass(AbstractInvokable.class); @@ -227,7 +227,7 @@ public class ExecutionJobVertexTest { * @param defaultMaxParallelism the global default max parallelism * @return the computed parallelism store */ - static VertexParallelismStore computeVertexParallelismStoreForDynamicGraph( + public static VertexParallelismStore computeVertexParallelismStoreForDynamicGraph( Iterable<JobVertex> vertices, int defaultMaxParallelism) { // for dynamic graph, there is no need to normalize vertex parallelism. 
if the max // parallelism is not configured and the parallelism is a positive value, max diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java index 42f60b2..611e18a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java @@ -221,7 +221,7 @@ public class IntermediateResultPartitionTest extends TestLogger { equalTo(expectedNumSubpartitions)); } - private static ExecutionGraph createExecutionGraph( + public static ExecutionGraph createExecutionGraph( int producerParallelism, int consumerParallelism, int consumerMaxParallelism, diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java index 54ad8eb..9481c82 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java @@ -21,10 +21,16 @@ package org.apache.flink.runtime.scheduler; import org.apache.flink.api.common.JobID; import org.apache.flink.configuration.MemorySize; import org.apache.flink.runtime.clusterframework.types.ResourceProfile; -import org.apache.flink.runtime.executiongraph.ExecutionGraph; +import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertexTest; +import org.apache.flink.runtime.executiongraph.IntermediateResult; +import org.apache.flink.runtime.executiongraph.IntermediateResultPartition; 
+import org.apache.flink.runtime.executiongraph.IntermediateResultPartitionTest; import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.JobGraphTestUtils; import org.apache.flink.runtime.jobgraph.JobVertex; @@ -39,9 +45,13 @@ import org.apache.flink.runtime.testtasks.NoOpInvokable; import org.junit.Test; import java.util.Arrays; +import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.assertEquals; /** Tests for {@link SsgNetworkMemoryCalculationUtils}. */ @@ -51,81 +61,219 @@ public class SsgNetworkMemoryCalculationUtilsTest { private static final ResourceProfile DEFAULT_RESOURCE = ResourceProfile.fromResources(1.0, 100); - private JobGraph jobGraph; - - private ExecutionGraph executionGraph; - - private List<SlotSharingGroup> slotSharingGroups; - @Test public void testGenerateEnrichedResourceProfile() throws Exception { - setup(DEFAULT_RESOURCE); - slotSharingGroups.forEach( - ssg -> - SsgNetworkMemoryCalculationUtils.enrichNetworkMemory( - ssg, executionGraph.getAllVertices()::get, SHUFFLE_MASTER)); + SlotSharingGroup slotSharingGroup0 = new SlotSharingGroup(); + slotSharingGroup0.setResourceProfile(DEFAULT_RESOURCE); + + SlotSharingGroup slotSharingGroup1 = new SlotSharingGroup(); + slotSharingGroup1.setResourceProfile(DEFAULT_RESOURCE); + + createExecutionGraphAndEnrichNetworkMemory( + Arrays.asList(slotSharingGroup0, slotSharingGroup0, slotSharingGroup1)); assertEquals( new MemorySize( TestShuffleMaster.computeRequiredShuffleMemoryBytes(0, 2) + 
TestShuffleMaster.computeRequiredShuffleMemoryBytes(1, 6)), - slotSharingGroups.get(0).getResourceProfile().getNetworkMemory()); - + slotSharingGroup0.getResourceProfile().getNetworkMemory()); assertEquals( new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(5, 0)), - slotSharingGroups.get(1).getResourceProfile().getNetworkMemory()); + slotSharingGroup1.getResourceProfile().getNetworkMemory()); } @Test public void testGenerateUnknownResourceProfile() throws Exception { - setup(ResourceProfile.UNKNOWN); + SlotSharingGroup slotSharingGroup0 = new SlotSharingGroup(); + slotSharingGroup0.setResourceProfile(ResourceProfile.UNKNOWN); + + SlotSharingGroup slotSharingGroup1 = new SlotSharingGroup(); + slotSharingGroup1.setResourceProfile(ResourceProfile.UNKNOWN); + + createExecutionGraphAndEnrichNetworkMemory( + Arrays.asList(slotSharingGroup0, slotSharingGroup0, slotSharingGroup1)); + + assertEquals(ResourceProfile.UNKNOWN, slotSharingGroup0.getResourceProfile()); + assertEquals(ResourceProfile.UNKNOWN, slotSharingGroup1.getResourceProfile()); + } + + @Test + public void testGenerateEnrichedResourceProfileForDynamicGraph() throws Exception { + List<SlotSharingGroup> slotSharingGroups = + Arrays.asList( + new SlotSharingGroup(), new SlotSharingGroup(), new SlotSharingGroup()); + + for (SlotSharingGroup group : slotSharingGroups) { + group.setResourceProfile(DEFAULT_RESOURCE); + } + + DefaultExecutionGraph executionGraph = createDynamicExecutionGraph(slotSharingGroups, 20); + Iterator<ExecutionJobVertex> jobVertices = + executionGraph.getVerticesTopologically().iterator(); + ExecutionJobVertex source = jobVertices.next(); + ExecutionJobVertex map = jobVertices.next(); + ExecutionJobVertex sink = jobVertices.next(); - slotSharingGroups.forEach( - ssg -> - SsgNetworkMemoryCalculationUtils.enrichNetworkMemory( - ssg, executionGraph.getAllVertices()::get, SHUFFLE_MASTER)); + executionGraph.initializeJobVertex(source, 0L); + 
triggerComputeNumOfSubpartitions(source.getProducedDataSets()[0]); + + map.setParallelism(5); + executionGraph.initializeJobVertex(map, 0L); + triggerComputeNumOfSubpartitions(map.getProducedDataSets()[0]); + + sink.setParallelism(7); + executionGraph.initializeJobVertex(sink, 0L); + + assertNetworkMemory( + slotSharingGroups, + Arrays.asList( + new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(0, 5)), + new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(5, 20)), + new MemorySize( + TestShuffleMaster.computeRequiredShuffleMemoryBytes(15, 0)))); + } - for (SlotSharingGroup slotSharingGroup : slotSharingGroups) { - assertEquals(ResourceProfile.UNKNOWN, slotSharingGroup.getResourceProfile()); + private void triggerComputeNumOfSubpartitions(IntermediateResult result) { + // call IntermediateResultPartition#getNumberOfSubpartitions to trigger computation of + // numOfSubpartitions + for (IntermediateResultPartition partition : result.getPartitions()) { + partition.getNumberOfSubpartitions(); } } - private void setup(final ResourceProfile resourceProfile) throws Exception { - slotSharingGroups = Arrays.asList(new SlotSharingGroup(), new SlotSharingGroup()); + private void assertNetworkMemory( + List<SlotSharingGroup> slotSharingGroups, List<MemorySize> networkMemory) { - for (SlotSharingGroup slotSharingGroup : slotSharingGroups) { - slotSharingGroup.setResourceProfile(resourceProfile); + assertEquals(slotSharingGroups.size(), networkMemory.size()); + for (int i = 0; i < slotSharingGroups.size(); ++i) { + assertThat( + slotSharingGroups.get(i).getResourceProfile().getNetworkMemory(), + is(networkMemory.get(i))); } + } + + @Test + public void testGetMaxInputChannelNumForResultForAllToAll() throws Exception { + testGetMaxInputChannelNumForResult(DistributionPattern.ALL_TO_ALL, 5, 20, 7, 15); + } - jobGraph = createJobGraph(slotSharingGroups); - executionGraph = - 
TestingDefaultExecutionGraphBuilder.newBuilder().setJobGraph(jobGraph).build(); + @Test + public void testGetMaxInputChannelNumForResultForPointWise() throws Exception { + testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 3, 8); + testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 5, 4); + testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 7, 4); } - private static JobGraph createJobGraph(final List<SlotSharingGroup> slotSharingGroups) { + private void testGetMaxInputChannelNumForResult( + DistributionPattern distributionPattern, + int producerParallelism, + int consumerMaxParallelism, + int decidedConsumerParallelism, + int expectedNumChannels) + throws Exception { + + final DefaultExecutionGraph eg = + (DefaultExecutionGraph) + IntermediateResultPartitionTest.createExecutionGraph( + producerParallelism, + -1, + consumerMaxParallelism, + distributionPattern, + true); + + final Iterator<ExecutionJobVertex> vertexIterator = + eg.getVerticesTopologically().iterator(); + final ExecutionJobVertex producer = vertexIterator.next(); + final ExecutionJobVertex consumer = vertexIterator.next(); + + eg.initializeJobVertex(producer, 0L); + final IntermediateResult result = producer.getProducedDataSets()[0]; + triggerComputeNumOfSubpartitions(result); + + consumer.setParallelism(decidedConsumerParallelism); + eg.initializeJobVertex(consumer, 0L); + + Map<IntermediateDataSetID, Integer> maxInputChannelNums = + SsgNetworkMemoryCalculationUtils.getMaxInputChannelNumsForDynamicGraph(consumer); + + assertThat(maxInputChannelNums.size(), is(1)); + assertThat(maxInputChannelNums.get(result.getId()), is(expectedNumChannels)); + } + + private DefaultExecutionGraph createDynamicExecutionGraph( + final List<SlotSharingGroup> slotSharingGroups, int defaultMaxParallelism) + throws Exception { + + JobGraph jobGraph = createBatchGraph(slotSharingGroups, Arrays.asList(4, -1, -1)); + + final VertexParallelismStore 
vertexParallelismStore = + ExecutionJobVertexTest.computeVertexParallelismStoreForDynamicGraph( + jobGraph.getVertices(), defaultMaxParallelism); + + return TestingDefaultExecutionGraphBuilder.newBuilder() + .setJobGraph(jobGraph) + .setVertexParallelismStore(vertexParallelismStore) + .setShuffleMaster(SHUFFLE_MASTER) + .buildDynamicGraph(); + } + + private void createExecutionGraphAndEnrichNetworkMemory( + final List<SlotSharingGroup> slotSharingGroups) throws Exception { + TestingDefaultExecutionGraphBuilder.newBuilder() + .setJobGraph(createStreamingGraph(slotSharingGroups, Arrays.asList(4, 5, 6))) + .setShuffleMaster(SHUFFLE_MASTER) + .build(); + } + + private static JobGraph createStreamingGraph( + final List<SlotSharingGroup> slotSharingGroups, List<Integer> parallelisms) { + return createJobGraph(slotSharingGroups, parallelisms, ResultPartitionType.PIPELINED); + } + + private static JobGraph createBatchGraph( + final List<SlotSharingGroup> slotSharingGroups, List<Integer> parallelisms) { + return createJobGraph(slotSharingGroups, parallelisms, ResultPartitionType.BLOCKING); + } + + private static JobGraph createJobGraph( + final List<SlotSharingGroup> slotSharingGroups, + List<Integer> parallelisms, + ResultPartitionType resultPartitionType) { + + assertThat(slotSharingGroups.size(), is(3)); + assertThat(parallelisms.size(), is(3)); JobVertex source = new JobVertex("source"); source.setInvokableClass(NoOpInvokable.class); - source.setParallelism(4); + trySetParallelism(source, parallelisms.get(0)); source.setSlotSharingGroup(slotSharingGroups.get(0)); JobVertex map = new JobVertex("map"); map.setInvokableClass(NoOpInvokable.class); - map.setParallelism(5); - map.setSlotSharingGroup(slotSharingGroups.get(0)); + trySetParallelism(map, parallelisms.get(1)); + map.setSlotSharingGroup(slotSharingGroups.get(1)); JobVertex sink = new JobVertex("sink"); sink.setInvokableClass(NoOpInvokable.class); - sink.setParallelism(6); - 
sink.setSlotSharingGroup(slotSharingGroups.get(1)); + trySetParallelism(sink, parallelisms.get(2)); + sink.setSlotSharingGroup(slotSharingGroups.get(2)); + + map.connectNewDataSetAsInput(source, DistributionPattern.POINTWISE, resultPartitionType); + sink.connectNewDataSetAsInput(map, DistributionPattern.ALL_TO_ALL, resultPartitionType); + + if (resultPartitionType.isPipelined()) { + return JobGraphTestUtils.streamingJobGraph(source, map, sink); - map.connectNewDataSetAsInput( - source, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED); - sink.connectNewDataSetAsInput( - map, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED); + } else { + return JobGraphTestUtils.batchJobGraph(source, map, sink); + } + } - return JobGraphTestUtils.streamingJobGraph(source, map, sink); + private static void trySetParallelism(JobVertex jobVertex, int parallelism) { + if (parallelism > 0) { + jobVertex.setParallelism(parallelism); + } } private static class TestShuffleMaster implements ShuffleMaster<ShuffleDescriptor> {