[FLINK-9256] [network] Fix NPE in SingleInputGate#updateInputChannel() for non-credit-based flow control

This closes #5914


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/56e2b0b5
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/56e2b0b5
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/56e2b0b5

Branch: refs/heads/release-1.5
Commit: 56e2b0b5d600935eae590a985643a5879f224d04
Parents: f1fa517
Author: Nico Kruber <[email protected]>
Authored: Wed Apr 25 18:28:48 2018 +0200
Committer: Stephan Ewen <[email protected]>
Committed: Mon Apr 30 23:25:38 2018 +0200

----------------------------------------------------------------------
 .../runtime/io/network/NetworkEnvironment.java  |   4 +
 .../partition/consumer/SingleInputGate.java     |  20 +-
 .../io/network/NetworkEnvironmentTest.java      |   5 +-
 .../PartitionRequestClientHandlerTest.java      |   3 +-
 .../partition/InputGateConcurrentTest.java      |   9 +-
 .../partition/InputGateFairnessTest.java        |  17 +-
 .../consumer/LocalInputChannelTest.java         |   6 +-
 .../consumer/RemoteInputChannelTest.java        |   3 +-
 .../partition/consumer/SingleInputGateTest.java | 318 +++++++++++++++----
 .../partition/consumer/TestSingleInputGate.java |   3 +-
 .../partition/consumer/UnionInputGateTest.java  |   6 +-
 11 files changed, 301 insertions(+), 93 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
index 0a9dc0f..f254756 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
@@ -157,6 +157,10 @@ public class NetworkEnvironment {
                return partitionRequestMaxBackoff;
        }
 
+       public boolean isCreditBased() {
+               return enableCreditBased;
+       }
+
        public KvStateRegistry getKvStateRegistry() {
                return kvStateRegistry;
        }

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index b9091b2..06e80ff 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -157,9 +157,11 @@ public class SingleInputGate implements InputGate {
         */
        private BufferPool bufferPool;
 
-       /** Global network buffer pool to request and recycle exclusive 
buffers. */
+       /** Global network buffer pool to request and recycle exclusive buffers 
(only for credit-based). */
        private NetworkBufferPool networkBufferPool;
 
+       private final boolean isCreditBased;
+
        private boolean hasReceivedAllEndOfPartitionEvents;
 
        /** Flag indicating whether partitions have been requested. */
@@ -189,7 +191,8 @@ public class SingleInputGate implements InputGate {
                int consumedSubpartitionIndex,
                int numberOfInputChannels,
                TaskActions taskActions,
-               TaskIOMetricGroup metrics) {
+               TaskIOMetricGroup metrics,
+               boolean isCreditBased) {
 
                this.owningTaskName = checkNotNull(owningTaskName);
                this.jobId = checkNotNull(jobId);
@@ -208,6 +211,7 @@ public class SingleInputGate implements InputGate {
                this.enqueuedInputChannelsWithData = new 
BitSet(numberOfInputChannels);
 
                this.taskActions = checkNotNull(taskActions);
+               this.isCreditBased = isCreditBased;
        }
 
        // 
------------------------------------------------------------------------
@@ -288,6 +292,7 @@ public class SingleInputGate implements InputGate {
         * @param networkBuffersPerChannel The number of exclusive buffers for 
each channel
         */
        public void assignExclusiveSegments(NetworkBufferPool 
networkBufferPool, int networkBuffersPerChannel) throws IOException {
+               checkState(this.isCreditBased, "Bug in input gate setup logic: 
exclusive buffers only exist with credit-based flow control.");
                checkState(this.networkBufferPool == null, "Bug in input gate 
setup logic: global buffer pool has" +
                        "already been set for this input gate.");
 
@@ -347,8 +352,13 @@ public class SingleInputGate implements InputGate {
                                }
                                else if (partitionLocation.isRemote()) {
                                        newChannel = 
unknownChannel.toRemoteInputChannel(partitionLocation.getConnectionId());
-                                       
((RemoteInputChannel)newChannel).assignExclusiveSegments(
-                                               
networkBufferPool.requestMemorySegments(networkBuffersPerChannel));
+
+                                       if (this.isCreditBased) {
+                                               
checkState(this.networkBufferPool != null, "Bug in input gate setup logic: " +
+                                                       "global buffer pool has 
not been set for this input gate.");
+                                               ((RemoteInputChannel) 
newChannel).assignExclusiveSegments(
+                                                       
networkBufferPool.requestMemorySegments(networkBuffersPerChannel));
+                                       }
                                }
                                else {
                                        throw new IllegalStateException("Tried 
to update unknown channel with unknown channel.");
@@ -661,7 +671,7 @@ public class SingleInputGate implements InputGate {
 
                final SingleInputGate inputGate = new SingleInputGate(
                        owningTaskName, jobId, consumedResultId, 
consumedPartitionType, consumedSubpartitionIndex,
-                       icdd.length, taskActions, metrics);
+                       icdd.length, taskActions, metrics, 
networkEnvironment.isCreditBased());
 
                // Create the input channels. There is one input channel for 
each consumed partition.
                final InputChannel[] inputChannels = new 
InputChannel[icdd.length];

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
index 317a214..f790b5f 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
@@ -329,7 +329,7 @@ public class NetworkEnvironmentTest {
         *
         * @return input gate with some fake settings
         */
-       private static SingleInputGate createSingleInputGate(
+       private SingleInputGate createSingleInputGate(
                        final ResultPartitionType partitionType, final int 
channels) {
                return spy(new SingleInputGate(
                        "Test Task Name",
@@ -339,7 +339,8 @@ public class NetworkEnvironmentTest {
                        0,
                        channels,
                        mock(TaskActions.class),
-                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()));
+                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                       enableCreditBasedFlowControl));
        }
 
        private static void createRemoteInputChannel(

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
index 13f7510..842aed8 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
@@ -221,7 +221,8 @@ public class PartitionRequestClientHandlerTest {
                        0,
                        1,
                        mock(TaskActions.class),
-                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                       true);
        }
 
        /**

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
index 289a398..73f3cfb 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
@@ -66,7 +66,8 @@ public class InputGateConcurrentTest {
                                new IntermediateDataSetID(), 
ResultPartitionType.PIPELINED,
                                0, numChannels,
                                mock(TaskActions.class),
-                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                               true);
 
                for (int i = 0; i < numChannels; i++) {
                        LocalInputChannel channel = new LocalInputChannel(gate, 
i, new ResultPartitionID(),
@@ -102,7 +103,8 @@ public class InputGateConcurrentTest {
                                0,
                                numChannels,
                                mock(TaskActions.class),
-                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                               true);
 
                for (int i = 0; i < numChannels; i++) {
                        RemoteInputChannel channel = new RemoteInputChannel(
@@ -151,7 +153,8 @@ public class InputGateConcurrentTest {
                                0,
                                numChannels,
                                mock(TaskActions.class),
-                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                               true);
 
                for (int i = 0, local = 0; i < numChannels; i++) {
                        if (localOrRemote.get(i)) {

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
index 45df56f..82a27cc 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
@@ -93,7 +93,8 @@ public class InputGateFairnessTest {
                                new IntermediateDataSetID(),
                                0, numChannels,
                                mock(TaskActions.class),
-                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                               true);
 
                for (int i = 0; i < numChannels; i++) {
                        LocalInputChannel channel = new LocalInputChannel(gate, 
i, new ResultPartitionID(),
@@ -146,7 +147,8 @@ public class InputGateFairnessTest {
                                new IntermediateDataSetID(),
                                0, numChannels,
                                mock(TaskActions.class),
-                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                               true);
 
                        for (int i = 0; i < numChannels; i++) {
                                LocalInputChannel channel = new 
LocalInputChannel(gate, i, new ResultPartitionID(),
@@ -196,7 +198,8 @@ public class InputGateFairnessTest {
                                new IntermediateDataSetID(),
                                0, numChannels,
                                mock(TaskActions.class),
-                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                               true);
 
                final ConnectionManager connManager = 
createDummyConnectionManager();
 
@@ -251,7 +254,8 @@ public class InputGateFairnessTest {
                                new IntermediateDataSetID(),
                                0, numChannels,
                                mock(TaskActions.class),
-                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                               
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                               true);
 
                final ConnectionManager connManager = 
createDummyConnectionManager();
 
@@ -349,11 +353,12 @@ public class InputGateFairnessTest {
                                int consumedSubpartitionIndex,
                                int numberOfInputChannels,
                                TaskActions taskActions,
-                               TaskIOMetricGroup metrics) {
+                               TaskIOMetricGroup metrics,
+                               boolean isCreditBased) {
 
                        super(owningTaskName, jobId, consumedResultId, 
ResultPartitionType.PIPELINED,
                                consumedSubpartitionIndex,
-                                       numberOfInputChannels, taskActions, 
metrics);
+                                       numberOfInputChannels, taskActions, 
metrics, isCreditBased);
 
                        try {
                                Field f = 
SingleInputGate.class.getDeclaredField("inputChannelsWithData");

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
index c78b7b9..1ecb67f 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannelTest.java
@@ -293,7 +293,8 @@ public class LocalInputChannelTest {
                        0,
                        1,
                        mock(TaskActions.class),
-                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()
+                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                       true
                );
 
                ResultPartitionManager partitionManager = 
mock(ResultPartitionManager.class);
@@ -490,7 +491,8 @@ public class LocalInputChannelTest {
                                        subpartitionIndex,
                                        numberOfInputChannels,
                                        mock(TaskActions.class),
-                                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                                       true);
 
                        // Set buffer pool
                        inputGate.setBufferPool(bufferPool);

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
index 97a5688..802cb93 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
@@ -889,7 +889,8 @@ public class RemoteInputChannelTest {
                        0,
                        1,
                        mock(TaskActions.class),
-                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                       true);
        }
 
        private RemoteInputChannel createRemoteInputChannel(SingleInputGate 
inputGate)

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index 8c54c1f..c244668 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -19,13 +19,13 @@
 package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionLocation;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.network.ConnectionID;
 import org.apache.flink.runtime.io.network.ConnectionManager;
 import org.apache.flink.runtime.io.network.LocalConnectionManager;
@@ -45,31 +45,51 @@ import 
org.apache.flink.runtime.io.network.util.TestTaskEvent;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.taskmanager.TaskActions;
 
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import java.io.IOException;
 import java.net.InetSocketAddress;
+import java.util.Arrays;
+import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
-import static org.mockito.Matchers.anyListOf;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+/**
+ * Tests for {@link SingleInputGate}.
+ */
+@RunWith(Parameterized.class)
 public class SingleInputGateTest {
 
+       @Parameterized.Parameter
+       public boolean enableCreditBasedFlowControl;
+
+       @Parameterized.Parameters(name = "Credit-based = {0}")
+       public static List<Boolean> parameters() {
+               return Arrays.asList(Boolean.TRUE, Boolean.FALSE);
+       }
+
        /**
         * Tests basic correctness of buffer-or-event interleaving and correct 
<code>null</code> return
         * value after receiving all end-of-partition events.
@@ -324,12 +344,7 @@ public class SingleInputGateTest {
                int initialBackoff = 137;
                int maxBackoff = 1001;
 
-               NetworkEnvironment netEnv = mock(NetworkEnvironment.class);
-               when(netEnv.getResultPartitionManager()).thenReturn(new 
ResultPartitionManager());
-               when(netEnv.getTaskEventDispatcher()).thenReturn(new 
TaskEventDispatcher());
-               
when(netEnv.getPartitionRequestInitialBackoff()).thenReturn(initialBackoff);
-               
when(netEnv.getPartitionRequestMaxBackoff()).thenReturn(maxBackoff);
-               when(netEnv.getConnectionManager()).thenReturn(new 
LocalConnectionManager());
+               final NetworkEnvironment netEnv = createNetworkEnvironment(2, 
8, initialBackoff, maxBackoff);
 
                SingleInputGate gate = SingleInputGate.create(
                        "TestTask",
@@ -340,37 +355,43 @@ public class SingleInputGateTest {
                        mock(TaskActions.class),
                        
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
 
-               assertEquals(gateDesc.getConsumedPartitionType(), 
gate.getConsumedPartitionType());
+               try {
+                       assertEquals(gateDesc.getConsumedPartitionType(), 
gate.getConsumedPartitionType());
 
-               Map<IntermediateResultPartitionID, InputChannel> channelMap = 
gate.getInputChannels();
+                       Map<IntermediateResultPartitionID, InputChannel> 
channelMap = gate.getInputChannels();
 
-               assertEquals(3, channelMap.size());
-               InputChannel localChannel = 
channelMap.get(partitionIds[0].getPartitionId());
-               assertEquals(LocalInputChannel.class, localChannel.getClass());
+                       assertEquals(3, channelMap.size());
+                       InputChannel localChannel = 
channelMap.get(partitionIds[0].getPartitionId());
+                       assertEquals(LocalInputChannel.class, 
localChannel.getClass());
 
-               InputChannel remoteChannel = 
channelMap.get(partitionIds[1].getPartitionId());
-               assertEquals(RemoteInputChannel.class, 
remoteChannel.getClass());
+                       InputChannel remoteChannel = 
channelMap.get(partitionIds[1].getPartitionId());
+                       assertEquals(RemoteInputChannel.class, 
remoteChannel.getClass());
 
-               InputChannel unknownChannel = 
channelMap.get(partitionIds[2].getPartitionId());
-               assertEquals(UnknownInputChannel.class, 
unknownChannel.getClass());
+                       InputChannel unknownChannel = 
channelMap.get(partitionIds[2].getPartitionId());
+                       assertEquals(UnknownInputChannel.class, 
unknownChannel.getClass());
 
-               InputChannel[] channels = new InputChannel[]{localChannel, 
remoteChannel, unknownChannel};
-               for (InputChannel ch : channels) {
-                       assertEquals(0, ch.getCurrentBackoff());
+                       InputChannel[] channels =
+                               new InputChannel[] {localChannel, 
remoteChannel, unknownChannel};
+                       for (InputChannel ch : channels) {
+                               assertEquals(0, ch.getCurrentBackoff());
 
-                       assertTrue(ch.increaseBackoff());
-                       assertEquals(initialBackoff, ch.getCurrentBackoff());
+                               assertTrue(ch.increaseBackoff());
+                               assertEquals(initialBackoff, 
ch.getCurrentBackoff());
 
-                       assertTrue(ch.increaseBackoff());
-                       assertEquals(initialBackoff * 2, 
ch.getCurrentBackoff());
+                               assertTrue(ch.increaseBackoff());
+                               assertEquals(initialBackoff * 2, 
ch.getCurrentBackoff());
 
-                       assertTrue(ch.increaseBackoff());
-                       assertEquals(initialBackoff * 2 * 2, 
ch.getCurrentBackoff());
+                               assertTrue(ch.increaseBackoff());
+                               assertEquals(initialBackoff * 2 * 2, 
ch.getCurrentBackoff());
 
-                       assertTrue(ch.increaseBackoff());
-                       assertEquals(maxBackoff, ch.getCurrentBackoff());
+                               assertTrue(ch.increaseBackoff());
+                               assertEquals(maxBackoff, 
ch.getCurrentBackoff());
 
-                       assertFalse(ch.increaseBackoff());
+                               assertFalse(ch.increaseBackoff());
+                       }
+               } finally {
+                       gate.releaseAllResources();
+                       netEnv.shutdown();
                }
        }
 
@@ -379,26 +400,39 @@ public class SingleInputGateTest {
         */
        @Test
        public void testRequestBuffersWithRemoteInputChannel() throws Exception 
{
-               final SingleInputGate inputGate = new SingleInputGate(
-                       "t1",
-                       new JobID(),
-                       new IntermediateDataSetID(),
-                       ResultPartitionType.PIPELINED_BOUNDED,
-                       0,
-                       1,
-                       mock(TaskActions.class),
-                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
-
-               RemoteInputChannel remote = mock(RemoteInputChannel.class);
-               inputGate.setInputChannel(new IntermediateResultPartitionID(), 
remote);
-
-               final int buffersPerChannel = 2;
-               NetworkBufferPool network = mock(NetworkBufferPool.class);
-               // Trigger requests of segments from global pool and assign 
buffers to remote input channel
-               inputGate.assignExclusiveSegments(network, buffersPerChannel);
-
-               verify(network, 
times(1)).requestMemorySegments(buffersPerChannel);
-               verify(remote, 
times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class));
+               final SingleInputGate inputGate = createInputGate(1, 
ResultPartitionType.PIPELINED_BOUNDED);
+               int buffersPerChannel = 2;
+               int extraNetworkBuffersPerGate = 8;
+               final NetworkEnvironment network = 
createNetworkEnvironment(buffersPerChannel,
+                       extraNetworkBuffersPerGate, 0, 0);
+
+               try {
+                       final ResultPartitionID resultPartitionId = new 
ResultPartitionID();
+                       final ConnectionID connectionId = new ConnectionID(new 
InetSocketAddress("localhost", 5000), 0);
+                       addRemoteInputChannel(network, inputGate, connectionId, 
resultPartitionId, 0);
+
+                       network.setupInputGate(inputGate);
+
+                       NetworkBufferPool bufferPool = 
network.getNetworkBufferPool();
+                       if (enableCreditBasedFlowControl) {
+                               verify(bufferPool,
+                                       
times(1)).requestMemorySegments(buffersPerChannel);
+                               RemoteInputChannel remote = 
(RemoteInputChannel) inputGate.getInputChannels()
+                                       
.get(resultPartitionId.getPartitionId());
+                               // only the exclusive buffers should be 
assigned/available now
+                               assertEquals(buffersPerChannel, 
remote.getNumberOfAvailableBuffers());
+
+                               
assertEquals(bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel,
+                                       
bufferPool.getNumberOfAvailableMemorySegments());
+                               // note: exclusive buffers are not handed out 
into LocalBufferPool and are thus not counted
+                               assertEquals(extraNetworkBuffersPerGate, 
bufferPool.countBuffers());
+                       } else {
+                               assertEquals(buffersPerChannel + 
extraNetworkBuffersPerGate, bufferPool.countBuffers());
+                       }
+               } finally {
+                       inputGate.releaseAllResources();
+                       network.shutdown();
+               }
        }
 
        /**
@@ -407,51 +441,195 @@ public class SingleInputGateTest {
         */
        @Test
        public void testRequestBuffersWithUnknownInputChannel() throws 
Exception {
-               final SingleInputGate inputGate = createInputGate(1);
+               final SingleInputGate inputGate = createInputGate(1, 
ResultPartitionType.PIPELINED_BOUNDED);
+               int buffersPerChannel = 2;
+               int extraNetworkBuffersPerGate = 8;
+               final NetworkEnvironment network = 
createNetworkEnvironment(buffersPerChannel, extraNetworkBuffersPerGate, 0, 0);
 
-               UnknownInputChannel unknown = mock(UnknownInputChannel.class);
-               final ResultPartitionID resultPartitionId = new 
ResultPartitionID();
-               inputGate.setInputChannel(resultPartitionId.getPartitionId(), 
unknown);
+               try {
+                       final ResultPartitionID resultPartitionId = new 
ResultPartitionID();
+                       addUnknownInputChannel(network, inputGate, 
resultPartitionId, 0);
 
-               RemoteInputChannel remote = mock(RemoteInputChannel.class);
-               final ConnectionID connectionId = new ConnectionID(new 
InetSocketAddress("localhost", 5000), 0);
-               
when(unknown.toRemoteInputChannel(connectionId)).thenReturn(remote);
+                       network.setupInputGate(inputGate);
+                       NetworkBufferPool bufferPool = 
network.getNetworkBufferPool();
 
-               final int buffersPerChannel = 2;
-               NetworkBufferPool network = mock(NetworkBufferPool.class);
-               inputGate.assignExclusiveSegments(network, buffersPerChannel);
+                       if (enableCreditBasedFlowControl) {
+                               verify(bufferPool, 
times(0)).requestMemorySegments(buffersPerChannel);
 
-               // Trigger updates to remote input channel from unknown input 
channel
-               inputGate.updateInputChannel(new 
InputChannelDeploymentDescriptor(
-                       resultPartitionId,
-                       ResultPartitionLocation.createRemote(connectionId)));
+                               
assertEquals(bufferPool.getTotalNumberOfMemorySegments(),
+                                       
bufferPool.getNumberOfAvailableMemorySegments());
+                               // note: exclusive buffers are not handed out 
into LocalBufferPool and are thus not counted
+                               assertEquals(extraNetworkBuffersPerGate, 
bufferPool.countBuffers());
+                       } else {
+                               assertEquals(buffersPerChannel + 
extraNetworkBuffersPerGate, bufferPool.countBuffers());
+                       }
+
+                       // Trigger updates to remote input channel from unknown 
input channel
+                       final ConnectionID connectionId = new ConnectionID(new 
InetSocketAddress("localhost", 5000), 0);
+                       inputGate.updateInputChannel(new 
InputChannelDeploymentDescriptor(
+                               resultPartitionId,
+                               
ResultPartitionLocation.createRemote(connectionId)));
+
+                       if (enableCreditBasedFlowControl) {
+                               verify(bufferPool,
+                                       
times(1)).requestMemorySegments(buffersPerChannel);
+                               RemoteInputChannel remote = 
(RemoteInputChannel) inputGate.getInputChannels()
+                                       
.get(resultPartitionId.getPartitionId());
+                               // only the exclusive buffers should be 
assigned/available now
+                               assertEquals(buffersPerChannel, 
remote.getNumberOfAvailableBuffers());
+
+                               
assertEquals(bufferPool.getTotalNumberOfMemorySegments() - buffersPerChannel,
+                                       
bufferPool.getNumberOfAvailableMemorySegments());
+                               // note: exclusive buffers are not handed out 
into LocalBufferPool and are thus not counted
+                               assertEquals(extraNetworkBuffersPerGate, 
bufferPool.countBuffers());
+                       } else {
+                               assertEquals(buffersPerChannel + 
extraNetworkBuffersPerGate, bufferPool.countBuffers());
+                       }
+               } finally {
+                       inputGate.releaseAllResources();
+                       network.shutdown();
+               }
+       }
 
-               verify(network, 
times(1)).requestMemorySegments(buffersPerChannel);
-               verify(remote, 
times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class));
+       /**
+        * Tests that input gate can successfully convert unknown input 
channels into local and remote
+        * channels.
+        */
+       @Test
+       public void testUpdateUnknownInputChannel() throws Exception {
+               final SingleInputGate inputGate = createInputGate(2);
+               int buffersPerChannel = 2;
+               final NetworkEnvironment network = 
createNetworkEnvironment(buffersPerChannel, 8, 0, 0);
+
+               try {
+                       final ResultPartitionID localResultPartitionId = new 
ResultPartitionID();
+                       addUnknownInputChannel(network, inputGate, 
localResultPartitionId, 0);
+
+                       final ResultPartitionID remoteResultPartitionId = new 
ResultPartitionID();
+                       addUnknownInputChannel(network, inputGate, 
remoteResultPartitionId, 1);
+
+                       network.setupInputGate(inputGate);
+
+                       
assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()),
+                               is(instanceOf((UnknownInputChannel.class))));
+                       
assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()),
+                               is(instanceOf((UnknownInputChannel.class))));
+
+                       // Trigger updates to remote input channel from unknown 
input channel
+                       final ConnectionID remoteConnectionId = new 
ConnectionID(new InetSocketAddress("localhost", 5000), 0);
+                       inputGate.updateInputChannel(new 
InputChannelDeploymentDescriptor(
+                               remoteResultPartitionId,
+                               
ResultPartitionLocation.createRemote(remoteConnectionId)));
+
+                       
assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()),
+                               is(instanceOf((RemoteInputChannel.class))));
+                       
assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()),
+                               is(instanceOf((UnknownInputChannel.class))));
+
+                       // Trigger updates to local input channel from unknown 
input channel
+                       inputGate.updateInputChannel(new 
InputChannelDeploymentDescriptor(
+                               localResultPartitionId,
+                               ResultPartitionLocation.createLocal()));
+
+                       
assertThat(inputGate.getInputChannels().get(remoteResultPartitionId.getPartitionId()),
+                               is(instanceOf((RemoteInputChannel.class))));
+                       
assertThat(inputGate.getInputChannels().get(localResultPartitionId.getPartitionId()),
+                               is(instanceOf((LocalInputChannel.class))));
+               } finally {
+                       inputGate.releaseAllResources();
+                       network.shutdown();
+               }
        }
 
        // 
---------------------------------------------------------------------------------------------
 
-       private static SingleInputGate createInputGate() {
+       private NetworkEnvironment createNetworkEnvironment(
+                       int buffersPerChannel,
+                       int extraNetworkBuffersPerGate,
+                       int initialBackoff,
+                       int maxBackoff) {
+               return new NetworkEnvironment(
+                       spy(new NetworkBufferPool(100, 32)),
+                       new LocalConnectionManager(),
+                       new ResultPartitionManager(),
+                       new TaskEventDispatcher(),
+                       new KvStateRegistry(),
+                       null,
+                       null,
+                       IOManager.IOMode.SYNC,
+                       initialBackoff,
+                       maxBackoff,
+                       buffersPerChannel,
+                       extraNetworkBuffersPerGate,
+                       enableCreditBasedFlowControl);
+       }
+
+       private SingleInputGate createInputGate() {
                return createInputGate(2);
        }
 
-       private static SingleInputGate createInputGate(int 
numberOfInputChannels) {
+       private SingleInputGate createInputGate(int numberOfInputChannels) {
+               return createInputGate(numberOfInputChannels, 
ResultPartitionType.PIPELINED);
+       }
+
+       private SingleInputGate createInputGate(
+                       int numberOfInputChannels, ResultPartitionType 
partitionType) {
                SingleInputGate inputGate = new SingleInputGate(
                        "Test Task Name",
                        new JobID(),
                        new IntermediateDataSetID(),
-                       ResultPartitionType.PIPELINED,
+                       partitionType,
                        0,
                        numberOfInputChannels,
                        mock(TaskActions.class),
-                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                       enableCreditBasedFlowControl);
 
-               assertEquals(ResultPartitionType.PIPELINED, 
inputGate.getConsumedPartitionType());
+               assertEquals(partitionType, 
inputGate.getConsumedPartitionType());
 
                return inputGate;
        }
 
+       private void addUnknownInputChannel(
+                       NetworkEnvironment network,
+                       SingleInputGate inputGate,
+                       ResultPartitionID partitionId,
+                       int channelIndex) {
+               UnknownInputChannel unknown =
+                       createUnknownInputChannel(network, inputGate, 
partitionId, channelIndex);
+               inputGate.setInputChannel(partitionId.getPartitionId(), 
unknown);
+       }
+
+       private UnknownInputChannel createUnknownInputChannel(
+                       NetworkEnvironment network,
+                       SingleInputGate inputGate,
+                       ResultPartitionID partitionId,
+                       int channelIndex) {
+               return new UnknownInputChannel(
+                       inputGate,
+                       channelIndex,
+                       partitionId,
+                       network.getResultPartitionManager(),
+                       network.getTaskEventDispatcher(),
+                       network.getConnectionManager(),
+                       network.getPartitionRequestInitialBackoff(),
+                       network.getPartitionRequestMaxBackoff(),
+                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()
+               );
+       }
+
+       private void addRemoteInputChannel(
+                       NetworkEnvironment network,
+                       SingleInputGate inputGate,
+                       ConnectionID connectionId,
+                       ResultPartitionID partitionId,
+                       int channelIndex) {
+               RemoteInputChannel remote =
+                       createUnknownInputChannel(network, inputGate, 
partitionId, channelIndex)
+                               .toRemoteInputChannel(connectionId);
+               inputGate.setInputChannel(partitionId.getPartitionId(), remote);
+       }
+
        static void verifyBufferOrEvent(
                        InputGate inputGate,
                        boolean expectedIsBuffer,

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
index 0ae6e74..33dc1ca 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestSingleInputGate.java
@@ -60,7 +60,8 @@ public class TestSingleInputGate {
                        0,
                        numberOfInputChannels,
                        mock(TaskActions.class),
-                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                       true);
 
                this.inputGate = spy(realGate);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/56e2b0b5/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
index 912cd5b..081d97d 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
@@ -50,13 +50,15 @@ public class UnionInputGateTest {
                        new IntermediateDataSetID(), 
ResultPartitionType.PIPELINED,
                        0, 3,
                        mock(TaskActions.class),
-                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                       true);
                final SingleInputGate ig2 = new SingleInputGate(
                        testTaskName, new JobID(),
                        new IntermediateDataSetID(), 
ResultPartitionType.PIPELINED,
                        0, 5,
                        mock(TaskActions.class),
-                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup(),
+                       true);
 
                final UnionInputGate union = new UnionInputGate(new 
SingleInputGate[]{ig1, ig2});
 

Reply via email to