[FLINK-7378][core] Create a fixed-size (non-rebalancing) buffer pool type for the 
floating buffers

This closes #4485.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/064a1e60
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/064a1e60
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/064a1e60

Branch: refs/heads/master
Commit: 064a1e60c5310688c7fa8bc5d07ce5da512e076d
Parents: a803dc7
Author: Zhijiang <[email protected]>
Authored: Mon Aug 7 17:31:17 2017 +0800
Committer: zentol <[email protected]>
Committed: Tue Oct 10 16:53:19 2017 +0200

----------------------------------------------------------------------
 .../runtime/io/network/NetworkEnvironment.java  |  25 +--
 .../io/network/buffer/NetworkBufferPool.java    |  62 +++++++
 .../network/partition/ResultPartitionType.java  |  27 ++-
 .../partition/consumer/RemoteInputChannel.java  |   6 +
 .../partition/consumer/SingleInputGate.java     |  57 +++++-
 .../io/network/NetworkEnvironmentTest.java      |  19 +-
 .../network/buffer/BufferPoolFactoryTest.java   | 102 ++++++++---
 .../network/buffer/NetworkBufferPoolTest.java   | 172 +++++++++++++++++++
 .../partition/consumer/SingleInputGateTest.java |  67 ++++++++
 9 files changed, 492 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
index 4269af6..9193859 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java
@@ -35,6 +35,7 @@ import org.apache.flink.runtime.query.netty.KvStateServer;
 import org.apache.flink.runtime.state.internal.InternalKvState;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskManager;
+import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -184,7 +185,7 @@ public class NetworkEnvironment {
                                                
partition.getNumberOfSubpartitions() * networkBuffersPerChannel +
                                                        
extraNetworkBuffersPerGate : Integer.MAX_VALUE;
                                        bufferPool = 
networkBufferPool.createBufferPool(partition.getNumberOfSubpartitions(),
-                                                       
maxNumberOfMemorySegments);
+                                               maxNumberOfMemorySegments);
                                        
partition.registerBufferPool(bufferPool);
 
                                        
resultPartitionManager.registerResultPartition(partition);
@@ -211,22 +212,24 @@ public class NetworkEnvironment {
                                BufferPool bufferPool = null;
 
                                try {
-                                       int maxNumberOfMemorySegments = 
gate.getConsumedPartitionType().isBounded() ?
-                                               gate.getNumberOfInputChannels() 
* networkBuffersPerChannel +
-                                                       
extraNetworkBuffersPerGate : Integer.MAX_VALUE;
-                                       bufferPool = 
networkBufferPool.createBufferPool(gate.getNumberOfInputChannels(),
-                                               maxNumberOfMemorySegments);
+                                       if 
(gate.getConsumedPartitionType().isCreditBased()) {
+                                               // Create a fixed-size buffer 
pool for floating buffers and assign exclusive buffers to input channels 
directly
+                                               bufferPool = 
networkBufferPool.createBufferPool(extraNetworkBuffersPerGate, 
extraNetworkBuffersPerGate);
+                                               
gate.assignExclusiveSegments(networkBufferPool, networkBuffersPerChannel);
+                                       } else {
+                                               int maxNumberOfMemorySegments = 
gate.getConsumedPartitionType().isBounded() ?
+                                                       
gate.getNumberOfInputChannels() * networkBuffersPerChannel +
+                                                               
extraNetworkBuffersPerGate : Integer.MAX_VALUE;
+                                               bufferPool = 
networkBufferPool.createBufferPool(gate.getNumberOfInputChannels(),
+                                                       
maxNumberOfMemorySegments);
+                                       }
                                        gate.setBufferPool(bufferPool);
                                } catch (Throwable t) {
                                        if (bufferPool != null) {
                                                bufferPool.lazyDestroy();
                                        }
 
-                                       if (t instanceof IOException) {
-                                               throw (IOException) t;
-                                       } else {
-                                               throw new 
IOException(t.getMessage(), t);
-                                       }
+                                       ExceptionUtils.rethrowIOException(t);
                                }
                        }
                }

http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPool.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPool.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPool.java
index b70e912..f899f05 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPool.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPool.java
@@ -22,16 +22,21 @@ import org.apache.flink.configuration.TaskManagerOptions;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.core.memory.MemoryType;
+import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.MathUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
 import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.TimeUnit;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -131,6 +136,63 @@ public class NetworkBufferPool implements 
BufferPoolFactory {
                availableMemorySegments.add(segment);
        }
 
+       /**
+        * Requests a fixed number of memory segments directly from this pool, bypassing the local
+        * buffer pools. The requested amount is first reserved against the total number of memory
+        * segments (which triggers a redistribution among the local buffer pools); afterwards the
+        * segments are taken from the queue of available segments, waiting if necessary.
+        *
+        * <p>If acquiring the segments fails part-way, the segments obtained so far are returned and
+        * the complete reservation of <tt>numRequiredBuffers</tt> is released again.
+        *
+        * @param numRequiredBuffers The number of segments to request; must be larger than 0
+        * @return A list containing exactly <tt>numRequiredBuffers</tt> memory segments
+        * @throws IllegalStateException If this pool has already been destroyed before the request
+        * @throws IOException If not enough segments can be reserved, or if acquiring the reserved
+        *                     segments fails (e.g. the pool is destroyed or the thread is interrupted
+        *                     while waiting)
+        */
+       public List<MemorySegment> requestMemorySegments(int numRequiredBuffers) throws IOException {
+               checkArgument(numRequiredBuffers > 0, "The number of required buffers should be larger than 0.");
+
+               synchronized (factoryLock) {
+                       if (isDestroyed) {
+                               throw new IllegalStateException("Network buffer pool has already been destroyed.");
+                       }
+
+                       if (numTotalRequiredBuffers + numRequiredBuffers > totalNumberOfMemorySegments) {
+                               throw new IOException(String.format("Insufficient number of network buffers: " +
+                                                       "required %d, but only %d available. The total number of network " +
+                                                       "buffers is currently set to %d of %d bytes each. You can increase this " +
+                                                       "number by setting the configuration keys '%s', '%s', and '%s'.",
+                                       numRequiredBuffers,
+                                       totalNumberOfMemorySegments - numTotalRequiredBuffers,
+                                       totalNumberOfMemorySegments,
+                                       memorySegmentSize,
+                                       TaskManagerOptions.NETWORK_BUFFERS_MEMORY_FRACTION.key(),
+                                       TaskManagerOptions.NETWORK_BUFFERS_MEMORY_MIN.key(),
+                                       TaskManagerOptions.NETWORK_BUFFERS_MEMORY_MAX.key()));
+                       }
+
+                       this.numTotalRequiredBuffers += numRequiredBuffers;
+
+                       redistributeBuffers();
+               }
+
+               final List<MemorySegment> segments = new ArrayList<>(numRequiredBuffers);
+               try {
+                       while (segments.size() < numRequiredBuffers) {
+                               if (isDestroyed) {
+                                       throw new IllegalStateException("Buffer pool is destroyed.");
+                               }
+
+                               // bounded wait so that the loop periodically re-checks isDestroyed
+                               final MemorySegment segment = availableMemorySegments.poll(2, TimeUnit.SECONDS);
+                               if (segment != null) {
+                                       segments.add(segment);
+                               }
+                       }
+               } catch (Throwable e) {
+                       if (e instanceof InterruptedException) {
+                               // the interrupt status would otherwise be lost when wrapping into an IOException
+                               Thread.currentThread().interrupt();
+                       }
+
+                       // release the whole reservation made above, not only the segments acquired so far
+                       try {
+                               recycleMemorySegments(segments, numRequiredBuffers);
+                       } catch (IOException inner) {
+                               e.addSuppressed(inner);
+                       }
+
+                       ExceptionUtils.rethrowIOException(e);
+               }
+
+               return segments;
+       }
+
+       /**
+        * Returns memory segments previously obtained via {@link #requestMemorySegments(int)} to this
+        * pool and redistributes the available buffers among the local buffer pools.
+        *
+        * @param segments The segments to return
+        * @throws IOException Propagated from redistributing the buffers among the local buffer pools
+        */
+       public void recycleMemorySegments(List<MemorySegment> segments) throws IOException {
+               recycleMemorySegments(segments, segments.size());
+       }
+
+       /**
+        * Returns the given segments to the queue of available segments and releases a reservation of
+        * <tt>size</tt> segments. <tt>size</tt> may exceed <tt>segments.size()</tt> when a request
+        * failed before all reserved segments could be acquired.
+        *
+        * @param segments The segments to return
+        * @param size The number of reserved segments to release
+        * @throws IOException Propagated from redistributing the buffers among the local buffer pools
+        */
+       private void recycleMemorySegments(List<MemorySegment> segments, int size) throws IOException {
+               synchronized (factoryLock) {
+                       numTotalRequiredBuffers -= size;
+
+                       availableMemorySegments.addAll(segments);
+
+                       redistributeBuffers();
+               }
+       }
+
        public void destroy() {
                synchronized (factoryLock) {
                        isDestroyed = true;

http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionType.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionType.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionType.java
index 256387c..683fe37 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionType.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionType.java
@@ -20,9 +20,9 @@ package org.apache.flink.runtime.io.network.partition;
 
 public enum ResultPartitionType {
 
-       BLOCKING(false, false, false),
+       BLOCKING(false, false, false, false),
 
-       PIPELINED(true, true, false),
+       PIPELINED(true, true, false, false),
 
        /**
         * Pipelined partitions with a bounded (local) buffer pool.
@@ -35,7 +35,13 @@ public enum ResultPartitionType {
         * For batch jobs, it will be best to keep this unlimited ({@link 
#PIPELINED}) since there are
         * no checkpoint barriers.
         */
-       PIPELINED_BOUNDED(true, true, true);
+       PIPELINED_BOUNDED(true, true, true, false),
+
+       /**
+        * Pipelined partitions with a bounded (local) buffer pool for floating 
buffers in input gate, and a number
+        * of exclusive buffers per input channel. The producer transfers data 
based on consumer's available credits.
+        */
+       PIPELINED_CREDIT_BASED(true, true, true, true);
 
        /** Can the partition be consumed while being produced? */
        private final boolean isPipelined;
@@ -46,13 +52,17 @@ public enum ResultPartitionType {
        /** Does this partition use a limited number of (network) buffers? */
        private final boolean isBounded;
 
+       /** Does this partition only send data when consumer has available 
buffers? */
+       private final boolean isCreditBased;
+
        /**
         * Specifies the behaviour of an intermediate result partition at 
runtime.
         */
-       ResultPartitionType(boolean isPipelined, boolean hasBackPressure, 
boolean isBounded) {
+       ResultPartitionType(boolean isPipelined, boolean hasBackPressure, 
boolean isBounded, boolean isCreditBased) {
                this.isPipelined = isPipelined;
                this.hasBackPressure = hasBackPressure;
                this.isBounded = isBounded;
+               this.isCreditBased = isCreditBased;
        }
 
        public boolean hasBackPressure() {
@@ -75,4 +85,13 @@ public enum ResultPartitionType {
        public boolean isBounded() {
                return isBounded;
        }
+
+       /**
+        * Tells whether data of this partition type is transferred in credit-based mode, i.e. the
+        * producer only sends data when the consumer has available buffers.
+        *
+        * @return {@code true} if the data is transferred based on the consumer's credit
+        */
+       public boolean isCreditBased() {
+               return this.isCreditBased;
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
index 719f340..58c9484 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.partition.consumer;
 
+import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.io.network.ConnectionID;
 import org.apache.flink.runtime.io.network.ConnectionManager;
@@ -30,6 +31,7 @@ import 
org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
 
 import java.io.IOException;
 import java.util.ArrayDeque;
+import java.util.List;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -97,6 +99,10 @@ public class RemoteInputChannel extends InputChannel {
                this.connectionManager = checkNotNull(connectionManager);
        }
 
+       /**
+        * Assigns exclusive memory segments to this channel for credit-based data transfer. The
+        * callers in {@code SingleInputGate} pass segments obtained from
+        * {@code NetworkBufferPool#requestMemorySegments(int)}.
+        *
+        * <p>NOTE(review): this is currently a placeholder. The given segments are neither stored nor
+        * recycled, so they remain reserved in the network buffer pool until the implementation lands.
+        *
+        * @param segments The exclusive segments for this channel
+        */
+       void assignExclusiveSegments(List<MemorySegment> segments) {
+               // TODO in next PR
+       }
+
        // 
------------------------------------------------------------------------
        // Consume
        // 
------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index ebfb300..945d127 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionLocation;
@@ -31,6 +32,7 @@ import 
org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.runtime.io.network.buffer.BufferProvider;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import 
org.apache.flink.runtime.io.network.partition.consumer.InputChannel.BufferAndAvailability;
@@ -147,6 +149,9 @@ public class SingleInputGate implements InputGate {
         */
        private BufferPool bufferPool;
 
+       /** Global network buffer pool to request and recycle exclusive 
buffers. */
+       private NetworkBufferPool networkBufferPool;
+
        private boolean hasReceivedAllEndOfPartitionEvents;
 
        /** Flag indicating whether partitions have been requested. */
@@ -162,6 +167,9 @@ public class SingleInputGate implements InputGate {
 
        private int numberOfUninitializedChannels;
 
+       /** Number of network buffers to use for each remote input channel. */
+       private int networkBuffersPerChannel;
+
        /** A timer to retrigger local partition requests. Only initialized if 
actually needed. */
        private Timer retriggerLocalRequestTimer;
 
@@ -259,21 +267,55 @@ public class SingleInputGate implements InputGate {
 
        public void setBufferPool(BufferPool bufferPool) {
                // Sanity checks
-               checkArgument(numberOfInputChannels == 
bufferPool.getNumberOfRequiredMemorySegments(),
+               if (!getConsumedPartitionType().isCreditBased()) {
+                       checkArgument(numberOfInputChannels == 
bufferPool.getNumberOfRequiredMemorySegments(),
                                "Bug in input gate setup logic: buffer pool has 
not enough guaranteed buffers " +
-                                               "for this input gate. Input 
gates require at least as many buffers as " +
+                                       "for this input gate. Input gates 
require at least as many buffers as " +
                                                "there are input channels.");
+               }
 
                checkState(this.bufferPool == null, "Bug in input gate setup 
logic: buffer pool has" +
-                               "already been set for this input gate.");
+                       "already been set for this input gate.");
 
                this.bufferPool = checkNotNull(bufferPool);
        }
 
+       /**
+        * Assign the exclusive buffers to all remote input channels directly for credit-based mode.
+        *
+        * <p>The given pool and per-channel count are remembered so that remote channels created later
+        * (when an unknown channel is updated) can request their exclusive buffers as well.
+        *
+        * @param networkBufferPool The global pool to request and recycle exclusive buffers
+        * @param networkBuffersPerChannel The number of exclusive buffers for each channel
+        * @throws IOException If the exclusive buffers cannot be requested from the global pool
+        */
+       public void assignExclusiveSegments(NetworkBufferPool networkBufferPool, int networkBuffersPerChannel) throws IOException {
+               checkState(this.networkBufferPool == null, "Bug in input gate setup logic: global buffer pool has " +
+                       "already been set for this input gate.");
+
+               this.networkBufferPool = checkNotNull(networkBufferPool);
+               this.networkBuffersPerChannel = networkBuffersPerChannel;
+
+               synchronized (requestLock) {
+                       for (InputChannel inputChannel : inputChannels.values()) {
+                               if (inputChannel instanceof RemoteInputChannel) {
+                                       ((RemoteInputChannel) inputChannel).assignExclusiveSegments(
+                                               networkBufferPool.requestMemorySegments(networkBuffersPerChannel));
+                               }
+                       }
+               }
+       }
+
+       /**
+        * Hands exclusive segments back to the global network buffer pool directly, bypassing this
+        * gate's local buffer pool. Meant to be used when an input channel is released.
+        *
+        * @param segments The exclusive segments to recycle
+        * @throws IOException If the global pool fails to recycle the segments
+        */
+       public void returnExclusiveSegments(List<MemorySegment> segments) throws IOException {
+               final NetworkBufferPool globalPool = this.networkBufferPool;
+               globalPool.recycleMemorySegments(segments);
+       }
+
        public void setInputChannel(IntermediateResultPartitionID partitionId, 
InputChannel inputChannel) {
                synchronized (requestLock) {
                        if (inputChannels.put(checkNotNull(partitionId), 
checkNotNull(inputChannel)) == null
-                                       && inputChannel.getClass() == 
UnknownInputChannel.class) {
+                                       && inputChannel instanceof 
UnknownInputChannel) {
 
                                numberOfUninitializedChannels++;
                        }
@@ -291,7 +333,7 @@ public class SingleInputGate implements InputGate {
 
                        InputChannel current = inputChannels.get(partitionId);
 
-                       if (current.getClass() == UnknownInputChannel.class) {
+                       if (current instanceof UnknownInputChannel) {
 
                                UnknownInputChannel unknownChannel = 
(UnknownInputChannel) current;
 
@@ -304,6 +346,11 @@ public class SingleInputGate implements InputGate {
                                }
                                else if (partitionLocation.isRemote()) {
                                        newChannel = 
unknownChannel.toRemoteInputChannel(partitionLocation.getConnectionId());
+
+                                       if 
(getConsumedPartitionType().isCreditBased()) {
+                                               
((RemoteInputChannel)newChannel).assignExclusiveSegments(
+                                                       
networkBufferPool.requestMemorySegments(networkBuffersPerChannel));
+                                       }
                                }
                                else {
                                        throw new IllegalStateException("Tried 
to update unknown channel with unknown channel.");

http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
index b956691..826ae3f 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java
@@ -37,10 +37,14 @@ import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
+import java.io.IOException;
+
 import static org.junit.Assert.assertEquals;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 /**
@@ -82,11 +86,11 @@ public class NetworkEnvironmentTest {
                        new ResultPartitionWriter(rp3), new 
ResultPartitionWriter(rp4)};
 
                // input gates
-               final SingleInputGate[] inputGates = new SingleInputGate[] {
-                       
createSingleInputGateMock(ResultPartitionType.PIPELINED, 2),
-                       createSingleInputGateMock(ResultPartitionType.BLOCKING, 
2),
-                       
createSingleInputGateMock(ResultPartitionType.PIPELINED_BOUNDED, 2),
-                       
createSingleInputGateMock(ResultPartitionType.PIPELINED_BOUNDED, 8)};
+               SingleInputGate ig1 = 
createSingleInputGateMock(ResultPartitionType.PIPELINED, 2);
+               SingleInputGate ig2 = 
createSingleInputGateMock(ResultPartitionType.BLOCKING, 2);
+               SingleInputGate ig3 = 
createSingleInputGateMock(ResultPartitionType.PIPELINED_BOUNDED, 2);
+               SingleInputGate ig4 = 
createSingleInputGateMock(ResultPartitionType.PIPELINED_CREDIT_BASED, 8);
+               final SingleInputGate[] inputGates = new SingleInputGate[] 
{ig1, ig2, ig3, ig4};
 
                // overall task to register
                Task task = mock(Task.class);
@@ -101,6 +105,8 @@ public class NetworkEnvironmentTest {
                assertEquals(2 * 2 + 8, 
rp3.getBufferPool().getMaxNumberOfMemorySegments());
                assertEquals(8 * 2 + 8, 
rp4.getBufferPool().getMaxNumberOfMemorySegments());
 
+               verify(ig4, 
times(1)).assignExclusiveSegments(network.getNetworkBufferPool(), 2);
+
                network.shutdown();
        }
 
@@ -154,12 +160,15 @@ public class NetworkEnvironmentTest {
                                BufferPool bp = invocation.getArgumentAt(0, 
BufferPool.class);
                                if (partitionType == 
ResultPartitionType.PIPELINED_BOUNDED) {
                                        assertEquals(channels * 2 + 8, 
bp.getMaxNumberOfMemorySegments());
+                               } else if (partitionType == 
ResultPartitionType.PIPELINED_CREDIT_BASED) {
+                                       assertEquals(8, 
bp.getMaxNumberOfMemorySegments());
                                } else {
                                        assertEquals(Integer.MAX_VALUE, 
bp.getMaxNumberOfMemorySegments());
                                }
                                return null;
                        }
                }).when(ig).setBufferPool(any(BufferPool.class));
+
                return ig;
        }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferPoolFactoryTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferPoolFactoryTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferPoolFactoryTest.java
index ce76a6d..7f2bcc6 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferPoolFactoryTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferPoolFactoryTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.buffer;
 
+import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemoryType;
 import org.junit.After;
 import org.junit.Before;
@@ -48,9 +49,13 @@ public class BufferPoolFactoryTest {
        @After
        public void verifyAllBuffersReturned() {
                String msg = "Did not return all buffers to network buffer pool 
after test.";
-               assertEquals(msg, numBuffers, 
networkBufferPool.getNumberOfAvailableMemorySegments());
-               // in case buffers have actually been requested, we must 
release them again
-               networkBufferPool.destroy();
+               try {
+                       assertEquals(msg, numBuffers, 
networkBufferPool.getNumberOfAvailableMemorySegments());
+               } finally {
+                       // in case buffers have actually been requested, we 
must release them again
+                       networkBufferPool.destroyAllBufferPools();
+                       networkBufferPool.destroy();
+               }
        }
 
        @Test(expected = IOException.class)
@@ -134,25 +139,82 @@ public class BufferPoolFactoryTest {
        @Test
        public void testUniformDistributionBounded3() throws IOException {
                NetworkBufferPool globalPool = new NetworkBufferPool(3, 128, 
MemoryType.HEAP);
-               BufferPool first = globalPool.createBufferPool(0, 10);
-               assertEquals(3, first.getNumBuffers());
-
-               BufferPool second = globalPool.createBufferPool(0, 10);
-               // the order of which buffer pool received 2 or 1 buffer is 
undefined
-               assertEquals(3, first.getNumBuffers() + second.getNumBuffers());
-               assertNotEquals(3, first.getNumBuffers());
-               assertNotEquals(3, second.getNumBuffers());
+               try {
+                       BufferPool first = globalPool.createBufferPool(0, 10);
+                       assertEquals(3, first.getNumBuffers());
+
+                       BufferPool second = globalPool.createBufferPool(0, 10);
+                       // the order of which buffer pool received 2 or 1 
buffer is undefined
+                       assertEquals(3, first.getNumBuffers() + 
second.getNumBuffers());
+                       assertNotEquals(3, first.getNumBuffers());
+                       assertNotEquals(3, second.getNumBuffers());
+
+                       BufferPool third = globalPool.createBufferPool(0, 10);
+                       assertEquals(1, first.getNumBuffers());
+                       assertEquals(1, second.getNumBuffers());
+                       assertEquals(1, third.getNumBuffers());
+
+                       // similar to #verifyAllBuffersReturned()
+                       String msg = "Wrong number of available segments after 
creating buffer pools.";
+                       assertEquals(msg, 3, 
globalPool.getNumberOfAvailableMemorySegments());
+               } finally {
+                       // in case buffers have actually been requested, we 
must release them again
+                       globalPool.destroyAllBufferPools();
+                       globalPool.destroy();
+               }
+       }
 
-               BufferPool third = globalPool.createBufferPool(0, 10);
-               assertEquals(1, first.getNumBuffers());
-               assertEquals(1, second.getNumBuffers());
-               assertEquals(1, third.getNumBuffers());
+       /**
+        * Tests the interaction of requesting memory segments and creating 
local buffer pool and
+        * verifies the number of assigned buffers match after redistributing 
buffers because of newly
+        * requested memory segments or new buffer pools created.
+        */
+       @Test
+       public void testUniformDistributionBounded4() throws IOException {
+               NetworkBufferPool globalPool = new NetworkBufferPool(10, 128, 
MemoryType.HEAP);
+               try {
+                       BufferPool first = globalPool.createBufferPool(0, 10);
+                       assertEquals(10, first.getNumBuffers());
 
-               // similar to #verifyAllBuffersReturned()
-               String msg = "Did not return all buffers to network buffer pool 
after test.";
-               assertEquals(msg, 3, 
globalPool.getNumberOfAvailableMemorySegments());
-               // in case buffers have actually been requested, we must 
release them again
-               globalPool.destroy();
+                       List<MemorySegment> segmentList1 = 
globalPool.requestMemorySegments(2);
+                       assertEquals(2, segmentList1.size());
+                       assertEquals(8, first.getNumBuffers());
+
+                       BufferPool second = globalPool.createBufferPool(0, 10);
+                       assertEquals(4, first.getNumBuffers());
+                       assertEquals(4, second.getNumBuffers());
+
+                       List<MemorySegment> segmentList2 = 
globalPool.requestMemorySegments(2);
+                       assertEquals(2, segmentList2.size());
+                       assertEquals(3, first.getNumBuffers());
+                       assertEquals(3, second.getNumBuffers());
+
+                       List<MemorySegment> segmentList3 = 
globalPool.requestMemorySegments(2);
+                       assertEquals(2, segmentList3.size());
+                       assertEquals(2, first.getNumBuffers());
+                       assertEquals(2, second.getNumBuffers());
+
+                       String msg = "Wrong number of available segments after 
creating buffer pools and requesting segments.";
+                       assertEquals(msg, 4, 
globalPool.getNumberOfAvailableMemorySegments());
+
+                       globalPool.recycleMemorySegments(segmentList1);
+                       assertEquals(msg, 6, 
globalPool.getNumberOfAvailableMemorySegments());
+                       assertEquals(3, first.getNumBuffers());
+                       assertEquals(3, second.getNumBuffers());
+
+                       globalPool.recycleMemorySegments(segmentList2);
+                       assertEquals(msg, 8, 
globalPool.getNumberOfAvailableMemorySegments());
+                       assertEquals(4, first.getNumBuffers());
+                       assertEquals(4, second.getNumBuffers());
+
+                       globalPool.recycleMemorySegments(segmentList3);
+                       assertEquals(msg, 10, 
globalPool.getNumberOfAvailableMemorySegments());
+                       assertEquals(5, first.getNumBuffers());
+                       assertEquals(5, second.getNumBuffers());
+               } finally {
+                       globalPool.destroyAllBufferPools();
+                       globalPool.destroy();
+               }
        }
 
        @Test

http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPoolTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPoolTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPoolTest.java
index 7c6a543..e30e955 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPoolTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPoolTest.java
@@ -18,20 +18,36 @@
 
 package org.apache.flink.runtime.io.network.buffer;
 
+import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemoryType;
+import org.apache.flink.core.testutils.CheckedThread;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 
+import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.hamcrest.core.IsCollectionContaining.hasItem;
+import static org.hamcrest.CoreMatchers.nullValue;
+import static org.hamcrest.core.IsNot.not;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 public class NetworkBufferPoolTest {
 
+       @Rule
+       public ExpectedException expectedException = ExpectedException.none();
+
        @Test
        public void testCreatePoolAfterDestroy() {
                try {
@@ -168,4 +184,160 @@ public class NetworkBufferPoolTest {
                        fail(e.getMessage());
                }
        }
+
+       /**
+        * Tests {@link NetworkBufferPool#requestMemorySegments(int)} with the {@link NetworkBufferPool}
+        * currently containing the number of required free segments.
+        *
+        * <p>Requests half of the segments, returns them, and checks that all segments are available again.
+        */
+       @Test
+       public void testRequestMemorySegmentsLessThanTotalBuffers() throws Exception {
+               final int numBuffers = 10;
+
+               NetworkBufferPool globalPool = new NetworkBufferPool(numBuffers, 128, MemoryType.HEAP);
+
+               List<MemorySegment> memorySegments = Collections.emptyList();
+               try {
+                       memorySegments = globalPool.requestMemorySegments(numBuffers / 2);
+                       // expected value first, so a failure message reports expected/actual the right way round
+                       assertEquals(numBuffers / 2, memorySegments.size());
+
+                       globalPool.recycleMemorySegments(memorySegments);
+                       // empty the list so the finally block does not hand the same segments back a second time
+                       memorySegments.clear();
+                       assertEquals(numBuffers, globalPool.getNumberOfAvailableMemorySegments());
+               } finally {
+                       // returns segments still held if an assertion above failed (the list is empty otherwise)
+                       globalPool.recycleMemorySegments(memorySegments);
+                       globalPool.destroy();
+               }
+       }
+
+       /**
+        * Tests {@link NetworkBufferPool#requestMemorySegments(int)} with the number of required
+        * buffers exceeding the capacity of {@link NetworkBufferPool}.
+        *
+        * <p>The request is expected to fail with an {@link IOException} without taking any segments.
+        */
+       @Test
+       public void testRequestMemorySegmentsMoreThanTotalBuffers() throws Exception {
+               final int numBuffers = 10;
+
+               NetworkBufferPool globalPool = new NetworkBufferPool(numBuffers, 128, MemoryType.HEAP);
+
+               try {
+                       globalPool.requestMemorySegments(numBuffers + 1);
+                       fail("Should throw an IOException");
+               } catch (IOException e) {
+                       // the failed request must leave every segment available (expected value first)
+                       assertEquals(numBuffers, globalPool.getNumberOfAvailableMemorySegments());
+               } finally {
+                       globalPool.destroy();
+               }
+       }
+
+       /**
+        * Tests {@link NetworkBufferPool#requestMemorySegments(int)} with the invalid argument to
+        * cause exception.
+        *
+        * <p>Requesting zero segments is expected to fail with an {@link IllegalArgumentException}
+        * without taking any segments.
+        */
+       @Test
+       public void testRequestMemorySegmentsWithInvalidArgument() throws Exception {
+               final int numBuffers = 10;
+
+               NetworkBufferPool globalPool = new NetworkBufferPool(numBuffers, 128, MemoryType.HEAP);
+
+               try {
+                       // the number of requested buffers should be larger than zero
+                       globalPool.requestMemorySegments(0);
+                       fail("Should throw an IllegalArgumentException");
+               } catch (IllegalArgumentException e) {
+                       // the rejected request must leave every segment available (expected value first)
+                       assertEquals(numBuffers, globalPool.getNumberOfAvailableMemorySegments());
+               } finally {
+                       globalPool.destroy();
+               }
+       }
+
+       /**
+        * Tests {@link NetworkBufferPool#requestMemorySegments(int)} with the {@link NetworkBufferPool}
+        * currently not containing the number of required free segments (currently occupied by a buffer pool).
+        *
+        * <p>A helper thread recycles all taken buffers shortly after the request starts, so the
+        * request has to wait for segments to be returned before it can complete.
+        */
+       @Test
+       public void testRequestMemorySegmentsWithBuffersTaken() throws IOException, InterruptedException {
+               final int numBuffers = 10;
+
+               NetworkBufferPool networkBufferPool = new NetworkBufferPool(numBuffers, 128, MemoryType.HEAP);
+
+               final List<Buffer> buffers = new ArrayList<>(numBuffers);
+               List<MemorySegment> memorySegments = Collections.emptyList();
+               Thread bufferRecycler = null;
+               BufferPool lbp1 = null;
+               try {
+                       lbp1 = networkBufferPool.createBufferPool(numBuffers / 2, numBuffers);
+
+                       // take all buffers (more than the minimum required)
+                       for (int i = 0; i < numBuffers; ++i) {
+                               Buffer buffer = lbp1.requestBuffer();
+                               // check before storing so the recycler thread never iterates over a null entry
+                               assertNotNull(buffer);
+                               buffers.add(buffer);
+                       }
+
+                       // requestMemorySegments() below will wait for buffers;
+                       // this thread makes sure that enough buffers are freed eventually for it to continue
+                       final OneShotLatch isRunning = new OneShotLatch();
+                       bufferRecycler = new Thread(() -> {
+                               try {
+                                       isRunning.trigger();
+                                       Thread.sleep(100);
+                               } catch (InterruptedException ignored) {
+                                       // the buffers must be recycled regardless; just keep the interrupt status visible
+                                       Thread.currentThread().interrupt();
+                               }
+
+                               for (Buffer buffer : buffers) {
+                                       buffer.recycle();
+                               }
+                       });
+                       bufferRecycler.start();
+
+                       // take more buffers than are freely available at the moment via requestMemorySegments()
+                       isRunning.await();
+                       memorySegments = networkBufferPool.requestMemorySegments(numBuffers / 2);
+                       assertThat(memorySegments, not(hasItem(nullValue())));
+               } finally {
+                       if (bufferRecycler != null) {
+                               bufferRecycler.join();
+                       } else {
+                               // the recycler was never created (e.g. failure while taking buffers),
+                               // so return whatever was taken before tearing down the pools
+                               for (Buffer buffer : buffers) {
+                                       buffer.recycle();
+                               }
+                       }
+                       if (lbp1 != null) {
+                               lbp1.lazyDestroy();
+                       }
+                       networkBufferPool.recycleMemorySegments(memorySegments);
+                       networkBufferPool.destroy();
+               }
+       }
+
+       /**
+        * Tests {@link NetworkBufferPool#requestMemorySegments(int)}, verifying it may be aborted in
+        * case of a concurrent {@link NetworkBufferPool#destroy()} call.
+        */
+       @Test
+       public void testRequestMemorySegmentsInterruptable() throws Exception {
+               final int numBuffers = 10;
+
+               final NetworkBufferPool globalPool = new NetworkBufferPool(numBuffers, 128, MemoryType.HEAP);
+               // hold one segment so that the request for all numBuffers segments below cannot be satisfied
+               MemorySegment segment = globalPool.requestMemorySegment();
+               assertNotNull(segment);
+
+               final OneShotLatch isRunning = new OneShotLatch();
+               CheckedThread asyncRequest = new CheckedThread() {
+                       @Override
+                       public void go() throws Exception {
+                               isRunning.trigger();
+                               // request the pool's full capacity while one segment is still held above
+                               globalPool.requestMemorySegments(numBuffers);
+                       }
+               };
+               asyncRequest.start();
+
+               // We want the destroy call inside the blocking part of the globalPool.requestMemorySegments()
+               // call above. We cannot guarantee this though but make it highly probable:
+               isRunning.await();
+               Thread.sleep(10);
+               globalPool.destroy();
+
+               segment.free();
+
+               expectedException.expect(IllegalStateException.class);
+               expectedException.expectMessage("destroyed");
+               asyncRequest.sync();
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index 737f17b..4d7d884 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
@@ -33,6 +34,7 @@ import org.apache.flink.runtime.io.network.TaskEventDispatcher;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import 
org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
@@ -57,6 +59,7 @@ import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.anyListOf;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
@@ -372,6 +375,70 @@ public class SingleInputGateTest {
                }
        }
 
+       /**
+        * Verifies that an input gate with an already registered remote input channel requests
+        * exclusive segments from the global {@link NetworkBufferPool} and hands them to that channel.
+        */
+       @Test
+       public void testRequestBuffersWithRemoteInputChannel() throws Exception {
+               final SingleInputGate inputGate = new SingleInputGate(
+                       "t1",
+                       new JobID(),
+                       new IntermediateDataSetID(),
+                       ResultPartitionType.PIPELINED_CREDIT_BASED,
+                       0,
+                       1,
+                       mock(TaskActions.class),
+                       new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+
+               final RemoteInputChannel remoteChannel = mock(RemoteInputChannel.class);
+               final IntermediateResultPartitionID partitionId = new IntermediateResultPartitionID();
+               inputGate.setInputChannel(partitionId, remoteChannel);
+
+               final int numExclusiveBuffers = 2;
+               final NetworkBufferPool globalPool = mock(NetworkBufferPool.class);
+
+               // one call both requests segments from the global pool and assigns them to the remote channel
+               inputGate.assignExclusiveSegments(globalPool, numExclusiveBuffers);
+
+               verify(globalPool, times(1)).requestMemorySegments(numExclusiveBuffers);
+               verify(remoteChannel, times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class));
+       }
+
+       /**
+        * Verifies that exclusive segments are requested from the global {@link NetworkBufferPool}
+        * and assigned once an unknown input channel is updated to a remote input channel.
+        */
+       @Test
+       public void testRequestBuffersWithUnknownInputChannel() throws Exception {
+               final SingleInputGate inputGate = new SingleInputGate(
+                       "t1",
+                       new JobID(),
+                       new IntermediateDataSetID(),
+                       ResultPartitionType.PIPELINED_CREDIT_BASED,
+                       0,
+                       1,
+                       mock(TaskActions.class),
+                       new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+
+               final UnknownInputChannel unknownChannel = mock(UnknownInputChannel.class);
+               final ResultPartitionID partitionId = new ResultPartitionID();
+               inputGate.setInputChannel(partitionId.getPartitionId(), unknownChannel);
+
+               final RemoteInputChannel remoteChannel = mock(RemoteInputChannel.class);
+               final ConnectionID connectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0);
+               when(unknownChannel.toRemoteInputChannel(connectionId)).thenReturn(remoteChannel);
+
+               final int numExclusiveBuffers = 2;
+               final NetworkBufferPool globalPool = mock(NetworkBufferPool.class);
+               inputGate.assignExclusiveSegments(globalPool, numExclusiveBuffers);
+
+               // updating the unknown channel to a remote location is what should trigger the request
+               final InputChannelDeploymentDescriptor remoteDescriptor = new InputChannelDeploymentDescriptor(
+                       partitionId,
+                       ResultPartitionLocation.createRemote(connectionId));
+               inputGate.updateInputChannel(remoteDescriptor);
+
+               verify(globalPool, times(1)).requestMemorySegments(numExclusiveBuffers);
+               verify(remoteChannel, times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class));
+       }
+
        // 
---------------------------------------------------------------------------------------------
 
        static void verifyBufferOrEvent(

Reply via email to