[FLINK-7378][core] Create a fixed-size (non-rebalancing) buffer pool type for the floating buffers
This closes #4485. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/064a1e60 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/064a1e60 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/064a1e60 Branch: refs/heads/master Commit: 064a1e60c5310688c7fa8bc5d07ce5da512e076d Parents: a803dc7 Author: Zhijiang <[email protected]> Authored: Mon Aug 7 17:31:17 2017 +0800 Committer: zentol <[email protected]> Committed: Tue Oct 10 16:53:19 2017 +0200 ---------------------------------------------------------------------- .../runtime/io/network/NetworkEnvironment.java | 25 +-- .../io/network/buffer/NetworkBufferPool.java | 62 +++++++ .../network/partition/ResultPartitionType.java | 27 ++- .../partition/consumer/RemoteInputChannel.java | 6 + .../partition/consumer/SingleInputGate.java | 57 +++++- .../io/network/NetworkEnvironmentTest.java | 19 +- .../network/buffer/BufferPoolFactoryTest.java | 102 ++++++++--- .../network/buffer/NetworkBufferPoolTest.java | 172 +++++++++++++++++++ .../partition/consumer/SingleInputGateTest.java | 67 ++++++++ 9 files changed, 492 insertions(+), 45 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index 4269af6..9193859 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -35,6 +35,7 @@ import org.apache.flink.runtime.query.netty.KvStateServer; import 
org.apache.flink.runtime.state.internal.InternalKvState; import org.apache.flink.runtime.taskmanager.Task; import org.apache.flink.runtime.taskmanager.TaskManager; +import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -184,7 +185,7 @@ public class NetworkEnvironment { partition.getNumberOfSubpartitions() * networkBuffersPerChannel + extraNetworkBuffersPerGate : Integer.MAX_VALUE; bufferPool = networkBufferPool.createBufferPool(partition.getNumberOfSubpartitions(), - maxNumberOfMemorySegments); + maxNumberOfMemorySegments); partition.registerBufferPool(bufferPool); resultPartitionManager.registerResultPartition(partition); @@ -211,22 +212,24 @@ public class NetworkEnvironment { BufferPool bufferPool = null; try { - int maxNumberOfMemorySegments = gate.getConsumedPartitionType().isBounded() ? - gate.getNumberOfInputChannels() * networkBuffersPerChannel + - extraNetworkBuffersPerGate : Integer.MAX_VALUE; - bufferPool = networkBufferPool.createBufferPool(gate.getNumberOfInputChannels(), - maxNumberOfMemorySegments); + if (gate.getConsumedPartitionType().isCreditBased()) { + // Create a fixed-size buffer pool for floating buffers and assign exclusive buffers to input channels directly + bufferPool = networkBufferPool.createBufferPool(extraNetworkBuffersPerGate, extraNetworkBuffersPerGate); + gate.assignExclusiveSegments(networkBufferPool, networkBuffersPerChannel); + } else { + int maxNumberOfMemorySegments = gate.getConsumedPartitionType().isBounded() ? 
+ gate.getNumberOfInputChannels() * networkBuffersPerChannel + + extraNetworkBuffersPerGate : Integer.MAX_VALUE; + bufferPool = networkBufferPool.createBufferPool(gate.getNumberOfInputChannels(), + maxNumberOfMemorySegments); + } gate.setBufferPool(bufferPool); } catch (Throwable t) { if (bufferPool != null) { bufferPool.lazyDestroy(); } - if (t instanceof IOException) { - throw (IOException) t; - } else { - throw new IOException(t.getMessage(), t); - } + ExceptionUtils.rethrowIOException(t); } } } http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPool.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPool.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPool.java index b70e912..f899f05 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPool.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPool.java @@ -22,16 +22,21 @@ import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.core.memory.MemoryType; +import org.apache.flink.util.ExceptionUtils; import org.apache.flink.util.MathUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.HashSet; +import java.util.List; import java.util.Set; import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.TimeUnit; +import static org.apache.flink.util.Preconditions.checkArgument; import static org.apache.flink.util.Preconditions.checkNotNull; /** @@ -131,6 +136,63 @@ public class NetworkBufferPool implements BufferPoolFactory { 
availableMemorySegments.add(segment); } + public List<MemorySegment> requestMemorySegments(int numRequiredBuffers) throws IOException { + checkArgument(numRequiredBuffers > 0, "The number of required buffers should be larger than 0."); + + synchronized (factoryLock) { + if (isDestroyed) { + throw new IllegalStateException("Network buffer pool has already been destroyed."); + } + + if (numTotalRequiredBuffers + numRequiredBuffers > totalNumberOfMemorySegments) { + throw new IOException(String.format("Insufficient number of network buffers: " + + "required %d, but only %d available. The total number of network " + + "buffers is currently set to %d of %d bytes each. You can increase this " + + "number by setting the configuration keys '%s', '%s', and '%s'.", + numRequiredBuffers, + totalNumberOfMemorySegments - numTotalRequiredBuffers, + totalNumberOfMemorySegments, + memorySegmentSize, + TaskManagerOptions.NETWORK_BUFFERS_MEMORY_FRACTION.key(), + TaskManagerOptions.NETWORK_BUFFERS_MEMORY_MIN.key(), + TaskManagerOptions.NETWORK_BUFFERS_MEMORY_MAX.key())); + } + + this.numTotalRequiredBuffers += numRequiredBuffers; + + redistributeBuffers(); + } + + final List<MemorySegment> segments = new ArrayList<>(numRequiredBuffers); + try { + while (segments.size() < numRequiredBuffers) { + if (isDestroyed) { + throw new IllegalStateException("Buffer pool is destroyed."); + } + + final MemorySegment segment = availableMemorySegments.poll(2, TimeUnit.SECONDS); + if (segment != null) { + segments.add(segment); + } + } + } catch (Throwable e) { + recycleMemorySegments(segments); + ExceptionUtils.rethrowIOException(e); + } + + return segments; + } + + public void recycleMemorySegments(List<MemorySegment> segments) throws IOException { + synchronized (factoryLock) { + numTotalRequiredBuffers -= segments.size(); + + availableMemorySegments.addAll(segments); + + redistributeBuffers(); + } + } + public void destroy() { synchronized (factoryLock) { isDestroyed = true; 
http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionType.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionType.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionType.java index 256387c..683fe37 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionType.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartitionType.java @@ -20,9 +20,9 @@ package org.apache.flink.runtime.io.network.partition; public enum ResultPartitionType { - BLOCKING(false, false, false), + BLOCKING(false, false, false, false), - PIPELINED(true, true, false), + PIPELINED(true, true, false, false), /** * Pipelined partitions with a bounded (local) buffer pool. @@ -35,7 +35,13 @@ public enum ResultPartitionType { * For batch jobs, it will be best to keep this unlimited ({@link #PIPELINED}) since there are * no checkpoint barriers. */ - PIPELINED_BOUNDED(true, true, true); + PIPELINED_BOUNDED(true, true, true, false), + + /** + * Pipelined partitions with a bounded (local) buffer pool for floating buffers in input gate, and a number + * of exclusive buffers per input channel. The producer transfers data based on consumer's available credits. + */ + PIPELINED_CREDIT_BASED(true, true, true, true); /** Can the partition be consumed while being produced? */ private final boolean isPipelined; @@ -46,13 +52,17 @@ public enum ResultPartitionType { /** Does this partition use a limited number of (network) buffers? */ private final boolean isBounded; + /** Does this partition only send data when consumer has available buffers? */ + private final boolean isCreditBased; + /** * Specifies the behaviour of an intermediate result partition at runtime. 
*/ - ResultPartitionType(boolean isPipelined, boolean hasBackPressure, boolean isBounded) { + ResultPartitionType(boolean isPipelined, boolean hasBackPressure, boolean isBounded, boolean isCreditBased) { this.isPipelined = isPipelined; this.hasBackPressure = hasBackPressure; this.isBounded = isBounded; + this.isCreditBased = isCreditBased; } public boolean hasBackPressure() { @@ -75,4 +85,13 @@ public enum ResultPartitionType { public boolean isBounded() { return isBounded; } + + /** + * Whether this partition uses the credit-based mode to transfer data or not. + * + * @return <tt>true</tt> if the data is transferred based on consumer's credit + */ + public boolean isCreditBased() { + return isCreditBased; + } } http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java index 719f340..58c9484 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/RemoteInputChannel.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.io.network.partition.consumer; +import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.event.TaskEvent; import org.apache.flink.runtime.io.network.ConnectionID; import org.apache.flink.runtime.io.network.ConnectionManager; @@ -30,6 +31,7 @@ import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import java.io.IOException; import java.util.ArrayDeque; +import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import static 
org.apache.flink.util.Preconditions.checkNotNull; @@ -97,6 +99,10 @@ public class RemoteInputChannel extends InputChannel { this.connectionManager = checkNotNull(connectionManager); } + void assignExclusiveSegments(List<MemorySegment> segments) { + // TODO in next PR + } + // ------------------------------------------------------------------------ // Consume // ------------------------------------------------------------------------ http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java index ebfb300..945d127 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.io.network.partition.consumer; import org.apache.flink.api.common.JobID; +import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor; import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor; import org.apache.flink.runtime.deployment.ResultPartitionLocation; @@ -31,6 +32,7 @@ import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.BufferProvider; +import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import 
org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.partition.consumer.InputChannel.BufferAndAvailability; @@ -147,6 +149,9 @@ public class SingleInputGate implements InputGate { */ private BufferPool bufferPool; + /** Global network buffer pool to request and recycle exclusive buffers. */ + private NetworkBufferPool networkBufferPool; + private boolean hasReceivedAllEndOfPartitionEvents; /** Flag indicating whether partitions have been requested. */ @@ -162,6 +167,9 @@ public class SingleInputGate implements InputGate { private int numberOfUninitializedChannels; + /** Number of network buffers to use for each remote input channel. */ + private int networkBuffersPerChannel; + /** A timer to retrigger local partition requests. Only initialized if actually needed. */ private Timer retriggerLocalRequestTimer; @@ -259,21 +267,55 @@ public class SingleInputGate implements InputGate { public void setBufferPool(BufferPool bufferPool) { // Sanity checks - checkArgument(numberOfInputChannels == bufferPool.getNumberOfRequiredMemorySegments(), + if (!getConsumedPartitionType().isCreditBased()) { + checkArgument(numberOfInputChannels == bufferPool.getNumberOfRequiredMemorySegments(), "Bug in input gate setup logic: buffer pool has not enough guaranteed buffers " + - "for this input gate. Input gates require at least as many buffers as " + + "for this input gate. Input gates require at least as many buffers as " + "there are input channels."); + } checkState(this.bufferPool == null, "Bug in input gate setup logic: buffer pool has" + - "already been set for this input gate."); + "already been set for this input gate."); this.bufferPool = checkNotNull(bufferPool); } + /** + * Assign the exclusive buffers to all remote input channels directly for credit-based mode. 
+ * + * @param networkBufferPool The global pool to request and recycle exclusive buffers + * @param networkBuffersPerChannel The number of exclusive buffers for each channel + */ + public void assignExclusiveSegments(NetworkBufferPool networkBufferPool, int networkBuffersPerChannel) throws IOException { + checkState(this.networkBufferPool == null, "Bug in input gate setup logic: global buffer pool has" + + "already been set for this input gate."); + + this.networkBufferPool = checkNotNull(networkBufferPool); + this.networkBuffersPerChannel = networkBuffersPerChannel; + + synchronized (requestLock) { + for (InputChannel inputChannel : inputChannels.values()) { + if (inputChannel instanceof RemoteInputChannel) { + ((RemoteInputChannel) inputChannel).assignExclusiveSegments( + networkBufferPool.requestMemorySegments(networkBuffersPerChannel)); + } + } + } + } + + /** + * The exclusive segments are recycled to network buffer pool directly when input channel is released. + * + * @param segments The exclusive segments need to be recycled + */ + public void returnExclusiveSegments(List<MemorySegment> segments) throws IOException { + networkBufferPool.recycleMemorySegments(segments); + } + public void setInputChannel(IntermediateResultPartitionID partitionId, InputChannel inputChannel) { synchronized (requestLock) { if (inputChannels.put(checkNotNull(partitionId), checkNotNull(inputChannel)) == null - && inputChannel.getClass() == UnknownInputChannel.class) { + && inputChannel instanceof UnknownInputChannel) { numberOfUninitializedChannels++; } @@ -291,7 +333,7 @@ public class SingleInputGate implements InputGate { InputChannel current = inputChannels.get(partitionId); - if (current.getClass() == UnknownInputChannel.class) { + if (current instanceof UnknownInputChannel) { UnknownInputChannel unknownChannel = (UnknownInputChannel) current; @@ -304,6 +346,11 @@ public class SingleInputGate implements InputGate { } else if (partitionLocation.isRemote()) { newChannel = 
unknownChannel.toRemoteInputChannel(partitionLocation.getConnectionId()); + + if (getConsumedPartitionType().isCreditBased()) { + ((RemoteInputChannel)newChannel).assignExclusiveSegments( + networkBufferPool.requestMemorySegments(networkBuffersPerChannel)); + } } else { throw new IllegalStateException("Tried to update unknown channel with unknown channel."); http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java index b956691..826ae3f 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java @@ -37,10 +37,14 @@ import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import java.io.IOException; + import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; /** @@ -82,11 +86,11 @@ public class NetworkEnvironmentTest { new ResultPartitionWriter(rp3), new ResultPartitionWriter(rp4)}; // input gates - final SingleInputGate[] inputGates = new SingleInputGate[] { - createSingleInputGateMock(ResultPartitionType.PIPELINED, 2), - createSingleInputGateMock(ResultPartitionType.BLOCKING, 2), - createSingleInputGateMock(ResultPartitionType.PIPELINED_BOUNDED, 2), - createSingleInputGateMock(ResultPartitionType.PIPELINED_BOUNDED, 8)}; + SingleInputGate ig1 = 
createSingleInputGateMock(ResultPartitionType.PIPELINED, 2); + SingleInputGate ig2 = createSingleInputGateMock(ResultPartitionType.BLOCKING, 2); + SingleInputGate ig3 = createSingleInputGateMock(ResultPartitionType.PIPELINED_BOUNDED, 2); + SingleInputGate ig4 = createSingleInputGateMock(ResultPartitionType.PIPELINED_CREDIT_BASED, 8); + final SingleInputGate[] inputGates = new SingleInputGate[] {ig1, ig2, ig3, ig4}; // overall task to register Task task = mock(Task.class); @@ -101,6 +105,8 @@ public class NetworkEnvironmentTest { assertEquals(2 * 2 + 8, rp3.getBufferPool().getMaxNumberOfMemorySegments()); assertEquals(8 * 2 + 8, rp4.getBufferPool().getMaxNumberOfMemorySegments()); + verify(ig4, times(1)).assignExclusiveSegments(network.getNetworkBufferPool(), 2); + network.shutdown(); } @@ -154,12 +160,15 @@ public class NetworkEnvironmentTest { BufferPool bp = invocation.getArgumentAt(0, BufferPool.class); if (partitionType == ResultPartitionType.PIPELINED_BOUNDED) { assertEquals(channels * 2 + 8, bp.getMaxNumberOfMemorySegments()); + } else if (partitionType == ResultPartitionType.PIPELINED_CREDIT_BASED) { + assertEquals(8, bp.getMaxNumberOfMemorySegments()); } else { assertEquals(Integer.MAX_VALUE, bp.getMaxNumberOfMemorySegments()); } return null; } }).when(ig).setBufferPool(any(BufferPool.class)); + return ig; } } http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferPoolFactoryTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferPoolFactoryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferPoolFactoryTest.java index ce76a6d..7f2bcc6 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferPoolFactoryTest.java +++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferPoolFactoryTest.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.io.network.buffer; +import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemoryType; import org.junit.After; import org.junit.Before; @@ -48,9 +49,13 @@ public class BufferPoolFactoryTest { @After public void verifyAllBuffersReturned() { String msg = "Did not return all buffers to network buffer pool after test."; - assertEquals(msg, numBuffers, networkBufferPool.getNumberOfAvailableMemorySegments()); - // in case buffers have actually been requested, we must release them again - networkBufferPool.destroy(); + try { + assertEquals(msg, numBuffers, networkBufferPool.getNumberOfAvailableMemorySegments()); + } finally { + // in case buffers have actually been requested, we must release them again + networkBufferPool.destroyAllBufferPools(); + networkBufferPool.destroy(); + } } @Test(expected = IOException.class) @@ -134,25 +139,82 @@ public class BufferPoolFactoryTest { @Test public void testUniformDistributionBounded3() throws IOException { NetworkBufferPool globalPool = new NetworkBufferPool(3, 128, MemoryType.HEAP); - BufferPool first = globalPool.createBufferPool(0, 10); - assertEquals(3, first.getNumBuffers()); - - BufferPool second = globalPool.createBufferPool(0, 10); - // the order of which buffer pool received 2 or 1 buffer is undefined - assertEquals(3, first.getNumBuffers() + second.getNumBuffers()); - assertNotEquals(3, first.getNumBuffers()); - assertNotEquals(3, second.getNumBuffers()); + try { + BufferPool first = globalPool.createBufferPool(0, 10); + assertEquals(3, first.getNumBuffers()); + + BufferPool second = globalPool.createBufferPool(0, 10); + // the order of which buffer pool received 2 or 1 buffer is undefined + assertEquals(3, first.getNumBuffers() + second.getNumBuffers()); + assertNotEquals(3, first.getNumBuffers()); + assertNotEquals(3, 
second.getNumBuffers()); + + BufferPool third = globalPool.createBufferPool(0, 10); + assertEquals(1, first.getNumBuffers()); + assertEquals(1, second.getNumBuffers()); + assertEquals(1, third.getNumBuffers()); + + // similar to #verifyAllBuffersReturned() + String msg = "Wrong number of available segments after creating buffer pools."; + assertEquals(msg, 3, globalPool.getNumberOfAvailableMemorySegments()); + } finally { + // in case buffers have actually been requested, we must release them again + globalPool.destroyAllBufferPools(); + globalPool.destroy(); + } + } - BufferPool third = globalPool.createBufferPool(0, 10); - assertEquals(1, first.getNumBuffers()); - assertEquals(1, second.getNumBuffers()); - assertEquals(1, third.getNumBuffers()); + /** + * Tests the interaction of requesting memory segments and creating local buffer pool and + * verifies the number of assigned buffers match after redistributing buffers because of newly + * requested memory segments or new buffer pools created. 
+ */ + @Test + public void testUniformDistributionBounded4() throws IOException { + NetworkBufferPool globalPool = new NetworkBufferPool(10, 128, MemoryType.HEAP); + try { + BufferPool first = globalPool.createBufferPool(0, 10); + assertEquals(10, first.getNumBuffers()); - // similar to #verifyAllBuffersReturned() - String msg = "Did not return all buffers to network buffer pool after test."; - assertEquals(msg, 3, globalPool.getNumberOfAvailableMemorySegments()); - // in case buffers have actually been requested, we must release them again - globalPool.destroy(); + List<MemorySegment> segmentList1 = globalPool.requestMemorySegments(2); + assertEquals(2, segmentList1.size()); + assertEquals(8, first.getNumBuffers()); + + BufferPool second = globalPool.createBufferPool(0, 10); + assertEquals(4, first.getNumBuffers()); + assertEquals(4, second.getNumBuffers()); + + List<MemorySegment> segmentList2 = globalPool.requestMemorySegments(2); + assertEquals(2, segmentList2.size()); + assertEquals(3, first.getNumBuffers()); + assertEquals(3, second.getNumBuffers()); + + List<MemorySegment> segmentList3 = globalPool.requestMemorySegments(2); + assertEquals(2, segmentList3.size()); + assertEquals(2, first.getNumBuffers()); + assertEquals(2, second.getNumBuffers()); + + String msg = "Wrong number of available segments after creating buffer pools and requesting segments."; + assertEquals(msg, 4, globalPool.getNumberOfAvailableMemorySegments()); + + globalPool.recycleMemorySegments(segmentList1); + assertEquals(msg, 6, globalPool.getNumberOfAvailableMemorySegments()); + assertEquals(3, first.getNumBuffers()); + assertEquals(3, second.getNumBuffers()); + + globalPool.recycleMemorySegments(segmentList2); + assertEquals(msg, 8, globalPool.getNumberOfAvailableMemorySegments()); + assertEquals(4, first.getNumBuffers()); + assertEquals(4, second.getNumBuffers()); + + globalPool.recycleMemorySegments(segmentList3); + assertEquals(msg, 10, 
globalPool.getNumberOfAvailableMemorySegments()); + assertEquals(5, first.getNumBuffers()); + assertEquals(5, second.getNumBuffers()); + } finally { + globalPool.destroyAllBufferPools(); + globalPool.destroy(); + } } @Test http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPoolTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPoolTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPoolTest.java index 7c6a543..e30e955 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPoolTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/NetworkBufferPoolTest.java @@ -18,20 +18,36 @@ package org.apache.flink.runtime.io.network.buffer; +import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.core.memory.MemoryType; +import org.apache.flink.core.testutils.CheckedThread; +import org.apache.flink.core.testutils.OneShotLatch; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; +import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.core.IsCollectionContaining.hasItem; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.core.IsNot.not; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; public class NetworkBufferPoolTest { + @Rule + public ExpectedException expectedException = ExpectedException.none(); + @Test public void 
testCreatePoolAfterDestroy() { try { @@ -168,4 +184,160 @@ public class NetworkBufferPoolTest { fail(e.getMessage()); } } + + /** + * Tests {@link NetworkBufferPool#requestMemorySegments(int)} with the {@link NetworkBufferPool} + * currently containing the number of required free segments. + */ + @Test + public void testRequestMemorySegmentsLessThanTotalBuffers() throws Exception { + final int numBuffers = 10; + + NetworkBufferPool globalPool = new NetworkBufferPool(numBuffers, 128, MemoryType.HEAP); + + List<MemorySegment> memorySegments = Collections.emptyList(); + try { + memorySegments = globalPool.requestMemorySegments(numBuffers / 2); + assertEquals(memorySegments.size(), numBuffers / 2); + + globalPool.recycleMemorySegments(memorySegments); + memorySegments.clear(); + assertEquals(globalPool.getNumberOfAvailableMemorySegments(), numBuffers); + } finally { + globalPool.recycleMemorySegments(memorySegments); // just in case + globalPool.destroy(); + } + } + + /** + * Tests {@link NetworkBufferPool#requestMemorySegments(int)} with the number of required + * buffers exceeding the capacity of {@link NetworkBufferPool}. + */ + @Test + public void testRequestMemorySegmentsMoreThanTotalBuffers() throws Exception { + final int numBuffers = 10; + + NetworkBufferPool globalPool = new NetworkBufferPool(numBuffers, 128, MemoryType.HEAP); + + try { + globalPool.requestMemorySegments(numBuffers + 1); + fail("Should throw an IOException"); + } catch (IOException e) { + assertEquals(globalPool.getNumberOfAvailableMemorySegments(), numBuffers); + } finally { + globalPool.destroy(); + } + } + + /** + * Tests {@link NetworkBufferPool#requestMemorySegments(int)} with the invalid argument to + * cause exception. 
+	 */
+	@Test
+	public void testRequestMemorySegmentsWithInvalidArgument() throws Exception {
+		final int numBuffers = 10;
+
+		NetworkBufferPool globalPool = new NetworkBufferPool(numBuffers, 128, MemoryType.HEAP);
+
+		try {
+			// the number of requested buffers must be larger than zero
+			globalPool.requestMemorySegments(0);
+			fail("Should throw an IllegalArgumentException");
+		} catch (IllegalArgumentException e) {
+			assertEquals(numBuffers, globalPool.getNumberOfAvailableMemorySegments());
+		} finally {
+			globalPool.destroy();
+		}
+	}
+
+	/**
+	 * Tests {@link NetworkBufferPool#requestMemorySegments(int)} with the {@link NetworkBufferPool}
+	 * currently not containing the number of required free segments (currently occupied by a buffer pool).
+	 */
+	@Test
+	public void testRequestMemorySegmentsWithBuffersTaken() throws IOException, InterruptedException {
+		final int numBuffers = 10;
+
+		NetworkBufferPool networkBufferPool = new NetworkBufferPool(numBuffers, 128, MemoryType.HEAP);
+
+		final List<Buffer> buffers = new ArrayList<>(numBuffers);
+		List<MemorySegment> memorySegments = Collections.emptyList();
+		Thread bufferRecycler = null;
+		BufferPool lbp1 = null;
+		try {
+			lbp1 = networkBufferPool.createBufferPool(numBuffers / 2, numBuffers);
+
+			// take all buffers (more than the minimum required)
+			for (int i = 0; i < numBuffers; ++i) {
+				Buffer buffer = lbp1.requestBuffer();
+				buffers.add(buffer);
+				assertNotNull(buffer);
+			}
+
+			// requestMemorySegments() below will wait for buffers;
+			// this thread makes sure that enough buffers are freed eventually for it to continue
+			final OneShotLatch isRunning = new OneShotLatch();
+			bufferRecycler = new Thread(() -> {
+				try {
+					isRunning.trigger();
+					Thread.sleep(100);
+				} catch (InterruptedException ignored) {
+				}
+
+				for (Buffer buffer : buffers) {
+					buffer.recycle();
+				}
+			});
+			bufferRecycler.start();
+
+			// take more buffers than are freely available at the moment via requestMemorySegments()
+			isRunning.await();
+			memorySegments = networkBufferPool.requestMemorySegments(numBuffers / 2);
+			assertThat(memorySegments, not(hasItem(nullValue())));
+		} finally {
+			if (bufferRecycler != null) {
+				bufferRecycler.join();
+			}
+			if (lbp1 != null) {
+				lbp1.lazyDestroy();
+			}
+			networkBufferPool.recycleMemorySegments(memorySegments);
+			networkBufferPool.destroy();
+		}
+	}
+
+	/**
+	 * Tests {@link NetworkBufferPool#requestMemorySegments(int)}, verifying it may be aborted in
+	 * case of a concurrent {@link NetworkBufferPool#destroy()} call.
+	 */
+	@Test
+	public void testRequestMemorySegmentsInterruptable() throws Exception {
+		final int numBuffers = 10;
+
+		NetworkBufferPool globalPool = new NetworkBufferPool(numBuffers, 128, MemoryType.HEAP);
+		MemorySegment segment = globalPool.requestMemorySegment();
+		assertNotNull(segment);
+
+		final OneShotLatch isRunning = new OneShotLatch();
+		CheckedThread asyncRequest = new CheckedThread() {
+			@Override
+			public void go() throws Exception {
+				isRunning.trigger();
+				globalPool.requestMemorySegments(numBuffers);
+			}
+		};
+		asyncRequest.start();
+
+		// We want the destroy call inside the blocking part of the globalPool.requestMemorySegments()
+		// call above. We cannot guarantee this, but we make it highly probable:
+		isRunning.await();
+		Thread.sleep(10);
+		globalPool.destroy();
+
+		segment.free();
+
+		expectedException.expect(IllegalStateException.class);
+		expectedException.expectMessage("destroyed");
+		asyncRequest.sync();
+	}
 }
http://git-wip-us.apache.org/repos/asf/flink/blob/064a1e60/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index 737f17b..4d7d884 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.deployment.InputChannelDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
@@ -33,6 +34,7 @@ import org.apache.flink.runtime.io.network.TaskEventDispatcher;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
@@ -57,6 +59,7 @@ import static 
org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
+import static org.mockito.Matchers.anyListOf;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
@@ -372,6 +375,70 @@ public class SingleInputGateTest {
 		}
 	}
 
+	/**
+	 * Tests that the input gate requests and assigns network buffers for a remote input channel.
+	 */
+	@Test
+	public void testRequestBuffersWithRemoteInputChannel() throws Exception {
+		final SingleInputGate inputGate = new SingleInputGate(
+			"t1",
+			new JobID(),
+			new IntermediateDataSetID(),
+			ResultPartitionType.PIPELINED_CREDIT_BASED,
+			0,
+			1,
+			mock(TaskActions.class),
+			new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+
+		RemoteInputChannel remote = mock(RemoteInputChannel.class);
+		inputGate.setInputChannel(new IntermediateResultPartitionID(), remote);
+
+		final int buffersPerChannel = 2;
+		NetworkBufferPool network = mock(NetworkBufferPool.class);
+		// Trigger the request of segments from the global pool and assign them to the remote input channel
+		inputGate.assignExclusiveSegments(network, buffersPerChannel);
+
+		verify(network, times(1)).requestMemorySegments(buffersPerChannel);
+		verify(remote, times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class));
+	}
+
+	/**
+	 * Tests that the input gate requests and assigns network buffers when an unknown input channel
+	 * is updated to a remote input channel.
+	 */
+	@Test
+	public void testRequestBuffersWithUnknownInputChannel() throws Exception {
+		final SingleInputGate inputGate = new SingleInputGate(
+			"t1",
+			new JobID(),
+			new IntermediateDataSetID(),
+			ResultPartitionType.PIPELINED_CREDIT_BASED,
+			0,
+			1,
+			mock(TaskActions.class),
+			new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+
+		UnknownInputChannel unknown = mock(UnknownInputChannel.class);
+		final ResultPartitionID resultPartitionId = new ResultPartitionID();
+		inputGate.setInputChannel(resultPartitionId.getPartitionId(), unknown);
+
+		RemoteInputChannel remote = mock(RemoteInputChannel.class);
+		final ConnectionID connectionId = new ConnectionID(new InetSocketAddress("localhost", 5000), 0);
+		when(unknown.toRemoteInputChannel(connectionId)).thenReturn(remote);
+
+		final int buffersPerChannel = 2;
+		NetworkBufferPool network = mock(NetworkBufferPool.class);
+		inputGate.assignExclusiveSegments(network, buffersPerChannel);
+
+		// Trigger the update from the unknown input channel to the remote input channel
+		inputGate.updateInputChannel(new InputChannelDeploymentDescriptor(
+			resultPartitionId,
+			ResultPartitionLocation.createRemote(connectionId)));
+
+		verify(network, times(1)).requestMemorySegments(buffersPerChannel);
+		verify(remote, times(1)).assignExclusiveSegments(anyListOf(MemorySegment.class));
+	}
+
 	// ---------------------------------------------------------------------------------------------
 
 	static void verifyBufferOrEvent(
