Repository: spark
Updated Branches:
  refs/heads/master 1c5475f14 -> 540bf58f1


[SPARK-11617][NETWORK] Fix leak in TransportFrameDecoder.

The code was using the wrong API to add data to the internal composite
buffer, causing buffers to leak in certain situations. Use the right
API and enhance the tests to catch memory leaks.

Also, avoid reusing the composite buffers when downstream handlers keep
references to them; this seems to cause a few different issues even though
the ref counting code seems to be correct, so instead pay the cost of copying
a few bytes when that situation happens.

Author: Marcelo Vanzin <van...@cloudera.com>

Closes #9619 from vanzin/SPARK-11617.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/540bf58f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/540bf58f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/540bf58f

Branch: refs/heads/master
Commit: 540bf58f18328c68107d6c616ffd70f3a4640054
Parents: 1c5475f
Author: Marcelo Vanzin <van...@cloudera.com>
Authored: Mon Nov 16 17:28:11 2015 -0800
Committer: Marcelo Vanzin <van...@cloudera.com>
Committed: Mon Nov 16 17:28:11 2015 -0800

----------------------------------------------------------------------
 .../network/util/TransportFrameDecoder.java     |  47 ++++--
 .../util/TransportFrameDecoderSuite.java        | 145 +++++++++++++++----
 2 files changed, 151 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/540bf58f/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
index 272ea84..5889562 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -56,32 +56,43 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
       buffer = in.alloc().compositeBuffer();
     }
 
-    buffer.writeBytes(in);
+    buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes());
 
     while (buffer.isReadable()) {
-      feedInterceptor();
-      if (interceptor != null) {
-        continue;
-      }
+      discardReadBytes();
+      if (!feedInterceptor()) {
+        ByteBuf frame = decodeNext();
+        if (frame == null) {
+          break;
+        }
 
-      ByteBuf frame = decodeNext();
-      if (frame != null) {
         ctx.fireChannelRead(frame);
-      } else {
-        break;
       }
     }
 
-    // We can't discard read sub-buffers if there are other references to the buffer (e.g.
-    // through slices used for framing). This assumes that code that retains references
-    // will call retain() from the thread that called "fireChannelRead()" above, otherwise
-    // ref counting will go awry.
-    if (buffer != null && buffer.refCnt() == 1) {
+    discardReadBytes();
+  }
+
+  private void discardReadBytes() {
+    // If the buffer's been retained by downstream code, then make a copy of the remaining
+    // bytes into a new buffer. Otherwise, just discard stale components.
+    if (buffer.refCnt() > 1) {
+      CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer();
+
+      if (buffer.readableBytes() > 0) {
+        ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes());
+        spillBuf.writeBytes(buffer);
+        newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes());
+      }
+
+      buffer.release();
+      buffer = newBuffer;
+    } else {
       buffer.discardReadComponents();
     }
   }
 
-  protected ByteBuf decodeNext() throws Exception {
+  private ByteBuf decodeNext() throws Exception {
     if (buffer.readableBytes() < LENGTH_SIZE) {
       return null;
     }
@@ -127,10 +138,14 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
     this.interceptor = interceptor;
   }
 
-  private void feedInterceptor() throws Exception {
+  /**
+   * @return Whether the interceptor is still active after processing the data.
+   */
+  private boolean feedInterceptor() throws Exception {
     if (interceptor != null && !interceptor.handle(buffer)) {
       interceptor = null;
     }
+    return interceptor != null;
   }
 
   public static interface Interceptor {

http://git-wip-us.apache.org/repos/asf/spark/blob/540bf58f/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
index ca74f0a..19475c2 100644
--- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
@@ -18,41 +18,36 @@
 package org.apache.spark.network.util;
 
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
 import io.netty.channel.ChannelHandlerContext;
+import org.junit.AfterClass;
 import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 import static org.junit.Assert.*;
 import static org.mockito.Mockito.*;
 
 public class TransportFrameDecoderSuite {
 
+  private static Random RND = new Random();
+
+  @AfterClass
+  public static void cleanup() {
+    RND = null;
+  }
+
   @Test
   public void testFrameDecoding() throws Exception {
-    Random rnd = new Random();
     TransportFrameDecoder decoder = new TransportFrameDecoder();
-    ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
-
-    final int frameCount = 100;
-    ByteBuf data = Unpooled.buffer();
-    try {
-      for (int i = 0; i < frameCount; i++) {
-        byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)];
-        data.writeLong(frame.length + 8);
-        data.writeBytes(frame);
-      }
-
-      while (data.isReadable()) {
-        int size = rnd.nextInt(16 * 1024) + 256;
-        decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)));
-      }
-
-      verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
-    } finally {
-      data.release();
-    }
+    ChannelHandlerContext ctx = mockChannelHandlerContext();
+    ByteBuf data = createAndFeedFrames(100, decoder, ctx);
+    verifyAndCloseDecoder(decoder, ctx, data);
   }
 
   @Test
@@ -60,7 +55,7 @@ public class TransportFrameDecoderSuite {
     final int interceptedReads = 3;
     TransportFrameDecoder decoder = new TransportFrameDecoder();
     TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
-    ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+    ChannelHandlerContext ctx = mockChannelHandlerContext();
 
     byte[] data = new byte[8];
     ByteBuf len = Unpooled.copyLong(8 + data.length);
@@ -70,16 +65,56 @@ public class TransportFrameDecoderSuite {
       decoder.setInterceptor(interceptor);
       for (int i = 0; i < interceptedReads; i++) {
         decoder.channelRead(ctx, dataBuf);
-        dataBuf.release();
+        assertEquals(0, dataBuf.refCnt());
         dataBuf = Unpooled.wrappedBuffer(data);
       }
       decoder.channelRead(ctx, len);
       decoder.channelRead(ctx, dataBuf);
       verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
       verify(ctx).fireChannelRead(any(ByteBuffer.class));
+      assertEquals(0, len.refCnt());
+      assertEquals(0, dataBuf.refCnt());
     } finally {
-      len.release();
-      dataBuf.release();
+      release(len);
+      release(dataBuf);
+    }
+  }
+
+  @Test
+  public void testRetainedFrames() throws Exception {
+    TransportFrameDecoder decoder = new TransportFrameDecoder();
+
+    final AtomicInteger count = new AtomicInteger();
+    final List<ByteBuf> retained = new ArrayList<>();
+
+    ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+    when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock in) {
+        // Retain a few frames but not others.
+        ByteBuf buf = (ByteBuf) in.getArguments()[0];
+        if (count.incrementAndGet() % 2 == 0) {
+          retained.add(buf);
+        } else {
+          buf.release();
+        }
+        return null;
+      }
+    });
+
+    ByteBuf data = createAndFeedFrames(100, decoder, ctx);
+    try {
+      // Verify all retained buffers are readable.
+      for (ByteBuf b : retained) {
+        byte[] tmp = new byte[b.readableBytes()];
+        b.readBytes(tmp);
+        b.release();
+      }
+      verifyAndCloseDecoder(decoder, ctx, data);
+    } finally {
+      for (ByteBuf b : retained) {
+        release(b);
+      }
     }
   }
 
@@ -100,6 +135,47 @@ public class TransportFrameDecoderSuite {
     testInvalidFrame(Integer.MAX_VALUE + 9);
   }
 
+  /**
+   * Creates a number of randomly sized frames and feed them to the given decoder, verifying
+   * that the frames were read.
+   */
+  private ByteBuf createAndFeedFrames(
+      int frameCount,
+      TransportFrameDecoder decoder,
+      ChannelHandlerContext ctx) throws Exception {
+    ByteBuf data = Unpooled.buffer();
+    for (int i = 0; i < frameCount; i++) {
+      byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
+      data.writeLong(frame.length + 8);
+      data.writeBytes(frame);
+    }
+
+    try {
+      while (data.isReadable()) {
+        int size = RND.nextInt(4 * 1024) + 256;
+        decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain());
+      }
+
+      verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
+    } catch (Exception e) {
+      release(data);
+      throw e;
+    }
+    return data;
+  }
+
+  private void verifyAndCloseDecoder(
+      TransportFrameDecoder decoder,
+      ChannelHandlerContext ctx,
+      ByteBuf data) throws Exception {
+    try {
+      decoder.channelInactive(ctx);
+      assertTrue("There shouldn't be dangling references to the data.", data.release());
+    } finally {
+      release(data);
+    }
+  }
+
   private void testInvalidFrame(long size) throws Exception {
     TransportFrameDecoder decoder = new TransportFrameDecoder();
     ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
@@ -111,6 +187,25 @@ public class TransportFrameDecoderSuite {
     }
   }
 
+  private ChannelHandlerContext mockChannelHandlerContext() {
+    ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+    when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock in) {
+        ByteBuf buf = (ByteBuf) in.getArguments()[0];
+        buf.release();
+        return null;
+      }
+    });
+    return ctx;
+  }
+
+  private void release(ByteBuf buf) {
+    if (buf.refCnt() > 0) {
+      buf.release(buf.refCnt());
+    }
+  }
+
   private static class MockInterceptor implements TransportFrameDecoder.Interceptor {
 
     private int remainingReads;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to