Repository: spark Updated Branches: refs/heads/master 1c5475f14 -> 540bf58f1
[SPARK-11617][NETWORK] Fix leak in TransportFrameDecoder. The code was using the wrong API (writeBytes()) to add data to the internal composite buffer, causing buffers to leak in certain situations. Use the correct API (addComponent() plus an explicit writerIndex update) and enhance the tests to catch memory leaks. Also, avoid reusing the composite buffers when downstream handlers keep references to them; this seems to cause a few different issues even though the ref-counting code appears to be correct, so instead pay the cost of copying a few bytes when that situation happens. Author: Marcelo Vanzin <van...@cloudera.com> Closes #9619 from vanzin/SPARK-11617. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/540bf58f Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/540bf58f Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/540bf58f Branch: refs/heads/master Commit: 540bf58f18328c68107d6c616ffd70f3a4640054 Parents: 1c5475f Author: Marcelo Vanzin <van...@cloudera.com> Authored: Mon Nov 16 17:28:11 2015 -0800 Committer: Marcelo Vanzin <van...@cloudera.com> Committed: Mon Nov 16 17:28:11 2015 -0800 ---------------------------------------------------------------------- .../network/util/TransportFrameDecoder.java | 47 ++++-- .../util/TransportFrameDecoderSuite.java | 145 +++++++++++++++---- 2 files changed, 151 insertions(+), 41 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/540bf58f/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java ---------------------------------------------------------------------- diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 272ea84..5889562 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ 
b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -56,32 +56,43 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { buffer = in.alloc().compositeBuffer(); } - buffer.writeBytes(in); + buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes()); while (buffer.isReadable()) { - feedInterceptor(); - if (interceptor != null) { - continue; - } + discardReadBytes(); + if (!feedInterceptor()) { + ByteBuf frame = decodeNext(); + if (frame == null) { + break; + } - ByteBuf frame = decodeNext(); - if (frame != null) { ctx.fireChannelRead(frame); - } else { - break; } } - // We can't discard read sub-buffers if there are other references to the buffer (e.g. - // through slices used for framing). This assumes that code that retains references - // will call retain() from the thread that called "fireChannelRead()" above, otherwise - // ref counting will go awry. - if (buffer != null && buffer.refCnt() == 1) { + discardReadBytes(); + } + + private void discardReadBytes() { + // If the buffer's been retained by downstream code, then make a copy of the remaining + // bytes into a new buffer. Otherwise, just discard stale components. 
+ if (buffer.refCnt() > 1) { + CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer(); + + if (buffer.readableBytes() > 0) { + ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes()); + spillBuf.writeBytes(buffer); + newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes()); + } + + buffer.release(); + buffer = newBuffer; + } else { buffer.discardReadComponents(); } } - protected ByteBuf decodeNext() throws Exception { + private ByteBuf decodeNext() throws Exception { if (buffer.readableBytes() < LENGTH_SIZE) { return null; } @@ -127,10 +138,14 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { this.interceptor = interceptor; } - private void feedInterceptor() throws Exception { + /** + * @return Whether the interceptor is still active after processing the data. + */ + private boolean feedInterceptor() throws Exception { if (interceptor != null && !interceptor.handle(buffer)) { interceptor = null; } + return interceptor != null; } public static interface Interceptor { http://git-wip-us.apache.org/repos/asf/spark/blob/540bf58f/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java ---------------------------------------------------------------------- diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index ca74f0a..19475c2 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -18,41 +18,36 @@ package org.apache.spark.network.util; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import 
io.netty.channel.ChannelHandlerContext; +import org.junit.AfterClass; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import static org.junit.Assert.*; import static org.mockito.Mockito.*; public class TransportFrameDecoderSuite { + private static Random RND = new Random(); + + @AfterClass + public static void cleanup() { + RND = null; + } + @Test public void testFrameDecoding() throws Exception { - Random rnd = new Random(); TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - - final int frameCount = 100; - ByteBuf data = Unpooled.buffer(); - try { - for (int i = 0; i < frameCount; i++) { - byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)]; - data.writeLong(frame.length + 8); - data.writeBytes(frame); - } - - while (data.isReadable()) { - int size = rnd.nextInt(16 * 1024) + 256; - decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size))); - } - - verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); - } finally { - data.release(); - } + ChannelHandlerContext ctx = mockChannelHandlerContext(); + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + verifyAndCloseDecoder(decoder, ctx, data); } @Test @@ -60,7 +55,7 @@ public class TransportFrameDecoderSuite { final int interceptedReads = 3; TransportFrameDecoder decoder = new TransportFrameDecoder(); TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelHandlerContext ctx = mockChannelHandlerContext(); byte[] data = new byte[8]; ByteBuf len = Unpooled.copyLong(8 + data.length); @@ -70,16 +65,56 @@ public class TransportFrameDecoderSuite { decoder.setInterceptor(interceptor); for (int i = 0; i < interceptedReads; i++) { decoder.channelRead(ctx, dataBuf); - dataBuf.release(); + assertEquals(0, dataBuf.refCnt()); dataBuf = 
Unpooled.wrappedBuffer(data); } decoder.channelRead(ctx, len); decoder.channelRead(ctx, dataBuf); verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); verify(ctx).fireChannelRead(any(ByteBuffer.class)); + assertEquals(0, len.refCnt()); + assertEquals(0, dataBuf.refCnt()); } finally { - len.release(); - dataBuf.release(); + release(len); + release(dataBuf); + } + } + + @Test + public void testRetainedFrames() throws Exception { + TransportFrameDecoder decoder = new TransportFrameDecoder(); + + final AtomicInteger count = new AtomicInteger(); + final List<ByteBuf> retained = new ArrayList<>(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() { + @Override + public Void answer(InvocationOnMock in) { + // Retain a few frames but not others. + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + if (count.incrementAndGet() % 2 == 0) { + retained.add(buf); + } else { + buf.release(); + } + return null; + } + }); + + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + try { + // Verify all retained buffers are readable. + for (ByteBuf b : retained) { + byte[] tmp = new byte[b.readableBytes()]; + b.readBytes(tmp); + b.release(); + } + verifyAndCloseDecoder(decoder, ctx, data); + } finally { + for (ByteBuf b : retained) { + release(b); + } } } @@ -100,6 +135,47 @@ public class TransportFrameDecoderSuite { testInvalidFrame(Integer.MAX_VALUE + 9); } + /** + * Creates a number of randomly sized frames and feeds them to the given decoder, verifying + * that the frames were read.
+ */ + private ByteBuf createAndFeedFrames( + int frameCount, + TransportFrameDecoder decoder, + ChannelHandlerContext ctx) throws Exception { + ByteBuf data = Unpooled.buffer(); + for (int i = 0; i < frameCount; i++) { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + data.writeLong(frame.length + 8); + data.writeBytes(frame); + } + + try { + while (data.isReadable()) { + int size = RND.nextInt(4 * 1024) + 256; + decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); + } + + verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + } catch (Exception e) { + release(data); + throw e; + } + return data; + } + + private void verifyAndCloseDecoder( + TransportFrameDecoder decoder, + ChannelHandlerContext ctx, + ByteBuf data) throws Exception { + try { + decoder.channelInactive(ctx); + assertTrue("There shouldn't be dangling references to the data.", data.release()); + } finally { + release(data); + } + } + private void testInvalidFrame(long size) throws Exception { TransportFrameDecoder decoder = new TransportFrameDecoder(); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); @@ -111,6 +187,25 @@ public class TransportFrameDecoderSuite { } } + private ChannelHandlerContext mockChannelHandlerContext() { + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() { + @Override + public Void answer(InvocationOnMock in) { + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + buf.release(); + return null; + } + }); + return ctx; + } + + private void release(ByteBuf buf) { + if (buf.refCnt() > 0) { + buf.release(buf.refCnt()); + } + } + private static class MockInterceptor implements TransportFrameDecoder.Interceptor { private int remainingReads; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org