Repository: spark
Updated Branches:
  refs/heads/master 8790ee6d6 -> 27feafccb


[SPARK-11235][NETWORK] Add ability to stream data using network lib.

The current interface used to fetch shuffle data is not very efficient for
large buffers; it requires the receiver to buffer the entire contents being
downloaded in memory before processing the data.

To use the network library to transfer large files (such as those that
can be added using SparkContext addJar / addFile), this change adds a
more efficient way of downloading data: the data is streamed and fed to
a callback as it arrives.

This is achieved by a custom frame decoder that replaces the current Netty
one; this decoder allows entering a mode where framing is skipped and data
is instead provided directly to a callback. The existing Netty classes
(ByteToMessageDecoder and LengthFieldBasedFrameDecoder) could not be reused
since their semantics do not allow for the interception approach the new
decoder uses.
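
As a rough sketch of the resulting client-side API (illustrative only: the
class name, stream ID, and file sink below are hypothetical, and real stream
IDs are negotiated out of band, e.g. via an RPC):

    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.nio.ByteBuffer;
    import java.nio.channels.FileChannel;

    import org.apache.spark.network.client.StreamCallback;
    import org.apache.spark.network.client.TransportClient;

    class StreamToFile {
      // Asynchronously streams "streamId" from the remote end into a local file.
      static void download(TransportClient client, String streamId, String path)
          throws IOException {
        final FileChannel out = new FileOutputStream(path).getChannel();
        client.stream(streamId, new StreamCallback() {
          @Override
          public void onData(String streamId, ByteBuffer buf) throws IOException {
            // Called as chunks arrive; the buffer must be consumed before returning.
            while (buf.hasRemaining()) {
              out.write(buf);
            }
          }

          @Override
          public void onComplete(String streamId) throws IOException {
            out.close();
          }

          @Override
          public void onFailure(String streamId, Throwable cause) throws IOException {
            out.close();
          }
        });
      }
    }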

Author: Marcelo Vanzin <van...@cloudera.com>

Closes #9206 from vanzin/SPARK-11235.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/27feafcc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/27feafcc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/27feafcc

Branch: refs/heads/master
Commit: 27feafccbd6945b000ca51b14c57912acbad9031
Parents: 8790ee6
Author: Marcelo Vanzin <van...@cloudera.com>
Authored: Wed Nov 4 09:11:54 2015 -0800
Committer: Marcelo Vanzin <van...@cloudera.com>
Committed: Wed Nov 4 09:11:54 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/network/TransportContext.java  |   3 +-
 .../spark/network/client/StreamCallback.java    |  40 +++
 .../spark/network/client/StreamInterceptor.java |  76 +++++
 .../spark/network/client/TransportClient.java   |  41 +++
 .../client/TransportResponseHandler.java        |  47 ++-
 .../network/protocol/ChunkFetchSuccess.java     |  16 +-
 .../apache/spark/network/protocol/Message.java  |   6 +-
 .../spark/network/protocol/MessageDecoder.java  |   9 +
 .../spark/network/protocol/MessageEncoder.java  |  27 +-
 .../network/protocol/ResponseWithBody.java      |  40 +++
 .../spark/network/protocol/StreamFailure.java   |  80 +++++
 .../spark/network/protocol/StreamRequest.java   |  78 +++++
 .../spark/network/protocol/StreamResponse.java  |  91 ++++++
 .../spark/network/server/StreamManager.java     |  13 +
 .../network/server/TransportRequestHandler.java |  20 ++
 .../apache/spark/network/util/NettyUtils.java   |   9 +-
 .../network/util/TransportFrameDecoder.java     | 154 +++++++++
 .../org/apache/spark/network/ProtocolSuite.java |   8 +
 .../org/apache/spark/network/StreamSuite.java   | 325 +++++++++++++++++++
 .../util/TransportFrameDecoderSuite.java        | 142 ++++++++
 20 files changed, 1196 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/TransportContext.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
index b8d073f..43900e6 100644
--- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -39,6 +39,7 @@ import org.apache.spark.network.server.TransportServer;
 import org.apache.spark.network.server.TransportServerBootstrap;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
+import org.apache.spark.network.util.TransportFrameDecoder;
 
 /**
  * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
@@ -119,7 +120,7 @@ public class TransportContext {
       TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
       channel.pipeline()
         .addLast("encoder", encoder)
-        .addLast("frameDecoder", NettyUtils.createFrameDecoder())
+        .addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
         .addLast("decoder", decoder)
         .addLast("idleStateHandler", new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
         // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java
new file mode 100644
index 0000000..093fada
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+/**
+ * Callback for streaming data. Stream data will be offered to the
+ * {@link #onData(String, ByteBuffer)} method as it arrives. Once all the stream data is
+ * received, {@link #onComplete(String)} will be called.
+ * <p>
+ * The network library guarantees that a single thread will call these methods at a time, but
+ * different calls may be made by different threads.
+ */
+public interface StreamCallback {
+  /** Called upon receipt of stream data. */
+  void onData(String streamId, ByteBuffer buf) throws IOException;
+
+  /** Called when all data from the stream has been received. */
+  void onComplete(String streamId) throws IOException;
+
+  /** Called if there's an error reading data from the stream. */
+  void onFailure(String streamId, Throwable cause) throws IOException;
+}
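
Since these callbacks fire on Netty event-loop threads, a caller that needs to
block until a stream finishes can latch on completion. A minimal sketch, not
part of this change (class name hypothetical):

    import java.nio.ByteBuffer;
    import java.util.concurrent.CountDownLatch;
    import java.util.concurrent.TimeUnit;

    import org.apache.spark.network.client.StreamCallback;

    class BlockingStreamCallback implements StreamCallback {
      private final CountDownLatch done = new CountDownLatch(1);
      private volatile Throwable error;

      @Override
      public void onData(String streamId, ByteBuffer buf) {
        // Consume buf here; the buffer is only valid for the duration of this call.
      }

      @Override
      public void onComplete(String streamId) {
        done.countDown();
      }

      @Override
      public void onFailure(String streamId, Throwable cause) {
        error = cause;
        done.countDown();
      }

      // Blocks the calling thread until the stream completes or fails.
      void await(long timeoutMs) throws Throwable {
        if (!done.await(timeoutMs, TimeUnit.MILLISECONDS)) {
          throw new IllegalStateException("Timed out waiting for stream.");
        }
        if (error != null) {
          throw error;
        }
      }
    }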

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
new file mode 100644
index 0000000..02230a0
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.util.TransportFrameDecoder;
+
+/**
+ * An interceptor that is registered with the frame decoder to feed stream data to a
+ * callback.
+ */
+class StreamInterceptor implements TransportFrameDecoder.Interceptor {
+
+  private final String streamId;
+  private final long byteCount;
+  private final StreamCallback callback;
+
+  private volatile long bytesRead;
+
+  StreamInterceptor(String streamId, long byteCount, StreamCallback callback) {
+    this.streamId = streamId;
+    this.byteCount = byteCount;
+    this.callback = callback;
+    this.bytesRead = 0;
+  }
+
+  @Override
+  public void exceptionCaught(Throwable cause) throws Exception {
+    callback.onFailure(streamId, cause);
+  }
+
+  @Override
+  public void channelInactive() throws Exception {
+    callback.onFailure(streamId, new ClosedChannelException());
+  }
+
+  @Override
+  public boolean handle(ByteBuf buf) throws Exception {
+    int toRead = (int) Math.min(buf.readableBytes(), byteCount - bytesRead);
+    ByteBuffer nioBuffer = buf.readSlice(toRead).nioBuffer();
+
+    int available = nioBuffer.remaining();
+    callback.onData(streamId, nioBuffer);
+    bytesRead += available;
+    if (bytesRead > byteCount) {
+      RuntimeException re = new IllegalStateException(String.format(
+        "Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
+      callback.onFailure(streamId, re);
+      throw re;
+    } else if (bytesRead == byteCount) {
+      callback.onComplete(streamId);
+    }
+
+    return bytesRead != byteCount;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index fbb8bb6..a0ba223 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -38,6 +38,7 @@ import org.slf4j.LoggerFactory;
 import org.apache.spark.network.protocol.ChunkFetchRequest;
 import org.apache.spark.network.protocol.RpcRequest;
 import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamRequest;
 import org.apache.spark.network.util.NettyUtils;
 
 /**
@@ -160,6 +161,46 @@ public class TransportClient implements Closeable {
   }
 
   /**
+   * Request to stream the data with the given stream ID from the remote end.
+   *
+   * @param streamId The stream to fetch.
+   * @param callback Object to call with the stream data.
+   */
+  public void stream(final String streamId, final StreamCallback callback) {
+    final String serverAddr = NettyUtils.getRemoteAddress(channel);
+    final long startTime = System.currentTimeMillis();
+    logger.debug("Sending stream request for {} to {}", streamId, serverAddr);
+
+    // Need to synchronize here so that the callback is added to the queue and the RPC is
+    // written to the socket atomically, so that callbacks are called in the right order
+    // when responses arrive.
+    synchronized (this) {
+      handler.addStreamCallback(callback);
+      channel.writeAndFlush(new StreamRequest(streamId)).addListener(
+        new ChannelFutureListener() {
+          @Override
+          public void operationComplete(ChannelFuture future) throws Exception {
+            if (future.isSuccess()) {
+              long timeTaken = System.currentTimeMillis() - startTime;
+              logger.trace("Sending request for {} to {} took {} ms", streamId, serverAddr,
+                timeTaken);
+            } else {
+              String errorMsg = String.format("Failed to send request for %s to %s: %s", streamId,
+                serverAddr, future.cause());
+              logger.error(errorMsg, future.cause());
+              channel.close();
+              try {
+                callback.onFailure(streamId, new IOException(errorMsg, future.cause()));
+              } catch (Exception e) {
+                logger.error("Uncaught exception in RPC response callback handler!", e);
+              }
+            }
+          }
+        });
+    }
+  }
+
+  /**
    * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
    * with the server's response or upon any failure.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 94fc21a..ed3f36a 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -19,7 +19,9 @@ package org.apache.spark.network.client;
 
 import java.io.IOException;
 import java.util.Map;
+import java.util.Queue;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.atomic.AtomicLong;
 
 import io.netty.channel.Channel;
@@ -32,8 +34,11 @@ import org.apache.spark.network.protocol.ResponseMessage;
 import org.apache.spark.network.protocol.RpcFailure;
 import org.apache.spark.network.protocol.RpcResponse;
 import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamResponse;
 import org.apache.spark.network.server.MessageHandler;
 import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportFrameDecoder;
 
 /**
  * Handler that processes server responses, in response to requests issued from a
@@ -50,6 +55,8 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
 
   private final Map<Long, RpcResponseCallback> outstandingRpcs;
 
+  private final Queue<StreamCallback> streamCallbacks;
+
   /** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
   private final AtomicLong timeOfLastRequestNs;
 
@@ -57,6 +64,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
     this.channel = channel;
     this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
     this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+    this.streamCallbacks = new ConcurrentLinkedQueue<StreamCallback>();
     this.timeOfLastRequestNs = new AtomicLong(0);
   }
 
@@ -78,6 +86,10 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
     outstandingRpcs.remove(requestId);
   }
 
+  public void addStreamCallback(StreamCallback callback) {
+    streamCallbacks.offer(callback);
+  }
+
   /**
    * Fire the failure callback for all outstanding requests. This is called when we have an
    * uncaught exception or premature connection termination.
@@ -124,11 +136,11 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
       if (listener == null) {
         logger.warn("Ignoring response for block {} from {} since it is not outstanding",
           resp.streamChunkId, remoteAddress);
-        resp.buffer.release();
+        resp.body.release();
       } else {
         outstandingFetches.remove(resp.streamChunkId);
-        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
-        resp.buffer.release();
+        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body);
+        resp.body.release();
       }
     } else if (message instanceof ChunkFetchFailure) {
       ChunkFetchFailure resp = (ChunkFetchFailure) message;
@@ -161,6 +173,34 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
         outstandingRpcs.remove(resp.requestId);
         listener.onFailure(new RuntimeException(resp.errorString));
       }
+    } else if (message instanceof StreamResponse) {
+      StreamResponse resp = (StreamResponse) message;
+      StreamCallback callback = streamCallbacks.poll();
+      if (callback != null) {
+        StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount,
+          callback);
+        try {
+          TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
+            channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
+          frameDecoder.setInterceptor(interceptor);
+        } catch (Exception e) {
+          logger.error("Error installing stream handler.", e);
+        }
+      } else {
+        logger.error("Could not find callback for StreamResponse.");
+      }
+    } else if (message instanceof StreamFailure) {
+      StreamFailure resp = (StreamFailure) message;
+      StreamCallback callback = streamCallbacks.poll();
+      if (callback != null) {
+        try {
+          callback.onFailure(resp.streamId, new RuntimeException(resp.error));
+        } catch (IOException ioe) {
+          logger.warn("Error in stream failure handler.", ioe);
+        }
+      } else {
+        logger.warn("Stream failure with unknown callback: {}", resp.error);
+      }
     } else {
       throw new IllegalStateException("Unknown response type: " + message.type());
     }
@@ -175,4 +215,5 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
   public long getTimeOfLastRequestNs() {
     return timeOfLastRequestNs.get();
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
index c962fb7..e6a7e9a 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -30,13 +30,12 @@ import org.apache.spark.network.buffer.NettyManagedBuffer;
  * may be written by Netty in a more efficient manner (i.e., zero-copy write).
  * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
  */
-public final class ChunkFetchSuccess implements ResponseMessage {
+public final class ChunkFetchSuccess extends ResponseWithBody {
   public final StreamChunkId streamChunkId;
-  public final ManagedBuffer buffer;
 
   public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
+    super(buffer, true);
     this.streamChunkId = streamChunkId;
-    this.buffer = buffer;
   }
 
   @Override
@@ -53,6 +52,11 @@ public final class ChunkFetchSuccess implements ResponseMessage {
     streamChunkId.encode(buf);
   }
 
+  @Override
+  public ResponseMessage createFailureResponse(String error) {
+    return new ChunkFetchFailure(streamChunkId, error);
+  }
+
   /** Decoding uses the given ByteBuf as our data, and will retain() it. */
   public static ChunkFetchSuccess decode(ByteBuf buf) {
     StreamChunkId streamChunkId = StreamChunkId.decode(buf);
@@ -63,14 +67,14 @@ public final class ChunkFetchSuccess implements ResponseMessage {
 
   @Override
   public int hashCode() {
-    return Objects.hashCode(streamChunkId, buffer);
+    return Objects.hashCode(streamChunkId, body);
   }
 
   @Override
   public boolean equals(Object other) {
     if (other instanceof ChunkFetchSuccess) {
       ChunkFetchSuccess o = (ChunkFetchSuccess) other;
-      return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer);
+      return streamChunkId.equals(o.streamChunkId) && body.equals(o.body);
     }
     return false;
   }
@@ -79,7 +83,7 @@ public final class ChunkFetchSuccess implements ResponseMessage {
   public String toString() {
     return Objects.toStringHelper(this)
       .add("streamChunkId", streamChunkId)
-      .add("buffer", buffer)
+      .add("buffer", body)
       .toString();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
index d568370..d01598c 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -27,7 +27,8 @@ public interface Message extends Encodable {
   /** Preceding every serialized Message is its type, which allows us to deserialize it. */
   public static enum Type implements Encodable {
     ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
-    RpcRequest(3), RpcResponse(4), RpcFailure(5);
+    RpcRequest(3), RpcResponse(4), RpcFailure(5),
+    StreamRequest(6), StreamResponse(7), StreamFailure(8);
 
     private final byte id;
 
@@ -51,6 +52,9 @@ public interface Message extends Encodable {
         case 3: return RpcRequest;
         case 4: return RpcResponse;
         case 5: return RpcFailure;
+        case 6: return StreamRequest;
+        case 7: return StreamResponse;
+        case 8: return StreamFailure;
         default: throw new IllegalArgumentException("Unknown message type: " + id);
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
index 81f8d7f..3c04048 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -63,6 +63,15 @@ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
       case RpcFailure:
         return RpcFailure.decode(in);
 
+      case StreamRequest:
+        return StreamRequest.decode(in);
+
+      case StreamResponse:
+        return StreamResponse.decode(in);
+
+      case StreamFailure:
+        return StreamFailure.decode(in);
+
       default:
         throw new IllegalArgumentException("Unexpected message type: " + msgType);
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
index 0f999f5..6cce97c 100644
--- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java
@@ -45,27 +45,32 @@ public final class MessageEncoder extends MessageToMessageEncoder<Message> {
   public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
     Object body = null;
     long bodyLength = 0;
+    boolean isBodyInFrame = false;
 
-    // Only ChunkFetchSuccesses have data besides the header.
+    // Detect ResponseWithBody messages and get the data buffer out of them.
     // The body is used in order to enable zero-copy transfer for the payload.
-    if (in instanceof ChunkFetchSuccess) {
-      ChunkFetchSuccess resp = (ChunkFetchSuccess) in;
+    if (in instanceof ResponseWithBody) {
+      ResponseWithBody resp = (ResponseWithBody) in;
       try {
-        bodyLength = resp.buffer.size();
-        body = resp.buffer.convertToNetty();
+        bodyLength = resp.body.size();
+        body = resp.body.convertToNetty();
+        isBodyInFrame = resp.isBodyInFrame;
       } catch (Exception e) {
-        // Re-encode this message as BlockFetchFailure.
-        logger.error(String.format("Error opening block %s for client %s",
-          resp.streamChunkId, ctx.channel().remoteAddress()), e);
-        encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out);
+        // Re-encode this message as a failure response.
+        String error = e.getMessage() != null ? e.getMessage() : "null";
+        logger.error(String.format("Error processing %s for client %s",
+          resp, ctx.channel().remoteAddress()), e);
+        encode(ctx, resp.createFailureResponse(error), out);
         return;
       }
     }
 
     Message.Type msgType = in.type();
-    // All messages have the frame length, message type, and message itself.
+    // All messages have the frame length, message type, and message itself. The frame length
+    // may optionally include the length of the body data, depending on what message is being
+    // sent.
     int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
-    long frameLength = headerLength + bodyLength;
+    long frameLength = headerLength + (isBodyInFrame ? bodyLength : 0);
     ByteBuf header = ctx.alloc().heapBuffer(headerLength);
     header.writeLong(frameLength);
     msgType.encode(header);
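
For reference, the frame length math above implies the following wire layout
(illustrative helper only, not part of the diff; it assumes the message type
encodes as a single byte, per the Message.Type enum above):

    // Illustrative only: frame layout produced by MessageEncoder.
    //   [ 8-byte frameLength | 1-byte type | encoded header fields | body? ]
    // ChunkFetchSuccess sets isBodyInFrame = true, so frameLength covers the body;
    // StreamResponse sets isBodyInFrame = false, so the frame ends after the header
    // and the raw stream bytes follow on the wire for an interceptor to consume.
    class FrameMath {
      static long frameLength(int encodedHeaderLength, long bodyLength, boolean isBodyInFrame) {
        int headerLength = 8 + 1 /* type byte */ + encodedHeaderLength;
        return headerLength + (isBodyInFrame ? bodyLength : 0);
      }
    }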

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
new file mode 100644
index 0000000..67be77e
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseWithBody.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Abstract class for response messages that contain a large data portion kept in a separate
+ * buffer. These messages are treated specially by MessageEncoder.
+ */
+public abstract class ResponseWithBody implements ResponseMessage {
+  public final ManagedBuffer body;
+  public final boolean isBodyInFrame;
+
+  protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) {
+    this.body = body;
+    this.isBodyInFrame = isBodyInFrame;
+  }
+
+  public abstract ResponseMessage createFailureResponse(String error);
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
new file mode 100644
index 0000000..e3dade2
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamFailure.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Message indicating an error when transferring a stream.
+ */
+public final class StreamFailure implements ResponseMessage {
+  public final String streamId;
+  public final String error;
+
+  public StreamFailure(String streamId, String error) {
+    this.streamId = streamId;
+    this.error = error;
+  }
+
+  @Override
+  public Type type() { return Type.StreamFailure; }
+
+  @Override
+  public int encodedLength() {
+    return Encoders.Strings.encodedLength(streamId) + Encoders.Strings.encodedLength(error);
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.Strings.encode(buf, streamId);
+    Encoders.Strings.encode(buf, error);
+  }
+
+  public static StreamFailure decode(ByteBuf buf) {
+    String streamId = Encoders.Strings.decode(buf);
+    String error = Encoders.Strings.decode(buf);
+    return new StreamFailure(streamId, error);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(streamId, error);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof StreamFailure) {
+      StreamFailure o = (StreamFailure) other;
+      return streamId.equals(o.streamId) && error.equals(o.error);
+    }
+    return false;
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("streamId", streamId)
+      .add("error", error)
+      .toString();
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
new file mode 100644
index 0000000..821e8f5
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamRequest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Request to stream data from the remote end.
+ * <p>
+ * The stream ID is an arbitrary string that needs to be negotiated between the two endpoints
+ * before the data can be streamed.
+ * the data can be streamed.
+ */
+public final class StreamRequest implements RequestMessage {
+  public final String streamId;
+
+  public StreamRequest(String streamId) {
+    this.streamId = streamId;
+  }
+
+  @Override
+  public Type type() { return Type.StreamRequest; }
+
+  @Override
+  public int encodedLength() {
+    return Encoders.Strings.encodedLength(streamId);
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.Strings.encode(buf, streamId);
+  }
+
+  public static StreamRequest decode(ByteBuf buf) {
+    String streamId = Encoders.Strings.decode(buf);
+    return new StreamRequest(streamId);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(streamId);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof StreamRequest) {
+      StreamRequest o = (StreamRequest) other;
+      return streamId.equals(o.streamId);
+    }
+    return false;
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("streamId", streamId)
+      .toString();
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
new file mode 100644
index 0000000..ac5ab9a
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamResponse.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to {@link StreamRequest} when the stream has been successfully opened.
+ * <p>
+ * Note the message itself does not contain the stream data. That is written separately by the
+ * sender. The receiver is expected to set a temporary channel handler that will consume the
+ * number of bytes this message says the stream has.
+ */
+public final class StreamResponse extends ResponseWithBody {
+  public final String streamId;
+  public final long byteCount;
+
+  public StreamResponse(String streamId, long byteCount, ManagedBuffer buffer) {
+    super(buffer, false);
+    this.streamId = streamId;
+    this.byteCount = byteCount;
+  }
+
+  @Override
+  public Type type() { return Type.StreamResponse; }
+
+  @Override
+  public int encodedLength() {
+    return 8 + Encoders.Strings.encodedLength(streamId);
+  }
+
+  /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.Strings.encode(buf, streamId);
+    buf.writeLong(byteCount);
+  }
+
+  @Override
+  public ResponseMessage createFailureResponse(String error) {
+    return new StreamFailure(streamId, error);
+  }
+
+  public static StreamResponse decode(ByteBuf buf) {
+    String streamId = Encoders.Strings.decode(buf);
+    long byteCount = buf.readLong();
+    return new StreamResponse(streamId, byteCount, null);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(byteCount, streamId);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof StreamResponse) {
+      StreamResponse o = (StreamResponse) other;
+      return byteCount == o.byteCount && streamId.equals(o.streamId);
+    }
+    return false;
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("streamId", streamId)
+      .add("byteCount", byteCount)
+      .toString();
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index aaa677c..3f01559 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -47,6 +47,19 @@ public abstract class StreamManager {
   public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
 
   /**
+   * Called in response to a stream() request. The returned data is streamed to the client
+   * through a single TCP connection.
+   *
+   * Note the <code>streamId</code> argument is not related to the similarly named argument in the
+   * {@link #getChunk(long, int)} method.
+   *
+   * @param streamId id of a stream that has been previously registered with the StreamManager.
+   */
+  public ManagedBuffer openStream(String streamId) {
+    throw new UnsupportedOperationException();
+  }
+
+  /**
    * Associates a stream with a single client connection, which is guaranteed to be the only reader
    * of the stream. The getChunk() method will be called serially on this connection and once the
    * connection is closed, the stream will never be used again, enabling cleanup.
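
A minimal sketch of a server-side openStream() implementation, following the
same pattern the new StreamSuite test uses below (class name and file registry
are hypothetical):

    import java.io.File;
    import java.util.concurrent.ConcurrentHashMap;

    import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.server.StreamManager;
    import org.apache.spark.network.util.TransportConf;

    // Illustrative only: serves previously registered files by stream ID.
    class FileStreamManager extends StreamManager {
      private final TransportConf conf;
      private final ConcurrentHashMap<String, File> files =
        new ConcurrentHashMap<String, File>();

      FileStreamManager(TransportConf conf) {
        this.conf = conf;
      }

      // The "negotiation" step: both ends agree on this ID out of band.
      void register(String streamId, File file) {
        files.put(streamId, file);
      }

      @Override
      public ManagedBuffer getChunk(long streamId, int chunkIndex) {
        throw new UnsupportedOperationException();
      }

      @Override
      public ManagedBuffer openStream(String streamId) {
        File f = files.get(streamId);
        if (f == null) {
          throw new IllegalArgumentException("Unknown stream: " + streamId);
        }
        return new FileSegmentManagedBuffer(conf, f, 0, f.length());
      }
    }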

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 9b8b047..4f67bd5 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -35,6 +35,9 @@ import org.apache.spark.network.protocol.ChunkFetchFailure;
 import org.apache.spark.network.protocol.ChunkFetchSuccess;
 import org.apache.spark.network.protocol.RpcFailure;
 import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.protocol.StreamResponse;
 import org.apache.spark.network.util.NettyUtils;
 
 /**
@@ -92,6 +95,8 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
       processFetchRequest((ChunkFetchRequest) request);
     } else if (request instanceof RpcRequest) {
       processRpcRequest((RpcRequest) request);
+    } else if (request instanceof StreamRequest) {
+      processStreamRequest((StreamRequest) request);
     } else {
       throw new IllegalArgumentException("Unknown request type: " + request);
     }
@@ -117,6 +122,21 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
     respond(new ChunkFetchSuccess(req.streamChunkId, buf));
   }
 
+  private void processStreamRequest(final StreamRequest req) {
+    final String client = NettyUtils.getRemoteAddress(channel);
+    ManagedBuffer buf;
+    try {
+      buf = streamManager.openStream(req.streamId);
+    } catch (Exception e) {
+      logger.error(String.format(
+        "Error opening stream %s for request from %s", req.streamId, client), e);
+      respond(new StreamFailure(req.streamId, Throwables.getStackTraceAsString(e)));
+      return;
+    }
+
+    respond(new StreamResponse(req.streamId, buf.size(), buf));
+  }
+
   private void processRpcRequest(final RpcRequest req) {
     try {
       rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() {

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
index 26c6399..caa7260 100644
--- a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
+++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java
@@ -89,13 +89,8 @@ public class NettyUtils {
    * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame.
    * This is used before all decoders.
    */
-  public static ByteToMessageDecoder createFrameDecoder() {
-    // maxFrameLength = 2G
-    // lengthFieldOffset = 0
-    // lengthFieldLength = 8
-    // lengthAdjustment = -8, i.e. exclude the 8 byte length itself
-    // initialBytesToStrip = 8, i.e. strip out the length field itself
-    return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8);
+  public static TransportFrameDecoder createFrameDecoder() {
+    return new TransportFrameDecoder();
   }
 
   /** Returns the remote address on the channel or "&lt;unknown remote&gt;" if none exists. */

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
----------------------------------------------------------------------
diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
new file mode 100644
index 0000000..272ea84
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+
+/**
+ * A customized frame decoder that allows intercepting raw data.
+ * <p>
+ * This behaves like Netty's frame decoder (with hardcoded parameters that match this library's
+ * needs), except it allows an interceptor to be installed to read data directly before it's
+ * framed.
+ * <p>
+ * Unlike Netty's frame decoder, each frame is dispatched to child handlers as soon as it's
+ * decoded, instead of building as many frames as the current buffer allows and dispatching
+ * all of them. This allows a child handler to install an interceptor if needed.
+ * <p>
+ * If an interceptor is installed, framing stops, and data is instead fed directly to the
+ * interceptor. When the interceptor indicates that it doesn't need to read any more data,
+ * framing resumes. Interceptors should not hold references to the data buffers provided
+ * to their handle() method.
+ */
+public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
+
+  public static final String HANDLER_NAME = "frameDecoder";
+  private static final int LENGTH_SIZE = 8;
+  private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE;
+
+  private CompositeByteBuf buffer;
+  private volatile Interceptor interceptor;
+
+  @Override
+  public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
+    ByteBuf in = (ByteBuf) data;
+
+    if (buffer == null) {
+      buffer = in.alloc().compositeBuffer();
+    }
+
+    buffer.writeBytes(in);
+
+    while (buffer.isReadable()) {
+      feedInterceptor();
+      if (interceptor != null) {
+        continue;
+      }
+
+      ByteBuf frame = decodeNext();
+      if (frame != null) {
+        ctx.fireChannelRead(frame);
+      } else {
+        break;
+      }
+    }
+
+    // We can't discard read sub-buffers if there are other references to the buffer (e.g.
+    // through slices used for framing). This assumes that code that retains references
+    // will call retain() from the thread that called "fireChannelRead()" above, otherwise
+    // ref counting will go awry.
+    if (buffer != null && buffer.refCnt() == 1) {
+      buffer.discardReadComponents();
+    }
+  }
+
+  protected ByteBuf decodeNext() throws Exception {
+    if (buffer.readableBytes() < LENGTH_SIZE) {
+      return null;
+    }
+
+    int frameLen = (int) buffer.readLong() - LENGTH_SIZE;
+    if (buffer.readableBytes() < frameLen) {
+      buffer.readerIndex(buffer.readerIndex() - LENGTH_SIZE);
+      return null;
+    }
+
+    Preconditions.checkArgument(frameLen < MAX_FRAME_SIZE, "Too large frame: %s", frameLen);
+    Preconditions.checkArgument(frameLen > 0, "Frame length should be positive: %s", frameLen);
+
+    ByteBuf frame = buffer.readSlice(frameLen);
+    frame.retain();
+    return frame;
+  }
+
+  @Override
+  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+    if (buffer != null) {
+      if (buffer.isReadable()) {
+        feedInterceptor();
+      }
+      buffer.release();
+    }
+    if (interceptor != null) {
+      interceptor.channelInactive();
+    }
+    super.channelInactive(ctx);
+  }
+
+  @Override
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
+    if (interceptor != null) {
+      interceptor.exceptionCaught(cause);
+    }
+    super.exceptionCaught(ctx, cause);
+  }
+
+  public void setInterceptor(Interceptor interceptor) {
+    Preconditions.checkState(this.interceptor == null, "Already have an interceptor.");
+    this.interceptor = interceptor;
+  }
+
+  private void feedInterceptor() throws Exception {
+    if (interceptor != null && !interceptor.handle(buffer)) {
+      interceptor = null;
+    }
+  }
+
+  public static interface Interceptor {
+
+    /**
+     * Handles data received from the remote end.
+     *
+     * @param data Buffer containing data.
+     * @return "true" if the interceptor expects more data, "false" to uninstall the interceptor.
+     */
+    boolean handle(ByteBuf data) throws Exception;
+
+    /** Called if an exception is thrown in the channel pipeline. */
+    void exceptionCaught(Throwable cause) throws Exception;
+
+    /** Called if the channel is closed and the interceptor is still installed. */
+    void channelInactive() throws Exception;
+
+  }
+
+}
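
To illustrate the Interceptor contract (returning false from handle()
uninstalls the interceptor and resumes framing), a minimal sketch, not part
of this change:

    import io.netty.buffer.ByteBuf;

    import org.apache.spark.network.util.TransportFrameDecoder;

    // Illustrative only: an interceptor that swallows a fixed number of raw
    // bytes, then uninstalls itself so normal framing resumes.
    class SkipBytesInterceptor implements TransportFrameDecoder.Interceptor {
      private long remaining;

      SkipBytesInterceptor(long byteCount) {
        this.remaining = byteCount;
      }

      @Override
      public boolean handle(ByteBuf data) {
        int toSkip = (int) Math.min(data.readableBytes(), remaining);
        data.skipBytes(toSkip);
        remaining -= toSkip;
        return remaining > 0;  // "false" uninstalls the interceptor
      }

      @Override
      public void exceptionCaught(Throwable cause) { }

      @Override
      public void channelInactive() { }
    }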

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
index d500bc3..22b451f 100644
--- a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java
@@ -39,6 +39,9 @@ import org.apache.spark.network.protocol.RpcFailure;
 import org.apache.spark.network.protocol.RpcRequest;
 import org.apache.spark.network.protocol.RpcResponse;
 import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.protocol.StreamFailure;
+import org.apache.spark.network.protocol.StreamRequest;
+import org.apache.spark.network.protocol.StreamResponse;
 import org.apache.spark.network.util.ByteArrayWritableChannel;
 import org.apache.spark.network.util.NettyUtils;
 
@@ -80,6 +83,7 @@ public class ProtocolSuite {
     testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2)));
     testClientToServer(new RpcRequest(12345, new byte[0]));
     testClientToServer(new RpcRequest(12345, new byte[100]));
+    testClientToServer(new StreamRequest("abcde"));
   }
 
   @Test
@@ -92,6 +96,10 @@ public class ProtocolSuite {
     testServerToClient(new RpcResponse(12345, new byte[1000]));
     testServerToClient(new RpcFailure(0, "this is an error"));
     testServerToClient(new RpcFailure(0, ""));
+    // Note: buffer size must be "0" since StreamResponse's buffer is written differently to the
+    // channel and cannot be tested like this.
+    testServerToClient(new StreamResponse("anId", 12345L, new TestManagedBuffer(0)));
+    testServerToClient(new StreamFailure("anId", "this is an error"));
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
new file mode 100644
index 0000000..6dcec83
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.io.Files;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.StreamCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.SystemPropertyConfigProvider;
+import org.apache.spark.network.util.TransportConf;
+
+public class StreamSuite {
+  private static final String[] STREAMS = { "largeBuffer", "smallBuffer", "file" };
+
+  private static TransportServer server;
+  private static TransportClientFactory clientFactory;
+  private static File testFile;
+  private static File tempDir;
+
+  private static ByteBuffer smallBuffer;
+  private static ByteBuffer largeBuffer;
+
+  private static ByteBuffer createBuffer(int bufSize) {
+    ByteBuffer buf = ByteBuffer.allocate(bufSize);
+    for (int i = 0; i < bufSize; i ++) {
+      buf.put((byte) i);
+    }
+    buf.flip();
+    return buf;
+  }
+
+  @BeforeClass
+  public static void setUp() throws Exception {
+    tempDir = Files.createTempDir();
+    smallBuffer = createBuffer(100);
+    largeBuffer = createBuffer(100000);
+
+    testFile = File.createTempFile("stream-test-file", "txt", tempDir);
+    FileOutputStream fp = new FileOutputStream(testFile);
+    try {
+      Random rnd = new Random();
+      for (int i = 0; i < 512; i++) {
+        byte[] fileContent = new byte[1024];
+        rnd.nextBytes(fileContent);
+        fp.write(fileContent);
+      }
+    } finally {
+      fp.close();
+    }
+
+    final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
+    final StreamManager streamManager = new StreamManager() {
+      @Override
+      public ManagedBuffer getChunk(long streamId, int chunkIndex) {
+        throw new UnsupportedOperationException();
+      }
+
+      @Override
+      public ManagedBuffer openStream(String streamId) {
+        switch (streamId) {
+          case "largeBuffer":
+            return new NioManagedBuffer(largeBuffer);
+          case "smallBuffer":
+            return new NioManagedBuffer(smallBuffer);
+          case "file":
+            return new FileSegmentManagedBuffer(conf, testFile, 0, testFile.length());
+          default:
+            throw new IllegalArgumentException("Invalid stream: " + streamId);
+        }
+      }
+    };
+    RpcHandler handler = new RpcHandler() {
+      @Override
+      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+        throw new UnsupportedOperationException();
+      }
+
+      @Override
+      public StreamManager getStreamManager() {
+        return streamManager;
+      }
+    };
+    TransportContext context = new TransportContext(conf, handler);
+    server = context.createServer();
+    clientFactory = context.createClientFactory();
+  }
+
+  @AfterClass
+  public static void tearDown() {
+    server.close();
+    clientFactory.close();
+    if (tempDir != null) {
+      for (File f : tempDir.listFiles()) {
+        f.delete();
+      }
+      tempDir.delete();
+    }
+  }
+
+  @Test
+  public void testSingleStream() throws Throwable {
+    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+    try {
+      StreamTask task = new StreamTask(client, "largeBuffer", TimeUnit.SECONDS.toMillis(5));
+      task.run();
+      task.check();
+    } finally {
+      client.close();
+    }
+  }
+
+  @Test
+  public void testMultipleStreams() throws Throwable {
+    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+    try {
+      for (int i = 0; i < 20; i++) {
+        StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
+          TimeUnit.SECONDS.toMillis(5));
+        task.run();
+        task.check();
+      }
+    } finally {
+      client.close();
+    }
+  }
+
+  @Test
+  public void testConcurrentStreams() throws Throwable {
+    ExecutorService executor = Executors.newFixedThreadPool(20);
+    TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
+
+    try {
+      List<StreamTask> tasks = new ArrayList<>();
+      for (int i = 0; i < 20; i++) {
+        StreamTask task = new StreamTask(client, STREAMS[i % STREAMS.length],
+          TimeUnit.SECONDS.toMillis(20));
+        tasks.add(task);
+        executor.submit(task);
+      }
+
+      executor.shutdown();
+      assertTrue("Timed out waiting for tasks.", executor.awaitTermination(30, 
TimeUnit.SECONDS));
+      for (StreamTask task : tasks) {
+        task.check();
+      }
+    } finally {
+      executor.shutdownNow();
+      client.close();
+    }
+  }
+
+  private static class StreamTask implements Runnable {
+
+    private final TransportClient client;
+    private final String streamId;
+    private final long timeoutMs;
+    private Throwable error;
+
+    StreamTask(TransportClient client, String streamId, long timeoutMs) {
+      this.client = client;
+      this.streamId = streamId;
+      this.timeoutMs = timeoutMs;
+    }
+
+    @Override
+    public void run() {
+      ByteBuffer srcBuffer = null;
+      OutputStream out = null;
+      File outFile = null;
+      try {
+        ByteArrayOutputStream baos = null;
+
+        switch (streamId) {
+          case "largeBuffer":
+            baos = new ByteArrayOutputStream();
+            out = baos;
+            srcBuffer = largeBuffer;
+            break;
+          case "smallBuffer":
+            baos = new ByteArrayOutputStream();
+            out = baos;
+            srcBuffer = smallBuffer;
+            break;
+          case "file":
+            outFile = File.createTempFile("data", ".tmp", tempDir);
+            out = new FileOutputStream(outFile);
+            break;
+          default:
+            throw new IllegalArgumentException(streamId);
+        }
+
+        TestCallback callback = new TestCallback(out);
+        client.stream(streamId, callback);
+        waitForCompletion(callback);
+
+        if (srcBuffer == null) {
+          assertTrue("File stream did not match.", Files.equal(testFile, outFile));
+        } else {
+          ByteBuffer base;
+          synchronized (srcBuffer) {
+            base = srcBuffer.duplicate();
+          }
+          byte[] result = baos.toByteArray();
+          byte[] expected = new byte[base.remaining()];
+          base.get(expected);
+          assertEquals(expected.length, result.length);
+          assertTrue("buffers don't match", Arrays.equals(expected, result));
+        }
+      } catch (Throwable t) {
+        error = t;
+      } finally {
+        if (out != null) {
+          try {
+            out.close();
+          } catch (Exception e) {
+            // ignore.
+          }
+        }
+        if (outFile != null) {
+          outFile.delete();
+        }
+      }
+    }
+
+    public void check() throws Throwable {
+      if (error != null) {
+        throw error;
+      }
+    }
+
+    private void waitForCompletion(TestCallback callback) throws Exception {
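+      // Block on the callback's monitor; onComplete/onFailure notify it when the stream ends.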
+      long now = System.currentTimeMillis();
+      long deadline = now + timeoutMs;
+      synchronized (callback) {
+        while (!callback.completed && now < deadline) {
+          callback.wait(deadline - now);
+          now = System.currentTimeMillis();
+        }
+      }
+      assertTrue("Timed out waiting for stream.", callback.completed);
+      assertNull(callback.error);
+    }
+
+  }
+
+  private static class TestCallback implements StreamCallback {
+
+    private final OutputStream out;
+    public volatile boolean completed;
+    public volatile Throwable error;
+
+    TestCallback(OutputStream out) {
+      this.out = out;
+      this.completed = false;
+    }
+
+    @Override
+    public void onData(String streamId, ByteBuffer buf) throws IOException {
+      byte[] tmp = new byte[buf.remaining()];
+      buf.get(tmp);
+      out.write(tmp);
+    }
+
+    @Override
+    public void onComplete(String streamId) throws IOException {
+      out.close();
+      synchronized (this) {
+        completed = true;
+        notifyAll();
+      }
+    }
+
+    @Override
+    public void onFailure(String streamId, Throwable cause) {
+      error = cause;
+      synchronized (this) {
+        completed = true;
+        notifyAll();
+      }
+    }
+
+  }
+
+}
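
The suite above drives the new streaming call end to end through TestCallback. For
orientation, here is a minimal sketch of what a standalone consumer of the same API
could look like, using only the TransportClient.stream(...) call and the
StreamCallback signatures exercised above; the class name, stream id, and output
path are hypothetical:

    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.io.OutputStream;
    import java.nio.ByteBuffer;

    import org.apache.spark.network.client.StreamCallback;

    // Hypothetical callback that spools stream data to a local file as it
    // arrives, mirroring the shape of TestCallback in the suite above.
    class SaveToFileCallback implements StreamCallback {

      private final OutputStream out;

      SaveToFileCallback(String path) throws IOException {
        this.out = new FileOutputStream(path);
      }

      @Override
      public void onData(String streamId, ByteBuffer buf) throws IOException {
        // Copy the received bytes out of the buffer before writing them.
        byte[] tmp = new byte[buf.remaining()];
        buf.get(tmp);
        out.write(tmp);
      }

      @Override
      public void onComplete(String streamId) throws IOException {
        out.close();
      }

      @Override
      public void onFailure(String streamId, Throwable cause) {
        try {
          out.close();
        } catch (IOException ioe) {
          // Ignore; the stream already failed.
        }
      }
    }

    // Usage, given an already-connected TransportClient (as in the tests above):
    //   client.stream("someStreamId", new SaveToFileCallback("/tmp/stream.out"));

As in TestCallback, completion is signaled only through onComplete/onFailure, so a
caller that needs the data synchronously has to block until one of those fires,
which is exactly what StreamTask.waitForCompletion does.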

http://git-wip-us.apache.org/repos/asf/spark/blob/27feafcc/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
new file mode 100644
index 0000000..ca74f0a
--- /dev/null
+++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.util;
+
+import java.nio.ByteBuffer;
+import java.util.Random;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.ChannelHandlerContext;
+import org.junit.Test;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+public class TransportFrameDecoderSuite {
+
+  @Test
+  public void testFrameDecoding() throws Exception {
+    Random rnd = new Random();
+    TransportFrameDecoder decoder = new TransportFrameDecoder();
+    ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+
+    final int frameCount = 100;
+    ByteBuf data = Unpooled.buffer();
+    try {
+      for (int i = 0; i < frameCount; i++) {
+        byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)];
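+        // The length prefix is a long that counts itself, hence the + 8.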
+        data.writeLong(frame.length + 8);
+        data.writeBytes(frame);
+      }
+
+      while (data.isReadable()) {
+        int size = rnd.nextInt(16 * 1024) + 256;
+        decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)));
+      }
+
+      verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
+    } finally {
+      data.release();
+    }
+  }
+
+  @Test
+  public void testInterception() throws Exception {
+    final int interceptedReads = 3;
+    TransportFrameDecoder decoder = new TransportFrameDecoder();
+    TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
+    ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+
+    byte[] data = new byte[8];
+    ByteBuf len = Unpooled.copyLong(8 + data.length);
+    ByteBuf dataBuf = Unpooled.wrappedBuffer(data);
+
+    try {
+      decoder.setInterceptor(interceptor);
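+      // While an interceptor is installed, incoming buffers bypass framing and go straight to it.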
+      for (int i = 0; i < interceptedReads; i++) {
+        decoder.channelRead(ctx, dataBuf);
+        dataBuf.release();
+        dataBuf = Unpooled.wrappedBuffer(data);
+      }
+      decoder.channelRead(ctx, len);
+      decoder.channelRead(ctx, dataBuf);
+      verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
+      verify(ctx).fireChannelRead(any(ByteBuffer.class));
+    } finally {
+      len.release();
+      dataBuf.release();
+    }
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testNegativeFrameSize() throws Exception {
+    testInvalidFrame(-1);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testEmptyFrame() throws Exception {
+    // 8 because frame size includes the frame length.
+    testInvalidFrame(8);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testLargeFrame() throws Exception {
+    // Frame length includes the frame size field, so need to add a few more bytes.
+    testInvalidFrame(Integer.MAX_VALUE + 9L);
+  }
+
+  private void testInvalidFrame(long size) throws Exception {
+    TransportFrameDecoder decoder = new TransportFrameDecoder();
+    ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
+    ByteBuf frame = Unpooled.copyLong(size);
+    try {
+      decoder.channelRead(ctx, frame);
+    } finally {
+      frame.release();
+    }
+  }
+
+  private static class MockInterceptor implements TransportFrameDecoder.Interceptor {
+
+    private int remainingReads;
+
+    MockInterceptor(int readCount) {
+      this.remainingReads = readCount;
+    }
+
+    @Override
+    public boolean handle(ByteBuf data) throws Exception {
+      data.readerIndex(data.readerIndex() + data.readableBytes());
+      assertFalse(data.isReadable());
+      remainingReads -= 1;
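+      // Returning false uninstalls the interceptor so the decoder resumes normal framing.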
+      return remainingReads != 0;
+    }
+
+    @Override
+    public void exceptionCaught(Throwable cause) throws Exception {
+
+    }
+
+    @Override
+    public void channelInactive() throws Exception {
+
+    }
+
+  }
+
+}
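
A note on the wire format these tests pin down: every frame starts with an 8-byte
length that counts the length field itself, which is why testFrameDecoding writes
frame.length + 8 and why a length of exactly 8 denotes an empty, invalid frame.
Below is a minimal sketch of producing such a frame with Netty's Unpooled buffers,
as the tests do; the class and method names are hypothetical:

    import io.netty.buffer.ByteBuf;
    import io.netty.buffer.Unpooled;

    // Hypothetical helper that frames a payload the way the decoder expects:
    // an 8-byte length field that includes its own size, then the payload.
    class FrameEncodingSketch {

      static ByteBuf encodeFrame(byte[] payload) {
        ByteBuf frame = Unpooled.buffer(8 + payload.length);
        frame.writeLong(8 + payload.length);
        frame.writeBytes(payload);
        return frame;
      }

      public static void main(String[] args) {
        ByteBuf frame = encodeFrame(new byte[] { 1, 2, 3 });
        System.out.println(frame.getLong(0));  // Prints 11: the length counts itself.
        frame.release();
      }
    }

An empty payload would produce the length of exactly 8 that testEmptyFrame rejects,
and a total length above the bound probed by testLargeFrame is likewise refused.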

