This is an automated email from the ASF dual-hosted git repository. zhuzh pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit c0bfb0b04bb38411c390813fea1ff93f6638f409 Author: Zhu Zhu <reed...@gmail.com> AuthorDate: Mon Jan 30 12:02:13 2023 +0800 [FLINK-15325][coordination] Set the ConsumedPartitionGroup/ConsumerVertexGroup to its corresponding ConsumerVertexGroup/ConsumedPartitionGroup --- .../executiongraph/EdgeManagerBuildUtil.java | 3 + .../scheduler/strategy/ConsumedPartitionGroup.java | 20 +- .../scheduler/strategy/ConsumerVertexGroup.java | 21 +- .../executiongraph/EdgeManagerBuildUtilTest.java | 64 ++++-- .../strategy/TestingSchedulingExecutionVertex.java | 61 +----- .../strategy/TestingSchedulingResultPartition.java | 14 +- .../strategy/TestingSchedulingTopology.java | 236 ++++++++++++++------- 7 files changed, 244 insertions(+), 175 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java index a3613eba27c..889d7a39ae1 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java @@ -174,6 +174,9 @@ public class EdgeManagerBuildUtil { for (IntermediateResultPartition partition : partitions) { partition.addConsumers(consumerVertexGroup); } + + consumedPartitionGroup.setConsumerVertexGroup(consumerVertexGroup); + consumerVertexGroup.setConsumedPartitionGroup(consumedPartitionGroup); } private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager( diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java index c3d32fa99fb..e4440d7575c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java +++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java @@ -23,14 +23,21 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.util.Preconditions; +import javax.annotation.Nullable; + import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import static org.apache.flink.util.Preconditions.checkArgument; +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; -/** Group of consumed {@link IntermediateResultPartitionID}s. */ +/** + * Group of consumed {@link IntermediateResultPartitionID}s. Each such group corresponds to one + * {@link ConsumerVertexGroup}. + */ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartitionID> { private final List<IntermediateResultPartitionID> resultPartitions; @@ -44,6 +51,8 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit /** Number of consumer tasks in the corresponding {@link ConsumerVertexGroup}.
*/ private final int numConsumers; + @Nullable private ConsumerVertexGroup consumerVertexGroup; + private ConsumedPartitionGroup( int numConsumers, List<IntermediateResultPartitionID> resultPartitions, @@ -130,4 +139,13 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit public ResultPartitionType getResultPartitionType() { return resultPartitionType; } + + public ConsumerVertexGroup getConsumerVertexGroup() { + return checkNotNull(consumerVertexGroup, "ConsumerVertexGroup is not properly set."); + } + + public void setConsumerVertexGroup(ConsumerVertexGroup consumerVertexGroup) { + checkState(this.consumerVertexGroup == null); + this.consumerVertexGroup = checkNotNull(consumerVertexGroup); + } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java index fb8b3f1951c..9939206b7d9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java @@ -20,16 +20,26 @@ package org.apache.flink.runtime.scheduler.strategy; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; +import javax.annotation.Nullable; + import java.util.Collections; import java.util.Iterator; import java.util.List; -/** Group of consumer {@link ExecutionVertexID}s. */ +import static org.apache.flink.util.Preconditions.checkNotNull; +import static org.apache.flink.util.Preconditions.checkState; + +/** + * Group of consumer {@link ExecutionVertexID}s. Each such group corresponds to one {@link + * ConsumedPartitionGroup}.
+ */ public class ConsumerVertexGroup implements Iterable<ExecutionVertexID> { private final List<ExecutionVertexID> vertices; private final ResultPartitionType resultPartitionType; + @Nullable private ConsumedPartitionGroup consumedPartitionGroup; + private ConsumerVertexGroup( List<ExecutionVertexID> vertices, ResultPartitionType resultPartitionType) { this.vertices = vertices; @@ -66,4 +76,13 @@ public class ConsumerVertexGroup implements Iterable<ExecutionVertexID> { public ExecutionVertexID getFirst() { return iterator().next(); } + + public ConsumedPartitionGroup getConsumedPartitionGroup() { + return checkNotNull(consumedPartitionGroup, "ConsumedPartitionGroup is not properly set."); + } + + public void setConsumedPartitionGroup(ConsumedPartitionGroup consumedPartitionGroup) { + checkState(this.consumedPartitionGroup == null); + this.consumedPartitionGroup = checkNotNull(consumedPartitionGroup); + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java index 41422bb48b1..63ecd9d442b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java @@ -22,6 +22,7 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.DistributionPattern; import org.apache.flink.runtime.jobgraph.JobVertex; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; +import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup; import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup; import org.apache.flink.testutils.TestingUtils; import org.apache.flink.testutils.executor.TestExecutorExtension; @@ -116,21 +117,24 @@ class EdgeManagerBuildUtilTest { ExecutionVertex vertex2 = 
consumer.getTaskVertices()[1]; // check consumers of the partitions - assertThat(partition1.getConsumerVertexGroups().get(0)) - .containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID()); - assertThat(partition1.getConsumerVertexGroups().get(0)) - .isEqualTo(partition1.getConsumerVertexGroups().get(0)); - assertThat(partition3.getConsumerVertexGroups().get(0)) - .isEqualTo(partition1.getConsumerVertexGroups().get(0)); + ConsumerVertexGroup consumerVertexGroup = partition1.getConsumerVertexGroups().get(0); + assertThat(consumerVertexGroup).containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID()); + assertThat(partition2.getConsumerVertexGroups().get(0)).isEqualTo(consumerVertexGroup); + assertThat(partition3.getConsumerVertexGroups().get(0)).isEqualTo(consumerVertexGroup); // check inputs of the execution vertices - assertThat(vertex1.getConsumedPartitionGroup(0)) + ConsumedPartitionGroup consumedPartitionGroup = vertex1.getConsumedPartitionGroup(0); + assertThat(consumedPartitionGroup) .containsExactlyInAnyOrder( partition1.getPartitionId(), partition2.getPartitionId(), partition3.getPartitionId()); - assertThat(vertex2.getConsumedPartitionGroup(0)) - .isEqualTo(vertex1.getConsumedPartitionGroup(0)); + assertThat(vertex2.getConsumedPartitionGroup(0)).isEqualTo(consumedPartitionGroup); + + // check the consumerVertexGroup and consumedPartitionGroup are set to each other + assertThat(consumerVertexGroup.getConsumedPartitionGroup()) + .isEqualTo(consumedPartitionGroup); + assertThat(consumedPartitionGroup.getConsumerVertexGroup()).isEqualTo(consumerVertexGroup); } @Test @@ -186,25 +190,39 @@ class EdgeManagerBuildUtilTest { ExecutionVertex vertex4 = consumer.getTaskVertices()[3]; // check consumers of the partitions - assertThat(partition1.getConsumerVertexGroups().get(0)) + ConsumerVertexGroup consumerVertexGroup1 = partition1.getConsumerVertexGroups().get(0); + ConsumerVertexGroup consumerVertexGroup2 = partition2.getConsumerVertexGroups().get(0); + 
ConsumerVertexGroup consumerVertexGroup3 = partition4.getConsumerVertexGroups().get(0); + assertThat(consumerVertexGroup1) .containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID()); - assertThat(partition2.getConsumerVertexGroups().get(0)) - .containsExactlyInAnyOrder(vertex3.getID()); - assertThat(partition3.getConsumerVertexGroups().get(0)) - .isEqualTo(partition2.getConsumerVertexGroups().get(0)); - assertThat(partition4.getConsumerVertexGroups().get(0)) - .containsExactlyInAnyOrder(vertex4.getID()); + assertThat(consumerVertexGroup2).containsExactlyInAnyOrder(vertex3.getID()); + assertThat(partition3.getConsumerVertexGroups().get(0)).isEqualTo(consumerVertexGroup2); + assertThat(consumerVertexGroup3).containsExactlyInAnyOrder(vertex4.getID()); // check inputs of the execution vertices - assertThat(vertex1.getConsumedPartitionGroup(0)) - .containsExactlyInAnyOrder(partition1.getPartitionId()); - assertThat(vertex2.getConsumedPartitionGroup(0)) - .isEqualTo(vertex1.getConsumedPartitionGroup(0)); - assertThat(vertex3.getConsumedPartitionGroup(0)) + ConsumedPartitionGroup consumedPartitionGroup1 = vertex1.getConsumedPartitionGroup(0); + ConsumedPartitionGroup consumedPartitionGroup2 = vertex3.getConsumedPartitionGroup(0); + ConsumedPartitionGroup consumedPartitionGroup3 = vertex4.getConsumedPartitionGroup(0); + assertThat(consumedPartitionGroup1).containsExactlyInAnyOrder(partition1.getPartitionId()); + assertThat(vertex2.getConsumedPartitionGroup(0)).isEqualTo(consumedPartitionGroup1); + assertThat(consumedPartitionGroup2) .containsExactlyInAnyOrder( partition2.getPartitionId(), partition3.getPartitionId()); - assertThat(vertex4.getConsumedPartitionGroup(0)) - .containsExactlyInAnyOrder(partition4.getPartitionId()); + assertThat(consumedPartitionGroup3).containsExactlyInAnyOrder(partition4.getPartitionId()); + + // check the consumerVertexGroups and consumedPartitionGroups are properly set + assertThat(consumerVertexGroup1.getConsumedPartitionGroup()) + 
.isEqualTo(consumedPartitionGroup1); + assertThat(consumedPartitionGroup1.getConsumerVertexGroup()) + .isEqualTo(consumerVertexGroup1); + assertThat(consumerVertexGroup2.getConsumedPartitionGroup()) + .isEqualTo(consumedPartitionGroup2); + assertThat(consumedPartitionGroup2.getConsumerVertexGroup()) + .isEqualTo(consumerVertexGroup2); + assertThat(consumerVertexGroup3.getConsumedPartitionGroup()) + .isEqualTo(consumedPartitionGroup3); + assertThat(consumedPartitionGroup3.getConsumerVertexGroup()) + .isEqualTo(consumerVertexGroup3); } private void testGetMaxNumEdgesToTarget( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java index 3bc9c404fae..b549da3e976 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java @@ -19,7 +19,6 @@ package org.apache.flink.runtime.scheduler.strategy; import org.apache.flink.runtime.execution.ExecutionState; -import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.util.IterableUtils; @@ -30,8 +29,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import static org.apache.flink.util.Preconditions.checkNotNull; - /** A simple scheduling execution vertex for testing purposes. 
*/ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVertex { @@ -47,17 +44,12 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert private ExecutionState executionState; public TestingSchedulingExecutionVertex( - JobVertexID jobVertexId, - int subtaskIndex, - List<ConsumedPartitionGroup> consumedPartitionGroups, - Map<IntermediateResultPartitionID, TestingSchedulingResultPartition> - resultPartitionsById, - ExecutionState executionState) { + JobVertexID jobVertexId, int subtaskIndex, ExecutionState executionState) { this.executionVertexId = new ExecutionVertexID(jobVertexId, subtaskIndex); - this.consumedPartitionGroups = checkNotNull(consumedPartitionGroups); + this.consumedPartitionGroups = new ArrayList<>(); this.producedPartitions = new ArrayList<>(); - this.resultPartitionsById = checkNotNull(resultPartitionsById); + this.resultPartitionsById = new HashMap<>(); this.executionState = executionState; } @@ -90,22 +82,6 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert return consumedPartitionGroups; } - void addConsumedPartition(TestingSchedulingResultPartition consumedPartition) { - final ConsumedPartitionGroup consumedPartitionGroup = - ConsumedPartitionGroup.fromSinglePartition( - consumedPartition.getNumConsumers(), - consumedPartition.getId(), - consumedPartition.getResultType()); - - consumedPartition.registerConsumedPartitionGroup(consumedPartitionGroup); - if (consumedPartition.getState() == ResultPartitionState.ALL_DATA_PRODUCED) { - consumedPartitionGroup.partitionFinished(); - } - - this.consumedPartitionGroups.add(consumedPartitionGroup); - this.resultPartitionsById.putIfAbsent(consumedPartition.getId(), consumedPartition); - } - void addConsumedPartitionGroup( ConsumedPartitionGroup consumedPartitionGroup, Map<IntermediateResultPartitionID, TestingSchedulingResultPartition> @@ -131,9 +107,6 @@ public class TestingSchedulingExecutionVertex implements 
SchedulingExecutionVert public static class Builder { private JobVertexID jobVertexId = new JobVertexID(); private int subtaskIndex = 0; - private final List<ConsumedPartitionGroup> consumedPartitionGroups = new ArrayList<>(); - private final Map<IntermediateResultPartitionID, TestingSchedulingResultPartition> - resultPartitionsById = new HashMap<>(); private ExecutionState executionState = ExecutionState.CREATED; Builder withExecutionVertexID(JobVertexID jobVertexId, int subtaskIndex) { @@ -142,39 +115,13 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert return this; } - public Builder withConsumedPartitionGroups( - List<ConsumedPartitionGroup> consumedPartitionGroups, - Map<IntermediateResultPartitionID, TestingSchedulingResultPartition> - resultPartitionsById) { - this.resultPartitionsById.putAll(resultPartitionsById); - final ResultPartitionType resultType = - resultPartitionsById.values().iterator().next().getResultType(); - - for (ConsumedPartitionGroup partitionGroup : consumedPartitionGroups) { - List<IntermediateResultPartitionID> partitionIds = - new ArrayList<>(partitionGroup.size()); - for (IntermediateResultPartitionID partitionId : partitionGroup) { - partitionIds.add(partitionId); - } - this.consumedPartitionGroups.add( - ConsumedPartitionGroup.fromMultiplePartitions( - partitionGroup.getNumConsumers(), partitionIds, resultType)); - } - return this; - } - public Builder withExecutionState(ExecutionState executionState) { this.executionState = executionState; return this; } public TestingSchedulingExecutionVertex build() { - return new TestingSchedulingExecutionVertex( - jobVertexId, - subtaskIndex, - consumedPartitionGroups, - resultPartitionsById, - executionState); + return new TestingSchedulingExecutionVertex(jobVertexId, subtaskIndex, executionState); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java index 70759fd2b6c..77514ba57b2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java @@ -25,10 +25,8 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import javax.annotation.Nullable; import java.util.ArrayList; -import java.util.Collection; import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; @@ -102,18 +100,8 @@ public class TestingSchedulingResultPartition implements SchedulingResultPartiti return Collections.unmodifiableList(consumedPartitionGroups); } - void addConsumerGroup( - Collection<TestingSchedulingExecutionVertex> consumerVertices, - ResultPartitionType resultPartitionType) { + void addConsumerGroup(ConsumerVertexGroup consumerVertexGroup) { checkState(this.consumerVertexGroup == null); - - final ConsumerVertexGroup consumerVertexGroup = - ConsumerVertexGroup.fromMultipleVertices( - consumerVertices.stream() - .map(TestingSchedulingExecutionVertex::getId) - .collect(Collectors.toList()), - resultPartitionType); - this.consumerVertexGroup = consumerVertexGroup; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java index ec86fa36d9b..b9653db3cf0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java @@ -38,6 +38,7 @@ import java.util.Set; import 
java.util.function.Function; import java.util.stream.Collectors; +import static org.apache.flink.util.Preconditions.checkNotNull; import static org.apache.flink.util.Preconditions.checkState; /** A simple scheduling topology for testing purposes. */ @@ -189,17 +190,12 @@ public class TestingSchedulingTopology implements SchedulingTopology { TestingSchedulingExecutionVertex consumer, ResultPartitionType resultPartitionType) { - final TestingSchedulingResultPartition resultPartition = - new TestingSchedulingResultPartition.Builder() - .withResultPartitionType(resultPartitionType) - .build(); - - resultPartition.addConsumerGroup( - Collections.singleton(consumer), resultPartition.getResultType()); - resultPartition.setProducer(producer); - - producer.addProducedPartition(resultPartition); - consumer.addConsumedPartition(resultPartition); + connectConsumersToProducers( + Collections.singletonList(consumer), + Collections.singletonList(producer), + new IntermediateDataSetID(), + resultPartitionType, + ResultPartitionState.ALL_DATA_PRODUCED); updateVertexResultPartitions(producer); updateVertexResultPartitions(consumer); @@ -223,6 +219,142 @@ public class TestingSchedulingTopology implements SchedulingTopology { return new ProducerConsumerAllToAllConnectionBuilder(producers, consumers); } + private static List<TestingSchedulingResultPartition> connectConsumersToProducers( + final List<TestingSchedulingExecutionVertex> consumers, + final List<TestingSchedulingExecutionVertex> producers, + final IntermediateDataSetID intermediateDataSetId, + final ResultPartitionType resultPartitionType, + final ResultPartitionState resultPartitionState) { + + final List<TestingSchedulingResultPartition> resultPartitions = new ArrayList<>(); + + final ConnectionResult connectionResult = + connectConsumersToProducersById( + consumers.stream() + .map(SchedulingExecutionVertex::getId) + .collect(Collectors.toList()), + producers.stream() + .map(SchedulingExecutionVertex::getId) + 
.collect(Collectors.toList()), + intermediateDataSetId, + resultPartitionType); + + final ConsumedPartitionGroup consumedPartitionGroup = + connectionResult.getConsumedPartitionGroup(); + final ConsumerVertexGroup consumerVertexGroup = connectionResult.getConsumerVertexGroup(); + + final TestingSchedulingResultPartition.Builder resultPartitionBuilder = + new TestingSchedulingResultPartition.Builder() + .withIntermediateDataSetID(intermediateDataSetId) + .withResultPartitionType(resultPartitionType) + .withResultPartitionState(resultPartitionState); + + for (int i = 0; i < producers.size(); i++) { + final TestingSchedulingExecutionVertex producer = producers.get(i); + final IntermediateResultPartitionID partitionId = + connectionResult.getResultPartitions().get(i); + final TestingSchedulingResultPartition resultPartition = + resultPartitionBuilder + .withPartitionNum(partitionId.getPartitionNumber()) + .build(); + + producer.addProducedPartition(resultPartition); + + resultPartition.setProducer(producer); + resultPartitions.add(resultPartition); + resultPartition.registerConsumedPartitionGroup(consumedPartitionGroup); + resultPartition.addConsumerGroup(consumerVertexGroup); + + if (resultPartition.getState() == ResultPartitionState.ALL_DATA_PRODUCED) { + consumedPartitionGroup.partitionFinished(); + } + } + + final Map<IntermediateResultPartitionID, TestingSchedulingResultPartition> + consumedPartitionById = + resultPartitions.stream() + .collect( + Collectors.toMap( + TestingSchedulingResultPartition::getId, + Function.identity())); + for (TestingSchedulingExecutionVertex consumer : consumers) { + consumer.addConsumedPartitionGroup(consumedPartitionGroup, consumedPartitionById); + } + + return resultPartitions; + } + + public static ConnectionResult connectConsumersToProducersById( + final List<ExecutionVertexID> consumers, + final List<ExecutionVertexID> producers, + final IntermediateDataSetID intermediateDataSetId, + final ResultPartitionType 
resultPartitionType) { + + final List<IntermediateResultPartitionID> resultPartitions = new ArrayList<>(); + for (ExecutionVertexID producer : producers) { + final IntermediateResultPartitionID resultPartition = + new IntermediateResultPartitionID( + intermediateDataSetId, producer.getSubtaskIndex()); + resultPartitions.add(resultPartition); + } + + final ConsumedPartitionGroup consumedPartitionGroup = + createConsumedPartitionGroup( + consumers.size(), resultPartitions, resultPartitionType); + final ConsumerVertexGroup consumerVertexGroup = + createConsumerVertexGroup(consumers, resultPartitionType); + + consumedPartitionGroup.setConsumerVertexGroup(consumerVertexGroup); + consumerVertexGroup.setConsumedPartitionGroup(consumedPartitionGroup); + + return new ConnectionResult(resultPartitions, consumedPartitionGroup, consumerVertexGroup); + } + + private static ConsumedPartitionGroup createConsumedPartitionGroup( + final int numConsumers, + final List<IntermediateResultPartitionID> consumedPartitions, + final ResultPartitionType resultPartitionType) { + return ConsumedPartitionGroup.fromMultiplePartitions( + numConsumers, consumedPartitions, resultPartitionType); + } + + private static ConsumerVertexGroup createConsumerVertexGroup( + final List<ExecutionVertexID> consumers, + final ResultPartitionType resultPartitionType) { + return ConsumerVertexGroup.fromMultipleVertices(consumers, resultPartitionType); + } + + /** + * The result of connecting a set of consumers to their producers, including the created result + * partitions and the consumption groups. 
+ */ + public static class ConnectionResult { + private final List<IntermediateResultPartitionID> resultPartitions; + private final ConsumedPartitionGroup consumedPartitionGroup; + private final ConsumerVertexGroup consumerVertexGroup; + + public ConnectionResult( + final List<IntermediateResultPartitionID> resultPartitions, + final ConsumedPartitionGroup consumedPartitionGroup, + final ConsumerVertexGroup consumerVertexGroup) { + this.resultPartitions = checkNotNull(resultPartitions); + this.consumedPartitionGroup = checkNotNull(consumedPartitionGroup); + this.consumerVertexGroup = checkNotNull(consumerVertexGroup); + } + + public List<IntermediateResultPartitionID> getResultPartitions() { + return resultPartitions; + } + + public ConsumedPartitionGroup getConsumedPartitionGroup() { + return consumedPartitionGroup; + } + + public ConsumerVertexGroup getConsumerVertexGroup() { + return consumerVertexGroup; + } + } + /** Builder for {@link TestingSchedulingResultPartition}. */ public abstract class ProducerConsumerConnectionBuilder { @@ -265,11 +397,6 @@ public class TestingSchedulingTopology implements SchedulingTopology { return resultPartitions; } - TestingSchedulingResultPartition.Builder initTestingSchedulingResultPartitionBuilder() { - return new TestingSchedulingResultPartition.Builder() - .withResultPartitionType(resultPartitionType); - } - protected abstract List<TestingSchedulingResultPartition> connect(); } @@ -292,25 +419,15 @@ public class TestingSchedulingTopology implements SchedulingTopology { protected List<TestingSchedulingResultPartition> connect() { final List<TestingSchedulingResultPartition> resultPartitions = new ArrayList<>(); final IntermediateDataSetID intermediateDataSetId = new IntermediateDataSetID(); - for (int idx = 0; idx < producers.size(); idx++) { - final TestingSchedulingExecutionVertex producer = producers.get(idx); - final TestingSchedulingExecutionVertex consumer = consumers.get(idx); - - final TestingSchedulingResultPartition 
resultPartition = - initTestingSchedulingResultPartitionBuilder() - .withIntermediateDataSetID(intermediateDataSetId) - .withResultPartitionState(resultPartitionState) - .withPartitionNum(idx) - .build(); - resultPartition.setProducer(producer); - producer.addProducedPartition(resultPartition); - consumer.addConsumedPartition(resultPartition); - resultPartition.addConsumerGroup( - Collections.singleton(consumer), resultPartitionType); - resultPartitions.add(resultPartition); + resultPartitions.addAll( + connectConsumersToProducers( + Collections.singletonList(consumers.get(idx)), + Collections.singletonList(producers.get(idx)), + intermediateDataSetId, + resultPartitionType, + resultPartitionState)); } - return resultPartitions; } } @@ -330,53 +447,12 @@ public class TestingSchedulingTopology implements SchedulingTopology { @Override protected List<TestingSchedulingResultPartition> connect() { - final List<TestingSchedulingResultPartition> resultPartitions = new ArrayList<>(); - final IntermediateDataSetID intermediateDataSetId = new IntermediateDataSetID(); - - TestingSchedulingResultPartition.Builder resultPartitionBuilder = - initTestingSchedulingResultPartitionBuilder() - .withIntermediateDataSetID(intermediateDataSetId) - .withResultPartitionState(resultPartitionState); - - int partitionNum = 0; - - for (TestingSchedulingExecutionVertex producer : producers) { - - final TestingSchedulingResultPartition resultPartition = - resultPartitionBuilder.withPartitionNum(partitionNum++).build(); - resultPartition.setProducer(producer); - producer.addProducedPartition(resultPartition); - - resultPartition.addConsumerGroup(consumers, resultPartitionType); - resultPartitions.add(resultPartition); - } - - ConsumedPartitionGroup consumedPartitionGroup = - ConsumedPartitionGroup.fromMultiplePartitions( - consumers.size(), - resultPartitions.stream() - .map(TestingSchedulingResultPartition::getId) - .collect(Collectors.toList()), - resultPartitions.get(0).getResultType()); - 
Map<IntermediateResultPartitionID, TestingSchedulingResultPartition> - consumedPartitionById = - resultPartitions.stream() - .collect( - Collectors.toMap( - TestingSchedulingResultPartition::getId, - Function.identity())); - for (TestingSchedulingExecutionVertex consumer : consumers) { - consumer.addConsumedPartitionGroup(consumedPartitionGroup, consumedPartitionById); - } - - for (TestingSchedulingResultPartition resultPartition : resultPartitions) { - resultPartition.registerConsumedPartitionGroup(consumedPartitionGroup); - if (resultPartition.getState() == ResultPartitionState.ALL_DATA_PRODUCED) { - consumedPartitionGroup.partitionFinished(); - } - } - - return resultPartitions; + return connectConsumersToProducers( + consumers, + producers, + new IntermediateDataSetID(), + resultPartitionType, + resultPartitionState); } }