[FLINK-9256] [network] Fix NPE in SingleInputGate#updateInputChannel() for non-credit-based flow control
This closes #5914 Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/56e2b0b5 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/56e2b0b5 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/56e2b0b5 Branch: refs/heads/release-1.5 Commit: 56e2b0b5d600935eae590a985643a5879f224d04 Parents: f1fa517 Author: Nico Kruber <[email protected]> Authored: Wed Apr 25 18:28:48 2018 +0200 Committer: Stephan Ewen <[email protected]> Committed: Mon Apr 30 23:25:38 2018 +0200 ---------------------------------------------------------------------- .../runtime/io/network/NetworkEnvironment.java | 4 + .../partition/consumer/SingleInputGate.java | 20 +- .../io/network/NetworkEnvironmentTest.java | 5 +- .../PartitionRequestClientHandlerTest.java | 3 +- .../partition/InputGateConcurrentTest.java | 9 +- .../partition/InputGateFairnessTest.java | 17 +- .../consumer/LocalInputChannelTest.java | 6 +- .../consumer/RemoteInputChannelTest.java | 3 +- .../partition/consumer/SingleInputGateTest.java | 318 +++++++++++++++---- .../partition/consumer/TestSingleInputGate.java | 3 +- .../partition/consumer/UnionInputGateTest.java | 6 +- 11 files changed, 301 insertions(+), 93 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index 0a9dc0f..f254756 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -157,6 +157,10 @@ public class NetworkEnvironment { 
return partitionRequestMaxBackoff; } + public boolean isCreditBased() { + return enableCreditBased; + } + public KvStateRegistry getKvStateRegistry() { return kvStateRegistry; } http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index b9091b2..06e80ff 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -157,9 +157,11 @@ public class SingleInputGate implements InputGate { */ private BufferPool bufferPool; - /** Global network buffer pool to request and recycle exclusive buffers. */ + /** Global network buffer pool to request and recycle exclusive buffers (only for credit-based). */ private NetworkBufferPool networkBufferPool; + private final boolean isCreditBased; + private boolean hasReceivedAllEndOfPartitionEvents; /** Flag indicating whether partitions have been requested. 
*/ @@ -189,7 +191,8 @@ public class SingleInputGate implements InputGate { int consumedSubpartitionIndex, int numberOfInputChannels, TaskActions taskActions, - TaskIOMetricGroup metrics) { + TaskIOMetricGroup metrics, + boolean isCreditBased) { this.owningTaskName = checkNotNull(owningTaskName); this.jobId = checkNotNull(jobId); @@ -208,6 +211,7 @@ public class SingleInputGate implements InputGate { this.enqueuedInputChannelsWithData = new BitSet(numberOfInputChannels); this.taskActions = checkNotNull(taskActions); + this.isCreditBased = isCreditBased; } // ------------------------------------------------------------------------ @@ -288,6 +292,7 @@ public class SingleInputGate implements InputGate { * @param networkBuffersPerChannel The number of exclusive buffers for each channel */ public void assignExclusiveSegments(NetworkBufferPool networkBufferPool, int networkBuffersPerChannel) throws IOException { + checkState(this.isCreditBased, "Bug in input gate setup logic: exclusive buffers only exist with credit-based flow control."); checkState(this.networkBufferPool == null, "Bug in input gate setup logic: global buffer pool has" + "already been set for this input gate."); @@ -347,8 +352,13 @@ public class SingleInputGate implements InputGate { } else if (partitionLocation.isRemote()) { newChannel = unknownChannel.toRemoteInputChannel(partitionLocation.getConnectionId()); - ((RemoteInputChannel)newChannel).assignExclusiveSegments( - networkBufferPool.requestMemorySegments(networkBuffersPerChannel)); + + if (this.isCreditBased) { + checkState(this.networkBufferPool != null, "Bug in input gate setup logic: " + + "global buffer pool has not been set for this input gate."); + ((RemoteInputChannel) newChannel).assignExclusiveSegments( + networkBufferPool.requestMemorySegments(networkBuffersPerChannel)); + } } else { throw new IllegalStateException("Tried to update unknown channel with unknown channel."); @@ -661,7 +671,7 @@ public class SingleInputGate implements 
InputGate { final SingleInputGate inputGate = new SingleInputGate( owningTaskName, jobId, consumedResultId, consumedPartitionType, consumedSubpartitionIndex, - icdd.length, taskActions, metrics); + icdd.length, taskActions, metrics, networkEnvironment.isCreditBased()); // Create the input channels. There is one input channel for each consumed partition. final InputChannel[] inputChannels = new InputChannel[icdd.length]; http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java index 317a214..f790b5f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java @@ -329,7 +329,7 @@ public class NetworkEnvironmentTest { * * @return input gate with some fake settings */ - private static SingleInputGate createSingleInputGate( + private SingleInputGate createSingleInputGate( final ResultPartitionType partitionType, final int channels) { return spy(new SingleInputGate( "Test Task Name", @@ -339,7 +339,8 @@ public class NetworkEnvironmentTest { 0, channels, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup())); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + enableCreditBasedFlowControl)); } private static void createRemoteInputChannel( http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java ---------------------------------------------------------------------- diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java index 13f7510..842aed8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java @@ -221,7 +221,8 @@ public class PartitionRequestClientHandlerTest { 0, 1, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); } /** http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java index 289a398..73f3cfb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java @@ -66,7 +66,8 @@ public class InputGateConcurrentTest { new IntermediateDataSetID(), ResultPartitionType.PIPELINED, 0, numChannels, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); for (int i = 0; i < numChannels; i++) { LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(), @@ -102,7 +103,8 @@ public class InputGateConcurrentTest { 0, numChannels, 
mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); for (int i = 0; i < numChannels; i++) { RemoteInputChannel channel = new RemoteInputChannel( @@ -151,7 +153,8 @@ public class InputGateConcurrentTest { 0, numChannels, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); for (int i = 0, local = 0; i < numChannels; i++) { if (localOrRemote.get(i)) { http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java index 45df56f..82a27cc 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java @@ -93,7 +93,8 @@ public class InputGateFairnessTest { new IntermediateDataSetID(), 0, numChannels, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); for (int i = 0; i < numChannels; i++) { LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(), @@ -146,7 +147,8 @@ public class InputGateFairnessTest { new IntermediateDataSetID(), 0, numChannels, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + 
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); for (int i = 0; i < numChannels; i++) { LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(), @@ -196,7 +198,8 @@ public class InputGateFairnessTest { new IntermediateDataSetID(), 0, numChannels, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); final ConnectionManager connManager = createDummyConnectionManager(); @@ -251,7 +254,8 @@ public class InputGateFairnessTest { new IntermediateDataSetID(), 0, numChannels, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); final ConnectionManager connManager = createDummyConnectionManager(); @@ -349,11 +353,12 @@ public class InputGateFairnessTest { int consumedSubpartitionIndex, int numberOfInputChannels, TaskActions taskActions, - TaskIOMetricGroup metrics) { + TaskIOMetricGroup metrics, + boolean isCreditBased) { super(owningTaskName, jobId, consumedResultId, ResultPartitionType.PIPELINED, consumedSubpartitionIndex, - numberOfInputChannels, taskActions, metrics); + numberOfInputChannels, taskActions, metrics, isCreditBased); try { Field f = SingleInputGate.class.getDeclaredField("inputChannelsWithData"); http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java index c78b7b9..1ecb67f 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java @@ -293,7 +293,8 @@ public class LocalInputChannelTest { 0, 1, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup() + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true ); ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); @@ -490,7 +491,8 @@ public class LocalInputChannelTest { subpartitionIndex, numberOfInputChannels, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); // Set buffer pool inputGate.setBufferPool(bufferPool); http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java index 97a5688..802cb93 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java @@ -889,7 +889,8 @@ public class RemoteInputChannelTest { 0, 1, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); } private RemoteInputChannel 
createRemoteInputChannel(SingleInputGate inputGate) http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index 8c54c1f..c244668 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -19,13 +19,13 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.api.common.JobID; -import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionLocation; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.ConnectionID; import org.apache.flink.runtime.io.network.ConnectionManager; import org.apache.flink.runtime.io.network.LocalConnectionManager; @@ -45,31 +45,51 @@ import org.apache.flink.runtime.io.network.util.TestTaskEvent; import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID; import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups; +import org.apache.flink.runtime.query.KvStateRegistry; import 
org.apache.flink.runtime.taskmanager.TaskActions; import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; import java.io.IOException; import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; -import static org.mockito.Matchers.anyListOf; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +/** + * Tests for {@link SingleInputGate}. + */ +@RunWith(Parameterized.class) public class SingleInputGateTest { + @Parameterized.Parameter + public boolean enableCreditBasedFlowControl; + + @Parameterized.Parameters(name = "Credit-based = {0}") + public static List<Boolean> parameters() { + return Arrays.asList(Boolean.TRUE, Boolean.FALSE); + } + /** * Tests basic correctness of buffer-or-event interleaving and correct <code>null</code> return * value after receiving all end-of-partition events. 
@@ -324,12 +344,7 @@ public class SingleInputGateTest { int initialBackoff = 137; int maxBackoff = 1001; - NetworkEnvironment netEnv = mock(NetworkEnvironment.class); - when(netEnv.getResultPartitionManager()).thenReturn(new ResultPartitionManager()); - when(netEnv.getTaskEventDispatcher()).thenReturn(new TaskEventDispatcher()); - when(netEnv.getPartitionRequestInitialBackoff()).thenReturn(initialBackoff); - when(netEnv.getPartitionRequestMaxBackoff()).thenReturn(maxBackoff); - when(netEnv.getConnectionManager()).thenReturn(new LocalConnectionManager()); + final NetworkEnvironment netEnv = createNetworkEnvironment(2, 8, initialBackoff, maxBackoff); SingleInputGate gate = SingleInputGate.create( "TestTask", @@ -340,37 +355,43 @@ public class SingleInputGateTest { mock(TaskActions.class), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); - assertEquals(gateDesc.getConsumedPartitionType(), gate.getConsumedPartitionType()); + try { + assertEquals(gateDesc.getConsumedPartitionType(), gate.getConsumedPartitionType()); - Map<IntermediateResultPartitionID, InputChannel> channelMap = gate.getInputChannels(); + Map<IntermediateResultPartitionID, InputChannel> channelMap = gate.getInputChannels(); - assertEquals(3, channelMap.size()); - InputChannel localChannel = channelMap.get(partitionIds[0].getPartitionId()); - assertEquals(LocalInputChannel.class, localChannel.getClass()); + assertEquals(3, channelMap.size()); + InputChannel localChannel = channelMap.get(partitionIds[0].getPartitionId()); + assertEquals(LocalInputChannel.class, localChannel.getClass()); - InputChannel remoteChannel = channelMap.get(partitionIds[1].getPartitionId()); - assertEquals(RemoteInputChannel.class, remoteChannel.getClass()); + InputChannel remoteChannel = channelMap.get(partitionIds[1].getPartitionId()); + assertEquals(RemoteInputChannel.class, remoteChannel.getClass()); - InputChannel unknownChannel = channelMap.get(partitionIds[2].getPartitionId()); - 
assertEquals(UnknownInputChannel.class, unknownChannel.getClass()); + InputChannel unknownChannel = channelMap.get(partitionIds[2].getPartitionId()); + assertEquals(UnknownInputChannel.class, unknownChannel.getClass()); - InputChannel[] channels = new InputChannel[]{localChannel, remoteChannel, unknownChannel}; - for (InputChannel ch : channels) { - assertEquals(0, ch.getCurrentBackoff()); + InputChannel[] channels = + new InputChannel[] {localChannel, remoteChannel, unknownChannel}; + for (InputChannel ch : channels) { + assertEquals(0, ch.getCurrentBackoff()); - assertTrue(ch.increaseBackoff()); - assertEquals(initialBackoff, ch.getCurrentBackoff()); + assertTrue(ch.increaseBackoff()); + assertEquals(initialBackoff, ch.getCurrentBackoff()); - assertTrue(ch.increaseBackoff()); - assertEquals(initialBackoff * 2, ch.getCurrentBackoff()); + assertTrue(ch.increaseBackoff()); + assertEquals(initialBackoff * 2, ch.getCurrentBackoff()); - assertTrue(ch.increaseBackoff()); - assertEquals(initialBackoff * 2 * 2, ch.getCurrentBackoff()); + assertTrue(ch.increaseBackoff()); + assertEquals(initialBackoff * 2 * 2, ch.getCurrentBackoff()); - assertTrue(ch.increaseBackoff()); - assertEquals(maxBackoff, ch.getCurrentBackoff()); + assertTrue(ch.increaseBackoff()); + assertEquals(maxBackoff, ch.getCurrentBackoff()); - assertFalse(ch.increaseBackoff()); + assertFalse(ch.increaseBackoff()); + } + } finally { + gate.releaseAllResources(); + netEnv.shutdown(); } } @@ -379,26 +400,39 @@ public class SingleInputGateTest { */ @Test public void testRequestBuffersWithRemoteInputChannel() throws Exception { - final SingleInputGate inputGate = new SingleInputGate( - "t1", - new JobID(), - new IntermediateDataSetID(), - ResultPartitionType.PIPELINED_BOUNDED, - 0, - 1, - mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); - - RemoteInputChannel remote = mock(RemoteInputChannel.class); - inputGate.setInputChannel(new 
IntermediateResultPartitionID(), remote); - - final int buffersPerChannel = 2; - NetworkBufferPool network = mock(NetworkBufferPool.class); - // Trigger requests of segments from global pool and assign buffers to remote input channel - inputGate.assignExclusiveSegments(network, buffersPerChannel); - - verify(network, times(1)).requestMemorySegments(buffersPerChannel); - verify(remote, times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class)); + final SingleInputGate inputGate = createInputGate(1, ResultPartitionType.PIPELINED_BOUNDED); + int buffersPerChannel = 2; + int extraNetworkBuffersPerGate = 8; + final NetworkEnvironment network = createNetworkEnvironment(buffersPerChannel, + extraNetworkBuffersPerGate, 0, 0); + + try { + final ResultPartitionID resultPartitionId = new ResultPartitionID(); + final ConnectionID connectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0); + addRemoteInputChannel(network, inputGate, connectionId, resultPartitionId, 0); + + network.setupInputGate(inputGate); + + NetworkBufferPool bufferPool = network.getNetworkBufferPool(); + if (enableCreditBasedFlowControl) { + verify(bufferPool, + times(1)).requestMemorySegments(buffersPerChannel); + RemoteInputChannel remote = (RemoteInputChannel) inputGate.getInputChannels() + .get(resultPartitionId.getPartitionId()); + // only the exclusive buffers should be assigned/available now + assertEquals(buffersPerChannel, remote.getNumberOfAvailableBuffers()); + + assertEquals(bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel, + bufferPool.getNumberOfAvailableMemorySegments()); + // note: exclusive buffers are not handed out into LocalBufferPool and are thus not counted + assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers()); + } else { + assertEquals(buffersPerChannel + extraNetworkBuffersPerGate, bufferPool.countBuffers()); + } + } finally { + inputGate.releaseAllResources(); + network.shutdown(); + } } /** @@ -407,51 +441,195 @@ public class 
SingleInputGateTest { */ @Test public void testRequestBuffersWithUnknownInputChannel() throws Exception { - final SingleInputGate inputGate = createInputGate(1); + final SingleInputGate inputGate = createInputGate(1, ResultPartitionType.PIPELINED_BOUNDED); + int buffersPerChannel = 2; + int extraNetworkBuffersPerGate = 8; + final NetworkEnvironment network = createNetworkEnvironment(buffersPerChannel, extraNetworkBuffersPerGate, 0, 0); - UnknownInputChannel unknown = mock(UnknownInputChannel.class); - final ResultPartitionID resultPartitionId = new ResultPartitionID(); - inputGate.setInputChannel(resultPartitionId.getPartitionId(), unknown); + try { + final ResultPartitionID resultPartitionId = new ResultPartitionID(); + addUnknownInputChannel(network, inputGate, resultPartitionId, 0); - RemoteInputChannel remote = mock(RemoteInputChannel.class); - final ConnectionID connectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0); - when(unknown.toRemoteInputChannel(connectionId)).thenReturn(remote); + network.setupInputGate(inputGate); + NetworkBufferPool bufferPool = network.getNetworkBufferPool(); - final int buffersPerChannel = 2; - NetworkBufferPool network = mock(NetworkBufferPool.class); - inputGate.assignExclusiveSegments(network, buffersPerChannel); + if (enableCreditBasedFlowControl) { + verify(bufferPool, times(0)).requestMemorySegments(buffersPerChannel); - // Trigger updates to remote input channel from unknown input channel - inputGate.updateInputChannel(new InputChannelDeploymentDescriptor( - resultPartitionId, - ResultPartitionLocation.createRemote(connectionId))); + assertEquals(bufferPool.getTotalNumberOfMemorySegments(), + bufferPool.getNumberOfAvailableMemorySegments()); + // note: exclusive buffers are not handed out into LocalBufferPool and are thus not counted + assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers()); + } else { + assertEquals(buffersPerChannel + extraNetworkBuffersPerGate, bufferPool.countBuffers()); 
+ } + + // Trigger updates to remote input channel from unknown input channel + final ConnectionID connectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0); + inputGate.updateInputChannel(new InputChannelDeploymentDescriptor( + resultPartitionId, + ResultPartitionLocation.createRemote(connectionId))); + + if (enableCreditBasedFlowControl) { + verify(bufferPool, + times(1)).requestMemorySegments(buffersPerChannel); + RemoteInputChannel remote = (RemoteInputChannel) inputGate.getInputChannels() + .get(resultPartitionId.getPartitionId()); + // only the exclusive buffers should be assigned/available now + assertEquals(buffersPerChannel, remote.getNumberOfAvailableBuffers()); + + assertEquals(bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel, + bufferPool.getNumberOfAvailableMemorySegments()); + // note: exclusive buffers are not handed out into LocalBufferPool and are thus not counted + assertEquals(extraNetworkBuffersPerGate, bufferPool.countBuffers()); + } else { + assertEquals(buffersPerChannel + extraNetworkBuffersPerGate, bufferPool.countBuffers()); + } + } finally { + inputGate.releaseAllResources(); + network.shutdown(); + } + } - verify(network, times(1)).requestMemorySegments(buffersPerChannel); - verify(remote, times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class)); + /** + * Tests that input gate can successfully convert unknown input channels into local and remote + * channels. 
+ */ + @Test + public void testUpdateUnknownInputChannel() throws Exception { + final SingleInputGate inputGate = createInputGate(2); + int buffersPerChannel = 2; + final NetworkEnvironment network = createNetworkEnvironment(buffersPerChannel, 8, 0, 0); + + try { + final ResultPartitionID localResultPartitionId = new ResultPartitionID(); + addUnknownInputChannel(network, inputGate, localResultPartitionId, 0); + + final ResultPartitionID remoteResultPartitionId = new ResultPartitionID(); + addUnknownInputChannel(network, inputGate, remoteResultPartitionId, 1); + + network.setupInputGate(inputGate); + + assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()), + is(instanceOf((UnknownInputChannel.class)))); + assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()), + is(instanceOf((UnknownInputChannel.class)))); + + // Trigger updates to remote input channel from unknown input channel + final ConnectionID remoteConnectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0); + inputGate.updateInputChannel(new InputChannelDeploymentDescriptor( + remoteResultPartitionId, + ResultPartitionLocation.createRemote(remoteConnectionId))); + + assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()), + is(instanceOf((RemoteInputChannel.class)))); + assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()), + is(instanceOf((UnknownInputChannel.class)))); + + // Trigger updates to local input channel from unknown input channel + inputGate.updateInputChannel(new InputChannelDeploymentDescriptor( + localResultPartitionId, + ResultPartitionLocation.createLocal())); + + assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()), + is(instanceOf((RemoteInputChannel.class)))); + assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()), + is(instanceOf((LocalInputChannel.class)))); + } finally { + 
inputGate.releaseAllResources(); + network.shutdown(); + } } // --------------------------------------------------------------------------------------------- - private static SingleInputGate createInputGate() { + private NetworkEnvironment createNetworkEnvironment( + int buffersPerChannel, + int extraNetworkBuffersPerGate, + int initialBackoff, + int maxBackoff) { + return new NetworkEnvironment( + spy(new NetworkBufferPool(100, 32)), + new LocalConnectionManager(), + new ResultPartitionManager(), + new TaskEventDispatcher(), + new KvStateRegistry(), + null, + null, + IOManager.IOMode.SYNC, + initialBackoff, + maxBackoff, + buffersPerChannel, + extraNetworkBuffersPerGate, + enableCreditBasedFlowControl); + } + + private SingleInputGate createInputGate() { return createInputGate(2); } - private static SingleInputGate createInputGate(int numberOfInputChannels) { + private SingleInputGate createInputGate(int numberOfInputChannels) { + return createInputGate(numberOfInputChannels, ResultPartitionType.PIPELINED); + } + + private SingleInputGate createInputGate( + int numberOfInputChannels, ResultPartitionType partitionType) { SingleInputGate inputGate = new SingleInputGate( "Test Task Name", new JobID(), new IntermediateDataSetID(), - ResultPartitionType.PIPELINED, + partitionType, 0, numberOfInputChannels, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + enableCreditBasedFlowControl); - assertEquals(ResultPartitionType.PIPELINED, inputGate.getConsumedPartitionType()); + assertEquals(partitionType, inputGate.getConsumedPartitionType()); return inputGate; } + private void addUnknownInputChannel( + NetworkEnvironment network, + SingleInputGate inputGate, + ResultPartitionID partitionId, + int channelIndex) { + UnknownInputChannel unknown = + createUnknownInputChannel(network, inputGate, partitionId, channelIndex); + 
inputGate.setInputChannel(partitionId.getPartitionId(), unknown); + } + + private UnknownInputChannel createUnknownInputChannel( + NetworkEnvironment network, + SingleInputGate inputGate, + ResultPartitionID partitionId, + int channelIndex) { + return new UnknownInputChannel( + inputGate, + channelIndex, + partitionId, + network.getResultPartitionManager(), + network.getTaskEventDispatcher(), + network.getConnectionManager(), + network.getPartitionRequestInitialBackoff(), + network.getPartitionRequestMaxBackoff(), + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup() + ); + } + + private void addRemoteInputChannel( + NetworkEnvironment network, + SingleInputGate inputGate, + ConnectionID connectionId, + ResultPartitionID partitionId, + int channelIndex) { + RemoteInputChannel remote = + createUnknownInputChannel(network, inputGate, partitionId, channelIndex) + .toRemoteInputChannel(connectionId); + inputGate.setInputChannel(partitionId.getPartitionId(), remote); + } + static void verifyBufferOrEvent( InputGate inputGate, boolean expectedIsBuffer, http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java index 0ae6e74..33dc1ca 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java @@ -60,7 +60,8 @@ public class TestSingleInputGate { 0, numberOfInputChannels, mock(TaskActions.class), - 
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); this.inputGate = spy(realGate); http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java index 912cd5b..081d97d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java @@ -50,13 +50,15 @@ public class UnionInputGateTest { new IntermediateDataSetID(), ResultPartitionType.PIPELINED, 0, 3, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); final SingleInputGate ig2 = new SingleInputGate( testTaskName, new JobID(), new IntermediateDataSetID(), ResultPartitionType.PIPELINED, 0, 5, mock(TaskActions.class), - UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(), + true); final UnionInputGate union = new UnionInputGate(new SingleInputGate[]{ig1, ig2});
