Repository: spark
Updated Branches:
  refs/heads/branch-1.6 a4e134827 -> ef6f8c262


http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
index 42955ef..f9b5bf9 100644
--- a/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java
@@ -31,6 +31,7 @@ import org.apache.spark.network.server.TransportServer;
 import org.apache.spark.network.util.MapConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 import org.junit.*;
+import static org.junit.Assert.*;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -84,13 +85,16 @@ public class RequestTimeoutIntegrationSuite {
   @Test
   public void timeoutInactiveRequests() throws Exception {
     final Semaphore semaphore = new Semaphore(1);
-    final byte[] response = new byte[16];
+    final int responseSize = 16;
     RpcHandler handler = new RpcHandler() {
       @Override
-      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
         try {
           semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
-          callback.onSuccess(response);
+          callback.onSuccess(ByteBuffer.allocate(responseSize));
         } catch (InterruptedException e) {
           // do nothing
         }
@@ -110,15 +114,15 @@ public class RequestTimeoutIntegrationSuite {
     // First completes quickly (semaphore starts at 1).
     TestCallback callback0 = new TestCallback();
     synchronized (callback0) {
-      client.sendRpc(new byte[0], callback0);
+      client.sendRpc(ByteBuffer.allocate(0), callback0);
       callback0.wait(FOREVER);
-      assert (callback0.success.length == response.length);
+      assertEquals(responseSize, callback0.successLength);
     }
 
     // Second times out after 2 seconds, with slack. Must be IOException.
     TestCallback callback1 = new TestCallback();
     synchronized (callback1) {
-      client.sendRpc(new byte[0], callback1);
+      client.sendRpc(ByteBuffer.allocate(0), callback1);
       callback1.wait(4 * 1000);
       assert (callback1.failure != null);
       assert (callback1.failure instanceof IOException);
@@ -131,13 +135,16 @@ public class RequestTimeoutIntegrationSuite {
   @Test
   public void timeoutCleanlyClosesClient() throws Exception {
     final Semaphore semaphore = new Semaphore(0);
-    final byte[] response = new byte[16];
+    final int responseSize = 16;
     RpcHandler handler = new RpcHandler() {
       @Override
-      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
         try {
           semaphore.tryAcquire(FOREVER, TimeUnit.MILLISECONDS);
-          callback.onSuccess(response);
+          callback.onSuccess(ByteBuffer.allocate(responseSize));
         } catch (InterruptedException e) {
           // do nothing
         }
@@ -158,7 +165,7 @@ public class RequestTimeoutIntegrationSuite {
       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
     TestCallback callback0 = new TestCallback();
     synchronized (callback0) {
-      client0.sendRpc(new byte[0], callback0);
+      client0.sendRpc(ByteBuffer.allocate(0), callback0);
       callback0.wait(FOREVER);
       assert (callback0.failure instanceof IOException);
       assert (!client0.isActive());
@@ -170,10 +177,10 @@ public class RequestTimeoutIntegrationSuite {
       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
     TestCallback callback1 = new TestCallback();
     synchronized (callback1) {
-      client1.sendRpc(new byte[0], callback1);
+      client1.sendRpc(ByteBuffer.allocate(0), callback1);
       callback1.wait(FOREVER);
-      assert (callback1.success.length == response.length);
-      assert (callback1.failure == null);
+      assertEquals(responseSize, callback1.successLength);
+      assertNull(callback1.failure);
     }
   }
 
@@ -191,7 +198,10 @@ public class RequestTimeoutIntegrationSuite {
     };
     RpcHandler handler = new RpcHandler() {
       @Override
-      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
         throw new UnsupportedOperationException();
       }
 
@@ -218,9 +228,10 @@ public class RequestTimeoutIntegrationSuite {
 
     synchronized (callback0) {
       // not complete yet, but should complete soon
-      assert (callback0.success == null && callback0.failure == null);
+      assertEquals(-1, callback0.successLength);
+      assertNull(callback0.failure);
       callback0.wait(2 * 1000);
-      assert (callback0.failure instanceof IOException);
+      assertTrue(callback0.failure instanceof IOException);
     }
 
     synchronized (callback1) {
@@ -235,13 +246,13 @@ public class RequestTimeoutIntegrationSuite {
    */
   class TestCallback implements RpcResponseCallback, ChunkReceivedCallback {
 
-    byte[] success;
+    int successLength = -1;
     Throwable failure;
 
     @Override
-    public void onSuccess(byte[] response) {
+    public void onSuccess(ByteBuffer response) {
       synchronized(this) {
-        success = response;
+        successLength = response.remaining();
         this.notifyAll();
       }
     }
@@ -258,7 +269,7 @@ public class RequestTimeoutIntegrationSuite {
     public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
       synchronized(this) {
         try {
-          success = buffer.nioByteBuffer().array();
+          successLength = buffer.nioByteBuffer().remaining();
           this.notifyAll();
         } catch (IOException e) {
           // weird

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
index 88fa225..9e9be98 100644
--- a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network;
 
+import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashSet;
@@ -26,7 +27,6 @@ import java.util.Set;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 
-import com.google.common.base.Charsets;
 import com.google.common.collect.Sets;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
@@ -41,6 +41,7 @@ import org.apache.spark.network.server.OneForOneStreamManager;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 
@@ -55,11 +56,14 @@ public class RpcIntegrationSuite {
     TransportConf conf = new TransportConf("shuffle", new SystemPropertyConfigProvider());
     rpcHandler = new RpcHandler() {
       @Override
-      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
-        String msg = new String(message, Charsets.UTF_8);
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
+        String msg = JavaUtils.bytesToString(message);
         String[] parts = msg.split("/");
         if (parts[0].equals("hello")) {
-          callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8));
+          callback.onSuccess(JavaUtils.stringToBytes("Hello, " + parts[1] + "!"));
         } else if (parts[0].equals("return error")) {
           callback.onFailure(new RuntimeException("Returned: " + parts[1]));
         } else if (parts[0].equals("throw error")) {
@@ -68,9 +72,8 @@ public class RpcIntegrationSuite {
       }
 
       @Override
-      public void receive(TransportClient client, byte[] message) {
-        String msg = new String(message, Charsets.UTF_8);
-        oneWayMsgs.add(msg);
+      public void receive(TransportClient client, ByteBuffer message) {
+        oneWayMsgs.add(JavaUtils.bytesToString(message));
       }
 
       @Override
@@ -103,8 +106,9 @@ public class RpcIntegrationSuite {
 
     RpcResponseCallback callback = new RpcResponseCallback() {
       @Override
-      public void onSuccess(byte[] message) {
-        res.successMessages.add(new String(message, Charsets.UTF_8));
+      public void onSuccess(ByteBuffer message) {
+        String response = JavaUtils.bytesToString(message);
+        res.successMessages.add(response);
         sem.release();
       }
 
@@ -116,7 +120,7 @@ public class RpcIntegrationSuite {
     };
 
     for (String command : commands) {
-      client.sendRpc(command.getBytes(Charsets.UTF_8), callback);
+      client.sendRpc(JavaUtils.stringToBytes(command), callback);
     }
 
     if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) {
@@ -173,7 +177,7 @@ public class RpcIntegrationSuite {
     final String message = "no reply";
     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
     try {
-      client.send(message.getBytes(Charsets.UTF_8));
+      client.send(JavaUtils.stringToBytes(message));
       assertEquals(0, client.getHandler().numOutstandingRequests());
 
       // Make sure the message arrives.

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
index 538f3ef..9c49556 100644
--- a/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/StreamSuite.java
@@ -116,7 +116,10 @@ public class StreamSuite {
     };
     RpcHandler handler = new RpcHandler() {
       @Override
-      public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+      public void receive(
+          TransportClient client,
+          ByteBuffer message,
+          RpcResponseCallback callback) {
         throw new UnsupportedOperationException();
       }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
index 30144f4..128f7cb 100644
--- a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network;
 
+import java.nio.ByteBuffer;
+
 import io.netty.channel.Channel;
 import io.netty.channel.local.LocalChannel;
 import org.junit.Test;
@@ -27,6 +29,7 @@ import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.*;
 
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
 import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.StreamCallback;
@@ -42,7 +45,7 @@ import org.apache.spark.network.util.TransportFrameDecoder;
 
 public class TransportResponseHandlerSuite {
   @Test
-  public void handleSuccessfulFetch() {
+  public void handleSuccessfulFetch() throws Exception {
     StreamChunkId streamChunkId = new StreamChunkId(1, 0);
 
     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
@@ -56,7 +59,7 @@ public class TransportResponseHandlerSuite {
   }
 
   @Test
-  public void handleFailedFetch() {
+  public void handleFailedFetch() throws Exception {
     StreamChunkId streamChunkId = new StreamChunkId(1, 0);
     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
     ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
@@ -69,7 +72,7 @@ public class TransportResponseHandlerSuite {
   }
 
   @Test
-  public void clearAllOutstandingRequests() {
+  public void clearAllOutstandingRequests() throws Exception {
     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
     ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class);
     handler.addFetchRequest(new StreamChunkId(1, 0), callback);
@@ -88,23 +91,24 @@ public class TransportResponseHandlerSuite {
   }
 
   @Test
-  public void handleSuccessfulRPC() {
+  public void handleSuccessfulRPC() throws Exception {
     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
     RpcResponseCallback callback = mock(RpcResponseCallback.class);
     handler.addRpcRequest(12345, callback);
     assertEquals(1, handler.numOutstandingRequests());
 
-    handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored
+    // This response should be ignored.
+    handler.handle(new RpcResponse(54321, new NioManagedBuffer(ByteBuffer.allocate(7))));
     assertEquals(1, handler.numOutstandingRequests());
 
-    byte[] arr = new byte[10];
-    handler.handle(new RpcResponse(12345, arr));
-    verify(callback, times(1)).onSuccess(eq(arr));
+    ByteBuffer resp = ByteBuffer.allocate(10);
+    handler.handle(new RpcResponse(12345, new NioManagedBuffer(resp)));
+    verify(callback, times(1)).onSuccess(eq(ByteBuffer.allocate(10)));
     assertEquals(0, handler.numOutstandingRequests());
   }
 
   @Test
-  public void handleFailedRPC() {
+  public void handleFailedRPC() throws Exception {
     TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel());
     RpcResponseCallback callback = mock(RpcResponseCallback.class);
     handler.addRpcRequest(12345, callback);
@@ -119,7 +123,7 @@ public class TransportResponseHandlerSuite {
   }
 
   @Test
-  public void testActiveStreams() {
+  public void testActiveStreams() throws Exception {
     Channel c = new LocalChannel();
     c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
     TransportResponseHandler handler = new TransportResponseHandler(c);

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index a6f180b..751516b 100644
--- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -22,7 +22,7 @@ import static org.mockito.Mockito.*;
 
 import java.io.File;
 import java.lang.reflect.Method;
-import java.nio.charset.StandardCharsets;
+import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Random;
@@ -57,6 +57,7 @@ import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.server.TransportServer;
 import org.apache.spark.network.server.TransportServerBootstrap;
 import org.apache.spark.network.util.ByteArrayWritableChannel;
+import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 
@@ -123,39 +124,53 @@ public class SparkSaslSuite {
   }
 
   @Test
-  public void testSaslAuthentication() throws Exception {
+  public void testSaslAuthentication() throws Throwable {
     testBasicSasl(false);
   }
 
   @Test
-  public void testSaslEncryption() throws Exception {
+  public void testSaslEncryption() throws Throwable {
     testBasicSasl(true);
   }
 
-  private void testBasicSasl(boolean encrypt) throws Exception {
+  private void testBasicSasl(boolean encrypt) throws Throwable {
     RpcHandler rpcHandler = mock(RpcHandler.class);
     doAnswer(new Answer<Void>() {
         @Override
         public Void answer(InvocationOnMock invocation) {
-          byte[] message = (byte[]) invocation.getArguments()[1];
+          ByteBuffer message = (ByteBuffer) invocation.getArguments()[1];
           RpcResponseCallback cb = (RpcResponseCallback) invocation.getArguments()[2];
-          assertEquals("Ping", new String(message, StandardCharsets.UTF_8));
-          cb.onSuccess("Pong".getBytes(StandardCharsets.UTF_8));
+          assertEquals("Ping", JavaUtils.bytesToString(message));
+          cb.onSuccess(JavaUtils.stringToBytes("Pong"));
           return null;
         }
       })
       .when(rpcHandler)
-      .receive(any(TransportClient.class), any(byte[].class), any(RpcResponseCallback.class));
+      .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
 
     SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
     try {
-      byte[] response = ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8),
-                                               TimeUnit.SECONDS.toMillis(10));
-      assertEquals("Pong", new String(response, StandardCharsets.UTF_8));
+      ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
+        TimeUnit.SECONDS.toMillis(10));
+      assertEquals("Pong", JavaUtils.bytesToString(response));
     } finally {
       ctx.close();
       // There should be 2 terminated events; one for the client, one for the server.
-      verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class));
+      Throwable error = null;
+      long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
+      while (deadline > System.nanoTime()) {
+        try {
+          verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class));
+          error = null;
+          break;
+        } catch (Throwable t) {
+          error = t;
+          TimeUnit.MILLISECONDS.sleep(10);
+        }
+      }
+      if (error != null) {
+        throw error;
+      }
     }
   }
 
@@ -325,8 +340,8 @@ public class SparkSaslSuite {
     SaslTestCtx ctx = null;
     try {
       ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
-      ctx.client.sendRpcSync("Ping".getBytes(StandardCharsets.UTF_8),
-                             TimeUnit.SECONDS.toMillis(10));
+      ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
+        TimeUnit.SECONDS.toMillis(10));
       fail("Should have failed to send RPC to server.");
     } catch (Exception e) {
       assertFalse(e.getCause() instanceof TimeoutException);

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
----------------------------------------------------------------------
diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
index 19475c2..d4de4a9 100644
--- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java
@@ -118,6 +118,27 @@ public class TransportFrameDecoderSuite {
     }
   }
 
+  @Test
+  public void testSplitLengthField() throws Exception {
+    byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
+    ByteBuf buf = Unpooled.buffer(frame.length + 8);
+    buf.writeLong(frame.length + 8);
+    buf.writeBytes(frame);
+
+    TransportFrameDecoder decoder = new TransportFrameDecoder();
+    ChannelHandlerContext ctx = mockChannelHandlerContext();
+    try {
+      decoder.channelRead(ctx, buf.readSlice(RND.nextInt(7)).retain());
+      verify(ctx, never()).fireChannelRead(any(ByteBuf.class));
+      decoder.channelRead(ctx, buf);
+      verify(ctx).fireChannelRead(any(ByteBuf.class));
+      assertEquals(0, buf.refCnt());
+    } finally {
+      decoder.channelInactive(ctx);
+      release(buf);
+    }
+  }
+
   @Test(expected = IllegalArgumentException.class)
   public void testNegativeFrameSize() throws Exception {
     testInvalidFrame(-1);
@@ -183,7 +204,7 @@ public class TransportFrameDecoderSuite {
     try {
       decoder.channelRead(ctx, frame);
     } finally {
-      frame.release();
+      release(frame);
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index 3ddf5c3..f22187a 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.shuffle;
 
 import java.io.File;
 import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.List;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -66,8 +67,8 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
   }
 
   @Override
-  public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
-    BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message);
+  public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
+    BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(message);
     handleMessage(msgObj, client, callback);
   }
 
@@ -85,13 +86,13 @@ public class ExternalShuffleBlockHandler extends RpcHandler {
       }
       long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator());
       logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length);
-      callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray());
+      callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer());
 
     } else if (msgObj instanceof RegisterExecutor) {
       RegisterExecutor msg = (RegisterExecutor) msgObj;
       checkAuth(client, msg.appId);
       blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
-      callback.onSuccess(new byte[0]);
+      callback.onSuccess(ByteBuffer.wrap(new byte[0]));
 
     } else {
       throw new UnsupportedOperationException("Unexpected message: " + msgObj);

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index ef3a9dc..58ca87d 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -18,6 +18,7 @@
 package org.apache.spark.network.shuffle;
 
 import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.List;
 
 import com.google.common.base.Preconditions;
@@ -139,7 +140,7 @@ public class ExternalShuffleClient extends ShuffleClient {
     checkInit();
     TransportClient client = clientFactory.createUnmanagedClient(host, port);
     try {
-      byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray();
+      ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer();
       client.sendRpcSync(registerMessage, 5000 /* timeoutMs */);
     } finally {
       client.close();

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index e653f5c..1b2ddbf 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.nio.ByteBuffer;
 import java.util.Arrays;
 
 import org.slf4j.Logger;
@@ -89,11 +90,11 @@ public class OneForOneBlockFetcher {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
 
-    client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() {
+    client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
       @Override
-      public void onSuccess(byte[] response) {
+      public void onSuccess(ByteBuffer response) {
         try {
-          streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+          streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
           logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);
 
           // Immediately request all chunks -- we expect that the total size of the request is

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
index 7543b6b..6758203 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java
@@ -18,6 +18,7 @@
 package org.apache.spark.network.shuffle.mesos;
 
 import java.io.IOException;
+import java.nio.ByteBuffer;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -54,11 +55,11 @@ public class MesosExternalShuffleClient extends ExternalShuffleClient {
 
   public void registerDriverWithShuffleService(String host, int port) throws IOException {
     checkInit();
-    byte[] registerDriver = new RegisterDriver(appId).toByteArray();
+    ByteBuffer registerDriver = new RegisterDriver(appId).toByteBuffer();
     TransportClient client = clientFactory.createClient(host, port);
     client.sendRpc(registerDriver, new RpcResponseCallback() {
       @Override
-      public void onSuccess(byte[] response) {
+      public void onSuccess(ByteBuffer response) {
         logger.info("Successfully registered app " + appId + " with external shuffle service.");
       }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index fcb5236..7fbe338 100644
--- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.shuffle.protocol;
 
+import java.nio.ByteBuffer;
+
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.Unpooled;
 
@@ -53,7 +55,7 @@ public abstract class BlockTransferMessage implements Encodable {
   // NB: Java does not support static methods in interfaces, so we must put this in a static class.
   public static class Decoder {
     /** Deserializes the 'type' byte followed by the message itself. */
-    public static BlockTransferMessage fromByteArray(byte[] msg) {
+    public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
       ByteBuf buf = Unpooled.wrappedBuffer(msg);
       byte type = buf.readByte();
       switch (type) {
@@ -68,12 +70,12 @@ public abstract class BlockTransferMessage implements Encodable {
   }
 
   /** Serializes the 'type' byte followed by the message itself. */
-  public byte[] toByteArray() {
+  public ByteBuffer toByteBuffer() {
     // Allow room for encoded message, plus the type byte
     ByteBuf buf = Unpooled.buffer(encodedLength() + 1);
     buf.writeByte(type().id);
     encode(buf);
     assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes();
-    return buf.array();
+    return buf.nioBuffer();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
----------------------------------------------------------------------
diff --git 
a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
 
b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 1c2fa4d..19c870a 100644
--- 
a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ 
b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -18,6 +18,7 @@
 package org.apache.spark.network.sasl;
 
 import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -52,6 +53,7 @@ import 
org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
 import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
+import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
 
@@ -107,8 +109,8 @@ public class SaslIntegrationSuite {
 
     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
     String msg = "Hello, World!";
-    byte[] resp = client.sendRpcSync(msg.getBytes(), TIMEOUT_MS);
-    assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg
+    ByteBuffer resp = client.sendRpcSync(JavaUtils.stringToBytes(msg), TIMEOUT_MS);
+    assertEquals(msg, JavaUtils.bytesToString(resp));
   }
 
   @Test
@@ -136,7 +138,7 @@ public class SaslIntegrationSuite {
 
     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
     try {
-      client.sendRpcSync(new byte[13], TIMEOUT_MS);
+      client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS);
       fail("Should have failed");
     } catch (Exception e) {
       assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
@@ -144,7 +146,7 @@ public class SaslIntegrationSuite {
 
     try {
       // Guessing the right tag byte doesn't magically get you in...
-      client.sendRpcSync(new byte[] { (byte) 0xEA }, TIMEOUT_MS);
+      client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS);
       fail("Should have failed");
     } catch (Exception e) {
       assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
@@ -222,13 +224,13 @@ public class SaslIntegrationSuite {
         new String[] { System.getProperty("java.io.tmpdir") }, 1,
         "org.apache.spark.shuffle.sort.SortShuffleManager");
       RegisterExecutor regmsg = new RegisterExecutor("app-1", "0", executorInfo);
-      client1.sendRpcSync(regmsg.toByteArray(), TIMEOUT_MS);
+      client1.sendRpcSync(regmsg.toByteBuffer(), TIMEOUT_MS);
 
       // Make a successful request to fetch blocks, which creates a new stream. But do not actually
       // fetch any blocks, to keep the stream open.
       OpenBlocks openMessage = new OpenBlocks("app-1", "0", blockIds);
-      byte[] response = client1.sendRpcSync(openMessage.toByteArray(), TIMEOUT_MS);
-      StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response);
+      ByteBuffer response = client1.sendRpcSync(openMessage.toByteBuffer(), TIMEOUT_MS);
+      StreamHandle stream = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
       long streamId = stream.streamId;
 
       // Create a second client, authenticated with a different app ID, and try to read from
@@ -275,7 +277,7 @@ public class SaslIntegrationSuite {
   /** RPC handler which simply responds with the message it received. */
   public static class TestRpcHandler extends RpcHandler {
     @Override
-    public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
+    public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
       callback.onSuccess(message);
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
index d65de9c..86c8609 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
@@ -36,7 +36,7 @@ public class BlockTransferMessagesSuite {
   }
 
   private void checkSerializeDeserialize(BlockTransferMessage msg) {
-    BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray());
+    BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer());
     assertEquals(msg, msg2);
     assertEquals(msg.hashCode(), msg2.hashCode());
     assertEquals(msg.toString(), msg2.toString());

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index e61390c..9379412 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -60,12 +60,12 @@ public class ExternalShuffleBlockHandlerSuite {
     RpcResponseCallback callback = mock(RpcResponseCallback.class);
 
     ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
-    byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray();
+    ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer();
     handler.receive(client, registerMessage, callback);
     verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config);
 
-    verify(callback, times(1)).onSuccess((byte[]) any());
-    verify(callback, never()).onFailure((Throwable) any());
+    verify(callback, times(1)).onSuccess(any(ByteBuffer.class));
+    verify(callback, never()).onFailure(any(Throwable.class));
   }
 
   @SuppressWarnings("unchecked")
@@ -77,17 +77,18 @@ public class ExternalShuffleBlockHandlerSuite {
     ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
     when(blockResolver.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker);
     when(blockResolver.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker);
-    byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray();
+    ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" })
+      .toByteBuffer();
     handler.receive(client, openBlocks, callback);
     verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0");
     verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1");
 
-    ArgumentCaptor<byte[]> response = ArgumentCaptor.forClass(byte[].class);
+    ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
     verify(callback, times(1)).onSuccess(response.capture());
     verify(callback, never()).onFailure((Throwable) any());
 
     StreamHandle handle =
-      (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue());
+      (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
     assertEquals(2, handle.numChunks);
 
     @SuppressWarnings("unchecked")
@@ -104,7 +105,7 @@ public class ExternalShuffleBlockHandlerSuite {
   public void testBadMessages() {
     RpcResponseCallback callback = mock(RpcResponseCallback.class);
 
-    byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 };
+    ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 });
     try {
       handler.receive(client, unserializableMsg, callback);
       fail("Should have thrown");
@@ -112,7 +113,7 @@ public class ExternalShuffleBlockHandlerSuite {
       // pass
     }
 
-    byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray();
+    ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteBuffer();
     try {
       handler.receive(client, unexpectedMsg, callback);
       fail("Should have thrown");
@@ -120,7 +121,7 @@ public class ExternalShuffleBlockHandlerSuite {
       // pass
     }
 
-    verify(callback, never()).onSuccess((byte[]) any());
-    verify(callback, never()).onFailure((Throwable) any());
+    verify(callback, never()).onSuccess(any(ByteBuffer.class));
+    verify(callback, never()).onFailure(any(Throwable.class));
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ef6f8c26/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
----------------------------------------------------------------------
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index b35a6d6..2590b9c 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -134,14 +134,14 @@ public class OneForOneBlockFetcherSuite {
     doAnswer(new Answer<Void>() {
       @Override
       public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
-        BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray(
-          (byte[]) invocationOnMock.getArguments()[0]);
+        BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer(
+          (ByteBuffer) invocationOnMock.getArguments()[0]);
         RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
-        callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray());
+        callback.onSuccess(new StreamHandle(123, blocks.size()).toByteBuffer());
         assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message);
         return null;
       }
-    }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());
+    }).when(client).sendRpc(any(ByteBuffer.class), any(RpcResponseCallback.class));
 
     // Respond to each chunk request with a single buffer from our blocks array.
     final AtomicInteger expectedChunkIndex = new AtomicInteger(0);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to