[FLINK-7406][network] Implement Netty receiver incoming pipeline for credit-based flow control


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/268867ce
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/268867ce
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/268867ce

Branch: refs/heads/master
Commit: 268867ce620a2c12879749db2ecb68bbe129cad5
Parents: 542419b
Author: Zhijiang <wangzhijiang...@aliyun.com>
Authored: Thu Aug 10 13:29:13 2017 +0800
Committer: Stefan Richter <s.rich...@data-artisans.com>
Committed: Mon Jan 8 11:46:00 2018 +0100

----------------------------------------------------------------------
 .../network/netty/CreditBasedClientHandler.java | 277 ++++++++
 .../runtime/io/network/netty/NettyMessage.java  |  15 +-
 .../netty/PartitionRequestClientHandler.java    |   8 +-
 .../io/network/netty/PartitionRequestQueue.java |   3 +-
 .../partition/consumer/RemoteInputChannel.java  | 257 +++++--
 .../netty/NettyMessageSerializationTest.java    |   3 +-
 .../PartitionRequestClientHandlerTest.java      | 151 ++---
 .../partition/InputGateConcurrentTest.java      |   2 +-
 .../partition/InputGateFairnessTest.java        |   8 +-
 .../consumer/RemoteInputChannelTest.java        | 665 +++++++++++++++++--
 10 files changed, 1175 insertions(+), 214 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java
new file mode 100644
index 0000000..1f18588
--- /dev/null
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.netty;
+
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
+import 
org.apache.flink.runtime.io.network.netty.exception.LocalTransportException;
+import 
org.apache.flink.runtime.io.network.netty.exception.RemoteTransportException;
+import org.apache.flink.runtime.io.network.netty.exception.TransportException;
+import 
org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
+import 
org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+
+import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
+import 
org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.SocketAddress;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * Channel handler to read the messages of buffer response or error response from the
+ * producer.
+ *
+ * <p>Buffer responses are copied into buffers requested from the target
+ * {@link RemoteInputChannel}. Error responses are routed to a single channel or, if fatal,
+ * to all registered channels. Writing and flushing unannounced credits to the producer is
+ * not handled by this class yet.
+ */
+class CreditBasedClientHandler extends ChannelInboundHandlerAdapter {
+
+       private static final Logger LOG = LoggerFactory.getLogger(CreditBasedClientHandler.class);
+
+       /** Channels, which already requested partitions from the producers. */
+       private final ConcurrentMap<InputChannelID, RemoteInputChannel> inputChannels = new ConcurrentHashMap<>();
+
+       /**
+        * The first error reported on this connection. Once set, all channels have been notified,
+        * the connection is closed and {@link #addInputChannel} rejects new channels.
+        */
+       private final AtomicReference<Throwable> channelError = new AtomicReference<>();
+
+       /**
+        * Set of cancelled partition requests. A request is cancelled iff an input channel is cleared
+        * while data is still coming in for this channel.
+        */
+       private final ConcurrentMap<InputChannelID, InputChannelID> cancelled = new ConcurrentHashMap<>();
+
+       /** Context captured on the first {@link #channelActive} call; {@code null} before that. */
+       private volatile ChannelHandlerContext ctx;
+
+       // ------------------------------------------------------------------------
+       // Input channel/receiver registration
+       // ------------------------------------------------------------------------
+
+       /**
+        * Registers an input channel so that incoming responses can be routed to it.
+        *
+        * @param listener the channel to register; a second registration for the same ID is ignored
+        * @throws IOException if an error has already been reported on this connection
+        */
+       void addInputChannel(RemoteInputChannel listener) throws IOException {
+               checkError();
+
+               // Atomic check-and-insert: keeps the first registration if the ID is already present.
+               inputChannels.putIfAbsent(listener.getInputChannelId(), listener);
+       }
+
+       /** Unregisters the given input channel; later responses for it will be cancelled. */
+       void removeInputChannel(RemoteInputChannel listener) {
+               inputChannels.remove(listener.getInputChannelId());
+       }
+
+       /**
+        * Sends a cancel request for the given channel to the producer, at most once per channel ID.
+        * Does nothing if the ID is {@code null} or the connection has not become active yet.
+        */
+       void cancelRequestFor(InputChannelID inputChannelId) {
+               if (inputChannelId == null || ctx == null) {
+                       return;
+               }
+
+               if (cancelled.putIfAbsent(inputChannelId, inputChannelId) == null) {
+                       ctx.writeAndFlush(new NettyMessage.CancelPartitionRequest(inputChannelId));
+               }
+       }
+
+       // ------------------------------------------------------------------------
+       // Network events
+       // ------------------------------------------------------------------------
+
+       @Override
+       public void channelActive(final ChannelHandlerContext ctx) throws Exception {
+               if (this.ctx == null) {
+                       this.ctx = ctx;
+               }
+
+               super.channelActive(ctx);
+       }
+
+       @Override
+       public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+               // Unexpected close. In normal operation, the client closes the connection after all input
+               // channels have been removed. This indicates a problem with the remote task manager.
+               if (!inputChannels.isEmpty()) {
+                       final SocketAddress remoteAddr = ctx.channel().remoteAddress();
+
+                       notifyAllChannelsOfErrorAndClose(new RemoteTransportException(
+                               "Connection unexpectedly closed by remote task manager '" + remoteAddr + "'. "
+                                       + "This might indicate that the remote task manager was lost.", remoteAddr));
+               }
+
+               super.channelInactive(ctx);
+       }
+
+       /**
+        * Called on exceptions in the client handler pipeline.
+        *
+        * <p>Remote exceptions are received as regular payload.
+        */
+       @Override
+       public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+
+               if (cause instanceof TransportException) {
+                       notifyAllChannelsOfErrorAndClose(cause);
+               } else {
+                       final SocketAddress remoteAddr = ctx.channel().remoteAddress();
+
+                       final TransportException tex;
+
+                       // Improve on the connection reset by peer error message. The comparison is
+                       // null-safe because an IOException may carry no message at all.
+                       if (cause instanceof IOException && "Connection reset by peer".equals(cause.getMessage())) {
+                               tex = new RemoteTransportException("Lost connection to task manager '" + remoteAddr + "'. " +
+                                       "This indicates that the remote task manager was lost.", remoteAddr, cause);
+                       } else {
+                               tex = new LocalTransportException(cause.getMessage(), ctx.channel().localAddress(), cause);
+                       }
+
+                       notifyAllChannelsOfErrorAndClose(tex);
+               }
+       }
+
+       @Override
+       public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+               try {
+                       decodeMsg(msg);
+               } catch (Throwable t) {
+                       notifyAllChannelsOfErrorAndClose(t);
+               }
+       }
+
+       /**
+        * Records the first error, notifies every registered channel of it, then clears the
+        * registrations and closes the connection. Later calls are no-ops.
+        */
+       private void notifyAllChannelsOfErrorAndClose(Throwable cause) {
+               if (channelError.compareAndSet(null, cause)) {
+                       try {
+                               for (RemoteInputChannel inputChannel : inputChannels.values()) {
+                                       // Catch per channel so that one failing channel does not prevent
+                                       // the remaining channels from being notified.
+                                       try {
+                                               inputChannel.onError(cause);
+                                       } catch (Throwable t) {
+                                               // We can only swallow the Exception at this point. :(
+                                               LOG.warn("An Exception was thrown during error notification of a remote input channel.", t);
+                                       }
+                               }
+                       } finally {
+                               inputChannels.clear();
+
+                               if (ctx != null) {
+                                       ctx.close();
+                               }
+                       }
+               }
+       }
+
+       // ------------------------------------------------------------------------
+
+       /**
+        * Checks for an error and rethrows it if one was reported.
+        */
+       private void checkError() throws IOException {
+               final Throwable t = channelError.get();
+
+               if (t != null) {
+                       if (t instanceof IOException) {
+                               throw (IOException) t;
+                       } else {
+                               throw new IOException("There has been an error in the channel.", t);
+                       }
+               }
+       }
+
+       /**
+        * Dispatches a decoded message by its exact class.
+        *
+        * <p>A buffer response for an unknown (e.g. removed) channel is released, and the
+        * request is cancelled. An error response is either fatal (all channels fail) or targeted
+        * at one channel. Any other message type is rejected with an
+        * {@link IllegalStateException}.
+        */
+       private void decodeMsg(Object msg) throws Throwable {
+               final Class<?> msgClazz = msg.getClass();
+
+               // ---- Buffer --------------------------------------------------------
+               if (msgClazz == NettyMessage.BufferResponse.class) {
+                       NettyMessage.BufferResponse bufferOrEvent = (NettyMessage.BufferResponse) msg;
+
+                       RemoteInputChannel inputChannel = inputChannels.get(bufferOrEvent.receiverId);
+                       if (inputChannel == null) {
+                               bufferOrEvent.releaseBuffer();
+
+                               cancelRequestFor(bufferOrEvent.receiverId);
+
+                               return;
+                       }
+
+                       decodeBufferOrEvent(inputChannel, bufferOrEvent);
+
+               } else if (msgClazz == NettyMessage.ErrorResponse.class) {
+                       // ---- Error ---------------------------------------------------------
+                       NettyMessage.ErrorResponse error = (NettyMessage.ErrorResponse) msg;
+
+                       SocketAddress remoteAddr = ctx.channel().remoteAddress();
+
+                       if (error.isFatalError()) {
+                               notifyAllChannelsOfErrorAndClose(new RemoteTransportException(
+                                       "Fatal error at remote task manager '" + remoteAddr + "'.",
+                                       remoteAddr,
+                                       error.cause));
+                       } else {
+                               RemoteInputChannel inputChannel = inputChannels.get(error.receiverId);
+
+                               if (inputChannel != null) {
+                                       if (error.cause.getClass() == PartitionNotFoundException.class) {
+                                               inputChannel.onFailedPartitionRequest();
+                                       } else {
+                                               inputChannel.onError(new RemoteTransportException(
+                                                       "Error at remote task manager '" + remoteAddr + "'.",
+                                                       remoteAddr,
+                                                       error.cause));
+                                       }
+                               }
+                       }
+               } else {
+                       throw new IllegalStateException("Received unknown message from producer: " + msg.getClass());
+               }
+       }
+
+       /**
+        * Copies the payload of a buffer response into a buffer of the given channel (or, for
+        * events, into a freshly allocated heap buffer) and hands it over via {@code onBuffer}.
+        * The Netty buffer of the response is always released.
+        */
+       private void decodeBufferOrEvent(RemoteInputChannel inputChannel, NettyMessage.BufferResponse bufferOrEvent) throws Throwable {
+               try {
+                       if (bufferOrEvent.isBuffer()) {
+                               // ---- Buffer ------------------------------------------------
+
+                               // Early return for empty buffers. Otherwise Netty's readBytes() throws an
+                               // IndexOutOfBoundsException.
+                               if (bufferOrEvent.getSize() == 0) {
+                                       inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
+                                       return;
+                               }
+
+                               Buffer buffer = inputChannel.requestBuffer();
+                               if (buffer != null) {
+                                       try {
+                                               buffer.setSize(bufferOrEvent.getSize());
+                                               bufferOrEvent.getNettyBuffer().readBytes(buffer.getNioBuffer());
+                                       } catch (Throwable t) {
+                                               // The buffer has not been handed over to the channel yet, so return
+                                               // it here instead of leaking it.
+                                               buffer.recycle();
+                                               throw t;
+                                       }
+
+                                       inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
+                               } else if (inputChannel.isReleased()) {
+                                       cancelRequestFor(bufferOrEvent.receiverId);
+                               } else {
+                                       throw new IllegalStateException("No buffer available in credit-based input channel.");
+                               }
+                       } else {
+                               // ---- Event -------------------------------------------------
+                               // TODO We can just keep the serialized data in the Netty buffer and release it later at the reader
+                               byte[] byteArray = new byte[bufferOrEvent.getSize()];
+                               bufferOrEvent.getNettyBuffer().readBytes(byteArray);
+
+                               MemorySegment memSeg = MemorySegmentFactory.wrap(byteArray);
+                               Buffer buffer = new Buffer(memSeg, FreeingBufferRecycler.INSTANCE, false);
+
+                               inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, bufferOrEvent.backlog);
+                       }
+               } finally {
+                       bufferOrEvent.releaseBuffer();
+               }
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
index 89fb9e8..db1b899 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java
@@ -221,6 +221,8 @@ public abstract class NettyMessage {
 
                final int sequenceNumber;
 
+               final int backlog;
+
                // ---- Deserialization 
-----------------------------------------------
 
                final boolean isBuffer;
@@ -232,7 +234,8 @@ public abstract class NettyMessage {
 
                private BufferResponse(
                                ByteBuf retainedSlice, boolean isBuffer, int 
sequenceNumber,
-                               InputChannelID receiverId) {
+                               InputChannelID receiverId,
+                               int backlog) {
                        // When deserializing we first have to request a buffer 
from the respective buffer
                        // provider (at the handler) and copy the buffer from 
Netty's space to ours. Only
                        // retainedSlice is set in this case.
@@ -242,15 +245,17 @@ public abstract class NettyMessage {
                        this.isBuffer = isBuffer;
                        this.sequenceNumber = sequenceNumber;
                        this.receiverId = checkNotNull(receiverId);
+                       this.backlog = backlog;
                }
 
-               BufferResponse(Buffer buffer, int sequenceNumber, 
InputChannelID receiverId) {
+               BufferResponse(Buffer buffer, int sequenceNumber, 
InputChannelID receiverId, int backlog) {
                        this.buffer = checkNotNull(buffer);
                        this.retainedSlice = null;
                        this.isBuffer = buffer.isBuffer();
                        this.size = buffer.getSize();
                        this.sequenceNumber = sequenceNumber;
                        this.receiverId = checkNotNull(receiverId);
+                       this.backlog = backlog;
                }
 
                boolean isBuffer() {
@@ -280,7 +285,7 @@ public abstract class NettyMessage {
                ByteBuf write(ByteBufAllocator allocator) throws IOException {
                        checkNotNull(buffer, "No buffer instance to 
serialize.");
 
-                       int length = 16 + 4 + 1 + 4 + buffer.getSize();
+                       int length = 16 + 4 + 4 + 1 + 4 + buffer.getSize();
 
                        ByteBuf result = null;
                        try {
@@ -288,6 +293,7 @@ public abstract class NettyMessage {
 
                                receiverId.writeTo(result);
                                result.writeInt(sequenceNumber);
+                               result.writeInt(backlog);
                                result.writeBoolean(buffer.isBuffer());
                                result.writeInt(buffer.getSize());
                                result.writeBytes(buffer.getNioBuffer());
@@ -309,12 +315,13 @@ public abstract class NettyMessage {
                static BufferResponse readFrom(ByteBuf buffer) {
                        InputChannelID receiverId = 
InputChannelID.fromByteBuf(buffer);
                        int sequenceNumber = buffer.readInt();
+                       int backlog = buffer.readInt();
                        boolean isBuffer = buffer.readBoolean();
                        int size = buffer.readInt();
 
                        ByteBuf retainedSlice = buffer.readSlice(size).retain();
 
-                       return new BufferResponse(retainedSlice, isBuffer, 
sequenceNumber, receiverId);
+                       return new BufferResponse(retainedSlice, isBuffer, 
sequenceNumber, receiverId, backlog);
                }
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
index 566b215..ab4798e 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java
@@ -276,7 +276,7 @@ class PartitionRequestClientHandler extends 
ChannelInboundHandlerAdapter {
                                // Early return for empty buffers. Otherwise 
Netty's readBytes() throws an
                                // IndexOutOfBoundsException.
                                if (bufferOrEvent.getSize() == 0) {
-                                       
inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber);
+                                       
inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, -1);
                                        return true;
                                }
 
@@ -295,7 +295,7 @@ class PartitionRequestClientHandler extends 
ChannelInboundHandlerAdapter {
                                                
buffer.setSize(bufferOrEvent.getSize());
                                                
bufferOrEvent.getNettyBuffer().readBytes(buffer.getNioBuffer());
 
-                                               inputChannel.onBuffer(buffer, 
bufferOrEvent.sequenceNumber);
+                                               inputChannel.onBuffer(buffer, 
bufferOrEvent.sequenceNumber, -1);
 
                                                return true;
                                        }
@@ -318,7 +318,7 @@ class PartitionRequestClientHandler extends 
ChannelInboundHandlerAdapter {
                                MemorySegment memSeg = 
MemorySegmentFactory.wrap(byteArray);
                                Buffer buffer = new Buffer(memSeg, 
FreeingBufferRecycler.INSTANCE, false);
 
-                               inputChannel.onBuffer(buffer, 
bufferOrEvent.sequenceNumber);
+                               inputChannel.onBuffer(buffer, 
bufferOrEvent.sequenceNumber, -1);
 
                                return true;
                        }
@@ -450,7 +450,7 @@ class PartitionRequestClientHandler extends 
ChannelInboundHandlerAdapter {
                                RemoteInputChannel inputChannel = 
inputChannels.get(stagedBufferResponse.receiverId);
 
                                if (inputChannel != null) {
-                                       inputChannel.onBuffer(buffer, 
stagedBufferResponse.sequenceNumber);
+                                       inputChannel.onBuffer(buffer, 
stagedBufferResponse.sequenceNumber, -1);
 
                                        success = true;
                                }

http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
index ff0f130..41f87ae 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java
@@ -193,7 +193,8 @@ class PartitionRequestQueue extends 
ChannelInboundHandlerAdapter {
                                                BufferResponse msg = new 
BufferResponse(
                                                        next.buffer(),
                                                        
reader.getSequenceNumber(),
-                                                       reader.getReceiverId());
+                                                       reader.getReceiverId(),
+                                                       0);
 
                                                if 
(isEndOfPartitionEvent(next.buffer())) {
                                                        
reader.notifySubpartitionConsumed();

http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
index cd00934..02c7b34 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.partition.consumer;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.io.network.ConnectionID;
@@ -32,11 +33,13 @@ import 
org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
 import org.apache.flink.util.ExceptionUtils;
 
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
 import java.io.IOException;
 import java.util.ArrayDeque;
+import java.util.Collections;
 import java.util.List;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -82,17 +85,19 @@ public class RemoteInputChannel extends InputChannel 
implements BufferRecycler,
        /** The initial number of exclusive buffers assigned to this channel. */
        private int initialCredit;
 
-       /** The current available buffers including both exclusive buffers and 
requested floating buffers. */
-       private final ArrayDeque<Buffer> availableBuffers = new ArrayDeque<>();
+       /** The available buffer queue wraps both exclusive and requested 
floating buffers. */
+       private final AvailableBufferQueue bufferQueue = new 
AvailableBufferQueue();
 
        /** The number of available buffers that have not been announced to the 
producer yet. */
        private final AtomicInteger unannouncedCredit = new AtomicInteger(0);
 
-       /** The number of unsent buffers in the producer's sub partition. */
-       private final AtomicInteger senderBacklog = new AtomicInteger(0);
+       /** The number of required buffers that equals to sender's backlog plus 
initial credit. */
+       @GuardedBy("bufferQueue")
+       private int numRequiredBuffers;
 
        /** The tag indicates whether this channel is waiting for additional 
floating buffers from the buffer pool. */
-       private final AtomicBoolean isWaitingForFloatingBuffers = new 
AtomicBoolean(false);
+       @GuardedBy("bufferQueue")
+       private boolean isWaitingForFloatingBuffers;
 
        public RemoteInputChannel(
                SingleInputGate inputGate,
@@ -133,10 +138,11 @@ public class RemoteInputChannel extends InputChannel 
implements BufferRecycler,
                checkArgument(segments.size() > 0, "The number of exclusive 
buffers per channel should be larger than 0.");
 
                this.initialCredit = segments.size();
+               this.numRequiredBuffers = segments.size();
 
-               synchronized(availableBuffers) {
+               synchronized(bufferQueue) {
                        for (MemorySegment segment : segments) {
-                               availableBuffers.add(new Buffer(segment, this));
+                               bufferQueue.addExclusiveBuffer(new 
Buffer(segment, this), numRequiredBuffers);
                        }
                }
        }
@@ -211,7 +217,7 @@ public class RemoteInputChannel extends InputChannel 
implements BufferRecycler,
        // 
------------------------------------------------------------------------
 
        @Override
-       boolean isReleased() {
+       public boolean isReleased() {
                return isReleased.get();
        }
 
@@ -227,7 +233,8 @@ public class RemoteInputChannel extends InputChannel 
implements BufferRecycler,
        void releaseAllResources() throws IOException {
                if (isReleased.compareAndSet(false, true)) {
 
-                       // Gather all exclusive buffers and recycle them to 
global pool in batch
+                       // Gather all exclusive buffers and recycle them to 
global pool in batch, because
+                       // we do not want to trigger redistribution of buffers 
after each recycle.
                        final List<MemorySegment> exclusiveRecyclingSegments = 
new ArrayList<>();
 
                        synchronized (receivedBuffers) {
@@ -240,16 +247,8 @@ public class RemoteInputChannel extends InputChannel 
implements BufferRecycler,
                                        }
                                }
                        }
-
-                       synchronized (availableBuffers) {
-                               Buffer buffer;
-                               while ((buffer = availableBuffers.poll()) != 
null) {
-                                       if (buffer.getRecycler() == this) {
-                                               
exclusiveRecyclingSegments.add(buffer.getMemorySegment());
-                                       } else {
-                                               buffer.recycle();
-                                       }
-                               }
+                       synchronized (bufferQueue) {
+                               
bufferQueue.releaseAll(exclusiveRecyclingSegments);
                        }
 
                        if (exclusiveRecyclingSegments.size() > 0) {
@@ -287,81 +286,93 @@ public class RemoteInputChannel extends InputChannel 
implements BufferRecycler,
        }
 
        /**
-        * Exclusive buffer is recycled to this input channel directly and it 
may trigger notify
-        * credit to producer.
+        * Exclusive buffer is recycled to this input channel directly and it 
may trigger return extra
+        * floating buffer and notify increased credit to the producer.
+        *
+        * <p>If the channel has already been released, the segment is handed straight back to
+        * the input gate via <tt>returnExclusiveSegments</tt> instead of being queued.
         *
         * @param segment The exclusive segment of this channel.
         */
        @Override
        public void recycle(MemorySegment segment) {
-               synchronized (availableBuffers) {
-                       // Important: the isReleased check should be inside the 
synchronized block.
-                       // that way the segment can also be returned to global 
pool after added into
-                       // the available queue during releasing all resources.
+               int numAddedBuffers;
+
+               synchronized (bufferQueue) {
+                       // Important: check the isReleased state inside 
synchronized block, so there is no
+                       // race condition when recycle and releaseAllResources 
running in parallel.
                        if (isReleased.get()) {
                                try {
-                                       
inputGate.returnExclusiveSegments(Arrays.asList(segment));
+                                       
inputGate.returnExclusiveSegments(Collections.singletonList(segment));
                                        return;
                                } catch (Throwable t) {
                                        ExceptionUtils.rethrow(t);
                                }
                        }
-                       availableBuffers.add(new Buffer(segment, this));
+                       // addExclusiveBuffer returns 0 when a surplus floating buffer was recycled in
+                       // exchange, otherwise 1 (the number of credits to announce).
+                       // NOTE(review): that floating buffer is recycled while this lock is held; confirm
+                       // the pool cannot call back into code that needs this channel's bufferQueue lock.
+                       numAddedBuffers = bufferQueue.addExclusiveBuffer(new 
Buffer(segment, this), numRequiredBuffers);
                }
 
-               if (unannouncedCredit.getAndAdd(1) == 0) {
+               // Credit is announced outside the lock; only the transition from zero
+               // unannounced credit triggers a notification.
+               if (numAddedBuffers > 0 && 
unannouncedCredit.getAndAdd(numAddedBuffers) == 0) {
                        notifyCreditAvailable();
                }
        }
 
+       /**
+        * Returns the number of floating plus exclusive buffers currently queued in this
+        * channel. The count is read while holding the <tt>bufferQueue</tt> lock.
+        */
        public int getNumberOfAvailableBuffers() {
-               synchronized (availableBuffers) {
-                       return availableBuffers.size();
+               synchronized (bufferQueue) {
+                       return bufferQueue.getAvailableBufferSize();
                }
        }
 
+       /**
+        * Returns how many buffers this channel currently aims to hold. onSenderBacklog()
+        * updates it to <tt>backlog + initialCredit</tt>; the initial value is assigned elsewhere.
+        * The field is read without taking the <tt>bufferQueue</tt> lock.
+        */
+       @VisibleForTesting
+       public int getNumberOfRequiredBuffers() {
+               final int required = this.numRequiredBuffers;
+               return required;
+       }
+
        /**
         * The Buffer pool notifies this channel of an available floating buffer. If the channel is released or
         * currently does not need extra buffers, the buffer should be recycled to the buffer pool. Otherwise,
-        * the buffer will be added into the <tt>availableBuffers</tt> queue and the unannounced credit is
-        * increased by one.
+        * the buffer will be added into the <tt>bufferQueue</tt> and the unannounced credit is increased
+        * by one.
+        *
+        * <p>Inside the <tt>bufferQueue</tt> lock, every path that returns <tt>false</tt> also clears
+        * <tt>isWaitingForFloatingBuffers</tt>. Otherwise {@link #onSenderBacklog(int)}, which only
+        * requests floating buffers while that flag is unset, would never request them again.
         *
         * @param buffer Buffer that becomes available in buffer pool.
         * @return True when this channel is waiting for more floating buffers, otherwise false.
         */
        @Override
        public boolean notifyBufferAvailable(Buffer buffer) {
-               checkState(isWaitingForFloatingBuffers.get(), "This channel should be waiting for floating buffers.");
+               // Check the isReleased state outside synchronized block first to avoid
+               // deadlock with releaseAllResources running in parallel.
+               if (isReleased.get()) {
+                       buffer.recycle();
+                       return false;
+               }
 
-               synchronized (availableBuffers) {
-                       // Important: the isReleased check should be inside the synchronized block.
-                       if (isReleased.get() || availableBuffers.size() >= senderBacklog.get()) {
-                               isWaitingForFloatingBuffers.set(false);
-                               buffer.recycle();
+               boolean needMoreBuffers = false;
+               synchronized (bufferQueue) {
+                       checkState(isWaitingForFloatingBuffers, "This channel should be waiting for floating buffers.");
 
+                       // Important: double check the isReleased state inside synchronized block, so there is no
+                       // race condition when notifyBufferAvailable and releaseAllResources running in parallel.
+                       if (isReleased.get() || bufferQueue.getAvailableBufferSize() >= numRequiredBuffers) {
+                               // The buffer is not needed, so stop waiting. A later onSenderBacklog() call
+                               // can then request floating buffers (and register as listener) again.
+                               isWaitingForFloatingBuffers = false;
+                               buffer.recycle();
                                return false;
                        }
 
-                       availableBuffers.add(buffer);
-
-                       if (unannouncedCredit.getAndAdd(1) == 0) {
-                               notifyCreditAvailable();
-                       }
+                       bufferQueue.addFloatingBuffer(buffer);
 
-                       if (availableBuffers.size() >= senderBacklog.get()) {
-                               isWaitingForFloatingBuffers.set(false);
-                               return false;
+                       if (bufferQueue.getAvailableBufferSize() == numRequiredBuffers) {
+                               isWaitingForFloatingBuffers = false;
                        } else {
-                               return true;
+                               needMoreBuffers = true;
                        }
                }
+
+               // Announce the new credit outside the lock; only the transition from zero notifies.
+               if (unannouncedCredit.getAndAdd(1) == 0) {
+                       notifyCreditAvailable();
+               }
+
+               return needMoreBuffers;
        }
 
+       /**
+        * Listener callback; judging by its name, invoked when the buffer pool is destroyed
+        * (TODO confirm against the BufferListener contract). This implementation does nothing.
+        */
        @Override
        public void notifyBufferDestroyed() {
-               if (!isWaitingForFloatingBuffers.compareAndSet(true, false)) {
-                       throw new IllegalStateException("This channel should be 
waiting for floating buffers currently.");
-               }
+               // Intentionally a no-op. NOTE(review): isWaitingForFloatingBuffers is left unchanged
+               // here; confirm that no channel state needs resetting when the pool goes away.
        }
 
        // 
------------------------------------------------------------------------
@@ -394,7 +405,58 @@ public class RemoteInputChannel extends InputChannel 
implements BufferRecycler,
                return inputGate.getBufferProvider();
        }
 
-       public void onBuffer(Buffer buffer, int sequenceNumber) {
+       /**
+        * Hands out one of this channel's buffers so that incoming network data can be copied
+        * into it. Floating buffers are handed out before exclusive ones. In credit-based mode a
+        * buffer is expected to be available, unless the channel has been released.
+        *
+        * @return A floating or exclusive buffer, or <tt>null</tt> if none is queued.
+        */
+       @Nullable
+       public Buffer requestBuffer() {
+               final Buffer taken;
+               synchronized (bufferQueue) {
+                       taken = bufferQueue.takeBuffer();
+               }
+               return taken;
+       }
+
+       /**
+        * Handles the backlog announced by the producer. The channel aims to hold
+        * <tt>backlog + initialCredit</tt> buffers. While it holds fewer, it takes floating buffers
+        * from the input gate's buffer pool, or registers itself as buffer listener when the pool
+        * has none. Credit for every floating buffer obtained here is announced afterwards.
+        *
+        * @param backlog The number of unsent buffers in the producer's sub partition.
+        */
+       @VisibleForTesting
+       void onSenderBacklog(int backlog) throws IOException {
+               int floatingBuffersGained = 0;
+
+               synchronized (bufferQueue) {
+                       // The released check must happen under the lock, so that no buffer is added
+                       // after releaseAllResources() has drained the queue.
+                       if (isReleased.get()) {
+                               return;
+                       }
+
+                       numRequiredBuffers = backlog + initialCredit;
+                       while (!isWaitingForFloatingBuffers && bufferQueue.getAvailableBufferSize() < numRequiredBuffers) {
+                               final Buffer floating = inputGate.getBufferPool().requestBuffer();
+                               if (floating == null) {
+                                       // The pool is empty, so wait for a buffer. Successful registration sets
+                                       // the waiting flag, which ends the loop; otherwise simply try again.
+                                       if (inputGate.getBufferProvider().addBufferListener(this)) {
+                                               isWaitingForFloatingBuffers = true;
+                                       }
+                                       continue;
+                               }
+                               bufferQueue.addFloatingBuffer(floating);
+                               floatingBuffersGained++;
+                       }
+               }
+
+               if (floatingBuffersGained > 0 && unannouncedCredit.getAndAdd(floatingBuffersGained) == 0) {
+                       notifyCreditAvailable();
+               }
+       }
+
+       public void onBuffer(Buffer buffer, int sequenceNumber, int backlog) 
throws IOException {
                boolean success = false;
 
                try {
@@ -416,6 +478,10 @@ public class RemoteInputChannel extends InputChannel 
implements BufferRecycler,
                                        }
                                }
                        }
+
+                       if (success && backlog >= 0) {
+                               onSenderBacklog(backlog);
+                       }
                } finally {
                        if (!success) {
                                buffer.recycle();
@@ -423,16 +489,23 @@ public class RemoteInputChannel extends InputChannel 
implements BufferRecycler,
                }
        }
 
-       public void onEmptyBuffer(int sequenceNumber) {
+       /**
+        * Handles an empty buffer response. If the sequence number matches, the expected sequence
+        * number is advanced; otherwise a {@link BufferReorderingException} is reported via onError.
+        * Nothing happens if the channel has already been released.
+        *
+        * @param sequenceNumber The sequence number carried by the response.
+        * @param backlog The sender backlog; it is only processed when it is non-negative and the
+        *                sequence number was accepted.
+        */
+       public void onEmptyBuffer(int sequenceNumber, int backlog) throws 
IOException {
+               boolean success = false;
+
                synchronized (receivedBuffers) {
                        if (!isReleased.get()) {
                                if (expectedSequenceNumber == sequenceNumber) {
                                        expectedSequenceNumber++;
+                                       success = true;
                                } else {
                                        onError(new 
BufferReorderingException(expectedSequenceNumber, sequenceNumber));
                                }
                        }
                }
+
+               // Handled outside the receivedBuffers lock; onSenderBacklog() synchronizes on bufferQueue.
+               // A negative backlog skips credit handling (presumably "no backlog info" — confirm).
+               if (success && backlog >= 0) {
+                       onSenderBacklog(backlog);
+               }
        }
 
        public void onFailedPartitionRequest() {
@@ -462,4 +535,82 @@ public class RemoteInputChannel extends InputChannel 
implements BufferRecycler,
                                expectedSequenceNumber, actualSequenceNumber);
                }
        }
+
+       /**
+        * Manages the exclusive and floating buffers of this channel, and handles the
+        * internal buffer related logic.
+        *
+        * <p>This class is not thread-safe by itself; callers synchronize on the queue instance.
+        */
+       private static class AvailableBufferQueue {
+
+               /** The current available floating buffers from the fixed buffer pool. */
+               private final ArrayDeque<Buffer> floatingBuffers;
+
+               /** The current available exclusive buffers from the global buffer pool. */
+               private final ArrayDeque<Buffer> exclusiveBuffers;
+
+               AvailableBufferQueue() {
+                       this.exclusiveBuffers = new ArrayDeque<>();
+                       this.floatingBuffers = new ArrayDeque<>();
+               }
+
+               /**
+                * Adds an exclusive buffer (back) into the queue and recycles one floating buffer if the
+                * number of available buffers in queue is more than the required amount.
+                *
+                * <p>If there is no floating buffer to give back, the exclusive buffer simply counts as
+                * an added buffer instead of failing on a null floating buffer.
+                *
+                * @param buffer The exclusive buffer to add
+                * @param numRequiredBuffers The number of required buffers
+                *
+                * @return How many buffers were added to the queue: 0 if a surplus floating buffer was
+                *         recycled in exchange, otherwise 1
+                */
+               int addExclusiveBuffer(Buffer buffer, int numRequiredBuffers) {
+                       exclusiveBuffers.add(buffer);
+                       if (getAvailableBufferSize() > numRequiredBuffers) {
+                               // Trade a surplus floating buffer back to its pool, so the total stays unchanged.
+                               // NOTE(review): recycle(MemorySegment) calls this while holding the bufferQueue lock.
+                               Buffer floatingBuffer = floatingBuffers.poll();
+                               if (floatingBuffer != null) {
+                                       floatingBuffer.recycle();
+                                       return 0;
+                               }
+                       }
+                       return 1;
+               }
+
+               /**
+                * Appends a floating buffer obtained from the buffer pool.
+                *
+                * @param buffer The floating buffer to add
+                */
+               void addFloatingBuffer(Buffer buffer) {
+                       floatingBuffers.add(buffer);
+               }
+
+               /**
+                * Takes the floating buffer first in order to make full use of floating
+                * buffers reasonably.
+                *
+                * @return An available floating or exclusive buffer, or <tt>null</tt> if both are
+                *         empty (for example after the channel has been released).
+                */
+               @Nullable
+               Buffer takeBuffer() {
+                       if (!floatingBuffers.isEmpty()) {
+                               return floatingBuffers.poll();
+                       }
+                       return exclusiveBuffers.poll();
+               }
+
+               /**
+                * The floating buffer is recycled to local buffer pool directly, and the
+                * exclusive buffer will be gathered to return to global buffer pool later.
+                *
+                * @param exclusiveSegments The list that we will add exclusive segments into.
+                */
+               void releaseAll(List<MemorySegment> exclusiveSegments) {
+                       Buffer buffer;
+                       while ((buffer = floatingBuffers.poll()) != null) {
+                               buffer.recycle();
+                       }
+                       while ((buffer = exclusiveBuffers.poll()) != null) {
+                               exclusiveSegments.add(buffer.getMemorySegment());
+                       }
+               }
+
+               /** Returns the total number of queued floating and exclusive buffers. */
+               int getAvailableBufferSize() {
+                       return floatingBuffers.size() + exclusiveBuffers.size();
+               }
+       }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
index 0651f97..8c87ceb 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java
@@ -62,7 +62,7 @@ public class NettyMessageSerializationTest {
                                nioBuffer.putInt(i);
                        }
 
-                       NettyMessage.BufferResponse expected = new 
NettyMessage.BufferResponse(buffer, random.nextInt(), new InputChannelID());
+                       NettyMessage.BufferResponse expected = new 
NettyMessage.BufferResponse(buffer, random.nextInt(), new InputChannelID(), 
random.nextInt());
                        NettyMessage.BufferResponse actual = 
encodeAndDecode(expected);
 
                        // Verify recycle has been called on buffer instance
@@ -85,6 +85,7 @@ public class NettyMessageSerializationTest {
 
                        assertEquals(expected.sequenceNumber, 
actual.sequenceNumber);
                        assertEquals(expected.receiverId, actual.receiverId);
+                       assertEquals(expected.backlog, actual.backlog);
                }
 
                {

http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
index e1e5bd3..d3ff6c2 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java
@@ -30,23 +30,16 @@ import 
org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID;
 import 
org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
 import org.apache.flink.runtime.io.network.util.TestBufferFactory;
-import org.apache.flink.runtime.testutils.DiscardingRecycler;
 
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
 import org.apache.flink.shaded.netty4.io.netty.buffer.UnpooledByteBufAllocator;
 import org.apache.flink.shaded.netty4.io.netty.channel.Channel;
 import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext;
-import 
org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel;
 
 import org.junit.Test;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
 
 import java.io.IOException;
-import java.util.concurrent.atomic.AtomicReference;
 
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -80,19 +73,19 @@ public class PartitionRequestClientHandlerTest {
                when(inputChannel.getInputChannelId()).thenReturn(new 
InputChannelID());
                
when(inputChannel.getBufferProvider()).thenReturn(bufferProvider);
 
-               final BufferResponse ReceivedBuffer = createBufferResponse(
-                               TestBufferFactory.createBuffer(), 0, 
inputChannel.getInputChannelId());
+               final BufferResponse receivedBuffer = createBufferResponse(
+                       TestBufferFactory.createBuffer(), 0, 
inputChannel.getInputChannelId(), 2);
 
                final PartitionRequestClientHandler client = new 
PartitionRequestClientHandler();
                client.addInputChannel(inputChannel);
 
-               client.channelRead(mock(ChannelHandlerContext.class), 
ReceivedBuffer);
+               client.channelRead(mock(ChannelHandlerContext.class), 
receivedBuffer);
        }
 
        /**
         * Tests a fix for FLINK-1761.
         *
-        * <p> FLINK-1761 discovered an IndexOutOfBoundsException, when 
receiving buffers of size 0.
+        * <p>FLINK-1761 discovered an IndexOutOfBoundsException, when 
receiving buffers of size 0.
         */
        @Test
        public void testReceiveEmptyBuffer() throws Exception {
@@ -108,10 +101,11 @@ public class PartitionRequestClientHandlerTest {
                final Buffer emptyBuffer = TestBufferFactory.createBuffer();
                emptyBuffer.setSize(0);
 
+               final int backlog = 2;
                final BufferResponse receivedBuffer = createBufferResponse(
-                               emptyBuffer, 0, 
inputChannel.getInputChannelId());
+                       emptyBuffer, 0, inputChannel.getInputChannelId(), 
backlog);
 
-               final PartitionRequestClientHandler client = new 
PartitionRequestClientHandler();
+               final CreditBasedClientHandler client = new 
CreditBasedClientHandler();
                client.addInputChannel(inputChannel);
 
                // Read the empty buffer
@@ -119,6 +113,51 @@ public class PartitionRequestClientHandlerTest {
 
                // This should not throw an exception
                verify(inputChannel, never()).onError(any(Throwable.class));
+               verify(inputChannel, times(1)).onEmptyBuffer(0, backlog);
+       }
+
+       /**
+        * Checks that a received {@link BufferResponse} is handed to the channel through
+        * {@link RemoteInputChannel#onBuffer(Buffer, int, int)}, using the buffer that the channel
+        * returned from requestBuffer().
+        */
+       @Test
+       public void testReceiveBuffer() throws Exception {
+               final int backlog = 2;
+               final Buffer channelBuffer = TestBufferFactory.createBuffer();
+
+               final RemoteInputChannel channel = mock(RemoteInputChannel.class);
+               when(channel.getInputChannelId()).thenReturn(new InputChannelID());
+               when(channel.requestBuffer()).thenReturn(channelBuffer);
+
+               final CreditBasedClientHandler handler = new CreditBasedClientHandler();
+               handler.addInputChannel(channel);
+
+               final BufferResponse response = createBufferResponse(
+                       TestBufferFactory.createBuffer(), 0, channel.getInputChannelId(), backlog);
+               handler.channelRead(mock(ChannelHandlerContext.class), response);
+
+               verify(channel, times(1)).onBuffer(channelBuffer, 0, backlog);
+       }
+
+       /**
+        * Checks that the channel is failed via {@link RemoteInputChannel#onError(Throwable)} with an
+        * {@link IllegalStateException} when a {@link BufferResponse} arrives but the channel has no
+        * buffer to receive it.
+        */
+       @Test
+       public void testThrowExceptionForNoAvailableBuffer() throws Exception {
+               final RemoteInputChannel channel = mock(RemoteInputChannel.class);
+               when(channel.getInputChannelId()).thenReturn(new InputChannelID());
+               // Simulate a channel without any available buffer.
+               when(channel.requestBuffer()).thenReturn(null);
+
+               final CreditBasedClientHandler handler = new CreditBasedClientHandler();
+               handler.addInputChannel(channel);
+
+               final int backlog = 2;
+               handler.channelRead(
+                       mock(ChannelHandlerContext.class),
+                       createBufferResponse(TestBufferFactory.createBuffer(), 0, channel.getInputChannelId(), backlog));
+
+               verify(channel, times(1)).onError(any(IllegalStateException.class));
        }
 
        /**
@@ -136,8 +175,8 @@ public class PartitionRequestClientHandlerTest {
                
when(inputChannel.getBufferProvider()).thenReturn(bufferProvider);
 
                final ErrorResponse partitionNotFound = new ErrorResponse(
-                               new PartitionNotFoundException(new 
ResultPartitionID()),
-                               inputChannel.getInputChannelId());
+                       new PartitionNotFoundException(new ResultPartitionID()),
+                       inputChannel.getInputChannelId());
 
                final PartitionRequestClientHandler client = new 
PartitionRequestClientHandler();
                client.addInputChannel(inputChannel);
@@ -169,95 +208,19 @@ public class PartitionRequestClientHandlerTest {
                client.cancelRequestFor(inputChannel.getInputChannelId());
        }
 
-       /**
-        * Tests that an unsuccessful message decode call for a staged message
-        * does not leave the channel with auto read set to false.
-        */
-       @Test
-       @SuppressWarnings("unchecked")
-       public void testAutoReadAfterUnsuccessfulStagedMessage() throws 
Exception {
-               PartitionRequestClientHandler handler = new 
PartitionRequestClientHandler();
-               EmbeddedChannel channel = new EmbeddedChannel(handler);
-
-               final AtomicReference<BufferListener> listener = new 
AtomicReference<>();
-
-               BufferProvider bufferProvider = mock(BufferProvider.class);
-               
when(bufferProvider.addBufferListener(any(BufferListener.class))).thenAnswer(new
 Answer<Boolean>() {
-                       @Override
-                       @SuppressWarnings("unchecked")
-                       public Boolean answer(InvocationOnMock invocation) 
throws Throwable {
-                               listener.set((BufferListener) 
invocation.getArguments()[0]);
-                               return true;
-                       }
-               });
-
-               when(bufferProvider.requestBuffer()).thenReturn(null);
-
-               InputChannelID channelId = new InputChannelID(0, 0);
-               RemoteInputChannel inputChannel = 
mock(RemoteInputChannel.class);
-               when(inputChannel.getInputChannelId()).thenReturn(channelId);
-
-               // The 3rd staged msg has a null buffer provider
-               
when(inputChannel.getBufferProvider()).thenReturn(bufferProvider, 
bufferProvider, null);
-
-               handler.addInputChannel(inputChannel);
-
-               BufferResponse msg = createBufferResponse(createBuffer(true), 
0, channelId);
-
-               // Write 1st buffer msg. No buffer is available, therefore the 
buffer
-               // should be staged and auto read should be set to false.
-               assertTrue(channel.config().isAutoRead());
-               channel.writeInbound(msg);
-
-               // No buffer available, auto read false
-               assertFalse(channel.config().isAutoRead());
-
-               // Write more buffers... all staged.
-               msg = createBufferResponse(createBuffer(true), 1, channelId);
-               channel.writeInbound(msg);
-
-               msg = createBufferResponse(createBuffer(true), 2, channelId);
-               channel.writeInbound(msg);
-
-               // Notify about buffer => handle 1st msg
-               Buffer availableBuffer = createBuffer(false);
-               listener.get().notifyBufferAvailable(availableBuffer);
-
-               // Start processing of staged buffers (in run pending tasks). 
Make
-               // sure that the buffer provider acts like it's destroyed.
-               
when(bufferProvider.addBufferListener(any(BufferListener.class))).thenReturn(false);
-               when(bufferProvider.isDestroyed()).thenReturn(true);
-
-               // Execute all tasks that are scheduled in the event loop. 
Further
-               // eventLoop().execute() calls are directly executed, if they 
are
-               // called in the scope of this call.
-               channel.runPendingTasks();
-
-               assertTrue(channel.config().isAutoRead());
-       }
-
        // 
---------------------------------------------------------------------------------------------
 
-       private static Buffer createBuffer(boolean fill) {
-               MemorySegment segment = 
MemorySegmentFactory.allocateUnpooledSegment(1024, null);
-               if (fill) {
-                       for (int i = 0; i < 1024; i++) {
-                               segment.put(i, (byte) i);
-                       }
-               }
-               return new Buffer(segment, DiscardingRecycler.INSTANCE, true);
-       }
-
        /**
         * Returns a deserialized buffer message as it would be received during 
runtime.
         */
        private BufferResponse createBufferResponse(
                        Buffer buffer,
                        int sequenceNumber,
-                       InputChannelID receivingChannelId) throws IOException {
+                       InputChannelID receivingChannelId,
+                       int backlog) throws IOException {
 
                // Mock buffer to serialize
-               BufferResponse resp = new BufferResponse(buffer, 
sequenceNumber, receivingChannelId);
+               BufferResponse resp = new BufferResponse(buffer, 
sequenceNumber, receivingChannelId, backlog);
 
                ByteBuf serialized = 
resp.write(UnpooledByteBufAllocator.DEFAULT);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
index 6f98119..81788c9 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
@@ -216,7 +216,7 @@ public class InputGateConcurrentTest {
 
                @Override
                void addBuffer(Buffer buffer) throws Exception {
-                       channel.onBuffer(buffer, seq++);
+                       channel.onBuffer(buffer, seq++, -1);
                }
        }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
index 324a060..4e90265 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
@@ -206,9 +206,9 @@ public class InputGateFairnessTest {
                        channels[i] = channel;
                        
                        for (int p = 0; p < buffersPerChannel; p++) {
-                               channel.onBuffer(mockBuffer, p);
+                               channel.onBuffer(mockBuffer, p, -1);
                        }
-                       
channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), 
buffersPerChannel);
+                       
channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), 
buffersPerChannel, -1);
 
                        gate.setInputChannel(new 
IntermediateResultPartitionID(), channel);
                }
@@ -263,7 +263,7 @@ public class InputGateFairnessTest {
                        gate.setInputChannel(new 
IntermediateResultPartitionID(), channel);
                }
 
-               channels[11].onBuffer(mockBuffer, 0);
+               channels[11].onBuffer(mockBuffer, 0, -1);
                channelSequenceNums[11]++;
 
                // read all the buffers and the EOF event
@@ -325,7 +325,7 @@ public class InputGateFairnessTest {
                Collections.shuffle(poss);
 
                for (int i : poss) {
-                       partitions[i].onBuffer(buffer, sequenceNumbers[i]++);
+                       partitions[i].onBuffer(buffer, sequenceNumbers[i]++, 
-1);
                }
        }
        

http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
index d791ced..863f886 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java
@@ -18,24 +18,28 @@
 
 package org.apache.flink.runtime.io.network.partition.consumer;
 
-import org.apache.flink.core.memory.MemorySegment;
-import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.io.network.ConnectionID;
 import org.apache.flink.runtime.io.network.ConnectionManager;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.netty.PartitionRequestClient;
 import org.apache.flink.runtime.io.network.partition.ProducerFailedException;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.io.network.util.TestBufferFactory;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
+import org.apache.flink.runtime.taskmanager.TaskActions;
 
 import org.apache.flink.shaded.guava18.com.google.common.collect.Lists;
 
 import org.junit.Test;
-import scala.Tuple2;
 
 import java.io.IOException;
+import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.Callable;
@@ -43,12 +47,14 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 
+import scala.Tuple2;
+
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyListOf;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
@@ -66,10 +72,10 @@ public class RemoteInputChannelTest {
                final Buffer buffer = TestBufferFactory.createBuffer();
 
                // The test
-               inputChannel.onBuffer(buffer.retain(), 0);
+               inputChannel.onBuffer(buffer.retain(), 0, -1);
 
                // This does not yet throw the exception, but sets the error at 
the channel.
-               inputChannel.onBuffer(buffer, 29);
+               inputChannel.onBuffer(buffer, 29, -1);
 
                try {
                        inputChannel.getNextBuffer();
@@ -113,7 +119,7 @@ public class RemoteInputChannelTest {
                                                        for (int j = 0; j < 
128; j++) {
                                                                // this is the 
same buffer over and over again which will be
                                                                // recycled by 
the RemoteInputChannel
-                                                               
inputChannel.onBuffer(buffer.retain(), j);
+                                                               
inputChannel.onBuffer(buffer.retain(), j, -1);
                                                        }
 
                                                        if 
(inputChannel.isReleased()) {
@@ -301,81 +307,562 @@ public class RemoteInputChannelTest {
        }
 
        /**
-        * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying 
the exclusive segment is
-        * recycled to available buffers directly and it triggers notify of 
announced credit.
+        * Tests to verify the behaviours of three different processes if the 
number of available
+        * buffers is less than required buffers.
+        *
+        * 1. Recycle the floating buffer
+        * 2. Recycle the exclusive buffer
+        * 3. Decrease the sender's backlog
         */
        @Test
-       public void testRecycleExclusiveBufferBeforeReleased() throws Exception 
{
-               final SingleInputGate inputGate = mock(SingleInputGate.class);
-               final RemoteInputChannel inputChannel = 
spy(createRemoteInputChannel(inputGate));
-
-               // Recycle exclusive segment
-               
inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, 
inputChannel));
+       public void testAvailableBuffersLessThanRequiredBuffers() throws 
Exception {
+               // Setup
+               final NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(16, 32);
+               final int numExclusiveBuffers = 2;
+               final int numFloatingBuffers = 14;
 
-               assertEquals("There should be one buffer available after 
recycle.",
-                       1, inputChannel.getNumberOfAvailableBuffers());
-               verify(inputChannel, times(1)).notifyCreditAvailable();
+               final SingleInputGate inputGate = createSingleInputGate();
+               final RemoteInputChannel inputChannel = 
createRemoteInputChannel(inputGate);
+               
inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), 
inputChannel);
+               try {
+                       final BufferPool bufferPool = 
spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
+                       inputGate.setBufferPool(bufferPool);
+                       inputGate.assignExclusiveSegments(networkBufferPool, 
numExclusiveBuffers);
+
+                       // Prepare the exclusive and floating buffers to verify 
recycle logic later
+                       final Buffer exclusiveBuffer = 
inputChannel.requestBuffer();
+                       assertNotNull(exclusiveBuffer);
+
+                       final int numRecycleFloatingBuffers = 2;
+                       final ArrayDeque<Buffer> floatingBufferQueue = new 
ArrayDeque<>(numRecycleFloatingBuffers);
+                       for (int i = 0; i < numRecycleFloatingBuffers; i++) {
+                               Buffer floatingBuffer = 
bufferPool.requestBuffer();
+                               assertNotNull(floatingBuffer);
+                               floatingBufferQueue.add(floatingBuffer);
+                       }
 
-               
inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, 
inputChannel));
+                       verify(bufferPool, 
times(numRecycleFloatingBuffers)).requestBuffer();
+
+                       // Receive the producer's backlog more than the number 
of available floating buffers
+                       inputChannel.onSenderBacklog(14);
+
+                       // The channel requests (backlog + numExclusiveBuffers) 
floating buffers from local pool.
+                       // It does not get enough floating buffers and register 
as buffer listener
+                       verify(bufferPool, times(15)).requestBuffer();
+                       verify(bufferPool, 
times(1)).addBufferListener(inputChannel);
+                       assertEquals("There should be 13 buffers available in 
the channel",
+                               13, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 16 buffers required in 
the channel",
+                               16, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 0 buffer available in 
local pool",
+                               0, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+                       // Increase the backlog
+                       inputChannel.onSenderBacklog(16);
+
+                       // The channel is already in the status of waiting for 
buffers and will not request any more
+                       verify(bufferPool, times(15)).requestBuffer();
+                       verify(bufferPool, 
times(1)).addBufferListener(inputChannel);
+                       assertEquals("There should be 13 buffers available in 
the channel",
+                               13, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 18 buffers required in 
the channel",
+                               18, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 0 buffer available in 
local pool",
+                               0, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+                       // Recycle one floating buffer
+                       floatingBufferQueue.poll().recycle();
+
+                       // Assign the floating buffer to the listener and the 
channel is still waiting for more floating buffers
+                       verify(bufferPool, times(15)).requestBuffer();
+                       verify(bufferPool, 
times(1)).addBufferListener(inputChannel);
+                       assertEquals("There should be 14 buffers available in 
the channel",
+                               14, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 18 buffers required in 
the channel",
+                               18, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 0 buffer available in 
local pool",
+                               0, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+                       // Recycle one more floating buffer
+                       floatingBufferQueue.poll().recycle();
+
+                       // Assign the floating buffer to the listener and the 
channel is still waiting for more floating buffers
+                       verify(bufferPool, times(15)).requestBuffer();
+                       verify(bufferPool, 
times(1)).addBufferListener(inputChannel);
+                       assertEquals("There should be 15 buffers available in 
the channel",
+                               15, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 18 buffers required in 
the channel",
+                               18, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 0 buffer available in 
local pool",
+                               0, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+                       // Decrease the backlog
+                       inputChannel.onSenderBacklog(15);
+
+                       // Only the number of required buffers is changed by 
(backlog + numExclusiveBuffers)
+                       verify(bufferPool, times(15)).requestBuffer();
+                       verify(bufferPool, 
times(1)).addBufferListener(inputChannel);
+                       assertEquals("There should be 15 buffers available in 
the channel",
+                               15, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 17 buffers required in 
the channel",
+                               17, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 0 buffer available in 
local pool",
+                               0, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+                       // Recycle one exclusive buffer
+                       exclusiveBuffer.recycle();
+
+                       // The exclusive buffer is returned to the channel 
directly
+                       verify(bufferPool, times(15)).requestBuffer();
+                       verify(bufferPool, 
times(1)).addBufferListener(inputChannel);
+                       assertEquals("There should be 16 buffers available in 
the channel",
+                               16, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 17 buffers required in 
the channel",
+                               17, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 0 buffers available in 
local pool",
+                               0, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+               } finally {
+                       // Release all the buffer resources
+                       inputChannel.releaseAllResources();
 
-               assertEquals("There should be two buffers available after 
recycle.",
-                       2, inputChannel.getNumberOfAvailableBuffers());
-               // It should be called only once when increased from zero.
-               verify(inputChannel, times(1)).notifyCreditAvailable();
+                       networkBufferPool.destroyAllBufferPools();
+                       networkBufferPool.destroy();
+               }
        }
 
        /**
-        * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying 
the exclusive segment is
-        * recycled to global pool via input gate when channel is released.
+        * Tests to verify the behaviours of recycling floating and exclusive 
buffers if the number of available
+        * buffers equals to required buffers.
         */
        @Test
-       public void testRecycleExclusiveBufferAfterReleased() throws Exception {
+       public void testAvailableBuffersEqualToRequiredBuffers() throws 
Exception {
                // Setup
-               final SingleInputGate inputGate = mock(SingleInputGate.class);
-               final RemoteInputChannel inputChannel = 
spy(createRemoteInputChannel(inputGate));
-
-               inputChannel.releaseAllResources();
+               final NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(16, 32);
+               final int numExclusiveBuffers = 2;
+               final int numFloatingBuffers = 14;
 
-               // Recycle exclusive segment after channel released
-               
inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, 
inputChannel));
+               final SingleInputGate inputGate = createSingleInputGate();
+               final RemoteInputChannel inputChannel = 
createRemoteInputChannel(inputGate);
+               
inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), 
inputChannel);
+               try {
+                       final BufferPool bufferPool = 
spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
+                       inputGate.setBufferPool(bufferPool);
+                       inputGate.assignExclusiveSegments(networkBufferPool, 
numExclusiveBuffers);
+
+                       // Prepare the exclusive and floating buffers to verify 
recycle logic later
+                       final Buffer exclusiveBuffer = 
inputChannel.requestBuffer();
+                       assertNotNull(exclusiveBuffer);
+                       final Buffer floatingBuffer = 
bufferPool.requestBuffer();
+                       assertNotNull(floatingBuffer);
+                       verify(bufferPool, times(1)).requestBuffer();
+
+                       // Receive the producer's backlog
+                       inputChannel.onSenderBacklog(12);
+
+                       // The channel requests (backlog + numExclusiveBuffers) 
floating buffers from local pool
+                       // and gets enough floating buffers
+                       verify(bufferPool, times(14)).requestBuffer();
+                       verify(bufferPool, 
times(0)).addBufferListener(inputChannel);
+                       assertEquals("There should be 14 buffers available in 
the channel",
+                               14, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 14 buffers required in 
the channel",
+                               14, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 0 buffer available in 
local pool",
+                               0, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+                       // Recycle one floating buffer
+                       floatingBuffer.recycle();
+
+                       // The floating buffer is returned to local buffer 
directly because the channel is not waiting
+                       // for floating buffers
+                       verify(bufferPool, times(14)).requestBuffer();
+                       verify(bufferPool, 
times(0)).addBufferListener(inputChannel);
+                       assertEquals("There should be 14 buffers available in 
the channel",
+                               14, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 14 buffers required in 
the channel",
+                               14, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 1 buffer available in 
local pool",
+                               1, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+                       // Recycle one exclusive buffer
+                       exclusiveBuffer.recycle();
+
+                       // Return one extra floating buffer to the local pool 
because the number of available buffers
+                       // already equals to required buffers
+                       verify(bufferPool, times(14)).requestBuffer();
+                       verify(bufferPool, 
times(0)).addBufferListener(inputChannel);
+                       assertEquals("There should be 14 buffers available in 
the channel",
+                               14, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 14 buffers required in 
the channel",
+                               14, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 2 buffers available in 
local pool",
+                               2, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+               } finally {
+                       // Release all the buffer resources
+                       inputChannel.releaseAllResources();
 
-               assertEquals("Resource leak during recycling buffer after 
channel is released.",
-                       0, inputChannel.getNumberOfAvailableBuffers());
-               verify(inputChannel, times(0)).notifyCreditAvailable();
-               verify(inputGate, 
times(1)).returnExclusiveSegments(anyListOf(MemorySegment.class));
+                       networkBufferPool.destroyAllBufferPools();
+                       networkBufferPool.destroy();
+               }
        }
 
        /**
-        * Tests {@link RemoteInputChannel#releaseAllResources()}, verifying 
the exclusive segments are
-        * recycled to global pool via input gate and no resource leak.
+        * Tests to verify the behaviours of recycling floating and exclusive 
buffers if the number of available
+        * buffers is more than required buffers by decreasing the sender's 
backlog.
         */
        @Test
-       public void testReleaseExclusiveBuffers() throws Exception {
+       public void testAvailableBuffersMoreThanRequiredBuffers() throws 
Exception {
                // Setup
-               final SingleInputGate inputGate = mock(SingleInputGate.class);
+               final NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(16, 32);
+               final int numExclusiveBuffers = 2;
+               final int numFloatingBuffers = 14;
+
+               final SingleInputGate inputGate = createSingleInputGate();
                final RemoteInputChannel inputChannel = 
createRemoteInputChannel(inputGate);
+               
inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), 
inputChannel);
+               try {
+                       final BufferPool bufferPool = 
spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
+                       inputGate.setBufferPool(bufferPool);
+                       inputGate.assignExclusiveSegments(networkBufferPool, 
numExclusiveBuffers);
+
+                       // Prepare the exclusive and floating buffers to verify 
recycle logic later
+                       final Buffer exclusiveBuffer = 
inputChannel.requestBuffer();
+                       assertNotNull(exclusiveBuffer);
+
+                       final Buffer floatingBuffer = 
bufferPool.requestBuffer();
+                       assertNotNull(floatingBuffer);
+
+                       verify(bufferPool, times(1)).requestBuffer();
+
+                       // Receive the producer's backlog
+                       inputChannel.onSenderBacklog(12);
+
+                       // The channel gets enough floating buffers from local 
pool
+                       verify(bufferPool, times(14)).requestBuffer();
+                       verify(bufferPool, 
times(0)).addBufferListener(inputChannel);
+                       assertEquals("There should be 14 buffers available in 
the channel",
+                               14, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 14 buffers required in 
the channel",
+                               14, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 0 buffer available in 
local pool",
+                               0, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+                       // Decrease the backlog to make the number of available 
buffers more than required buffers
+                       inputChannel.onSenderBacklog(10);
+
+                       // Only the number of required buffers is changed by 
(backlog + numExclusiveBuffers)
+                       verify(bufferPool, times(14)).requestBuffer();
+                       verify(bufferPool, 
times(0)).addBufferListener(inputChannel);
+                       assertEquals("There should be 14 buffers available in 
the channel",
+                               14, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 12 buffers required in 
the channel",
+                               12, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 0 buffer available in 
local pool",
+                               0, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+                       // Recycle one exclusive buffer
+                       exclusiveBuffer.recycle();
+
+                       // Return one extra floating buffer to the local pool 
because the number of available buffers
+                       // is more than required buffers
+                       verify(bufferPool, times(14)).requestBuffer();
+                       verify(bufferPool, 
times(0)).addBufferListener(inputChannel);
+                       assertEquals("There should be 14 buffers available in 
the channel",
+                               14, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 12 buffers required in 
the channel",
+                               12, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 1 buffer available in 
local pool",
+                               1, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+                       // Recycle one floating buffer
+                       floatingBuffer.recycle();
+
+                       // The floating buffer is returned to local pool 
directly because the channel is not waiting for
+                       // floating buffers
+                       verify(bufferPool, times(14)).requestBuffer();
+                       verify(bufferPool, 
times(0)).addBufferListener(inputChannel);
+                       assertEquals("There should be 14 buffers available in 
the channel",
+                               14, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 12 buffers required in 
the channel",
+                               12, inputChannel.getNumberOfRequiredBuffers());
+                       assertEquals("There should be 2 buffers available in 
local pool",
+                               2, 
bufferPool.getNumberOfAvailableMemorySegments());
+
+               } finally {
+                       // Release all the buffer resources
+                       inputChannel.releaseAllResources();
 
-               // Assign exclusive segments to channel
-               final List<MemorySegment> exclusiveSegments = new ArrayList<>();
+                       networkBufferPool.destroyAllBufferPools();
+                       networkBufferPool.destroy();
+               }
+       }
+
+       /**
+        * Tests to verify that the buffer pool will distribute available 
floating buffers among
+        * all the channel listeners in a fair way.
+        */
+       @Test
+       public void testFairDistributionFloatingBuffers() throws Exception {
+               // Setup
+               final NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(12, 32);
                final int numExclusiveBuffers = 2;
-               for (int i = 0; i < numExclusiveBuffers; i++) {
-                       
exclusiveSegments.add(MemorySegmentFactory.allocateUnpooledSegment(1024, 
inputChannel));
+               final int numFloatingBuffers = 3;
+
+               final SingleInputGate inputGate = createSingleInputGate();
+               final RemoteInputChannel channel1 = 
spy(createRemoteInputChannel(inputGate));
+               final RemoteInputChannel channel2 = 
spy(createRemoteInputChannel(inputGate));
+               final RemoteInputChannel channel3 = 
spy(createRemoteInputChannel(inputGate));
+               
inputGate.setInputChannel(channel1.partitionId.getPartitionId(), channel1);
+               
inputGate.setInputChannel(channel2.partitionId.getPartitionId(), channel2);
+               
inputGate.setInputChannel(channel3.partitionId.getPartitionId(), channel3);
+               try {
+                       final BufferPool bufferPool = 
spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers));
+                       inputGate.setBufferPool(bufferPool);
+                       inputGate.assignExclusiveSegments(networkBufferPool, 
numExclusiveBuffers);
+
+                       // Exhaust all the floating buffers
+                       final List<Buffer> floatingBuffers = new 
ArrayList<>(numFloatingBuffers);
+                       for (int i = 0; i < numFloatingBuffers; i++) {
+                               Buffer buffer = bufferPool.requestBuffer();
+                               assertNotNull(buffer);
+                               floatingBuffers.add(buffer);
+                       }
+
+                       // Receive the producer's backlog to trigger request 
floating buffers from pool
+                       // and register as listeners as a result
+                       channel1.onSenderBacklog(8);
+                       channel2.onSenderBacklog(8);
+                       channel3.onSenderBacklog(8);
+
+                       verify(bufferPool, 
times(1)).addBufferListener(channel1);
+                       verify(bufferPool, 
times(1)).addBufferListener(channel2);
+                       verify(bufferPool, 
times(1)).addBufferListener(channel3);
+                       assertEquals("There should be " + numExclusiveBuffers + 
" buffers available in the channel",
+                               numExclusiveBuffers, 
channel1.getNumberOfAvailableBuffers());
+                       assertEquals("There should be " + numExclusiveBuffers + 
" buffers available in the channel",
+                               numExclusiveBuffers, 
channel2.getNumberOfAvailableBuffers());
+                       assertEquals("There should be " + numExclusiveBuffers + 
" buffers available in the channel",
+                               numExclusiveBuffers, 
channel3.getNumberOfAvailableBuffers());
+
+                       // Recycle three floating buffers to trigger notify 
buffer available
+                       for (Buffer buffer : floatingBuffers) {
+                               buffer.recycle();
+                       }
+
+                       verify(channel1, 
times(1)).notifyBufferAvailable(any(Buffer.class));
+                       verify(channel2, 
times(1)).notifyBufferAvailable(any(Buffer.class));
+                       verify(channel3, 
times(1)).notifyBufferAvailable(any(Buffer.class));
+                       assertEquals("There should be 3 buffers available in 
the channel", 3, channel1.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 3 buffers available in 
the channel", 3, channel2.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 3 buffers available in 
the channel", 3, channel3.getNumberOfAvailableBuffers());
+
+               } finally {
+                       // Release all the buffer resources
+                       channel1.releaseAllResources();
+                       channel2.releaseAllResources();
+                       channel3.releaseAllResources();
+
+                       networkBufferPool.destroyAllBufferPools();
+                       networkBufferPool.destroy();
                }
-               inputChannel.assignExclusiveSegments(exclusiveSegments);
+       }
+
+       /**
+        * Tests to verify that there is no race condition with two things 
running in parallel:
+        * requesting floating buffers on sender backlog and some other thread 
releasing
+        * the input channel.
+        */
+       @Test
+       public void testConcurrentOnSenderBacklogAndRelease() throws Exception {
+               // Setup
+               final NetworkBufferPool networkBufferPool = new 
NetworkBufferPool(130, 32);
+               final int numExclusiveBuffers = 2;
+               final int numFloatingBuffers = 128;
+
+               final ExecutorService executor = 
Executors.newFixedThreadPool(2);
+
+               final SingleInputGate inputGate = createSingleInputGate();
+               final RemoteInputChannel inputChannel  = 
createRemoteInputChannel(inputGate);
+               
inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), 
inputChannel);
+               try {
+                       final BufferPool bufferPool = 
networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
+                       inputGate.setBufferPool(bufferPool);
+                       inputGate.assignExclusiveSegments(networkBufferPool, 
numExclusiveBuffers);
+
+                       final Callable<Void> requestBufferTask = new 
Callable<Void>() {
+                               @Override
+                               public Void call() throws Exception {
+                                       while (true) {
+                                               for (int j = 1; j <= 
numFloatingBuffers; j++) {
+                                                       
inputChannel.onSenderBacklog(j);
+                                               }
 
-               assertEquals("The number of available buffers is not equal to 
the assigned amount.",
-                       numExclusiveBuffers, 
inputChannel.getNumberOfAvailableBuffers());
+                                               if (inputChannel.isReleased()) {
+                                                       return null;
+                                               }
+                                       }
+                               }
+                       };
+
+                       final Callable<Void> releaseTask = new Callable<Void>() 
{
+                               @Override
+                               public Void call() throws Exception {
+                                       inputChannel.releaseAllResources();
+
+                                       return null;
+                               }
+                       };
+
+                       // Submit tasks and wait to finish
+                       submitTasksAndWaitForResults(executor, new 
Callable[]{requestBufferTask, releaseTask});
+
+                       assertEquals("There should be no buffers available in 
the channel.",
+                               0, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be 130 buffers available in 
local pool.",
+                               130, 
bufferPool.getNumberOfAvailableMemorySegments() + 
networkBufferPool.getNumberOfAvailableMemorySegments());
 
-               // Release this channel
-               inputChannel.releaseAllResources();
+               } finally {
+                       // Release all the buffer resources once exception
+                       if (!inputChannel.isReleased()) {
+                               inputChannel.releaseAllResources();
+                       }
+
+                       networkBufferPool.destroyAllBufferPools();
+                       networkBufferPool.destroy();
 
-               assertEquals("Resource leak after channel is released.",
-                       0, inputChannel.getNumberOfAvailableBuffers());
-               verify(inputGate, 
times(1)).returnExclusiveSegments(anyListOf(MemorySegment.class));
+                       executor.shutdown();
+               }
+       }
+
+       /**
+        * Tests to verify that there is no race condition with three things running in parallel:
+        * requesting floating buffers on sender backlog, recycling the exclusive buffers and
+        * recycling the floating buffers.
+        *
+        * <p>All exclusive buffers of the channel and all floating buffers of the local pool are
+        * requested before the tasks are submitted, so the buffers the backlog task can obtain only
+        * become available through the concurrently running recycling tasks.
+        */
+       @Test
+       public void testConcurrentOnSenderBacklogAndRecycle() throws Exception {
+               // Setup
+               final NetworkBufferPool networkBufferPool = new NetworkBufferPool(248, 32);
+               final int numExclusiveSegments = 120;
+               final int numFloatingBuffers = 128;
+               final int backlog = 128;
+
+               final ExecutorService executor = Executors.newFixedThreadPool(3);
+
+               final SingleInputGate inputGate = createSingleInputGate();
+               final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+               inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel);
+               try {
+                       final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
+                       inputGate.setBufferPool(bufferPool);
+                       inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments);
+
+                       final Callable<Void> requestBufferTask = new Callable<Void>() {
+                               @Override
+                               public Void call() throws Exception {
+                                       for (int j = 1; j <= backlog; j++) {
+                                               inputChannel.onSenderBacklog(j);
+                                       }
+
+                                       return null;
+                               }
+                       };
+
+                       // Submit tasks and wait to finish
+                       submitTasksAndWaitForResults(executor, new Callable<?>[]{
+                               recycleExclusiveBufferTask(inputChannel, numExclusiveSegments),
+                               recycleFloatingBufferTask(bufferPool, numFloatingBuffers),
+                               requestBufferTask});
+
+                       assertEquals("There should be " + inputChannel.getNumberOfRequiredBuffers() + " buffers available in channel.",
+                               inputChannel.getNumberOfRequiredBuffers(), inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be no buffers available in local pool.",
+                               0, bufferPool.getNumberOfAvailableMemorySegments());
+
+               } finally {
+                       // Stop the workers first. shutdownNow() interrupts a task that is still running
+                       // because waiting on another task failed. Doing it before the teardown calls below
+                       // also means an exception thrown there cannot leak the executor's threads.
+                       executor.shutdownNow();
+
+                       // Release all the buffer resources
+                       inputChannel.releaseAllResources();
+
+                       networkBufferPool.destroyAllBufferPools();
+                       networkBufferPool.destroy();
+               }
+       }
+
+       /**
+        * Tests to verify that there is no race condition with three things running in parallel:
+        * recycling the exclusive buffers, recycling the floating buffers and releasing the
+        * input channel.
+        *
+        * <p>Once all tasks are done, the assertions expect the channel to hold no buffers, the
+        * floating buffers to be back in the local pool and the exclusive buffers to be back in
+        * the global pool.
+        */
+       @Test
+       public void testConcurrentRecycleAndRelease() throws Exception {
+               // Setup
+               final NetworkBufferPool networkBufferPool = new NetworkBufferPool(248, 32);
+               final int numExclusiveSegments = 120;
+               final int numFloatingBuffers = 128;
+
+               final ExecutorService executor = Executors.newFixedThreadPool(3);
+
+               final SingleInputGate inputGate = createSingleInputGate();
+               final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate);
+               inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel);
+               try {
+                       final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers);
+                       inputGate.setBufferPool(bufferPool);
+                       inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments);
+
+                       final Callable<Void> releaseTask = new Callable<Void>() {
+                               @Override
+                               public Void call() throws Exception {
+                                       inputChannel.releaseAllResources();
+
+                                       return null;
+                               }
+                       };
+
+                       // Submit tasks and wait to finish
+                       submitTasksAndWaitForResults(executor, new Callable<?>[]{
+                               recycleExclusiveBufferTask(inputChannel, numExclusiveSegments),
+                               recycleFloatingBufferTask(bufferPool, numFloatingBuffers),
+                               releaseTask});
+
+                       assertEquals("There should be no buffers available in the channel.",
+                               0, inputChannel.getNumberOfAvailableBuffers());
+                       assertEquals("There should be " + numFloatingBuffers + " buffers available in local pool.",
+                               numFloatingBuffers, bufferPool.getNumberOfAvailableMemorySegments());
+                       assertEquals("There should be " + numExclusiveSegments + " buffers available in global pool.",
+                               numExclusiveSegments, networkBufferPool.getNumberOfAvailableMemorySegments());
+
+               } finally {
+                       // Stop the workers first. shutdownNow() interrupts a task that is still running
+                       // because waiting on another task failed. Doing it before the teardown calls below
+                       // also means an exception thrown there cannot leak the executor's threads.
+                       executor.shutdownNow();
+
+                       // Release the channel here only if the release task did not get to do it,
+                       // e.g. because the test failed earlier
+                       if (!inputChannel.isReleased()) {
+                               inputChannel.releaseAllResources();
+                       }
+
+                       networkBufferPool.destroyAllBufferPools();
+                       networkBufferPool.destroy();
+               }
        }
 
        // 
---------------------------------------------------------------------------------------------
 
+       private SingleInputGate createSingleInputGate() {
+               return new SingleInputGate(
+                       "InputGate",
+                       new JobID(),
+                       new IntermediateDataSetID(),
+                       ResultPartitionType.PIPELINED_CREDIT_BASED,
+                       0,
+                       1,
+                       mock(TaskActions.class),
+                       
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
+       }
+
        private RemoteInputChannel createRemoteInputChannel(SingleInputGate 
inputGate)
                        throws IOException, InterruptedException {
 
@@ -403,4 +890,78 @@ public class RemoteInputChannelTest {
                        initialAndMaxRequestBackoff._2(),
                        
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup());
        }
+
+       /**
+        * Takes {@code numExclusiveSegments} buffers out of the given input channel right away, on
+        * the calling thread, and returns a task that hands every one of them back when called.
+        *
+        * @param inputChannel The input channel whose exclusive buffers are requested.
+        * @param numExclusiveSegments How many buffers to request from the channel.
+        * @return A task that recycles the requested buffers in the order they were requested.
+        */
+       private Callable<Void> recycleExclusiveBufferTask(RemoteInputChannel inputChannel, int numExclusiveSegments) {
+               final List<Buffer> requested = new ArrayList<>(numExclusiveSegments);
+
+               // Drain the exclusive buffers up front so that the returned task only has to recycle them
+               int remaining = numExclusiveSegments;
+               while (remaining-- > 0) {
+                       final Buffer next = inputChannel.requestBuffer();
+                       assertNotNull(next);
+                       requested.add(next);
+               }
+
+               return new Callable<Void>() {
+                       @Override
+                       public Void call() throws Exception {
+                               for (int i = 0; i < requested.size(); i++) {
+                                       requested.get(i).recycle();
+                               }
+                               return null;
+                       }
+               };
+       }
+
+       /**
+        * Takes {@code numFloatingBuffers} buffers out of the given pool right away, on the calling
+        * thread, and returns a task that hands every one of them back when called.
+        *
+        * @param bufferPool The buffer pool the floating buffers are requested from.
+        * @param numFloatingBuffers How many buffers to request from the pool.
+        * @return A task that recycles the requested buffers in the order they were requested.
+        * @throws Exception If requesting a buffer from the pool fails.
+        */
+       private Callable<Void> recycleFloatingBufferTask(BufferPool bufferPool, int numFloatingBuffers) throws Exception {
+               final List<Buffer> requested = new ArrayList<>(numFloatingBuffers);
+
+               // Drain the floating buffers up front so that the returned task only has to recycle them
+               int remaining = numFloatingBuffers;
+               while (remaining-- > 0) {
+                       final Buffer next = bufferPool.requestBuffer();
+                       assertNotNull(next);
+                       requested.add(next);
+               }
+
+               return new Callable<Void>() {
+                       @Override
+                       public Void call() throws Exception {
+                               for (int i = 0; i < requested.size(); i++) {
+                                       requested.get(i).recycle();
+                               }
+                               return null;
+                       }
+               };
+       }
+
+       /**
+        * Submits all the callable tasks to the executor and waits for each of them to complete.
+        *
+        * <p>If waiting on any task fails, all submitted tasks are cancelled (interrupting those
+        * still running) before the failure is rethrown. This keeps a sibling task, such as one
+        * looping until the channel is released, from running on after the test has failed.
+        *
+        * @param executor The executor service for running tasks.
+        * @param tasks The callable tasks to be submitted and executed.
+        * @throws Exception The first failure observed while waiting, e.g. an
+        *                   {@link java.util.concurrent.ExecutionException} wrapping a task's error.
+        */
+       private void submitTasksAndWaitForResults(ExecutorService executor, Callable<?>[] tasks) throws Exception {
+               final List<Future<?>> results = Lists.newArrayListWithCapacity(tasks.length);
+
+               for (Callable<?> task : tasks) {
+                       results.add(executor.submit(task));
+               }
+
+               try {
+                       for (Future<?> result : results) {
+                               result.get();
+                       }
+               } catch (Exception e) {
+                       // cancel() is a no-op for futures that have already completed
+                       for (Future<?> result : results) {
+                               result.cancel(true);
+                       }
+                       throw e;
+               }
+       }
 }

Reply via email to