[FLINK-7406][network] Implement Netty receiver incoming pipeline for credit-based
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/268867ce Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/268867ce Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/268867ce Branch: refs/heads/master Commit: 268867ce620a2c12879749db2ecb68bbe129cad5 Parents: 542419b Author: Zhijiang <wangzhijiang...@aliyun.com> Authored: Thu Aug 10 13:29:13 2017 +0800 Committer: Stefan Richter <s.rich...@data-artisans.com> Committed: Mon Jan 8 11:46:00 2018 +0100 ---------------------------------------------------------------------- .../network/netty/CreditBasedClientHandler.java | 277 ++++++++ .../runtime/io/network/netty/NettyMessage.java | 15 +- .../netty/PartitionRequestClientHandler.java | 8 +- .../io/network/netty/PartitionRequestQueue.java | 3 +- .../partition/consumer/RemoteInputChannel.java | 257 +++++-- .../netty/NettyMessageSerializationTest.java | 3 +- .../PartitionRequestClientHandlerTest.java | 151 ++--- .../partition/InputGateConcurrentTest.java | 2 +- .../partition/InputGateFairnessTest.java | 8 +- .../consumer/RemoteInputChannelTest.java | 665 +++++++++++++++++-- 10 files changed, 1175 insertions(+), 214 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java new file mode 100644 index 0000000..1f18588 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/CreditBasedClientHandler.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network.netty; + +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.netty.exception.LocalTransportException; +import org.apache.flink.runtime.io.network.netty.exception.RemoteTransportException; +import org.apache.flink.runtime.io.network.netty.exception.TransportException; +import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException; +import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; +import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; + +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; +import org.apache.flink.shaded.netty4.io.netty.channel.ChannelInboundHandlerAdapter; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.SocketAddress; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Channel handler to 
read the messages of buffer response or error response from the + * producer, to write and flush the unannounced credits for the producer. + */ +class CreditBasedClientHandler extends ChannelInboundHandlerAdapter { + + private static final Logger LOG = LoggerFactory.getLogger(CreditBasedClientHandler.class); + + /** Channels, which already requested partitions from the producers. */ + private final ConcurrentMap<InputChannelID, RemoteInputChannel> inputChannels = new ConcurrentHashMap<>(); + + private final AtomicReference<Throwable> channelError = new AtomicReference<>(); + + /** + * Set of cancelled partition requests. A request is cancelled iff an input channel is cleared + * while data is still coming in for this channel. + */ + private final ConcurrentMap<InputChannelID, InputChannelID> cancelled = new ConcurrentHashMap<>(); + + private volatile ChannelHandlerContext ctx; + + // ------------------------------------------------------------------------ + // Input channel/receiver registration + // ------------------------------------------------------------------------ + + void addInputChannel(RemoteInputChannel listener) throws IOException { + checkError(); + + if (!inputChannels.containsKey(listener.getInputChannelId())) { + inputChannels.put(listener.getInputChannelId(), listener); + } + } + + void removeInputChannel(RemoteInputChannel listener) { + inputChannels.remove(listener.getInputChannelId()); + } + + void cancelRequestFor(InputChannelID inputChannelId) { + if (inputChannelId == null || ctx == null) { + return; + } + + if (cancelled.putIfAbsent(inputChannelId, inputChannelId) == null) { + ctx.writeAndFlush(new NettyMessage.CancelPartitionRequest(inputChannelId)); + } + } + + // ------------------------------------------------------------------------ + // Network events + // ------------------------------------------------------------------------ + + @Override + public void channelActive(final ChannelHandlerContext ctx) throws Exception { + if (this.ctx 
== null) { + this.ctx = ctx; + } + + super.channelActive(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + // Unexpected close. In normal operation, the client closes the connection after all input + // channels have been removed. This indicates a problem with the remote task manager. + if (!inputChannels.isEmpty()) { + final SocketAddress remoteAddr = ctx.channel().remoteAddress(); + + notifyAllChannelsOfErrorAndClose(new RemoteTransportException( + "Connection unexpectedly closed by remote task manager '" + remoteAddr + "'. " + + "This might indicate that the remote task manager was lost.", remoteAddr)); + } + + super.channelInactive(ctx); + } + + /** + * Called on exceptions in the client handler pipeline. + * + * <p>Remote exceptions are received as regular payload. + */ + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + + if (cause instanceof TransportException) { + notifyAllChannelsOfErrorAndClose(cause); + } else { + final SocketAddress remoteAddr = ctx.channel().remoteAddress(); + + final TransportException tex; + + // Improve on the connection reset by peer error message + if (cause instanceof IOException && cause.getMessage().equals("Connection reset by peer")) { + tex = new RemoteTransportException("Lost connection to task manager '" + remoteAddr + "'. 
" + + "This indicates that the remote task manager was lost.", remoteAddr, cause); + } else { + tex = new LocalTransportException(cause.getMessage(), ctx.channel().localAddress(), cause); + } + + notifyAllChannelsOfErrorAndClose(tex); + } + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + try { + decodeMsg(msg); + } catch (Throwable t) { + notifyAllChannelsOfErrorAndClose(t); + } + } + + private void notifyAllChannelsOfErrorAndClose(Throwable cause) { + if (channelError.compareAndSet(null, cause)) { + try { + for (RemoteInputChannel inputChannel : inputChannels.values()) { + inputChannel.onError(cause); + } + } catch (Throwable t) { + // We can only swallow the Exception at this point. :( + LOG.warn("An Exception was thrown during error notification of a remote input channel.", t); + } finally { + inputChannels.clear(); + + if (ctx != null) { + ctx.close(); + } + } + } + } + + // ------------------------------------------------------------------------ + + /** + * Checks for an error and rethrows it if one was reported. 
+ */ + private void checkError() throws IOException { + final Throwable t = channelError.get(); + + if (t != null) { + if (t instanceof IOException) { + throw (IOException) t; + } else { + throw new IOException("There has been an error in the channel.", t); + } + } + } + + private void decodeMsg(Object msg) throws Throwable { + final Class<?> msgClazz = msg.getClass(); + + // ---- Buffer -------------------------------------------------------- + if (msgClazz == NettyMessage.BufferResponse.class) { + NettyMessage.BufferResponse bufferOrEvent = (NettyMessage.BufferResponse) msg; + + RemoteInputChannel inputChannel = inputChannels.get(bufferOrEvent.receiverId); + if (inputChannel == null) { + bufferOrEvent.releaseBuffer(); + + cancelRequestFor(bufferOrEvent.receiverId); + + return; + } + + decodeBufferOrEvent(inputChannel, bufferOrEvent); + + } else if (msgClazz == NettyMessage.ErrorResponse.class) { + // ---- Error --------------------------------------------------------- + NettyMessage.ErrorResponse error = (NettyMessage.ErrorResponse) msg; + + SocketAddress remoteAddr = ctx.channel().remoteAddress(); + + if (error.isFatalError()) { + notifyAllChannelsOfErrorAndClose(new RemoteTransportException( + "Fatal error at remote task manager '" + remoteAddr + "'.", + remoteAddr, + error.cause)); + } else { + RemoteInputChannel inputChannel = inputChannels.get(error.receiverId); + + if (inputChannel != null) { + if (error.cause.getClass() == PartitionNotFoundException.class) { + inputChannel.onFailedPartitionRequest(); + } else { + inputChannel.onError(new RemoteTransportException( + "Error at remote task manager '" + remoteAddr + "'.", + remoteAddr, + error.cause)); + } + } + } + } else { + throw new IllegalStateException("Received unknown message from producer: " + msg.getClass()); + } + } + + private void decodeBufferOrEvent(RemoteInputChannel inputChannel, NettyMessage.BufferResponse bufferOrEvent) throws Throwable { + try { + if (bufferOrEvent.isBuffer()) { + // ---- 
Buffer ------------------------------------------------ + + // Early return for empty buffers. Otherwise Netty's readBytes() throws an + // IndexOutOfBoundsException. + if (bufferOrEvent.getSize() == 0) { + inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, bufferOrEvent.backlog); + return; + } + + Buffer buffer = inputChannel.requestBuffer(); + if (buffer != null) { + buffer.setSize(bufferOrEvent.getSize()); + bufferOrEvent.getNettyBuffer().readBytes(buffer.getNioBuffer()); + + inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, bufferOrEvent.backlog); + } else if (inputChannel.isReleased()) { + cancelRequestFor(bufferOrEvent.receiverId); + } else { + throw new IllegalStateException("No buffer available in credit-based input channel."); + } + } else { + // ---- Event ------------------------------------------------- + // TODO We can just keep the serialized data in the Netty buffer and release it later at the reader + byte[] byteArray = new byte[bufferOrEvent.getSize()]; + bufferOrEvent.getNettyBuffer().readBytes(byteArray); + + MemorySegment memSeg = MemorySegmentFactory.wrap(byteArray); + Buffer buffer = new Buffer(memSeg, FreeingBufferRecycler.INSTANCE, false); + + inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, bufferOrEvent.backlog); + } + } finally { + bufferOrEvent.releaseBuffer(); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java index 89fb9e8..db1b899 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/NettyMessage.java @@ -221,6 +221,8 @@ 
public abstract class NettyMessage { final int sequenceNumber; + final int backlog; + // ---- Deserialization ----------------------------------------------- final boolean isBuffer; @@ -232,7 +234,8 @@ public abstract class NettyMessage { private BufferResponse( ByteBuf retainedSlice, boolean isBuffer, int sequenceNumber, - InputChannelID receiverId) { + InputChannelID receiverId, + int backlog) { // When deserializing we first have to request a buffer from the respective buffer // provider (at the handler) and copy the buffer from Netty's space to ours. Only // retainedSlice is set in this case. @@ -242,15 +245,17 @@ public abstract class NettyMessage { this.isBuffer = isBuffer; this.sequenceNumber = sequenceNumber; this.receiverId = checkNotNull(receiverId); + this.backlog = backlog; } - BufferResponse(Buffer buffer, int sequenceNumber, InputChannelID receiverId) { + BufferResponse(Buffer buffer, int sequenceNumber, InputChannelID receiverId, int backlog) { this.buffer = checkNotNull(buffer); this.retainedSlice = null; this.isBuffer = buffer.isBuffer(); this.size = buffer.getSize(); this.sequenceNumber = sequenceNumber; this.receiverId = checkNotNull(receiverId); + this.backlog = backlog; } boolean isBuffer() { @@ -280,7 +285,7 @@ public abstract class NettyMessage { ByteBuf write(ByteBufAllocator allocator) throws IOException { checkNotNull(buffer, "No buffer instance to serialize."); - int length = 16 + 4 + 1 + 4 + buffer.getSize(); + int length = 16 + 4 + 4 + 1 + 4 + buffer.getSize(); ByteBuf result = null; try { @@ -288,6 +293,7 @@ public abstract class NettyMessage { receiverId.writeTo(result); result.writeInt(sequenceNumber); + result.writeInt(backlog); result.writeBoolean(buffer.isBuffer()); result.writeInt(buffer.getSize()); result.writeBytes(buffer.getNioBuffer()); @@ -309,12 +315,13 @@ public abstract class NettyMessage { static BufferResponse readFrom(ByteBuf buffer) { InputChannelID receiverId = InputChannelID.fromByteBuf(buffer); int sequenceNumber = 
buffer.readInt(); + int backlog = buffer.readInt(); boolean isBuffer = buffer.readBoolean(); int size = buffer.readInt(); ByteBuf retainedSlice = buffer.readSlice(size).retain(); - return new BufferResponse(retainedSlice, isBuffer, sequenceNumber, receiverId); + return new BufferResponse(retainedSlice, isBuffer, sequenceNumber, receiverId, backlog); } } http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java index 566b215..ab4798e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandler.java @@ -276,7 +276,7 @@ class PartitionRequestClientHandler extends ChannelInboundHandlerAdapter { // Early return for empty buffers. Otherwise Netty's readBytes() throws an // IndexOutOfBoundsException. 
if (bufferOrEvent.getSize() == 0) { - inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber); + inputChannel.onEmptyBuffer(bufferOrEvent.sequenceNumber, -1); return true; } @@ -295,7 +295,7 @@ class PartitionRequestClientHandler extends ChannelInboundHandlerAdapter { buffer.setSize(bufferOrEvent.getSize()); bufferOrEvent.getNettyBuffer().readBytes(buffer.getNioBuffer()); - inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber); + inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, -1); return true; } @@ -318,7 +318,7 @@ class PartitionRequestClientHandler extends ChannelInboundHandlerAdapter { MemorySegment memSeg = MemorySegmentFactory.wrap(byteArray); Buffer buffer = new Buffer(memSeg, FreeingBufferRecycler.INSTANCE, false); - inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber); + inputChannel.onBuffer(buffer, bufferOrEvent.sequenceNumber, -1); return true; } @@ -450,7 +450,7 @@ class PartitionRequestClientHandler extends ChannelInboundHandlerAdapter { RemoteInputChannel inputChannel = inputChannels.get(stagedBufferResponse.receiverId); if (inputChannel != null) { - inputChannel.onBuffer(buffer, stagedBufferResponse.sequenceNumber); + inputChannel.onBuffer(buffer, stagedBufferResponse.sequenceNumber, -1); success = true; } http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java index ff0f130..41f87ae 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/netty/PartitionRequestQueue.java @@ -193,7 +193,8 @@ class PartitionRequestQueue extends 
ChannelInboundHandlerAdapter { BufferResponse msg = new BufferResponse( next.buffer(), reader.getSequenceNumber(), - reader.getReceiverId()); + reader.getReceiverId(), + 0); if (isEndOfPartitionEvent(next.buffer())) { reader.notifySubpartitionConsumed(); http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index cd00934..02c7b34 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.io.network.partition.consumer; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.io.network.ConnectionID; @@ -32,11 +33,13 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.util.ExceptionUtils; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import java.io.IOException; import java.util.ArrayDeque; +import java.util.Collections; import java.util.List; import java.util.ArrayList; -import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -82,17 +85,19 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, /** The initial number of exclusive buffers assigned to this channel. 
*/ private int initialCredit; - /** The current available buffers including both exclusive buffers and requested floating buffers. */ - private final ArrayDeque<Buffer> availableBuffers = new ArrayDeque<>(); + /** The available buffer queue wraps both exclusive and requested floating buffers. */ + private final AvailableBufferQueue bufferQueue = new AvailableBufferQueue(); /** The number of available buffers that have not been announced to the producer yet. */ private final AtomicInteger unannouncedCredit = new AtomicInteger(0); - /** The number of unsent buffers in the producer's sub partition. */ - private final AtomicInteger senderBacklog = new AtomicInteger(0); + /** The number of required buffers that equals to sender's backlog plus initial credit. */ + @GuardedBy("bufferQueue") + private int numRequiredBuffers; /** The tag indicates whether this channel is waiting for additional floating buffers from the buffer pool. */ - private final AtomicBoolean isWaitingForFloatingBuffers = new AtomicBoolean(false); + @GuardedBy("bufferQueue") + private boolean isWaitingForFloatingBuffers; public RemoteInputChannel( SingleInputGate inputGate, @@ -133,10 +138,11 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, checkArgument(segments.size() > 0, "The number of exclusive buffers per channel should be larger than 0."); this.initialCredit = segments.size(); + this.numRequiredBuffers = segments.size(); - synchronized(availableBuffers) { + synchronized(bufferQueue) { for (MemorySegment segment : segments) { - availableBuffers.add(new Buffer(segment, this)); + bufferQueue.addExclusiveBuffer(new Buffer(segment, this), numRequiredBuffers); } } } @@ -211,7 +217,7 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, // ------------------------------------------------------------------------ @Override - boolean isReleased() { + public boolean isReleased() { return isReleased.get(); } @@ -227,7 +233,8 @@ public class 
RemoteInputChannel extends InputChannel implements BufferRecycler, void releaseAllResources() throws IOException { if (isReleased.compareAndSet(false, true)) { - // Gather all exclusive buffers and recycle them to global pool in batch + // Gather all exclusive buffers and recycle them to global pool in batch, because + // we do not want to trigger redistribution of buffers after each recycle. final List<MemorySegment> exclusiveRecyclingSegments = new ArrayList<>(); synchronized (receivedBuffers) { @@ -240,16 +247,8 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, } } } - - synchronized (availableBuffers) { - Buffer buffer; - while ((buffer = availableBuffers.poll()) != null) { - if (buffer.getRecycler() == this) { - exclusiveRecyclingSegments.add(buffer.getMemorySegment()); - } else { - buffer.recycle(); - } - } + synchronized (bufferQueue) { + bufferQueue.releaseAll(exclusiveRecyclingSegments); } if (exclusiveRecyclingSegments.size() > 0) { @@ -287,81 +286,93 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, } /** - * Exclusive buffer is recycled to this input channel directly and it may trigger notify - * credit to producer. + * Exclusive buffer is recycled to this input channel directly and it may trigger return extra + * floating buffer and notify increased credit to the producer. * * @param segment The exclusive segment of this channel. */ @Override public void recycle(MemorySegment segment) { - synchronized (availableBuffers) { - // Important: the isReleased check should be inside the synchronized block. - // that way the segment can also be returned to global pool after added into - // the available queue during releasing all resources. + int numAddedBuffers; + + synchronized (bufferQueue) { + // Important: check the isReleased state inside synchronized block, so there is no + // race condition when recycle and releaseAllResources running in parallel. 
if (isReleased.get()) { try { - inputGate.returnExclusiveSegments(Arrays.asList(segment)); + inputGate.returnExclusiveSegments(Collections.singletonList(segment)); return; } catch (Throwable t) { ExceptionUtils.rethrow(t); } } - availableBuffers.add(new Buffer(segment, this)); + numAddedBuffers = bufferQueue.addExclusiveBuffer(new Buffer(segment, this), numRequiredBuffers); } - if (unannouncedCredit.getAndAdd(1) == 0) { + if (numAddedBuffers > 0 && unannouncedCredit.getAndAdd(numAddedBuffers) == 0) { notifyCreditAvailable(); } } public int getNumberOfAvailableBuffers() { - synchronized (availableBuffers) { - return availableBuffers.size(); + synchronized (bufferQueue) { + return bufferQueue.getAvailableBufferSize(); } } + @VisibleForTesting + public int getNumberOfRequiredBuffers() { + return numRequiredBuffers; + } + /** * The Buffer pool notifies this channel of an available floating buffer. If the channel is released or * currently does not need extra buffers, the buffer should be recycled to the buffer pool. Otherwise, - * the buffer will be added into the <tt>availableBuffers</tt> queue and the unannounced credit is - * increased by one. + * the buffer will be added into the <tt>bufferQueue</tt> and the unannounced credit is increased + * by one. * * @param buffer Buffer that becomes available in buffer pool. * @return True when this channel is waiting for more floating buffers, otherwise false. */ @Override public boolean notifyBufferAvailable(Buffer buffer) { - checkState(isWaitingForFloatingBuffers.get(), "This channel should be waiting for floating buffers."); + // Check the isReleased state outside synchronized block first to avoid + // deadlock with releaseAllResources running in parallel. + if (isReleased.get()) { + buffer.recycle(); + return false; + } - synchronized (availableBuffers) { - // Important: the isReleased check should be inside the synchronized block. 
- if (isReleased.get() || availableBuffers.size() >= senderBacklog.get()) { - isWaitingForFloatingBuffers.set(false); - buffer.recycle(); + boolean needMoreBuffers = false; + synchronized (bufferQueue) { + checkState(isWaitingForFloatingBuffers, "This channel should be waiting for floating buffers."); + // Important: double check the isReleased state inside synchronized block, so there is no + // race condition when notifyBufferAvailable and releaseAllResources running in parallel. + if (isReleased.get() || bufferQueue.getAvailableBufferSize() >= numRequiredBuffers) { + buffer.recycle(); return false; } - availableBuffers.add(buffer); - - if (unannouncedCredit.getAndAdd(1) == 0) { - notifyCreditAvailable(); - } + bufferQueue.addFloatingBuffer(buffer); - if (availableBuffers.size() >= senderBacklog.get()) { - isWaitingForFloatingBuffers.set(false); - return false; + if (bufferQueue.getAvailableBufferSize() == numRequiredBuffers) { + isWaitingForFloatingBuffers = false; } else { - return true; + needMoreBuffers = true; } } + + if (unannouncedCredit.getAndAdd(1) == 0) { + notifyCreditAvailable(); + } + + return needMoreBuffers; } @Override public void notifyBufferDestroyed() { - if (!isWaitingForFloatingBuffers.compareAndSet(true, false)) { - throw new IllegalStateException("This channel should be waiting for floating buffers currently."); - } + // Nothing to do actually. } // ------------------------------------------------------------------------ @@ -394,7 +405,58 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, return inputGate.getBufferProvider(); } - public void onBuffer(Buffer buffer, int sequenceNumber) { + /** + * Requests buffer from input channel directly for receiving network data. + * It should always return an available buffer in credit-based mode unless + * the channel has been released. + * + * @return The available buffer. 
+ */ + @Nullable + public Buffer requestBuffer() { + synchronized (bufferQueue) { + return bufferQueue.takeBuffer(); + } + } + + /** + * Receives the backlog from the producer's buffer response. If the number of available + * buffers is less than backlog + initialCredit, it will request floating buffers from the buffer + * pool, and then notify unannounced credits to the producer. + * + * @param backlog The number of unsent buffers in the producer's sub partition. + */ + @VisibleForTesting + void onSenderBacklog(int backlog) throws IOException { + int numRequestedBuffers = 0; + + synchronized (bufferQueue) { + // Important: check the isReleased state inside synchronized block, so there is no + // race condition when onSenderBacklog and releaseAllResources running in parallel. + if (isReleased.get()) { + return; + } + + numRequiredBuffers = backlog + initialCredit; + while (bufferQueue.getAvailableBufferSize() < numRequiredBuffers && !isWaitingForFloatingBuffers) { + Buffer buffer = inputGate.getBufferPool().requestBuffer(); + if (buffer != null) { + bufferQueue.addFloatingBuffer(buffer); + numRequestedBuffers++; + } else if (inputGate.getBufferProvider().addBufferListener(this)) { + // If the channel has not got enough buffers, register it as listener to wait for more floating buffers. 
+ isWaitingForFloatingBuffers = true; + break; + } + } + } + + if (numRequestedBuffers > 0 && unannouncedCredit.getAndAdd(numRequestedBuffers) == 0) { + notifyCreditAvailable(); + } + } + + public void onBuffer(Buffer buffer, int sequenceNumber, int backlog) throws IOException { boolean success = false; try { @@ -416,6 +478,10 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, } } } + + if (success && backlog >= 0) { + onSenderBacklog(backlog); + } } finally { if (!success) { buffer.recycle(); @@ -423,16 +489,23 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, } } - public void onEmptyBuffer(int sequenceNumber) { + public void onEmptyBuffer(int sequenceNumber, int backlog) throws IOException { + boolean success = false; + synchronized (receivedBuffers) { if (!isReleased.get()) { if (expectedSequenceNumber == sequenceNumber) { expectedSequenceNumber++; + success = true; } else { onError(new BufferReorderingException(expectedSequenceNumber, sequenceNumber)); } } } + + if (success && backlog >= 0) { + onSenderBacklog(backlog); + } } public void onFailedPartitionRequest() { @@ -462,4 +535,82 @@ public class RemoteInputChannel extends InputChannel implements BufferRecycler, expectedSequenceNumber, actualSequenceNumber); } } + + /** + * Manages the exclusive and floating buffers of this channel, and handles the + * internal buffer related logic. + */ + private static class AvailableBufferQueue { + + /** The current available floating buffers from the fixed buffer pool. */ + private final ArrayDeque<Buffer> floatingBuffers; + + /** The current available exclusive buffers from the global buffer pool. 
*/ + private final ArrayDeque<Buffer> exclusiveBuffers; + + AvailableBufferQueue() { + this.exclusiveBuffers = new ArrayDeque<>(); + this.floatingBuffers = new ArrayDeque<>(); + } + + /** + * Adds an exclusive buffer (back) into the queue and recycles one floating buffer if the + * number of available buffers in queue is more than the required amount. + * + * @param buffer The exclusive buffer to add + * @param numRequiredBuffers The number of required buffers + * + * @return How many buffers were added to the queue + */ + int addExclusiveBuffer(Buffer buffer, int numRequiredBuffers) { + exclusiveBuffers.add(buffer); + if (getAvailableBufferSize() > numRequiredBuffers) { + Buffer floatingBuffer = floatingBuffers.poll(); + floatingBuffer.recycle(); + return 0; + } else { + return 1; + } + } + + void addFloatingBuffer(Buffer buffer) { + floatingBuffers.add(buffer); + } + + /** + * Takes the floating buffer first in order to make full use of floating + * buffers reasonably. + * + * @return An available floating or exclusive buffer, may be null + * if the channel is released. + */ + @Nullable + Buffer takeBuffer() { + if (floatingBuffers.size() > 0) { + return floatingBuffers.poll(); + } else { + return exclusiveBuffers.poll(); + } + } + + /** + * The floating buffer is recycled to local buffer pool directly, and the + * exclusive buffer will be gathered to return to global buffer pool later. + * + * @param exclusiveSegments The list that we will add exclusive segments into. 
+ */ + void releaseAll(List<MemorySegment> exclusiveSegments) { + Buffer buffer; + while ((buffer = floatingBuffers.poll()) != null) { + buffer.recycle(); + } + while ((buffer = exclusiveBuffers.poll()) != null) { + exclusiveSegments.add(buffer.getMemorySegment()); + } + } + + int getAvailableBufferSize() { + return floatingBuffers.size() + exclusiveBuffers.size(); + } + } } http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java index 0651f97..8c87ceb 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/NettyMessageSerializationTest.java @@ -62,7 +62,7 @@ public class NettyMessageSerializationTest { nioBuffer.putInt(i); } - NettyMessage.BufferResponse expected = new NettyMessage.BufferResponse(buffer, random.nextInt(), new InputChannelID()); + NettyMessage.BufferResponse expected = new NettyMessage.BufferResponse(buffer, random.nextInt(), new InputChannelID(), random.nextInt()); NettyMessage.BufferResponse actual = encodeAndDecode(expected); // Verify recycle has been called on buffer instance @@ -85,6 +85,7 @@ public class NettyMessageSerializationTest { assertEquals(expected.sequenceNumber, actual.sequenceNumber); assertEquals(expected.receiverId, actual.receiverId); + assertEquals(expected.backlog, actual.backlog); } { http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java 
---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java index e1e5bd3..d3ff6c2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestClientHandlerTest.java @@ -30,23 +30,16 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.consumer.InputChannelID; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.io.network.util.TestBufferFactory; -import org.apache.flink.runtime.testutils.DiscardingRecycler; import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf; import org.apache.flink.shaded.netty4.io.netty.buffer.UnpooledByteBufAllocator; import org.apache.flink.shaded.netty4.io.netty.channel.Channel; import org.apache.flink.shaded.netty4.io.netty.channel.ChannelHandlerContext; -import org.apache.flink.shaded.netty4.io.netty.channel.embedded.EmbeddedChannel; import org.junit.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import java.io.IOException; -import java.util.concurrent.atomic.AtomicReference; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -80,19 +73,19 @@ public class PartitionRequestClientHandlerTest { when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID()); when(inputChannel.getBufferProvider()).thenReturn(bufferProvider); - final BufferResponse ReceivedBuffer = createBufferResponse( - 
TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId()); + final BufferResponse receivedBuffer = createBufferResponse( + TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); final PartitionRequestClientHandler client = new PartitionRequestClientHandler(); client.addInputChannel(inputChannel); - client.channelRead(mock(ChannelHandlerContext.class), ReceivedBuffer); + client.channelRead(mock(ChannelHandlerContext.class), receivedBuffer); } /** * Tests a fix for FLINK-1761. * - * <p> FLINK-1761 discovered an IndexOutOfBoundsException, when receiving buffers of size 0. + * <p>FLINK-1761 discovered an IndexOutOfBoundsException, when receiving buffers of size 0. */ @Test public void testReceiveEmptyBuffer() throws Exception { @@ -108,10 +101,11 @@ public class PartitionRequestClientHandlerTest { final Buffer emptyBuffer = TestBufferFactory.createBuffer(); emptyBuffer.setSize(0); + final int backlog = 2; final BufferResponse receivedBuffer = createBufferResponse( - emptyBuffer, 0, inputChannel.getInputChannelId()); + emptyBuffer, 0, inputChannel.getInputChannelId(), backlog); - final PartitionRequestClientHandler client = new PartitionRequestClientHandler(); + final CreditBasedClientHandler client = new CreditBasedClientHandler(); client.addInputChannel(inputChannel); // Read the empty buffer @@ -119,6 +113,51 @@ public class PartitionRequestClientHandlerTest { // This should not throw an exception verify(inputChannel, never()).onError(any(Throwable.class)); + verify(inputChannel, times(1)).onEmptyBuffer(0, backlog); + } + + /** + * Verifies that {@link RemoteInputChannel#onBuffer(Buffer, int, int)} is called when a + * {@link BufferResponse} is received. 
+ */ + @Test + public void testReceiveBuffer() throws Exception { + final Buffer buffer = TestBufferFactory.createBuffer(); + final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class); + when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID()); + when(inputChannel.requestBuffer()).thenReturn(buffer); + + final int backlog = 2; + final BufferResponse bufferResponse = createBufferResponse( + TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), backlog); + + final CreditBasedClientHandler client = new CreditBasedClientHandler(); + client.addInputChannel(inputChannel); + + client.channelRead(mock(ChannelHandlerContext.class), bufferResponse); + + verify(inputChannel, times(1)).onBuffer(buffer, 0, backlog); + } + + /** + * Verifies that {@link RemoteInputChannel#onError(Throwable)} is called when a + * {@link BufferResponse} is received but no available buffer in input channel. + */ + @Test + public void testThrowExceptionForNoAvailableBuffer() throws Exception { + final RemoteInputChannel inputChannel = mock(RemoteInputChannel.class); + when(inputChannel.getInputChannelId()).thenReturn(new InputChannelID()); + when(inputChannel.requestBuffer()).thenReturn(null); + + final BufferResponse bufferResponse = createBufferResponse( + TestBufferFactory.createBuffer(), 0, inputChannel.getInputChannelId(), 2); + + final CreditBasedClientHandler client = new CreditBasedClientHandler(); + client.addInputChannel(inputChannel); + + client.channelRead(mock(ChannelHandlerContext.class), bufferResponse); + + verify(inputChannel, times(1)).onError(any(IllegalStateException.class)); } /** @@ -136,8 +175,8 @@ public class PartitionRequestClientHandlerTest { when(inputChannel.getBufferProvider()).thenReturn(bufferProvider); final ErrorResponse partitionNotFound = new ErrorResponse( - new PartitionNotFoundException(new ResultPartitionID()), - inputChannel.getInputChannelId()); + new PartitionNotFoundException(new ResultPartitionID()), + 
inputChannel.getInputChannelId()); final PartitionRequestClientHandler client = new PartitionRequestClientHandler(); client.addInputChannel(inputChannel); @@ -169,95 +208,19 @@ public class PartitionRequestClientHandlerTest { client.cancelRequestFor(inputChannel.getInputChannelId()); } - /** - * Tests that an unsuccessful message decode call for a staged message - * does not leave the channel with auto read set to false. - */ - @Test - @SuppressWarnings("unchecked") - public void testAutoReadAfterUnsuccessfulStagedMessage() throws Exception { - PartitionRequestClientHandler handler = new PartitionRequestClientHandler(); - EmbeddedChannel channel = new EmbeddedChannel(handler); - - final AtomicReference<BufferListener> listener = new AtomicReference<>(); - - BufferProvider bufferProvider = mock(BufferProvider.class); - when(bufferProvider.addBufferListener(any(BufferListener.class))).thenAnswer(new Answer<Boolean>() { - @Override - @SuppressWarnings("unchecked") - public Boolean answer(InvocationOnMock invocation) throws Throwable { - listener.set((BufferListener) invocation.getArguments()[0]); - return true; - } - }); - - when(bufferProvider.requestBuffer()).thenReturn(null); - - InputChannelID channelId = new InputChannelID(0, 0); - RemoteInputChannel inputChannel = mock(RemoteInputChannel.class); - when(inputChannel.getInputChannelId()).thenReturn(channelId); - - // The 3rd staged msg has a null buffer provider - when(inputChannel.getBufferProvider()).thenReturn(bufferProvider, bufferProvider, null); - - handler.addInputChannel(inputChannel); - - BufferResponse msg = createBufferResponse(createBuffer(true), 0, channelId); - - // Write 1st buffer msg. No buffer is available, therefore the buffer - // should be staged and auto read should be set to false. - assertTrue(channel.config().isAutoRead()); - channel.writeInbound(msg); - - // No buffer available, auto read false - assertFalse(channel.config().isAutoRead()); - - // Write more buffers... all staged. 
- msg = createBufferResponse(createBuffer(true), 1, channelId); - channel.writeInbound(msg); - - msg = createBufferResponse(createBuffer(true), 2, channelId); - channel.writeInbound(msg); - - // Notify about buffer => handle 1st msg - Buffer availableBuffer = createBuffer(false); - listener.get().notifyBufferAvailable(availableBuffer); - - // Start processing of staged buffers (in run pending tasks). Make - // sure that the buffer provider acts like it's destroyed. - when(bufferProvider.addBufferListener(any(BufferListener.class))).thenReturn(false); - when(bufferProvider.isDestroyed()).thenReturn(true); - - // Execute all tasks that are scheduled in the event loop. Further - // eventLoop().execute() calls are directly executed, if they are - // called in the scope of this call. - channel.runPendingTasks(); - - assertTrue(channel.config().isAutoRead()); - } - // --------------------------------------------------------------------------------------------- - private static Buffer createBuffer(boolean fill) { - MemorySegment segment = MemorySegmentFactory.allocateUnpooledSegment(1024, null); - if (fill) { - for (int i = 0; i < 1024; i++) { - segment.put(i, (byte) i); - } - } - return new Buffer(segment, DiscardingRecycler.INSTANCE, true); - } - /** * Returns a deserialized buffer message as it would be received during runtime. 
*/ private BufferResponse createBufferResponse( Buffer buffer, int sequenceNumber, - InputChannelID receivingChannelId) throws IOException { + InputChannelID receivingChannelId, + int backlog) throws IOException { // Mock buffer to serialize - BufferResponse resp = new BufferResponse(buffer, sequenceNumber, receivingChannelId); + BufferResponse resp = new BufferResponse(buffer, sequenceNumber, receivingChannelId, backlog); ByteBuf serialized = resp.write(UnpooledByteBufAllocator.DEFAULT); http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java index 6f98119..81788c9 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java @@ -216,7 +216,7 @@ public class InputGateConcurrentTest { @Override void addBuffer(Buffer buffer) throws Exception { - channel.onBuffer(buffer, seq++); + channel.onBuffer(buffer, seq++, -1); } } http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java index 324a060..4e90265 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java @@ -206,9 +206,9 @@ public class InputGateFairnessTest { channels[i] = channel; for (int p = 0; p < buffersPerChannel; p++) { - channel.onBuffer(mockBuffer, p); + channel.onBuffer(mockBuffer, p, -1); } - channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel); + channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel, -1); gate.setInputChannel(new IntermediateResultPartitionID(), channel); } @@ -263,7 +263,7 @@ public class InputGateFairnessTest { gate.setInputChannel(new IntermediateResultPartitionID(), channel); } - channels[11].onBuffer(mockBuffer, 0); + channels[11].onBuffer(mockBuffer, 0, -1); channelSequenceNums[11]++; // read all the buffers and the EOF event @@ -325,7 +325,7 @@ public class InputGateFairnessTest { Collections.shuffle(poss); for (int i : poss) { - partitions[i].onBuffer(buffer, sequenceNumbers[i]++); + partitions[i].onBuffer(buffer, sequenceNumbers[i]++, -1); } } http://git-wip-us.apache.org/repos/asf/flink/blob/268867ce/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java index d791ced..863f886 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannelTest.java @@ -18,24 +18,28 @@ package org.apache.flink.runtime.io.network.partition.consumer; -import org.apache.flink.core.memory.MemorySegment; -import 
org.apache.flink.core.memory.MemorySegmentFactory; +import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.io.network.ConnectionID; import org.apache.flink.runtime.io.network.ConnectionManager; import org.apache.flink.runtime.io.network.buffer.Buffer; +import org.apache.flink.runtime.io.network.buffer.BufferPool; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; import org.apache.flink.runtime.io.network.netty.PartitionRequestClient; import org.apache.flink.runtime.io.network.partition.ProducerFailedException; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.util.TestBufferFactory; +import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups; +import org.apache.flink.runtime.taskmanager.TaskActions; import org.apache.flink.shaded.guava18.com.google.common.collect.Lists; import org.junit.Test; -import scala.Tuple2; import java.io.IOException; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Callable; @@ -43,12 +47,14 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import scala.Tuple2; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyListOf; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; @@ -66,10 +72,10 @@ public class RemoteInputChannelTest { final Buffer buffer = TestBufferFactory.createBuffer(); // The 
test - inputChannel.onBuffer(buffer.retain(), 0); + inputChannel.onBuffer(buffer.retain(), 0, -1); // This does not yet throw the exception, but sets the error at the channel. - inputChannel.onBuffer(buffer, 29); + inputChannel.onBuffer(buffer, 29, -1); try { inputChannel.getNextBuffer(); @@ -113,7 +119,7 @@ public class RemoteInputChannelTest { for (int j = 0; j < 128; j++) { // this is the same buffer over and over again which will be // recycled by the RemoteInputChannel - inputChannel.onBuffer(buffer.retain(), j); + inputChannel.onBuffer(buffer.retain(), j, -1); } if (inputChannel.isReleased()) { @@ -301,81 +307,562 @@ public class RemoteInputChannelTest { } /** - * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying the exclusive segment is - * recycled to available buffers directly and it triggers notify of announced credit. + * Tests to verify the behaviours of three different processes if the number of available + * buffers is less than required buffers. + * + * 1. Recycle the floating buffer + * 2. Recycle the exclusive buffer + * 3. 
Decrease the sender's backlog */ @Test - public void testRecycleExclusiveBufferBeforeReleased() throws Exception { - final SingleInputGate inputGate = mock(SingleInputGate.class); - final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate)); - - // Recycle exclusive segment - inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + public void testAvailableBuffersLessThanRequiredBuffers() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(16, 32); + final int numExclusiveBuffers = 2; + final int numFloatingBuffers = 14; - assertEquals("There should be one buffer available after recycle.", - 1, inputChannel.getNumberOfAvailableBuffers()); - verify(inputChannel, times(1)).notifyCreditAvailable(); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + // Prepare the exclusive and floating buffers to verify recycle logic later + final Buffer exclusiveBuffer = inputChannel.requestBuffer(); + assertNotNull(exclusiveBuffer); + + final int numRecycleFloatingBuffers = 2; + final ArrayDeque<Buffer> floatingBufferQueue = new ArrayDeque<>(numRecycleFloatingBuffers); + for (int i = 0; i < numRecycleFloatingBuffers; i++) { + Buffer floatingBuffer = bufferPool.requestBuffer(); + assertNotNull(floatingBuffer); + floatingBufferQueue.add(floatingBuffer); + } - inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + verify(bufferPool, times(numRecycleFloatingBuffers)).requestBuffer(); + + // Receive the producer's backlog more than the number of 
available floating buffers + inputChannel.onSenderBacklog(14); + + // The channel requests (backlog + numExclusiveBuffers) floating buffers from local pool. + // It does not get enough floating buffers and register as buffer listener + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 13 buffers available in the channel", + 13, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 16 buffers required in the channel", + 16, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Increase the backlog + inputChannel.onSenderBacklog(16); + + // The channel is already in the status of waiting for buffers and will not request any more + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 13 buffers available in the channel", + 13, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 18 buffers required in the channel", + 18, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one floating buffer + floatingBufferQueue.poll().recycle(); + + // Assign the floating buffer to the listener and the channel is still waiting for more floating buffers + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 18 buffers required in the channel", + 18, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, 
bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one more floating buffer + floatingBufferQueue.poll().recycle(); + + // Assign the floating buffer to the listener and the channel is still waiting for more floating buffers + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 15 buffers available in the channel", + 15, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 18 buffers required in the channel", + 18, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Decrease the backlog + inputChannel.onSenderBacklog(15); + + // Only the number of required buffers is changed by (backlog + numExclusiveBuffers) + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 15 buffers available in the channel", + 15, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 17 buffers required in the channel", + 17, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one exclusive buffer + exclusiveBuffer.recycle(); + + // The exclusive buffer is returned to the channel directly + verify(bufferPool, times(15)).requestBuffer(); + verify(bufferPool, times(1)).addBufferListener(inputChannel); + assertEquals("There should be 16 buffers available in the channel", + 16, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 17 buffers required in the channel", + 17, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffers available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer 
resources + inputChannel.releaseAllResources(); - assertEquals("There should be two buffers available after recycle.", - 2, inputChannel.getNumberOfAvailableBuffers()); - // It should be called only once when increased from zero. - verify(inputChannel, times(1)).notifyCreditAvailable(); + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } } /** - * Tests {@link RemoteInputChannel#recycle(MemorySegment)}, verifying the exclusive segment is - * recycled to global pool via input gate when channel is released. + * Tests to verify the behaviours of recycling floating and exclusive buffers if the number of available + * buffers equals to required buffers. */ @Test - public void testRecycleExclusiveBufferAfterReleased() throws Exception { + public void testAvailableBuffersEqualToRequiredBuffers() throws Exception { // Setup - final SingleInputGate inputGate = mock(SingleInputGate.class); - final RemoteInputChannel inputChannel = spy(createRemoteInputChannel(inputGate)); - - inputChannel.releaseAllResources(); + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(16, 32); + final int numExclusiveBuffers = 2; + final int numFloatingBuffers = 14; - // Recycle exclusive segment after channel released - inputChannel.recycle(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + // Prepare the exclusive and floating buffers to verify recycle logic later + final Buffer exclusiveBuffer = inputChannel.requestBuffer(); + assertNotNull(exclusiveBuffer); + final Buffer 
floatingBuffer = bufferPool.requestBuffer(); + assertNotNull(floatingBuffer); + verify(bufferPool, times(1)).requestBuffer(); + + // Receive the producer's backlog + inputChannel.onSenderBacklog(12); + + // The channel requests (backlog + numExclusiveBuffers) floating buffers from local pool + // and gets enough floating buffers + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 14 buffers required in the channel", + 14, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one floating buffer + floatingBuffer.recycle(); + + // The floating buffer is returned to local buffer directly because the channel is not waiting + // for floating buffers + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 14 buffers required in the channel", + 14, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 1 buffer available in local pool", + 1, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one exclusive buffer + exclusiveBuffer.recycle(); + + // Return one extra floating buffer to the local pool because the number of available buffers + // already equals to required buffers + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 14 buffers required in the channel", + 14, 
inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 2 buffers available in local pool", + 2, bufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer resources + inputChannel.releaseAllResources(); - assertEquals("Resource leak during recycling buffer after channel is released.", - 0, inputChannel.getNumberOfAvailableBuffers()); - verify(inputChannel, times(0)).notifyCreditAvailable(); - verify(inputGate, times(1)).returnExclusiveSegments(anyListOf(MemorySegment.class)); + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } } /** - * Tests {@link RemoteInputChannel#releaseAllResources()}, verifying the exclusive segments are - * recycled to global pool via input gate and no resource leak. + * Tests to verify the behaviours of recycling floating and exclusive buffers if the number of available + * buffers is more than required buffers by decreasing the sender's backlog. */ @Test - public void testReleaseExclusiveBuffers() throws Exception { + public void testAvailableBuffersMoreThanRequiredBuffers() throws Exception { // Setup - final SingleInputGate inputGate = mock(SingleInputGate.class); + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(16, 32); + final int numExclusiveBuffers = 2; + final int numFloatingBuffers = 14; + + final SingleInputGate inputGate = createSingleInputGate(); final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + // Prepare the exclusive and floating buffers to verify recycle logic later + final Buffer exclusiveBuffer = inputChannel.requestBuffer(); + assertNotNull(exclusiveBuffer); + + final 
Buffer floatingBuffer = bufferPool.requestBuffer(); + assertNotNull(floatingBuffer); + + verify(bufferPool, times(1)).requestBuffer(); + + // Receive the producer's backlog + inputChannel.onSenderBacklog(12); + + // The channel gets enough floating buffers from local pool + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 14 buffers required in the channel", + 14, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Decrease the backlog to make the number of available buffers more than required buffers + inputChannel.onSenderBacklog(10); + + // Only the number of required buffers is changed by (backlog + numExclusiveBuffers) + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 12 buffers required in the channel", + 12, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 0 buffer available in local pool", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one exclusive buffer + exclusiveBuffer.recycle(); + + // Return one extra floating buffer to the local pool because the number of available buffers + // is more than required buffers + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 12 buffers required in the channel", + 12, inputChannel.getNumberOfRequiredBuffers()); + 
assertEquals("There should be 1 buffer available in local pool", + 1, bufferPool.getNumberOfAvailableMemorySegments()); + + // Recycle one floating buffer + floatingBuffer.recycle(); + + // The floating buffer is returned to local pool directly because the channel is not waiting for + // floating buffers + verify(bufferPool, times(14)).requestBuffer(); + verify(bufferPool, times(0)).addBufferListener(inputChannel); + assertEquals("There should be 14 buffers available in the channel", + 14, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 12 buffers required in the channel", + 12, inputChannel.getNumberOfRequiredBuffers()); + assertEquals("There should be 2 buffers available in local pool", + 2, bufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer resources + inputChannel.releaseAllResources(); - // Assign exclusive segments to channel - final List<MemorySegment> exclusiveSegments = new ArrayList<>(); + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } + } + + /** + * Tests to verify that the buffer pool will distribute available floating buffers among + * all the channel listeners in a fair way. 
+ */ + @Test + public void testFairDistributionFloatingBuffers() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(12, 32); final int numExclusiveBuffers = 2; - for (int i = 0; i < numExclusiveBuffers; i++) { - exclusiveSegments.add(MemorySegmentFactory.allocateUnpooledSegment(1024, inputChannel)); + final int numFloatingBuffers = 3; + + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel channel1 = spy(createRemoteInputChannel(inputGate)); + final RemoteInputChannel channel2 = spy(createRemoteInputChannel(inputGate)); + final RemoteInputChannel channel3 = spy(createRemoteInputChannel(inputGate)); + inputGate.setInputChannel(channel1.partitionId.getPartitionId(), channel1); + inputGate.setInputChannel(channel2.partitionId.getPartitionId(), channel2); + inputGate.setInputChannel(channel3.partitionId.getPartitionId(), channel3); + try { + final BufferPool bufferPool = spy(networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers)); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + // Exhaust all the floating buffers of the local pool + final List<Buffer> floatingBuffers = new ArrayList<>(numFloatingBuffers); + for (int i = 0; i < numFloatingBuffers; i++) { + Buffer buffer = bufferPool.requestBuffer(); + assertNotNull(buffer); + floatingBuffers.add(buffer); + } + + // Receive the producer's backlog to trigger requesting floating buffers from the exhausted pool + // and, as none are available, registering each channel as a buffer listener + channel1.onSenderBacklog(8); + channel2.onSenderBacklog(8); + channel3.onSenderBacklog(8); + + verify(bufferPool, times(1)).addBufferListener(channel1); + verify(bufferPool, times(1)).addBufferListener(channel2); + verify(bufferPool, times(1)).addBufferListener(channel3); + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, 
channel1.getNumberOfAvailableBuffers()); + 
assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, channel2.getNumberOfAvailableBuffers()); + assertEquals("There should be " + numExclusiveBuffers + " buffers available in the channel", + numExclusiveBuffers, channel3.getNumberOfAvailableBuffers()); + + // Recycle the three floating buffers; each registered channel should be notified exactly once + for (Buffer buffer : floatingBuffers) { + buffer.recycle(); + } + + verify(channel1, times(1)).notifyBufferAvailable(any(Buffer.class)); + verify(channel2, times(1)).notifyBufferAvailable(any(Buffer.class)); + verify(channel3, times(1)).notifyBufferAvailable(any(Buffer.class)); + assertEquals("There should be 3 buffers available in the channel", 3, channel1.getNumberOfAvailableBuffers()); + assertEquals("There should be 3 buffers available in the channel", 3, channel2.getNumberOfAvailableBuffers()); + assertEquals("There should be 3 buffers available in the channel", 3, channel3.getNumberOfAvailableBuffers()); + + } finally { + // Release all the buffer resources + channel1.releaseAllResources(); + channel2.releaseAllResources(); + channel3.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); } - inputChannel.assignExclusiveSegments(exclusiveSegments); + } + + /** + * Tests to verify that there is no race condition with two things running in parallel: + * requesting floating buffers on sender backlog and some other thread releasing + * the input channel. 
+ */ + @Test + public void testConcurrentOnSenderBacklogAndRelease() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(130, 32); + final int numExclusiveBuffers = 2; + final int numFloatingBuffers = 128; + + final ExecutorService executor = Executors.newFixedThreadPool(2); + + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveBuffers); + + final Callable<Void> requestBufferTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + while (true) { + for (int j = 1; j <= numFloatingBuffers; j++) { + inputChannel.onSenderBacklog(j); + } - assertEquals("The number of available buffers is not equal to the assigned amount.", - numExclusiveBuffers, inputChannel.getNumberOfAvailableBuffers()); + if (inputChannel.isReleased()) { + return null; + } + } + } + }; + + final Callable<Void> releaseTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + inputChannel.releaseAllResources(); + + return null; + } + }; + + // Submit tasks and wait for them to finish; the request task loops until the channel is released + submitTasksAndWaitForResults(executor, new Callable[]{requestBufferTask, releaseTask}); + + assertEquals("There should be no buffers available in the channel.", + 0, inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be 130 buffers available in local pool.", + 130, bufferPool.getNumberOfAvailableMemorySegments() + networkBufferPool.getNumberOfAvailableMemorySegments()); - // Release this channel - inputChannel.releaseAllResources(); + } finally { + // Release all the 
buffer resources if an exception prevented the release task from doing so + if 
(!inputChannel.isReleased()) { + inputChannel.releaseAllResources(); + } + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); - assertEquals("Resource leak after channel is released.", - 0, inputChannel.getNumberOfAvailableBuffers()); - verify(inputGate, times(1)).returnExclusiveSegments(anyListOf(MemorySegment.class)); + executor.shutdown(); + } + } + + /** + * Tests to verify that there is no race condition with three tasks running in parallel: + * requesting floating buffers on sender backlog while two other threads recycle + * floating and exclusive buffers. + */ + @Test + public void testConcurrentOnSenderBacklogAndRecycle() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(248, 32); + final int numExclusiveSegments = 120; + final int numFloatingBuffers = 128; + final int backlog = 128; + + final ExecutorService executor = Executors.newFixedThreadPool(3); + + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + + final Callable<Void> requestBufferTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + for (int j = 1; j <= backlog; j++) { + inputChannel.onSenderBacklog(j); + } + + return null; + } + }; + + // Submit the two recycle tasks and the request task, and wait for all of them to finish + submitTasksAndWaitForResults(executor, new Callable[]{ + recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), + recycleFloatingBufferTask(bufferPool, numFloatingBuffers), + requestBufferTask}); + + assertEquals("There should be " + 
inputChannel.getNumberOfRequiredBuffers() +" buffers available in channel.", + 
inputChannel.getNumberOfRequiredBuffers(), inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be no buffers available in local pool.", + 0, bufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer resources + inputChannel.releaseAllResources(); + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + + executor.shutdown(); + } + } + + /** + * Tests to verify that there is no race condition with three tasks running in parallel: + * recycling the exclusive and floating buffers on two threads while a third thread + * releases the input channel. + */ + @Test + public void testConcurrentRecycleAndRelease() throws Exception { + // Setup + final NetworkBufferPool networkBufferPool = new NetworkBufferPool(248, 32); + final int numExclusiveSegments = 120; + final int numFloatingBuffers = 128; + + final ExecutorService executor = Executors.newFixedThreadPool(3); + + final SingleInputGate inputGate = createSingleInputGate(); + final RemoteInputChannel inputChannel = createRemoteInputChannel(inputGate); + inputGate.setInputChannel(inputChannel.partitionId.getPartitionId(), inputChannel); + try { + final BufferPool bufferPool = networkBufferPool.createBufferPool(numFloatingBuffers, numFloatingBuffers); + inputGate.setBufferPool(bufferPool); + inputGate.assignExclusiveSegments(networkBufferPool, numExclusiveSegments); + + final Callable<Void> releaseTask = new Callable<Void>() { + @Override + public Void call() throws Exception { + inputChannel.releaseAllResources(); + + return null; + } + }; + + // Submit the two recycle tasks and the release task, and wait for all of them to finish + submitTasksAndWaitForResults(executor, new Callable[]{ + recycleExclusiveBufferTask(inputChannel, numExclusiveSegments), + recycleFloatingBufferTask(bufferPool, numFloatingBuffers), + releaseTask}); + + assertEquals("There should be no buffers available in the channel.", + 0, 
inputChannel.getNumberOfAvailableBuffers()); + assertEquals("There should be " + numFloatingBuffers + " 
buffers available in local pool.", + numFloatingBuffers, bufferPool.getNumberOfAvailableMemorySegments()); + assertEquals("There should be " + numExclusiveSegments + " buffers available in global pool.", + numExclusiveSegments, networkBufferPool.getNumberOfAvailableMemorySegments()); + + } finally { + // Release all the buffer resources if an exception prevented the release task from doing so + if (!inputChannel.isReleased()) { + inputChannel.releaseAllResources(); + } + + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + + executor.shutdown(); + } } // --------------------------------------------------------------------------------------------- + private SingleInputGate createSingleInputGate() { + return new SingleInputGate( + "InputGate", + new JobID(), + new IntermediateDataSetID(), + ResultPartitionType.PIPELINED_CREDIT_BASED, + 0, + 1, + mock(TaskActions.class), + UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); + } + private RemoteInputChannel createRemoteInputChannel(SingleInputGate inputGate) throws IOException, InterruptedException { @@ -403,4 +890,78 @@ public class RemoteInputChannelTest { initialAndMaxRequestBackoff._2(), UnregisteredMetricGroups.createUnregisteredTaskMetricGroup().getIOMetricGroup()); } + + /** + * Eagerly requests all the exclusive buffers from the input channel and returns a task that recycles them when called. + * + * @param inputChannel The input channel to request the exclusive buffers from. + * @param numExclusiveSegments The number of exclusive buffers to request. + * @return The callable task that recycles the requested exclusive buffers.
+ */ + private Callable<Void> recycleExclusiveBufferTask(RemoteInputChannel inputChannel, int numExclusiveSegments) { + final List<Buffer> exclusiveBuffers = new ArrayList<>(numExclusiveSegments); + // Exhaust all the exclusive buffers now, on the calling thread + for (int i = 0; i < numExclusiveSegments; i++) { + Buffer buffer = inputChannel.requestBuffer(); + assertNotNull(buffer); + exclusiveBuffers.add(buffer); + } + + return new Callable<Void>() { + @Override + public Void call() throws Exception { + for (Buffer buffer : exclusiveBuffers) { + buffer.recycle(); + } + + return null; + } + }; + } + + /** + * Eagerly requests the floating buffers from the pool and returns a task that recycles them when called. + * + * @param bufferPool The buffer pool to request the floating buffers from. + * @param numFloatingBuffers The number of floating buffers to request. + * @return The callable task that recycles the requested floating buffers. + */ + private Callable<Void> recycleFloatingBufferTask(BufferPool bufferPool, int numFloatingBuffers) throws Exception { + final List<Buffer> floatingBuffers = new ArrayList<>(numFloatingBuffers); + // Exhaust all the floating buffers now, on the calling thread + for (int i = 0; i < numFloatingBuffers; i++) { + Buffer buffer = bufferPool.requestBuffer(); + assertNotNull(buffer); + floatingBuffers.add(buffer); + } + + return new Callable<Void>() { + @Override + public Void call() throws Exception { + for (Buffer buffer : floatingBuffers) { + buffer.recycle(); + } + + return null; + } + }; + } + + /** + * Submits all the callable tasks to the executor and waits for their results in submission order (NOTE(review): get() has no timeout, so a never-ending earlier task blocks here even if a later task already failed). + * + * @param executor The executor service for running tasks. + * @param tasks The callable tasks to be submitted and executed.
+ */ + private void submitTasksAndWaitForResults(ExecutorService executor, Callable[] tasks) throws Exception { + final List<Future> results = Lists.newArrayListWithCapacity(tasks.length); + + for(Callable task : tasks) { + results.add(executor.submit(task)); + } + + for (Future result : results) { + result.get(); + } + } }