[ https://issues.apache.org/jira/browse/FLINK-10820?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16701711#comment-16701711 ]
ASF GitHub Bot commented on FLINK-10820: ---------------------------------------- asfgit closed pull request #7051: [FLINK-10820][network] Simplify the RebalancePartitioner implementation URL: https://github.com/apache/flink/pull/7051 This is a foreign pull request, merged from a forked repository. As GitHub hides the original diff of such a pull request on merge, the diff is displayed below for the sake of provenance: diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelector.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelector.java index 65012fe6c74..403b75c2ed2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelector.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ChannelSelector.java @@ -28,14 +28,21 @@ */ public interface ChannelSelector<T extends IOReadableWritable> { + /** + * Initializes the channel selector with the number of output channels. + * + * @param numberOfChannels the total number of output channels which are attached + * to respective output gate. + */ + void setup(int numberOfChannels); + /** * Returns the logical channel indexes, to which the given record should be * written. * - * @param record the record to the determine the output channels for - * @param numChannels the total number of output channels which are attached to respective output gate - * @return a (possibly empty) array of integer numbers which indicate the indices of the output channels through - * which the record shall be forwarded + * @param record the record to determine the output channels for. + * @return an array of integer numbers which indicate the indices of the output + * channels through which the record shall be forwarded. 
*/ - int[] selectChannels(T record, int numChannels); + int[] selectChannels(T record); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java index 84d81837ddd..6a691c31bf9 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java @@ -56,7 +56,7 @@ private final ChannelSelector<T> channelSelector; - private final int numChannels; + private final int numberOfChannels; private final int[] broadcastChannels; @@ -85,20 +85,20 @@ public RecordWriter(ResultPartitionWriter writer, ChannelSelector<T> channelSele this.flushAlways = flushAlways; this.targetPartition = writer; this.channelSelector = channelSelector; - - this.numChannels = writer.getNumberOfSubpartitions(); + this.numberOfChannels = writer.getNumberOfSubpartitions(); + this.channelSelector.setup(numberOfChannels); this.serializer = new SpanningRecordSerializer<T>(); - this.bufferBuilders = new Optional[numChannels]; - this.broadcastChannels = new int[numChannels]; - for (int i = 0; i < numChannels; i++) { + this.bufferBuilders = new Optional[numberOfChannels]; + this.broadcastChannels = new int[numberOfChannels]; + for (int i = 0; i < numberOfChannels; i++) { broadcastChannels[i] = i; bufferBuilders[i] = Optional.empty(); } } public void emit(T record) throws IOException, InterruptedException { - emit(record, channelSelector.selectChannels(record, numChannels)); + emit(record, channelSelector.selectChannels(record)); } /** @@ -115,7 +115,7 @@ public void broadcastEmit(T record) throws IOException, InterruptedException { public void randomEmit(T record) throws IOException, InterruptedException { serializer.serializeRecord(record); - if (copyFromSerializerToTargetChannel(rng.nextInt(numChannels))) { + if 
(copyFromSerializerToTargetChannel(rng.nextInt(numberOfChannels))) { serializer.prune(); } } @@ -174,7 +174,7 @@ private boolean copyFromSerializerToTargetChannel(int targetChannel) throws IOEx public void broadcastEvent(AbstractEvent event) throws IOException { try (BufferConsumer eventBufferConsumer = EventSerializer.toBufferConsumer(event)) { - for (int targetChannel = 0; targetChannel < numChannels; targetChannel++) { + for (int targetChannel = 0; targetChannel < numberOfChannels; targetChannel++) { tryFinishCurrentBufferBuilder(targetChannel); // Retain the buffer so that it can be recycled by each channel of targetPartition @@ -192,7 +192,7 @@ public void flushAll() { } public void clearBuffers() { - for (int targetChannel = 0; targetChannel < numChannels; targetChannel++) { + for (int targetChannel = 0; targetChannel < numberOfChannels; targetChannel++) { closeBufferBuilder(targetChannel); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RoundRobinChannelSelector.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RoundRobinChannelSelector.java index 96a4e1a081f..5da9534f340 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RoundRobinChannelSelector.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RoundRobinChannelSelector.java @@ -32,9 +32,16 @@ /** Stores the index of the channel to send the next record to. 
*/ private final int[] nextChannelToSendTo = new int[] { -1 }; + private int numberOfChannels; + + @Override + public void setup(int numberOfChannels) { + this.numberOfChannels = numberOfChannels; + } + @Override - public int[] selectChannels(final T record, final int numberOfOutputChannels) { - nextChannelToSendTo[0] = (nextChannelToSendTo[0] + 1) % numberOfOutputChannels; + public int[] selectChannels(final T record) { + nextChannelToSendTo[0] = (nextChannelToSendTo[0] + 1) % numberOfChannels; return nextChannelToSendTo; } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java index e6f3d262477..91547e53418 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/shipping/OutputEmitter.java @@ -43,6 +43,9 @@ /** counter to go over channels round robin */ private int nextChannelToSendTo = 0; + + /** the total number of output channels */ + private int numberOfChannels; /** the comparator for hashing / sorting */ private final TypeComparator<T> comparator; @@ -131,7 +134,12 @@ public OutputEmitter(ShipStrategyType strategy, int indexInSubtaskGroup, // ------------------------------------------------------------------------ @Override - public final int[] selectChannels(SerializationDelegate<T> record, int numberOfChannels) { + public void setup(int numberOfChannels) { + this.numberOfChannels = numberOfChannels; + } + + @Override + public final int[] selectChannels(SerializationDelegate<T> record) { switch (strategy) { case FORWARD: return forward(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/DefaultChannelSelectorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/DefaultChannelSelectorTest.java index 61bcde9e6cb..4abc67ecb5e 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/DefaultChannelSelectorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/DefaultChannelSelectorTest.java @@ -38,19 +38,18 @@ public void channelSelect() { final StringValue dummyRecord = new StringValue("abc"); final RoundRobinChannelSelector<StringValue> selector = new RoundRobinChannelSelector<>(); - final int numberOfChannels = 2; + selector.setup(2); - assertSelectedChannel(selector, dummyRecord, numberOfChannels, 0); - assertSelectedChannel(selector, dummyRecord, numberOfChannels, 1); + assertSelectedChannel(selector, dummyRecord, 0); + assertSelectedChannel(selector, dummyRecord, 1); } private void assertSelectedChannel( ChannelSelector<StringValue> selector, StringValue record, - int numberOfChannels, int expectedChannel) { - int[] actualResult = selector.selectChannels(record, numberOfChannels); + int[] actualResult = selector.selectChannels(record); assertEquals(1, actualResult.length); assertEquals(expectedChannel, actualResult[0]); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java index 52796024ac9..dcd496150cf 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java @@ -198,12 +198,12 @@ public void testSerializerClearedAfterClearBuffers() throws Exception { */ @Test public void testBroadcastEventNoRecords() throws Exception { - int numChannels = 4; + int numberOfChannels = 4; int bufferSize = 32; @SuppressWarnings("unchecked") - Queue<BufferConsumer>[] queues = new Queue[numChannels]; - for (int i = 0; i < numChannels; i++) { + Queue<BufferConsumer>[] queues = new Queue[numberOfChannels]; + for (int i = 0; i < numberOfChannels; i++) { queues[i] = 
new ArrayDeque<>(); } @@ -218,7 +218,7 @@ public void testBroadcastEventNoRecords() throws Exception { assertEquals(0, bufferProvider.getNumberOfCreatedBuffers()); - for (int i = 0; i < numChannels; i++) { + for (int i = 0; i < numberOfChannels; i++) { assertEquals(1, queues[i].size()); BufferOrEvent boe = parseBuffer(queues[i].remove(), i); assertTrue(boe.isEvent()); @@ -234,13 +234,13 @@ public void testBroadcastEventNoRecords() throws Exception { @Test public void testBroadcastEventMixedRecords() throws Exception { Random rand = new XORShiftRandom(); - int numChannels = 4; + int numberOfChannels = 4; int bufferSize = 32; int lenBytes = 4; // serialized length @SuppressWarnings("unchecked") - Queue<BufferConsumer>[] queues = new Queue[numChannels]; - for (int i = 0; i < numChannels; i++) { + Queue<BufferConsumer>[] queues = new Queue[numberOfChannels]; + for (int i = 0; i < numberOfChannels; i++) { queues[i] = new ArrayDeque<>(); } @@ -290,7 +290,7 @@ public void testBroadcastEventMixedRecords() throws Exception { assertEquals(1, queues[3].size()); // 0 buffers + 1 event // every queue's last element should be the event - for (int i = 0; i < numChannels; i++) { + for (int i = 0; i < numberOfChannels; i++) { boe = parseBuffer(queues[i].remove(), i); assertTrue(boe.isEvent()); assertEquals(barrier, boe.getEvent()); @@ -413,22 +413,22 @@ public void testBroadcastEmitRecord() throws Exception { * @param isBroadcastEmit whether using {@link RecordWriter#broadcastEmit(IOReadableWritable)} or not */ private void emitRecordWithBroadcastPartitionerOrBroadcastEmitRecord(boolean isBroadcastEmit) throws Exception { - final int numChannels = 4; + final int numberOfChannels = 4; final int bufferSize = 32; final int numValues = 8; final int serializationLength = 4; @SuppressWarnings("unchecked") - final Queue<BufferConsumer>[] queues = new Queue[numChannels]; - for (int i = 0; i < numChannels; i++) { + final Queue<BufferConsumer>[] queues = new Queue[numberOfChannels]; + for 
(int i = 0; i < numberOfChannels; i++) { queues[i] = new ArrayDeque<>(); } final TestPooledBufferProvider bufferProvider = new TestPooledBufferProvider(Integer.MAX_VALUE, bufferSize); final ResultPartitionWriter partitionWriter = new CollectingPartitionWriter(queues, bufferProvider); + final ChannelSelector selector = new Broadcast<>(); final RecordWriter<SerializationTestType> writer = isBroadcastEmit ? - new RecordWriter<>(partitionWriter) : - new RecordWriter<>(partitionWriter, new Broadcast<>()); + new RecordWriter<>(partitionWriter) : new RecordWriter<>(partitionWriter, selector); final RecordDeserializer<SerializationTestType> deserializer = new SpillingAdaptiveSpanningRecordDeserializer<>( new String[]{ tempFolder.getRoot().getAbsolutePath() }); @@ -445,7 +445,7 @@ private void emitRecordWithBroadcastPartitionerOrBroadcastEmitRecord(boolean isB } final int requiredBuffers = numValues / (bufferSize / (4 + serializationLength)); - for (int i = 0; i < numChannels; i++) { + for (int i = 0; i < numberOfChannels; i++) { assertEquals(requiredBuffers, queues[i].size()); final ArrayDeque<SerializationTestType> expectedRecords = serializedRecords.clone(); @@ -600,8 +600,15 @@ public void read(DataInputView in) throws IOException { private int[] returnChannel; + private int numberOfOutputChannels; + + @Override + public void setup(int numberOfChannels) { + this.numberOfOutputChannels = numberOfChannels; + } + @Override - public int[] selectChannels(final T record, final int numberOfOutputChannels) { + public int[] selectChannels(final T record) { if (returnChannel != null && returnChannel.length == numberOfOutputChannels) { return returnChannel; } else { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java index 5c643af1739..d56139f7e2c 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java @@ -54,13 +54,13 @@ @Test public void testConsumptionWithLocalChannels() throws Exception { - final int numChannels = 11; + final int numberOfChannels = 11; final int buffersPerChannel = 1000; final ResultPartition resultPartition = mock(ResultPartition.class); - final PipelinedSubpartition[] partitions = new PipelinedSubpartition[numChannels]; - final Source[] sources = new Source[numChannels]; + final PipelinedSubpartition[] partitions = new PipelinedSubpartition[numberOfChannels]; + final Source[] sources = new Source[numberOfChannels]; final ResultPartitionManager resultPartitionManager = createResultPartitionManager(partitions); @@ -68,12 +68,12 @@ public void testConsumptionWithLocalChannels() throws Exception { "Test Task Name", new JobID(), new IntermediateDataSetID(), ResultPartitionType.PIPELINED, - 0, numChannels, + 0, numberOfChannels, mock(TaskActions.class), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), true); - for (int i = 0; i < numChannels; i++) { + for (int i = 0; i < numberOfChannels; i++) { LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(), resultPartitionManager, mock(TaskEventDispatcher.class), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); gate.setInputChannel(new IntermediateResultPartitionID(), channel); @@ -82,8 +82,8 @@ public void testConsumptionWithLocalChannels() throws Exception { sources[i] = new PipelinedSubpartitionSource(partitions[i]); } - ProducerThread producer = new ProducerThread(sources, numChannels * buffersPerChannel, 4, 10); - ConsumerThread consumer = new ConsumerThread(gate, numChannels * buffersPerChannel); + ProducerThread producer = new ProducerThread(sources, numberOfChannels * buffersPerChannel, 4, 10); 
+ ConsumerThread consumer = new ConsumerThread(gate, numberOfChannels * buffersPerChannel); producer.start(); consumer.start(); @@ -94,23 +94,23 @@ public void testConsumptionWithLocalChannels() throws Exception { @Test public void testConsumptionWithRemoteChannels() throws Exception { - final int numChannels = 11; + final int numberOfChannels = 11; final int buffersPerChannel = 1000; final ConnectionManager connManager = createDummyConnectionManager(); - final Source[] sources = new Source[numChannels]; + final Source[] sources = new Source[numberOfChannels]; final SingleInputGate gate = new SingleInputGate( "Test Task Name", new JobID(), new IntermediateDataSetID(), ResultPartitionType.PIPELINED, 0, - numChannels, + numberOfChannels, mock(TaskActions.class), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), true); - for (int i = 0; i < numChannels; i++) { + for (int i = 0; i < numberOfChannels; i++) { RemoteInputChannel channel = new RemoteInputChannel( gate, i, new ResultPartitionID(), mock(ConnectionID.class), connManager, 0, 0, UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); @@ -119,8 +119,8 @@ public void testConsumptionWithRemoteChannels() throws Exception { sources[i] = new RemoteChannelSource(channel); } - ProducerThread producer = new ProducerThread(sources, numChannels * buffersPerChannel, 4, 10); - ConsumerThread consumer = new ConsumerThread(gate, numChannels * buffersPerChannel); + ProducerThread producer = new ProducerThread(sources, numberOfChannels * buffersPerChannel, 4, 10); + ConsumerThread consumer = new ConsumerThread(gate, numberOfChannels * buffersPerChannel); producer.start(); consumer.start(); @@ -131,13 +131,13 @@ public void testConsumptionWithRemoteChannels() throws Exception { @Test public void testConsumptionWithMixedChannels() throws Exception { - final int numChannels = 61; + final int numberOfChannels = 61; final int numLocalChannels = 20; final int buffersPerChannel 
= 1000; // fill the local/remote decision - List<Boolean> localOrRemote = new ArrayList<>(numChannels); - for (int i = 0; i < numChannels; i++) { + List<Boolean> localOrRemote = new ArrayList<>(numberOfChannels); + for (int i = 0; i < numberOfChannels; i++) { localOrRemote.add(i < numLocalChannels); } Collections.shuffle(localOrRemote); @@ -148,19 +148,19 @@ public void testConsumptionWithMixedChannels() throws Exception { final PipelinedSubpartition[] localPartitions = new PipelinedSubpartition[numLocalChannels]; final ResultPartitionManager resultPartitionManager = createResultPartitionManager(localPartitions); - final Source[] sources = new Source[numChannels]; + final Source[] sources = new Source[numberOfChannels]; final SingleInputGate gate = new SingleInputGate( "Test Task Name", new JobID(), new IntermediateDataSetID(), ResultPartitionType.PIPELINED, 0, - numChannels, + numberOfChannels, mock(TaskActions.class), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), true); - for (int i = 0, local = 0; i < numChannels; i++) { + for (int i = 0, local = 0; i < numberOfChannels; i++) { if (localOrRemote.get(i)) { // local channel PipelinedSubpartition psp = new PipelinedSubpartition(0, resultPartition); @@ -182,8 +182,8 @@ public void testConsumptionWithMixedChannels() throws Exception { } } - ProducerThread producer = new ProducerThread(sources, numChannels * buffersPerChannel, 4, 10); - ConsumerThread consumer = new ConsumerThread(gate, numChannels * buffersPerChannel); + ProducerThread producer = new ProducerThread(sources, numberOfChannels * buffersPerChannel, 4, 10); + ConsumerThread consumer = new ConsumerThread(gate, numberOfChannels * buffersPerChannel); producer.start(); consumer.start(); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java index 
66918757462..7a08d873440 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java @@ -65,7 +65,7 @@ @Test public void testFairConsumptionLocalChannelsPreFilled() throws Exception { - final int numChannels = 37; + final int numberOfChannels = 37; final int buffersPerChannel = 27; final ResultPartition resultPartition = mock(ResultPartition.class); @@ -73,9 +73,9 @@ public void testFairConsumptionLocalChannelsPreFilled() throws Exception { // ----- create some source channels and fill them with buffers ----- - final PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels]; + final PipelinedSubpartition[] sources = new PipelinedSubpartition[numberOfChannels]; - for (int i = 0; i < numChannels; i++) { + for (int i = 0; i < numberOfChannels; i++) { PipelinedSubpartition partition = new PipelinedSubpartition(0, resultPartition); for (int p = 0; p < buffersPerChannel; p++) { @@ -94,19 +94,19 @@ public void testFairConsumptionLocalChannelsPreFilled() throws Exception { "Test Task Name", new JobID(), new IntermediateDataSetID(), - 0, numChannels, + 0, numberOfChannels, mock(TaskActions.class), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), true); - for (int i = 0; i < numChannels; i++) { + for (int i = 0; i < numberOfChannels; i++) { LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(), resultPartitionManager, mock(TaskEventDispatcher.class), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); gate.setInputChannel(new IntermediateResultPartitionID(), channel); } // read all the buffers and the EOF event - for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) { + for (int i = numberOfChannels * (buffersPerChannel + 1); i > 0; --i) { assertNotNull(gate.getNextBufferOrEvent()); int min = 
Integer.MAX_VALUE; @@ -126,7 +126,7 @@ public void testFairConsumptionLocalChannelsPreFilled() throws Exception { @Test public void testFairConsumptionLocalChannels() throws Exception { - final int numChannels = 37; + final int numberOfChannels = 37; final int buffersPerChannel = 27; final ResultPartition resultPartition = mock(ResultPartition.class); @@ -134,9 +134,9 @@ public void testFairConsumptionLocalChannels() throws Exception { // ----- create some source channels and fill them with one buffer each ----- - final PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels]; + final PipelinedSubpartition[] sources = new PipelinedSubpartition[numberOfChannels]; - for (int i = 0; i < numChannels; i++) { + for (int i = 0; i < numberOfChannels; i++) { sources[i] = new PipelinedSubpartition(0, resultPartition); } @@ -148,12 +148,12 @@ public void testFairConsumptionLocalChannels() throws Exception { "Test Task Name", new JobID(), new IntermediateDataSetID(), - 0, numChannels, + 0, numberOfChannels, mock(TaskActions.class), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), true); - for (int i = 0; i < numChannels; i++) { + for (int i = 0; i < numberOfChannels; i++) { LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(), resultPartitionManager, mock(TaskEventDispatcher.class), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); gate.setInputChannel(new IntermediateResultPartitionID(), channel); @@ -163,7 +163,7 @@ public void testFairConsumptionLocalChannels() throws Exception { sources[12].add(bufferConsumer.copy()); // read all the buffers and the EOF event - for (int i = 0; i < numChannels * buffersPerChannel; i++) { + for (int i = 0; i < numberOfChannels * buffersPerChannel; i++) { assertNotNull(gate.getNextBufferOrEvent()); int min = Integer.MAX_VALUE; @@ -177,7 +177,7 @@ public void testFairConsumptionLocalChannels() throws Exception { assertTrue(max == min 
|| max == min + 1); - if (i % (2 * numChannels) == 0) { + if (i % (2 * numberOfChannels) == 0) { // add three buffers to each channel, in random order fillRandom(sources, 3, bufferConsumer); } @@ -188,7 +188,7 @@ public void testFairConsumptionLocalChannels() throws Exception { @Test public void testFairConsumptionRemoteChannelsPreFilled() throws Exception { - final int numChannels = 37; + final int numberOfChannels = 37; final int buffersPerChannel = 27; final Buffer mockBuffer = TestBufferFactory.createBuffer(42); @@ -199,16 +199,16 @@ public void testFairConsumptionRemoteChannelsPreFilled() throws Exception { "Test Task Name", new JobID(), new IntermediateDataSetID(), - 0, numChannels, + 0, numberOfChannels, mock(TaskActions.class), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), true); final ConnectionManager connManager = createDummyConnectionManager(); - final RemoteInputChannel[] channels = new RemoteInputChannel[numChannels]; + final RemoteInputChannel[] channels = new RemoteInputChannel[numberOfChannels]; - for (int i = 0; i < numChannels; i++) { + for (int i = 0; i < numberOfChannels; i++) { RemoteInputChannel channel = new RemoteInputChannel( gate, i, new ResultPartitionID(), mock(ConnectionID.class), connManager, 0, 0, UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); @@ -224,7 +224,7 @@ public void testFairConsumptionRemoteChannelsPreFilled() throws Exception { } // read all the buffers and the EOF event - for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) { + for (int i = numberOfChannels * (buffersPerChannel + 1); i > 0; --i) { assertNotNull(gate.getNextBufferOrEvent()); int min = Integer.MAX_VALUE; @@ -244,7 +244,7 @@ public void testFairConsumptionRemoteChannelsPreFilled() throws Exception { @Test public void testFairConsumptionRemoteChannels() throws Exception { - final int numChannels = 37; + final int numberOfChannels = 37; final int buffersPerChannel = 27; final 
Buffer mockBuffer = TestBufferFactory.createBuffer(42); @@ -255,17 +255,17 @@ public void testFairConsumptionRemoteChannels() throws Exception { "Test Task Name", new JobID(), new IntermediateDataSetID(), - 0, numChannels, + 0, numberOfChannels, mock(TaskActions.class), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), true); final ConnectionManager connManager = createDummyConnectionManager(); - final RemoteInputChannel[] channels = new RemoteInputChannel[numChannels]; - final int[] channelSequenceNums = new int[numChannels]; + final RemoteInputChannel[] channels = new RemoteInputChannel[numberOfChannels]; + final int[] channelSequenceNums = new int[numberOfChannels]; - for (int i = 0; i < numChannels; i++) { + for (int i = 0; i < numberOfChannels; i++) { RemoteInputChannel channel = new RemoteInputChannel( gate, i, new ResultPartitionID(), mock(ConnectionID.class), connManager, 0, 0, UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); @@ -278,7 +278,7 @@ public void testFairConsumptionRemoteChannels() throws Exception { channelSequenceNums[11]++; // read all the buffers and the EOF event - for (int i = 0; i < numChannels * buffersPerChannel; i++) { + for (int i = 0; i < numberOfChannels * buffersPerChannel; i++) { assertNotNull(gate.getNextBufferOrEvent()); int min = Integer.MAX_VALUE; @@ -292,7 +292,7 @@ public void testFairConsumptionRemoteChannels() throws Exception { assertTrue(max == min || max == (min + 1)); - if (i % (2 * numChannels) == 0) { + if (i % (2 * numberOfChannels) == 0) { // add three buffers to each channel, in random order fillRandom(channels, channelSequenceNums, 3, mockBuffer); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/util/OutputEmitterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/util/OutputEmitterTest.java index 5c7ed3ad148..0231fbfdd9a 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/util/OutputEmitterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/util/OutputEmitterTest.java @@ -61,8 +61,8 @@ public void testPartitionHash() { // Test hash corner cases final TestIntComparator testIntComp = new TestIntComparator(); - final ChannelSelector<SerializationDelegate<Integer>> selector = new OutputEmitter<>( - ShipStrategyType.PARTITION_HASH, testIntComp); + final ChannelSelector<SerializationDelegate<Integer>> selector = createChannelSelector( + ShipStrategyType.PARTITION_HASH, testIntComp, 100); final SerializationDelegate<Integer> serializationDelegate = new SerializationDelegate<>(new IntSerializer()); assertPartitionHashSelectedChannels(selector, serializationDelegate, Integer.MIN_VALUE, 100); @@ -74,57 +74,59 @@ public void testPartitionHash() { @Test public void testForward() { - final int numChannels = 100; + final int numberOfChannels = 100; // Test for IntValue - int numRecords = 50000 + numChannels / 2; - verifyForwardSelectedChannels(numRecords, numChannels, RecordType.INTEGER); + int numRecords = 50000 + numberOfChannels / 2; + verifyForwardSelectedChannels(numRecords, numberOfChannels, RecordType.INTEGER); // Test for StringValue - numRecords = 10000 + numChannels / 2; - verifyForwardSelectedChannels(numRecords, numChannels, RecordType.STRING); + numRecords = 10000 + numberOfChannels / 2; + verifyForwardSelectedChannels(numRecords, numberOfChannels, RecordType.STRING); } @Test public void testForcedRebalance() { - final int numChannels = 100; - int toTaskIndex = numChannels * 6 / 7; - int fromTaskIndex = toTaskIndex + numChannels; - int extraRecords = numChannels / 3; + final int numberOfChannels = 100; + int toTaskIndex = numberOfChannels * 6 / 7; + int fromTaskIndex = toTaskIndex + numberOfChannels; + int extraRecords = numberOfChannels / 3; int numRecords = 50000 + extraRecords; final SerializationDelegate<Record> delegate = new 
SerializationDelegate<>( new RecordSerializerFactory().getSerializer()); final ChannelSelector<SerializationDelegate<Record>> selector = new OutputEmitter<>( ShipStrategyType.PARTITION_FORCED_REBALANCE, fromTaskIndex); + selector.setup(numberOfChannels); // Test for IntValue - int[] hits = getSelectedChannelsHitCount(selector, delegate, RecordType.INTEGER, numRecords, numChannels); + int[] hits = getSelectedChannelsHitCount(selector, delegate, RecordType.INTEGER, numRecords, numberOfChannels); int totalHitCount = 0; for (int i = 0; i < hits.length; i++) { - if (toTaskIndex <= i || i < toTaskIndex+extraRecords - numChannels) { - assertTrue(hits[i] == (numRecords / numChannels) + 1); + if (toTaskIndex <= i || i < toTaskIndex+extraRecords - numberOfChannels) { + assertTrue(hits[i] == (numRecords / numberOfChannels) + 1); } else { - assertTrue(hits[i] == numRecords/numChannels); + assertTrue(hits[i] == numRecords/numberOfChannels); } totalHitCount += hits[i]; } assertTrue(totalHitCount == numRecords); - toTaskIndex = numChannels / 5; - fromTaskIndex = toTaskIndex + 2 * numChannels; - extraRecords = numChannels * 2 / 9; + toTaskIndex = numberOfChannels / 5; + fromTaskIndex = toTaskIndex + 2 * numberOfChannels; + extraRecords = numberOfChannels * 2 / 9; numRecords = 10000 + extraRecords; // Test for StringValue final ChannelSelector<SerializationDelegate<Record>> selector2 = new OutputEmitter<>( ShipStrategyType.PARTITION_FORCED_REBALANCE, fromTaskIndex); - hits = getSelectedChannelsHitCount(selector2, delegate, RecordType.STRING, numRecords, numChannels); + selector2.setup(numberOfChannels); + hits = getSelectedChannelsHitCount(selector2, delegate, RecordType.STRING, numRecords, numberOfChannels); totalHitCount = 0; for (int i = 0; i < hits.length; i++) { if (toTaskIndex <= i && i < toTaskIndex + extraRecords) { - assertTrue(hits[i] == (numRecords / numChannels) + 1); + assertTrue(hits[i] == (numRecords / numberOfChannels) + 1); } else { - assertTrue(hits[i] == 
numRecords / numChannels); + assertTrue(hits[i] == numRecords / numberOfChannels); } totalHitCount += hits[i]; } @@ -141,15 +143,16 @@ public void testBroadcast() { @Test public void testMultiKeys() { + final int numberOfChannels = 100; + final int numRecords = 5000; final TypeComparator<Record> multiComp = new RecordComparatorFactory( new int[] {0,1, 3}, new Class[] {IntValue.class, StringValue.class, DoubleValue.class}).createComparator(); - final ChannelSelector<SerializationDelegate<Record>> selector = new OutputEmitter<>( - ShipStrategyType.PARTITION_HASH, multiComp); + + final ChannelSelector<SerializationDelegate<Record>> selector = createChannelSelector( + ShipStrategyType.PARTITION_HASH, multiComp, numberOfChannels); final SerializationDelegate<Record> delegate = new SerializationDelegate<>(new RecordSerializerFactory().getSerializer()); - - int numChannels = 100; - int numRecords = 5000; - int[] hits = new int[numChannels]; + + int[] hits = new int[numberOfChannels]; for (int i = 0; i < numRecords; i++) { Record record = new Record(4); record.setField(0, new IntValue(i)); @@ -157,7 +160,7 @@ public void testMultiKeys() { record.setField(3, new DoubleValue(i * 3.141d)); delegate.setInstance(record); - int[] channels = selector.selectChannels(delegate, hits.length); + int[] channels = selector.selectChannels(delegate); for (int channel : channels) { hits[channel]++; } @@ -190,8 +193,8 @@ public void testWrongKeyClass() throws Exception { // Test for IntValue final TypeComparator<Record> doubleComp = new RecordComparatorFactory( new int[] {0}, new Class[] {DoubleValue.class}).createComparator(); - final ChannelSelector<SerializationDelegate<Record>> selector = new OutputEmitter<>( - ShipStrategyType.PARTITION_HASH, doubleComp); + final ChannelSelector<SerializationDelegate<Record>> selector = createChannelSelector( + ShipStrategyType.PARTITION_HASH, doubleComp, 100); final SerializationDelegate<Record> delegate = new SerializationDelegate<>(new 
RecordSerializerFactory().getSerializer()); PipedInputStream pipedInput = new PipedInputStream(1024 * 1024); @@ -206,15 +209,15 @@ public void testWrongKeyClass() throws Exception { try { delegate.setInstance(record); - selector.selectChannels(delegate, 100); + selector.selectChannels(delegate); } catch (DeserializationException re) { return; } Assert.fail("Expected a NullKeyFieldException."); } - private void verifyPartitionHashSelectedChannels(int numRecords, int numChannels, Enum recordType) { - int[] hits = getSelectedChannelsHitCount(ShipStrategyType.PARTITION_HASH, numRecords, numChannels, recordType); + private void verifyPartitionHashSelectedChannels(int numRecords, int numberOfChannels, Enum recordType) { + int[] hits = getSelectedChannelsHitCount(ShipStrategyType.PARTITION_HASH, numRecords, numberOfChannels, recordType); int totalHitCount = 0; for (int hit : hits) { @@ -224,20 +227,8 @@ private void verifyPartitionHashSelectedChannels(int numRecords, int numChannels assertTrue(totalHitCount == numRecords); } - private void assertPartitionHashSelectedChannels( - ChannelSelector selector, - SerializationDelegate<Integer> serializationDelegate, - int record, - int numChannels) { - serializationDelegate.setInstance(record); - int[] selectedChannels = selector.selectChannels(serializationDelegate, numChannels); - - assertTrue(selectedChannels.length == 1); - assertTrue(selectedChannels[0] >= 0 && selectedChannels[0] <= numChannels - 1); - } - - private void verifyForwardSelectedChannels(int numRecords, int numChannels, Enum recordType) { - int[] hits = getSelectedChannelsHitCount(ShipStrategyType.FORWARD, numRecords, numChannels, recordType); + private void verifyForwardSelectedChannels(int numRecords, int numberOfChannels, Enum recordType) { + int[] hits = getSelectedChannelsHitCount(ShipStrategyType.FORWARD, numRecords, numberOfChannels, recordType); assertTrue(hits[0] == numRecords); for (int i = 1; i < hits.length; i++) { @@ -245,8 +236,8 @@ private void 
verifyForwardSelectedChannels(int numRecords, int numChannels, Enum } } - private void verifyBroadcastSelectedChannels(int numRecords, int numChannels, Enum recordType) { - int[] hits = getSelectedChannelsHitCount(ShipStrategyType.BROADCAST, numRecords, numChannels, recordType); + private void verifyBroadcastSelectedChannels(int numRecords, int numberOfChannels, Enum recordType) { + int[] hits = getSelectedChannelsHitCount(ShipStrategyType.BROADCAST, numRecords, numberOfChannels, recordType); for (int hit : hits) { assertTrue(hit + "", hit == numRecords); @@ -256,8 +247,8 @@ private void verifyBroadcastSelectedChannels(int numRecords, int numChannels, En private boolean verifyWrongPartitionHashKey(int position, int fieldNum) { final TypeComparator<Record> comparator = new RecordComparatorFactory( new int[] {position}, new Class[] {IntValue.class}).createComparator(); - final ChannelSelector<SerializationDelegate<Record>> selector = new OutputEmitter<>( - ShipStrategyType.PARTITION_HASH, comparator); + final ChannelSelector<SerializationDelegate<Record>> selector = createChannelSelector( + ShipStrategyType.PARTITION_HASH, comparator, 100); final SerializationDelegate<Record> delegate = new SerializationDelegate<>(new RecordSerializerFactory().getSerializer()); Record record = new Record(2); @@ -265,7 +256,7 @@ private boolean verifyWrongPartitionHashKey(int position, int fieldNum) { delegate.setInstance(record); try { - selector.selectChannels(delegate, 100); + selector.selectChannels(delegate); } catch (NullKeyFieldException re) { Assert.assertEquals(position, re.getFieldNumber()); return true; @@ -276,14 +267,23 @@ private boolean verifyWrongPartitionHashKey(int position, int fieldNum) { private int[] getSelectedChannelsHitCount( ShipStrategyType shipStrategyType, int numRecords, - int numChannels, + int numberOfChannels, Enum recordType) { final TypeComparator<Record> comparator = new RecordComparatorFactory( new int[] {0}, new Class[] {recordType == 
RecordType.INTEGER ? IntValue.class : StringValue.class}).createComparator(); - final ChannelSelector<SerializationDelegate<Record>> selector = new OutputEmitter<>(shipStrategyType, comparator); + final ChannelSelector<SerializationDelegate<Record>> selector = createChannelSelector(shipStrategyType, comparator, numberOfChannels); final SerializationDelegate<Record> delegate = new SerializationDelegate<>(new RecordSerializerFactory().getSerializer()); - return getSelectedChannelsHitCount(selector, delegate, recordType, numRecords, numChannels); + return getSelectedChannelsHitCount(selector, delegate, recordType, numRecords, numberOfChannels); + } + + private ChannelSelector createChannelSelector( + ShipStrategyType shipStrategyType, + TypeComparator comparator, + int numberOfChannels) { + final ChannelSelector selector = new OutputEmitter<>(shipStrategyType, comparator); + selector.setup(numberOfChannels); + return selector; } private int[] getSelectedChannelsHitCount( @@ -291,8 +291,8 @@ private boolean verifyWrongPartitionHashKey(int position, int fieldNum) { SerializationDelegate<Record> delegate, Enum recordType, int numRecords, - int numChannels) { - int[] hits = new int[numChannels]; + int numberOfChannels) { + int[] hits = new int[numberOfChannels]; Value value; for (int i = 0; i < numRecords; i++) { if (recordType == RecordType.INTEGER) { @@ -303,7 +303,7 @@ private boolean verifyWrongPartitionHashKey(int position, int fieldNum) { Record record = new Record(value); delegate.setInstance(record); - int[] channels = selector.selectChannels(delegate, hits.length); + int[] channels = selector.selectChannels(delegate); for (int channel : channels) { hits[channel]++; } @@ -311,6 +311,18 @@ private boolean verifyWrongPartitionHashKey(int position, int fieldNum) { return hits; } + private void assertPartitionHashSelectedChannels( + ChannelSelector selector, + SerializationDelegate<Integer> serializationDelegate, + int record, + int numberOfChannels) { + 
serializationDelegate.setInstance(record); + int[] selectedChannels = selector.selectChannels(serializationDelegate); + + assertTrue(selectedChannels.length == 1); + assertTrue(selectedChannels[0] >= 0 && selectedChannels[0] <= numberOfChannels - 1); + } + private static class TestIntComparator extends TypeComparator<Integer> { private TypeComparator[] comparators = new TypeComparator[]{new IntComparator(true)}; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/BroadcastPartitioner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/BroadcastPartitioner.java index c796813adc5..0614ca1a83e 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/BroadcastPartitioner.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/BroadcastPartitioner.java @@ -33,13 +33,12 @@ private int[] returnArray; @Override - public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record, - int numberOfOutputChannels) { - if (returnArray != null && returnArray.length == numberOfOutputChannels) { + public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record) { + if (returnArray != null && returnArray.length == numberOfChannels) { return returnArray; } else { - this.returnArray = new int[numberOfOutputChannels]; - for (int i = 0; i < numberOfOutputChannels; i++) { + this.returnArray = new int[numberOfChannels]; + for (int i = 0; i < numberOfChannels; i++) { returnArray[i] = i; } return returnArray; diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/CustomPartitionerWrapper.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/CustomPartitionerWrapper.java index f19c87d7dfc..73041d1a9a7 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/CustomPartitionerWrapper.java +++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/CustomPartitionerWrapper.java @@ -45,17 +45,15 @@ public CustomPartitionerWrapper(Partitioner<K> partitioner, KeySelector<T, K> ke } @Override - public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record, int numberOfOutputChannels) { - - K key = null; + public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record) { + K key; try { key = keySelector.getKey(record.getInstance().getValue()); } catch (Exception e) { throw new RuntimeException("Could not extract key from " + record.getInstance(), e); } - returnArray[0] = partitioner.partition(key, - numberOfOutputChannels); + returnArray[0] = partitioner.partition(key, numberOfChannels); return returnArray; } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/ForwardPartitioner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/ForwardPartitioner.java index c952282a810..91530f13f8a 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/ForwardPartitioner.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/ForwardPartitioner.java @@ -33,7 +33,7 @@ private final int[] returnArray = new int[] {0}; @Override - public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record, int numberOfOutputChannels) { + public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record) { return returnArray; } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/GlobalPartitioner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/GlobalPartitioner.java index 69c8d0073b9..9f95ecef9bd 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/GlobalPartitioner.java +++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/GlobalPartitioner.java @@ -33,8 +33,7 @@ private final int[] returnArray = new int[] { 0 }; @Override - public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record, - int numberOfOutputChannels) { + public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record) { return returnArray; } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitioner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitioner.java index ddbdaea5537..9c58e2a9306 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitioner.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitioner.java @@ -50,17 +50,14 @@ public int getMaxParallelism() { } @Override - public int[] selectChannels( - SerializationDelegate<StreamRecord<T>> record, - int numberOfOutputChannels) { - + public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record) { K key; try { key = keySelector.getKey(record.getInstance().getValue()); } catch (Exception e) { throw new RuntimeException("Could not extract key from " + record.getInstance().getValue(), e); } - returnArray[0] = KeyGroupRangeAssignment.assignKeyToParallelOperator(key, maxParallelism, numberOfOutputChannels); + returnArray[0] = KeyGroupRangeAssignment.assignKeyToParallelOperator(key, maxParallelism, numberOfChannels); return returnArray; } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/RebalancePartitioner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/RebalancePartitioner.java index 6c5c063cf10..d74a25dac21 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/RebalancePartitioner.java 
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/RebalancePartitioner.java @@ -33,27 +33,19 @@ public class RebalancePartitioner<T> extends StreamPartitioner<T> { private static final long serialVersionUID = 1L; - private final int[] returnArray = {Integer.MAX_VALUE - 1}; + private final int[] returnArray = new int[1]; @Override - public int[] selectChannels( - SerializationDelegate<StreamRecord<T>> record, - int numChannels) { - int newChannel = ++returnArray[0]; - if (newChannel >= numChannels) { - returnArray[0] = resetValue(numChannels, newChannel); - } - return returnArray; + public void setup(int numberOfChannels) { + super.setup(numberOfChannels); + + returnArray[0] = ThreadLocalRandom.current().nextInt(numberOfChannels); } - private static int resetValue( - int numChannels, - int newChannel) { - if (newChannel == Integer.MAX_VALUE) { - // Initializes the first partition, this branch is only entered when initializing. - return ThreadLocalRandom.current().nextInt(numChannels); - } - return 0; + @Override + public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record) { + returnArray[0] = (returnArray[0] + 1) % numberOfChannels; + return returnArray; } public StreamPartitioner<T> copy() { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitioner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitioner.java index b9af629b89a..bd65d0b1081 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitioner.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitioner.java @@ -51,12 +51,12 @@ private final int[] returnArray = new int[] {-1}; @Override - public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record, int numberOfOutputChannels) { - int newChannel = ++this.returnArray[0]; - if (newChannel 
>= numberOfOutputChannels) { - this.returnArray[0] = 0; + public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record) { + int newChannel = ++returnArray[0]; + if (newChannel >= numberOfChannels) { + returnArray[0] = 0; } - return this.returnArray; + return returnArray; } public StreamPartitioner<T> copy() { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/ShufflePartitioner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/ShufflePartitioner.java index ddcbec72130..0cc1d870075 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/ShufflePartitioner.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/ShufflePartitioner.java @@ -39,9 +39,8 @@ private final int[] returnArray = new int[1]; @Override - public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record, - int numberOfOutputChannels) { - returnArray[0] = random.nextInt(numberOfOutputChannels); + public int[] selectChannels(SerializationDelegate<StreamRecord<T>> record) { + returnArray[0] = random.nextInt(numberOfChannels); return returnArray; } diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java index 411aa8bad5e..d023c1af2a0 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitioner.java @@ -32,5 +32,12 @@ ChannelSelector<SerializationDelegate<StreamRecord<T>>>, Serializable { private static final long serialVersionUID = 1L; + protected int numberOfChannels; + + @Override + public void setup(int numberOfChannels) { + this.numberOfChannels = numberOfChannels; + } + public abstract 
StreamPartitioner<T> copy(); } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java index e968101ca2d..afbb03694f4 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BarrierBufferMassiveRandomTest.java @@ -132,7 +132,7 @@ public boolean isNextBarrier() { private static class RandomGeneratingInputGate implements InputGate { - private final int numChannels; + private final int numberOfChannels; private final BufferPool[] bufferPools; private final int[] currentBarriers; private final BarrierGenerator[] barrierGens; @@ -146,8 +146,8 @@ public RandomGeneratingInputGate(BufferPool[] bufferPools, BarrierGenerator[] ba } public RandomGeneratingInputGate(BufferPool[] bufferPools, BarrierGenerator[] barrierGens, String owningTaskName) { - this.numChannels = bufferPools.length; - this.currentBarriers = new int[numChannels]; + this.numberOfChannels = bufferPools.length; + this.currentBarriers = new int[numberOfChannels]; this.bufferPools = bufferPools; this.barrierGens = barrierGens; this.owningTaskName = owningTaskName; @@ -155,7 +155,7 @@ public RandomGeneratingInputGate(BufferPool[] bufferPools, BarrierGenerator[] ba @Override public int getNumberOfInputChannels() { - return numChannels; + return numberOfChannels; } @Override @@ -173,7 +173,7 @@ public void requestPartitions() {} @Override public Optional<BufferOrEvent> getNextBufferOrEvent() throws IOException, InterruptedException { - currentChannel = (currentChannel + 1) % numChannels; + currentChannel = (currentChannel + 1) % numberOfChannels; if (barrierGens[currentChannel].isNextBarrier()) { return Optional.of( diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BufferBlockerTestBase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BufferBlockerTestBase.java index 4448eddbdc9..4533a659656 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BufferBlockerTestBase.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/BufferBlockerTestBase.java @@ -72,7 +72,7 @@ public void testSpillAndRollOverSimple() throws IOException { bufferRnd.setSeed(bufferSeed); final int numEventsAndBuffers = rnd.nextInt(maxNumEventsAndBuffers) + 1; - final int numChannels = rnd.nextInt(maxNumChannels) + 1; + final int numberOfChannels = rnd.nextInt(maxNumChannels) + 1; final ArrayList<BufferOrEvent> events = new ArrayList<BufferOrEvent>(128); @@ -81,10 +81,10 @@ public void testSpillAndRollOverSimple() throws IOException { boolean isEvent = rnd.nextDouble() < 0.05d; BufferOrEvent evt; if (isEvent) { - evt = generateRandomEvent(rnd, numChannels); + evt = generateRandomEvent(rnd, numberOfChannels); events.add(evt); } else { - evt = generateRandomBuffer(bufferRnd.nextInt(PAGE_SIZE) + 1, bufferRnd.nextInt(numChannels)); + evt = generateRandomBuffer(bufferRnd.nextInt(PAGE_SIZE) + 1, bufferRnd.nextInt(numberOfChannels)); } bufferBlocker.add(evt); } @@ -106,7 +106,7 @@ public void testSpillAndRollOverSimple() throws IOException { assertEquals(expected.getEvent(), next.getEvent()); assertEquals(expected.getChannelIndex(), next.getChannelIndex()); } else { - validateBuffer(next, bufferRnd.nextInt(PAGE_SIZE) + 1, bufferRnd.nextInt(numChannels)); + validateBuffer(next, bufferRnd.nextInt(PAGE_SIZE) + 1, bufferRnd.nextInt(numberOfChannels)); } } @@ -150,7 +150,7 @@ public void testSpillWhileReading() throws IOException { final Random bufferRnd = new Random(bufferSeed); final int numEventsAndBuffers = rnd.nextInt(maxNumEventsAndBuffers) + 1; - final int numChannels = rnd.nextInt(maxNumChannels) + 
1; + final int numberOfChannels = rnd.nextInt(maxNumChannels) + 1; final ArrayList<BufferOrEvent> events = new ArrayList<BufferOrEvent>(128); @@ -162,10 +162,10 @@ public void testSpillWhileReading() throws IOException { boolean isEvent = rnd.nextDouble() < 0.05; BufferOrEvent evt; if (isEvent) { - evt = generateRandomEvent(rnd, numChannels); + evt = generateRandomEvent(rnd, numberOfChannels); events.add(evt); } else { - evt = generateRandomBuffer(bufferRnd.nextInt(PAGE_SIZE) + 1, bufferRnd.nextInt(numChannels)); + evt = generateRandomBuffer(bufferRnd.nextInt(PAGE_SIZE) + 1, bufferRnd.nextInt(numberOfChannels)); } bufferBlocker.add(evt); generated++; @@ -179,7 +179,7 @@ public void testSpillWhileReading() throws IOException { assertEquals(expected.getChannelIndex(), next.getChannelIndex()); } else { Random validationRnd = currentSequence.bufferRnd; - validateBuffer(next, validationRnd.nextInt(PAGE_SIZE) + 1, validationRnd.nextInt(currentSequence.numChannels)); + validateBuffer(next, validationRnd.nextInt(PAGE_SIZE) + 1, validationRnd.nextInt(currentSequence.numberOfChannels)); } currentNumRecordAndEvents++; @@ -207,7 +207,7 @@ public void testSpillWhileReading() throws IOException { bufferRnd.setSeed(bufferSeed); BufferOrEventSequence seq = bufferBlocker.rollOverReusingResources(); - SequenceToConsume stc = new SequenceToConsume(bufferRnd, events, seq, numEventsAndBuffers, numChannels); + SequenceToConsume stc = new SequenceToConsume(bufferRnd, events, seq, numEventsAndBuffers, numberOfChannels); if (currentSequence == null) { currentSequence = stc; @@ -229,7 +229,7 @@ public void testSpillWhileReading() throws IOException { assertEquals(expected.getChannelIndex(), next.getChannelIndex()); } else { Random validationRnd = currentSequence.bufferRnd; - validateBuffer(next, validationRnd.nextInt(PAGE_SIZE) + 1, validationRnd.nextInt(currentSequence.numChannels)); + validateBuffer(next, validationRnd.nextInt(PAGE_SIZE) + 1, 
validationRnd.nextInt(currentSequence.numberOfChannels)); } currentNumRecordAndEvents++; @@ -259,13 +259,13 @@ public void testSpillWhileReading() throws IOException { // Utils // ------------------------------------------------------------------------ - private static BufferOrEvent generateRandomEvent(Random rnd, int numChannels) { + private static BufferOrEvent generateRandomEvent(Random rnd, int numberOfChannels) { long magicNumber = rnd.nextLong(); byte[] data = new byte[rnd.nextInt(1000)]; rnd.nextBytes(data); TestEvent evt = new TestEvent(magicNumber, data); - int channelIndex = rnd.nextInt(numChannels); + int channelIndex = rnd.nextInt(numberOfChannels); return new BufferOrEvent(evt, channelIndex); } @@ -307,19 +307,19 @@ private static void validateBuffer(BufferOrEvent boe, int expectedSize, int expe final ArrayList<BufferOrEvent> events; final Random bufferRnd; final int numBuffersAndEvents; - final int numChannels; + final int numberOfChannels; private SequenceToConsume( Random bufferRnd, ArrayList<BufferOrEvent> events, BufferOrEventSequence sequence, int numBuffersAndEvents, - int numChannels) { + int numberOfChannels) { this.bufferRnd = bufferRnd; this.events = events; this.sequence = sequence; this.numBuffersAndEvents = numBuffersAndEvents; - this.numChannels = numChannels; + this.numberOfChannels = numberOfChannels; } } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java index 6400a175e20..a150b8f3571 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java @@ -36,7 +36,7 @@ private final int pageSize; - private final int numChannels; + private final int numberOfChannels; private final Queue<BufferOrEvent> bufferOrEvents; @@ -46,15 +46,15 @@ private final 
String owningTaskName; - public MockInputGate(int pageSize, int numChannels, List<BufferOrEvent> bufferOrEvents) { - this(pageSize, numChannels, bufferOrEvents, "MockTask"); + public MockInputGate(int pageSize, int numberOfChannels, List<BufferOrEvent> bufferOrEvents) { + this(pageSize, numberOfChannels, bufferOrEvents, "MockTask"); } - public MockInputGate(int pageSize, int numChannels, List<BufferOrEvent> bufferOrEvents, String owningTaskName) { + public MockInputGate(int pageSize, int numberOfChannels, List<BufferOrEvent> bufferOrEvents, String owningTaskName) { this.pageSize = pageSize; - this.numChannels = numChannels; + this.numberOfChannels = numberOfChannels; this.bufferOrEvents = new ArrayDeque<BufferOrEvent>(bufferOrEvents); - this.closed = new boolean[numChannels]; + this.closed = new boolean[numberOfChannels]; this.owningTaskName = owningTaskName; } @@ -65,7 +65,7 @@ public int getPageSize() { @Override public int getNumberOfInputChannels() { - return numChannels; + return numberOfChannels; } @Override diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/SpilledBufferOrEventSequenceTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/SpilledBufferOrEventSequenceTest.java index c1ff79faf32..09c18e042c3 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/SpilledBufferOrEventSequenceTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/SpilledBufferOrEventSequenceTest.java @@ -135,12 +135,12 @@ public void testBufferSequence() { final long seed = rnd.nextLong(); final int numBuffers = 325; - final int numChannels = 671; + final int numberOfChannels = 671; rnd.setSeed(seed); for (int i = 0; i < numBuffers; i++) { - writeBuffer(fileChannel, rnd.nextInt(pageSize) + 1, rnd.nextInt(numChannels)); + writeBuffer(fileChannel, rnd.nextInt(pageSize) + 1, rnd.nextInt(numberOfChannels)); } fileChannel.position(0L); @@ -150,7 +150,7 @@ 
public void testBufferSequence() { seq.open(); for (int i = 0; i < numBuffers; i++) { - validateBuffer(seq.getNext(), rnd.nextInt(pageSize) + 1, rnd.nextInt(numChannels)); + validateBuffer(seq.getNext(), rnd.nextInt(pageSize) + 1, rnd.nextInt(numberOfChannels)); } // should have no more data @@ -205,12 +205,12 @@ public void testEventSequence() { try { final Random rnd = new Random(); final int numEvents = 3000; - final int numChannels = 1656; + final int numberOfChannels = 1656; final ArrayList<BufferOrEvent> events = new ArrayList<BufferOrEvent>(numEvents); for (int i = 0; i < numEvents; i++) { - events.add(generateAndWriteEvent(fileChannel, rnd, numChannels)); + events.add(generateAndWriteEvent(fileChannel, rnd, numberOfChannels)); } fileChannel.position(0L); @@ -245,7 +245,7 @@ public void testMixedSequence() { bufferRnd.setSeed(bufferSeed); final int numEventsAndBuffers = 3000; - final int numChannels = 1656; + final int numberOfChannels = 1656; final ArrayList<BufferOrEvent> events = new ArrayList<BufferOrEvent>(128); @@ -254,10 +254,10 @@ public void testMixedSequence() { for (int i = 0; i < numEventsAndBuffers; i++) { boolean isEvent = rnd.nextDouble() < 0.05d; if (isEvent) { - events.add(generateAndWriteEvent(fileChannel, rnd, numChannels)); + events.add(generateAndWriteEvent(fileChannel, rnd, numberOfChannels)); } else { - writeBuffer(fileChannel, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numChannels)); + writeBuffer(fileChannel, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numberOfChannels)); } } @@ -279,7 +279,7 @@ public void testMixedSequence() { assertEquals(expected.getChannelIndex(), next.getChannelIndex()); } else { - validateBuffer(next, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numChannels)); + validateBuffer(next, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numberOfChannels)); } } @@ -314,7 +314,7 @@ public void testMultipleSequences() { final int numEventsAndBuffers1 = 272; final int numEventsAndBuffers2 = 151; - 
final int numChannels = 1656; + final int numberOfChannels = 1656; final ArrayList<BufferOrEvent> events1 = new ArrayList<BufferOrEvent>(128); final ArrayList<BufferOrEvent> events2 = new ArrayList<BufferOrEvent>(128); @@ -324,10 +324,10 @@ public void testMultipleSequences() { for (int i = 0; i < numEventsAndBuffers1; i++) { boolean isEvent = rnd.nextDouble() < 0.05d; if (isEvent) { - events1.add(generateAndWriteEvent(fileChannel, rnd, numChannels)); + events1.add(generateAndWriteEvent(fileChannel, rnd, numberOfChannels)); } else { - writeBuffer(fileChannel, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numChannels)); + writeBuffer(fileChannel, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numberOfChannels)); } } @@ -336,10 +336,10 @@ public void testMultipleSequences() { for (int i = 0; i < numEventsAndBuffers2; i++) { boolean isEvent = rnd.nextDouble() < 0.05d; if (isEvent) { - events2.add(generateAndWriteEvent(secondChannel, rnd, numChannels)); + events2.add(generateAndWriteEvent(secondChannel, rnd, numberOfChannels)); } else { - writeBuffer(secondChannel, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numChannels)); + writeBuffer(secondChannel, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numberOfChannels)); } } @@ -365,7 +365,7 @@ public void testMultipleSequences() { assertEquals(expected.getChannelIndex(), next.getChannelIndex()); } else { - validateBuffer(next, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numChannels)); + validateBuffer(next, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numberOfChannels)); } } assertNull(seq1.getNext()); @@ -383,7 +383,7 @@ public void testMultipleSequences() { assertEquals(expected.getChannelIndex(), next.getChannelIndex()); } else { - validateBuffer(next, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numChannels)); + validateBuffer(next, bufferRnd.nextInt(pageSize) + 1, bufferRnd.nextInt(numberOfChannels)); } } assertNull(seq2.getNext()); @@ -435,13 +435,13 @@ public void 
testCleanup() { // Utils // ------------------------------------------------------------------------ - private static BufferOrEvent generateAndWriteEvent(FileChannel fileChannel, Random rnd, int numChannels) throws IOException { + private static BufferOrEvent generateAndWriteEvent(FileChannel fileChannel, Random rnd, int numberOfChannels) throws IOException { long magicNumber = rnd.nextLong(); byte[] data = new byte[rnd.nextInt(1000)]; rnd.nextBytes(data); TestEvent evt = new TestEvent(magicNumber, data); - int channelIndex = rnd.nextInt(numChannels); + int channelIndex = rnd.nextInt(numberOfChannels); ByteBuffer serializedEvent = EventSerializer.toSerializedEvent(evt); ByteBuffer header = ByteBuffer.allocate(9); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/BroadcastPartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/BroadcastPartitionerTest.java index 63a45fde064..aea191a9a1c 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/BroadcastPartitionerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/BroadcastPartitionerTest.java @@ -40,9 +40,9 @@ @Before public void setPartitioner() { - broadcastPartitioner1 = new BroadcastPartitioner<>(); - broadcastPartitioner2 = new BroadcastPartitioner<>(); - broadcastPartitioner3 = new BroadcastPartitioner<>(); + broadcastPartitioner1 = createBroadcastPartitioner(1); + broadcastPartitioner2 = createBroadcastPartitioner(2); + broadcastPartitioner3 = createBroadcastPartitioner(6); } @Test @@ -53,8 +53,14 @@ public void testSelectChannels() { serializationDelegate.setInstance(streamRecord); - assertArrayEquals(first, broadcastPartitioner1.selectChannels(serializationDelegate, 1)); - assertArrayEquals(second, broadcastPartitioner2.selectChannels(serializationDelegate, 2)); - assertArrayEquals(sixth, 
broadcastPartitioner3.selectChannels(serializationDelegate, 6)); + assertArrayEquals(first, broadcastPartitioner1.selectChannels(serializationDelegate)); + assertArrayEquals(second, broadcastPartitioner2.selectChannels(serializationDelegate)); + assertArrayEquals(sixth, broadcastPartitioner3.selectChannels(serializationDelegate)); + } + + private BroadcastPartitioner<Tuple> createBroadcastPartitioner(int numberOfChannels) { + BroadcastPartitioner<Tuple> broadcastPartitioner = new BroadcastPartitioner<>(); + broadcastPartitioner.setup(numberOfChannels); + return broadcastPartitioner; } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/ForwardPartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/ForwardPartitionerTest.java index 9b84b1238b1..593119a57b3 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/ForwardPartitionerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/ForwardPartitionerTest.java @@ -33,8 +33,8 @@ @Test public void testSelectChannelsInterval() { - assertSelectedChannel(0, 1); - assertSelectedChannel(0, 2); - assertSelectedChannel(0, 1024); + assertSelectedChannelWithSetup(0, 1); + assertSelectedChannelWithSetup(0, 2); + assertSelectedChannelWithSetup(0, 1024); } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/GlobalPartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/GlobalPartitionerTest.java index b382f3e6ea8..194a0997b79 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/GlobalPartitionerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/GlobalPartitionerTest.java @@ -33,8 +33,8 @@ @Test public void testSelectChannels() { - assertSelectedChannel(0, 1); - 
assertSelectedChannel(0, 2); - assertSelectedChannel(0, 1024); + assertSelectedChannelWithSetup(0, 1); + assertSelectedChannelWithSetup(0, 2); + assertSelectedChannelWithSetup(0, 1024); } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java index 07b57217221..65554f4373f 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java @@ -58,9 +58,9 @@ public String getKey(Tuple2<String, Integer> value) throws Exception { public void testSelectChannelsLength() { serializationDelegate1.setInstance(streamRecord1); - assertEquals(1, keyGroupPartitioner.selectChannels(serializationDelegate1, 1).length); - assertEquals(1, keyGroupPartitioner.selectChannels(serializationDelegate1, 2).length); - assertEquals(1, keyGroupPartitioner.selectChannels(serializationDelegate1, 1024).length); + assertEquals(1, selectChannels(serializationDelegate1, 1).length); + assertEquals(1, selectChannels(serializationDelegate1, 2).length); + assertEquals(1, selectChannels(serializationDelegate1, 1024).length); } @Test @@ -68,11 +68,15 @@ public void testSelectChannelsGrouping() { serializationDelegate1.setInstance(streamRecord1); serializationDelegate2.setInstance(streamRecord2); - assertArrayEquals(keyGroupPartitioner.selectChannels(serializationDelegate1, 1), - keyGroupPartitioner.selectChannels(serializationDelegate2, 1)); - assertArrayEquals(keyGroupPartitioner.selectChannels(serializationDelegate1, 2), - keyGroupPartitioner.selectChannels(serializationDelegate2, 2)); - assertArrayEquals(keyGroupPartitioner.selectChannels(serializationDelegate1, 1024), - 
keyGroupPartitioner.selectChannels(serializationDelegate2, 1024)); + assertArrayEquals(selectChannels(serializationDelegate1, 1), selectChannels(serializationDelegate2, 1)); + assertArrayEquals(selectChannels(serializationDelegate1, 2), selectChannels(serializationDelegate2, 2)); + assertArrayEquals(selectChannels(serializationDelegate1, 1024), selectChannels(serializationDelegate2, 1024)); + } + + private int[] selectChannels( + SerializationDelegate<StreamRecord<Tuple2<String, Integer>>> serializationDelegate, + int numberOfChannels) { + keyGroupPartitioner.setup(numberOfChannels); + return keyGroupPartitioner.selectChannels(serializationDelegate); } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RebalancePartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RebalancePartitionerTest.java index f5ed0aa815a..75b551c51f7 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RebalancePartitionerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RebalancePartitionerTest.java @@ -35,12 +35,15 @@ @Test public void testSelectChannelsInterval() { - int initialChannel = streamPartitioner.selectChannels(serializationDelegate, 3)[0]; + final int numberOfChannels = 3; + streamPartitioner.setup(numberOfChannels); + + int initialChannel = selectChannelAndAssertLength(); assertTrue(0 <= initialChannel); - assertTrue(3 > initialChannel); + assertTrue(numberOfChannels > initialChannel); - assertSelectedChannel((initialChannel + 1) % 3, 3); - assertSelectedChannel((initialChannel + 2) % 3, 3); - assertSelectedChannel((initialChannel + 3) % 3, 3); + for (int i = 1; i <= 3; i++) { + assertSelectedChannel((initialChannel + i) % numberOfChannels); + } } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java index 212ffbd335c..6d60d789d5a 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java @@ -67,10 +67,12 @@ @Test public void testSelectChannelsInterval() { - assertSelectedChannel(0, 3); - assertSelectedChannel(1, 3); - assertSelectedChannel(2, 3); - assertSelectedChannel(0, 3); + streamPartitioner.setup(3); + + assertSelectedChannel(0); + assertSelectedChannel(1); + assertSelectedChannel(2); + assertSelectedChannel(0); } @Test diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/ShufflePartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/ShufflePartitionerTest.java index 5198ecf7ac9..8c2c4e133ea 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/ShufflePartitionerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/ShufflePartitionerTest.java @@ -35,16 +35,14 @@ @Test public void testSelectChannelsInterval() { - assertSelectedChannel(0, 1); + assertSelectedChannelWithSetup(0, 1); - assertTrue(0 <= selectChannel(2)); - assertTrue(2 > selectChannel(2)); + streamPartitioner.setup(2); + assertTrue(0 <= selectChannelAndAssertLength()); + assertTrue(2 > selectChannelAndAssertLength()); - assertTrue(0 <= selectChannel(1024)); - assertTrue(1024 > selectChannel(1024)); - } - - private int selectChannel(int numberOfChannels) { - return streamPartitioner.selectChannels(serializationDelegate, numberOfChannels)[0]; + streamPartitioner.setup(1024); + assertTrue(0 <= selectChannelAndAssertLength()); + assertTrue(1024 > selectChannelAndAssertLength()); } } diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitionerTest.java index f98b0a524f5..1e72c496efc 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitionerTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/StreamPartitionerTest.java @@ -23,7 +23,6 @@ import org.apache.flink.util.TestLogger; import org.junit.Before; -import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -44,16 +43,20 @@ public void setup() { serializationDelegate.setInstance(streamRecord); } - @Test - public void testSelectChannelsLength() { - assertEquals(1, streamPartitioner.selectChannels(serializationDelegate, 1).length); - assertEquals(1, streamPartitioner.selectChannels(serializationDelegate, 2).length); - assertEquals(1, streamPartitioner.selectChannels(serializationDelegate, 1024).length); + protected int selectChannelAndAssertLength() { + int[] selectedChannels = streamPartitioner.selectChannels(serializationDelegate); + assertEquals(1, selectedChannels.length); + + return selectedChannels[0]; + } + + protected void assertSelectedChannel(int expectedChannel) { + int actualResult = selectChannelAndAssertLength(); + assertEquals(expectedChannel, actualResult); } - protected void assertSelectedChannel(int expectedChannel, int numberOfChannels) { - int[] actualResult = streamPartitioner.selectChannels(serializationDelegate, numberOfChannels); - assertEquals(1, actualResult.length); - assertEquals(expectedChannel, actualResult[0]); + protected void assertSelectedChannelWithSetup(int expectedChannel, int numberOfChannels) { + streamPartitioner.setup(numberOfChannels); + assertSelectedChannel(expectedChannel); } } ---------------------------------------------------------------- This is an automated message from the Apache Git 
Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org > Simplify the RebalancePartitioner implementation > ------------------------------------------------ > > Key: FLINK-10820 > URL: https://issues.apache.org/jira/browse/FLINK-10820 > Project: Flink > Issue Type: Sub-task > Components: Network > Affects Versions: 1.8.0 > Reporter: zhijiang > Assignee: zhijiang > Priority: Minor > Labels: pull-request-available > Fix For: 1.8.0 > > > _The current {{RebalancePartitioner}} implementation seems a little hacky: it > selects a random number as the first channel index, and the following > selections continue from this random index in round-robin fashion._ > _Especially in the corner case of {{numChannels = Integer.MAX_VALUE}}, it > would trigger a new random index once the last channel index is reached. > Actually, the random index should be selected only once, on the first selection._ -- This message was sent by Atlassian JIRA (v7.6.3#76005)