Repository: flink
Updated Branches:
  refs/heads/master 1e1e4dc70 -> cbdb784dc


[FLINK-4894] [network] Don't request buffer after writing to partition

Previously, after emitting a record via the RecordWriter, we eagerly
requested a new buffer for the next emit on that channel, even though it
was not clear that the buffer would be needed immediately. With this
change, the buffer is requested lazily, only when an emit call requires it.

This closes #2690.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/cbdb784d
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/cbdb784d
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/cbdb784d

Branch: refs/heads/master
Commit: cbdb784dc24abba50674d054bb21c94dd7a559a5
Parents: 1e1e4dc
Author: Ufuk Celebi <u...@apache.org>
Authored: Mon Oct 24 18:01:28 2016 +0200
Committer: Ufuk Celebi <u...@apache.org>
Committed: Thu Oct 27 17:33:38 2016 +0200

----------------------------------------------------------------------
 .../serialization/SpanningRecordSerializer.java |   2 +-
 .../io/network/api/writer/RecordWriter.java     |  58 +++---
 .../io/network/api/writer/RecordWriterTest.java | 200 ++++++++++++++++++-
 3 files changed, 226 insertions(+), 34 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/cbdb784d/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java
index 7c4d937..65b3d20 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/SpanningRecordSerializer.java
@@ -188,7 +188,7 @@ public class SpanningRecordSerializer<T extends 
IOReadableWritable> implements R
        @Override
        public boolean hasData() {
                // either data in current target buffer or intermediate buffers
-               return this.position > 0 || (this.lengthBuffer.hasRemaining() 
|| this.dataBuffer.hasRemaining());
+               return (this.position > 0 && this.position < this.limit) || 
(this.lengthBuffer.hasRemaining() || this.dataBuffer.hasRemaining());
        }
 
        @Override

http://git-wip-us.apache.org/repos/asf/flink/blob/cbdb784d/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java
 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java
index 96eea23..fb35843 100644
--- 
a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java
+++ 
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java
@@ -47,7 +47,7 @@ import static 
org.apache.flink.runtime.io.network.api.serialization.RecordSerial
  */
 public class RecordWriter<T extends IOReadableWritable> {
 
-       protected final ResultPartitionWriter writer;
+       protected final ResultPartitionWriter targetPartition;
 
        private final ChannelSelector<T> channelSelector;
 
@@ -64,7 +64,7 @@ public class RecordWriter<T extends IOReadableWritable> {
 
        @SuppressWarnings("unchecked")
        public RecordWriter(ResultPartitionWriter writer, ChannelSelector<T> 
channelSelector) {
-               this.writer = writer;
+               this.targetPartition = writer;
                this.channelSelector = channelSelector;
 
                this.numChannels = writer.getNumberOfOutputChannels();
@@ -108,15 +108,25 @@ public class RecordWriter<T extends IOReadableWritable> {
 
                synchronized (serializer) {
                        SerializationResult result = 
serializer.addRecord(record);
+
                        while (result.isFullBuffer()) {
                                Buffer buffer = serializer.getCurrentBuffer();
 
                                if (buffer != null) {
-                                       writeBuffer(buffer, targetChannel, 
serializer);
+                                       writeAndClearBuffer(buffer, 
targetChannel, serializer);
+
+                                       // If this was a full record, we are 
done. Not breaking
+                                       // out of the loop at this point will 
lead to another
+                                       // buffer request before breaking out 
(that would not be
+                                       // a problem per se, but it can lead to 
stalls in the
+                                       // pipeline).
+                                       if (result.isFullRecord()) {
+                                               break;
+                                       }
+                               } else {
+                                       buffer = 
targetPartition.getBufferProvider().requestBufferBlocking();
+                                       result = 
serializer.setNextBuffer(buffer);
                                }
-
-                               buffer = 
writer.getBufferProvider().requestBufferBlocking();
-                               result = serializer.setNextBuffer(buffer);
                        }
                }
        }
@@ -126,23 +136,14 @@ public class RecordWriter<T extends IOReadableWritable> {
                        RecordSerializer<T> serializer = 
serializers[targetChannel];
 
                        synchronized (serializer) {
-
-                               if (serializer.hasData()) {
-                                       Buffer buffer = 
serializer.getCurrentBuffer();
-                                       if (buffer == null) {
-                                               throw new 
IllegalStateException("Serializer has data but no buffer.");
-                                       }
-
-                                       writeBuffer(buffer, targetChannel, 
serializer);
-
-                                       writer.writeEvent(event, targetChannel);
-
-                                       buffer = 
writer.getBufferProvider().requestBufferBlocking();
-                                       serializer.setNextBuffer(buffer);
-                               }
-                               else {
-                                       writer.writeEvent(event, targetChannel);
+                               Buffer buffer = serializer.getCurrentBuffer();
+                               if (buffer != null) {
+                                       writeAndClearBuffer(buffer, 
targetChannel, serializer);
+                               } else if (serializer.hasData()) {
+                                       throw new IllegalStateException("No 
buffer, but serializer has buffered data.");
                                }
+
+                               targetPartition.writeEvent(event, 
targetChannel);
                        }
                }
        }
@@ -154,15 +155,12 @@ public class RecordWriter<T extends IOReadableWritable> {
                        synchronized (serializer) {
                                Buffer buffer = serializer.getCurrentBuffer();
                                if (buffer != null) {
-                                       writeBuffer(buffer, targetChannel, 
serializer);
-
-                                       buffer = 
writer.getBufferProvider().requestBufferBlocking();
-                                       serializer.setNextBuffer(buffer);
+                                       writeAndClearBuffer(buffer, 
targetChannel, serializer);
                                }
                        }
                }
 
-               writer.writeEndOfSuperstep();
+               targetPartition.writeEndOfSuperstep();
        }
 
        public void flush() throws IOException {
@@ -174,7 +172,7 @@ public class RecordWriter<T extends IOReadableWritable> {
                                        Buffer buffer = 
serializer.getCurrentBuffer();
 
                                        if (buffer != null) {
-                                               writeBuffer(buffer, 
targetChannel, serializer);
+                                               writeAndClearBuffer(buffer, 
targetChannel, serializer);
                                        }
                                } finally {
                                        serializer.clear();
@@ -224,13 +222,13 @@ public class RecordWriter<T extends IOReadableWritable> {
         *
         * <p> The buffer is cleared from the serializer state after a call to 
this method.
         */
-       private void writeBuffer(
+       private void writeAndClearBuffer(
                        Buffer buffer,
                        int targetChannel,
                        RecordSerializer<T> serializer) throws IOException {
 
                try {
-                       writer.writeBuffer(buffer, targetChannel);
+                       targetPartition.writeBuffer(buffer, targetChannel);
                }
                finally {
                        serializer.clearCurrentBuffer();

http://git-wip-us.apache.org/repos/asf/flink/blob/cbdb784d/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
----------------------------------------------------------------------
diff --git 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
index 70faf22..43a93c6 100644
--- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
+++ 
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
@@ -18,25 +18,37 @@
 
 package org.apache.flink.runtime.io.network.api.writer;
 
+import org.apache.flink.core.io.IOReadableWritable;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.core.memory.MemoryType;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import 
org.apache.flink.runtime.io.network.api.serialization.RecordSerializer.SerializationResult;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.runtime.io.network.buffer.BufferProvider;
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.io.network.util.TestBufferFactory;
 import org.apache.flink.runtime.io.network.util.TestTaskEvent;
+import org.apache.flink.runtime.testutils.DiscardingRecycler;
 import org.apache.flink.types.IntValue;
-
+import org.apache.flink.util.XORShiftRandom;
 import org.junit.Assert;
 import org.junit.Test;
 import org.junit.runner.RunWith;
-
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
 
 import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.Queue;
+import java.util.Random;
 import java.util.concurrent.Callable;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
@@ -44,7 +56,6 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
-
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
@@ -293,10 +304,160 @@ public class RecordWriterTest {
                recordWriter.flush();
        }
 
+       /**
+        * Tests broadcasting events when no records have been emitted yet.
+        */
+       @Test
+       public void testBroadcastEventNoRecords() throws Exception {
+               int numChannels = 4;
+               int bufferSize = 32;
+
+               @SuppressWarnings("unchecked")
+               Queue<BufferOrEvent>[] queues = new Queue[numChannels];
+               for (int i = 0; i < numChannels; i++) {
+                       queues[i] = new ArrayDeque<>();
+               }
+
+               BufferProvider bufferProvider = 
createBufferProvider(bufferSize);
+
+               ResultPartitionWriter partitionWriter = 
createCollectingPartitionWriter(queues, bufferProvider);
+               RecordWriter<ByteArrayIO> writer = new 
RecordWriter<>(partitionWriter, new RoundRobin<ByteArrayIO>());
+               CheckpointBarrier barrier = new 
CheckpointBarrier(Integer.MAX_VALUE + 919192L, Integer.MAX_VALUE + 18828228L);
+
+               // No records emitted yet, broadcast should not request a buffer
+               writer.broadcastEvent(barrier);
+
+               verify(bufferProvider, times(0)).requestBufferBlocking();
+
+               for (Queue<BufferOrEvent> queue : queues) {
+                       assertEquals(1, queue.size());
+                       BufferOrEvent boe = queue.remove();
+                       assertTrue(boe.isEvent());
+                       assertEquals(barrier, boe.getEvent());
+               }
+       }
+
+       /**
+        * Tests broadcasting events when records have been emitted. The emitted
+        * records cover all three {@link SerializationResult} types.
+        */
+       @Test
+       public void testBroadcastEventMixedRecords() throws Exception {
+               Random rand = new XORShiftRandom();
+               int numChannels = 4;
+               int bufferSize = 32;
+               int lenBytes = 4; // serialized length
+
+               @SuppressWarnings("unchecked")
+               Queue<BufferOrEvent>[] queues = new Queue[numChannels];
+               for (int i = 0; i < numChannels; i++) {
+                       queues[i] = new ArrayDeque<>();
+               }
+
+               BufferProvider bufferProvider = 
createBufferProvider(bufferSize);
+
+               ResultPartitionWriter partitionWriter = 
createCollectingPartitionWriter(queues, bufferProvider);
+               RecordWriter<ByteArrayIO> writer = new 
RecordWriter<>(partitionWriter, new RoundRobin<ByteArrayIO>());
+               CheckpointBarrier barrier = new 
CheckpointBarrier(Integer.MAX_VALUE + 1292L, Integer.MAX_VALUE + 199L);
+
+               // Emit records on some channels first (requesting buffers), 
then
+               // broadcast the event. The record buffers should be emitted 
first, then
+               // the event. After the event, no new buffer should be 
requested.
+
+               // (i) Smaller than the buffer size (single buffer request => 1)
+               byte[] bytes = new byte[bufferSize / 2];
+               rand.nextBytes(bytes);
+
+               writer.emit(new ByteArrayIO(bytes));
+
+               // (ii) Larger than the buffer size (two buffer requests => 1 + 
2)
+               bytes = new byte[bufferSize + 1];
+               rand.nextBytes(bytes);
+
+               writer.emit(new ByteArrayIO(bytes));
+
+               // (iii) Exactly the buffer size (single buffer request => 1 + 
2 + 1)
+               bytes = new byte[bufferSize - lenBytes];
+               rand.nextBytes(bytes);
+
+               writer.emit(new ByteArrayIO(bytes));
+
+               // (iv) Nothing on the 4th channel (no buffer request => 1 + 2 
+ 1 + 0 = 4)
+
+               // (v) Broadcast the event
+               writer.broadcastEvent(barrier);
+
+               verify(bufferProvider, times(4)).requestBufferBlocking();
+
+               assertEquals(2, queues[0].size()); // 1 buffer + 1 event
+               assertEquals(3, queues[1].size()); // 2 buffers + 1 event
+               assertEquals(2, queues[2].size()); // 1 buffer + 1 event
+               assertEquals(1, queues[3].size()); // 0 buffers + 1 event
+       }
+
        // 
---------------------------------------------------------------------------------------------
        // Helpers
        // 
---------------------------------------------------------------------------------------------
 
+       /**
+        * Creates a mock partition writer that collects the added 
buffers/events.
+        *
+        * <p>This much mocking should not be necessary with better designed
+        * interfaces. Refactoring this will take too much time now though, 
hence
+        * the mocking. Ideally, we will refactor all of this mess in order to 
make
+        * our lives easier and test it better.
+        */
+       private ResultPartitionWriter createCollectingPartitionWriter(
+                       final Queue<BufferOrEvent>[] queues,
+                       BufferProvider bufferProvider) throws IOException {
+
+               int numChannels = queues.length;
+
+               ResultPartitionWriter partitionWriter = 
mock(ResultPartitionWriter.class);
+               
when(partitionWriter.getBufferProvider()).thenReturn(checkNotNull(bufferProvider));
+               
when(partitionWriter.getNumberOfOutputChannels()).thenReturn(numChannels);
+
+               doAnswer(new Answer<Void>() {
+                       @Override
+                       public Void answer(InvocationOnMock invocationOnMock) 
throws Throwable {
+                               Buffer buffer = (Buffer) 
invocationOnMock.getArguments()[0];
+                               Integer targetChannel = (Integer) 
invocationOnMock.getArguments()[1];
+                               queues[targetChannel].add(new 
BufferOrEvent(buffer, targetChannel));
+                               return null;
+                       }
+               }).when(partitionWriter).writeBuffer(any(Buffer.class), 
anyInt());
+
+               doAnswer(new Answer<Void>() {
+                       @Override
+                       public Void answer(InvocationOnMock invocationOnMock) 
throws Throwable {
+                               AbstractEvent event = (AbstractEvent) 
invocationOnMock.getArguments()[0];
+                               Integer targetChannel = (Integer) 
invocationOnMock.getArguments()[1];
+                               queues[targetChannel].add(new 
BufferOrEvent(event, targetChannel));
+                               return null;
+                       }
+               }).when(partitionWriter).writeEvent(any(AbstractEvent.class), 
anyInt());
+
+               return partitionWriter;
+       }
+
+       private BufferProvider createBufferProvider(final int bufferSize)
+                       throws IOException, InterruptedException {
+
+               BufferProvider bufferProvider = mock(BufferProvider.class);
+               when(bufferProvider.requestBufferBlocking()).thenAnswer(
+                               new Answer<Buffer>() {
+                                       @Override
+                                       public Buffer answer(InvocationOnMock 
invocationOnMock) throws Throwable {
+                                               MemorySegment segment = 
MemorySegmentFactory.allocateUnpooledSegment(bufferSize);
+                                               Buffer buffer = new 
Buffer(segment, DiscardingRecycler.INSTANCE);
+                                               return buffer;
+                                       }
+                               }
+               );
+
+               return bufferProvider;
+       }
+
        private BufferProvider createBufferProvider(Buffer... buffers)
                        throws IOException, InterruptedException {
 
@@ -328,4 +489,37 @@ public class RecordWriterTest {
 
                return partitionWriter;
        }
+
+       private static class ByteArrayIO implements IOReadableWritable {
+
+               private final byte[] bytes;
+
+               public ByteArrayIO(byte[] bytes) {
+                       this.bytes = bytes;
+               }
+
+               @Override
+               public void write(DataOutputView out) throws IOException {
+                       out.write(bytes);
+               }
+
+               @Override
+               public void read(DataInputView in) throws IOException {
+                       in.read(bytes);
+               }
+       }
+
+       /**
+        * RoundRobin channel selector starting at 0 ({@link 
RoundRobinChannelSelector} starts at 1).
+        */
+       private static class RoundRobin<T extends IOReadableWritable> 
implements ChannelSelector<T> {
+
+               private int[] nextChannel = new int[] { -1 };
+
+               @Override
+               public int[] selectChannels(final T record, final int 
numberOfOutputChannels) {
+                       nextChannel[0] = (nextChannel[0] + 1) % 
numberOfOutputChannels;
+                       return nextChannel;
+               }
+       }
 }

Reply via email to