HADOOP-13547. Optimize IPC client protobuf decoding. Contributed by Daryn Sharp.


Project: http://git-wip-us.apache.org/repos/asf/hadoop/repo
Commit: http://git-wip-us.apache.org/repos/asf/hadoop/commit/23abb09c
Tree: http://git-wip-us.apache.org/repos/asf/hadoop/tree/23abb09c
Diff: http://git-wip-us.apache.org/repos/asf/hadoop/diff/23abb09c

Branch: refs/heads/YARN-3368
Commit: 23abb09c1f979d8c18ece81e32630a35ed569399
Parents: 05f5c0f
Author: Kihwal Lee <kih...@apache.org>
Authored: Fri Sep 2 11:03:18 2016 -0500
Committer: Kihwal Lee <kih...@apache.org>
Committed: Fri Sep 2 11:03:18 2016 -0500

----------------------------------------------------------------------
 .../main/java/org/apache/hadoop/ipc/Client.java |  77 +++---
 .../apache/hadoop/ipc/ProtobufRpcEngine.java    | 241 +++----------------
 .../org/apache/hadoop/ipc/ResponseBuffer.java   |  12 +-
 .../java/org/apache/hadoop/ipc/RpcWritable.java |  20 +-
 .../apache/hadoop/security/SaslRpcClient.java   |  45 ++--
 5 files changed, 102 insertions(+), 293 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/hadoop/blob/23abb09c/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
index 183bad4..567b932 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
@@ -21,7 +21,6 @@ package org.apache.hadoop.ipc;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
-import com.google.protobuf.CodedOutputStream;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.classification.InterfaceAudience;
@@ -31,13 +30,11 @@ import org.apache.hadoop.classification.InterfaceStability.Unstable;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.CommonConfigurationKeys;
 import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
-import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.io.IOUtils;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.retry.RetryPolicies;
 import org.apache.hadoop.io.retry.RetryPolicy;
 import org.apache.hadoop.io.retry.RetryPolicy.RetryAction;
-import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcRequestMessageWrapper;
 import org.apache.hadoop.ipc.RPC.RpcKind;
 import org.apache.hadoop.ipc.Server.AuthProtocol;
 import org.apache.hadoop.ipc.protobuf.IpcConnectionContextProtos.IpcConnectionContextProto;
@@ -54,7 +51,6 @@ import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
 import org.apache.hadoop.security.SecurityUtil;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.util.ProtoUtil;
-import org.apache.hadoop.util.ReflectionUtils;
 import org.apache.hadoop.util.StringUtils;
 import org.apache.hadoop.util.Time;
 import org.apache.hadoop.util.concurrent.AsyncGet;
@@ -65,6 +61,7 @@ import javax.net.SocketFactory;
 import javax.security.sasl.Sasl;
 import java.io.*;
 import java.net.*;
+import java.nio.ByteBuffer;
 import java.security.PrivilegedExceptionAction;
 import java.util.*;
 import java.util.Map.Entry;
@@ -429,7 +426,7 @@ public class Client implements AutoCloseable {
     private final boolean doPing; //do we need to send ping message
     private final int pingInterval; // how often sends ping to the server
     private final int soTimeout; // used by ipc ping and rpc timeout
-    private ByteArrayOutputStream pingRequest; // ping message
+    private ResponseBuffer pingRequest; // ping message
     
     // currently active calls
     private Hashtable<Integer, Call> calls = new Hashtable<Integer, Call>();
@@ -459,7 +456,7 @@ public class Client implements AutoCloseable {
       this.doPing = remoteId.getDoPing();
       if (doPing) {
         // construct a RPC header with the callId as the ping callId
-        pingRequest = new ByteArrayOutputStream();
+        pingRequest = new ResponseBuffer();
         RpcRequestHeaderProto pingHeader = ProtoUtil
             .makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER,
                 OperationProto.RPC_FINAL_PACKET, PING_CALL_ID,
@@ -979,12 +976,10 @@ public class Client implements AutoCloseable {
           .makeRpcRequestHeader(RpcKind.RPC_PROTOCOL_BUFFER,
               OperationProto.RPC_FINAL_PACKET, CONNECTION_CONTEXT_CALL_ID,
               RpcConstants.INVALID_RETRY_COUNT, clientId);
-      RpcRequestMessageWrapper request =
-          new RpcRequestMessageWrapper(connectionContextHeader, message);
-      
-      // Write out the packet length
-      out.writeInt(request.getLength());
-      request.write(out);
+      final ResponseBuffer buf = new ResponseBuffer();
+      connectionContextHeader.writeDelimitedTo(buf);
+      message.writeDelimitedTo(buf);
+      buf.writeTo(out);
     }
     
     /* wait till someone signals us to start reading RPC response or
@@ -1030,7 +1025,6 @@ public class Client implements AutoCloseable {
       if ( curTime - lastActivity.get() >= pingInterval) {
         lastActivity.set(curTime);
         synchronized (out) {
-          out.writeInt(pingRequest.size());
           pingRequest.writeTo(out);
           out.flush();
         }
@@ -1085,12 +1079,13 @@ public class Client implements AutoCloseable {
       // 2) RpcRequest
       //
       // Items '1' and '2' are prepared here. 
-      final DataOutputBuffer d = new DataOutputBuffer();
       RpcRequestHeaderProto header = ProtoUtil.makeRpcRequestHeader(
           call.rpcKind, OperationProto.RPC_FINAL_PACKET, call.id, call.retry,
           clientId);
-      header.writeDelimitedTo(d);
-      call.rpcRequest.write(d);
+
+      final ResponseBuffer buf = new ResponseBuffer();
+      header.writeDelimitedTo(buf);
+      RpcWritable.wrap(call.rpcRequest).writeTo(buf);
 
       synchronized (sendRpcRequestLock) {
         Future<?> senderFuture = sendParamsExecutor.submit(new Runnable() {
@@ -1101,14 +1096,10 @@ public class Client implements AutoCloseable {
                 if (shouldCloseConnection.get()) {
                   return;
                 }
-                
-                if (LOG.isDebugEnabled())
+                if (LOG.isDebugEnabled()) {
                   LOG.debug(getName() + " sending #" + call.id);
-         
-                byte[] data = d.getData();
-                int totalLength = d.getLength();
-                out.writeInt(totalLength); // Total Length
-                out.write(data, 0, totalLength);// RpcRequestHeader + RpcRequest
+                }
+                buf.writeTo(out); // RpcRequestHeader + RpcRequest
                 out.flush();
               }
             } catch (IOException e) {
@@ -1119,7 +1110,7 @@ public class Client implements AutoCloseable {
             } finally {
               //the buffer is just an in-memory buffer, but it is still polite to
               // close early
-              IOUtils.closeStream(d);
+              IOUtils.closeStream(buf);
             }
           }
         });
@@ -1151,12 +1142,13 @@ public class Client implements AutoCloseable {
       
       try {
         int totalLen = in.readInt();
-        RpcResponseHeaderProto header = 
-            RpcResponseHeaderProto.parseDelimitedFrom(in);
-        checkResponse(header);
+        ByteBuffer bb = ByteBuffer.allocate(totalLen);
+        in.readFully(bb.array());
 
-        int headerLen = header.getSerializedSize();
-        headerLen += CodedOutputStream.computeRawVarint32Size(headerLen);
+        RpcWritable.Buffer packet = RpcWritable.Buffer.wrap(bb);
+        RpcResponseHeaderProto header =
+            packet.getValue(RpcResponseHeaderProto.getDefaultInstance());
+        checkResponse(header);
 
         int callId = header.getCallId();
         if (LOG.isDebugEnabled())
@@ -1164,28 +1156,15 @@ public class Client implements AutoCloseable {
 
         RpcStatusProto status = header.getStatus();
         if (status == RpcStatusProto.SUCCESS) {
-          Writable value = ReflectionUtils.newInstance(valueClass, conf);
-          value.readFields(in);                 // read value
+          Writable value = packet.newInstance(valueClass, conf);
           final Call call = calls.remove(callId);
           call.setRpcResponse(value);
-          
-          // verify that length was correct
-          // only for ProtobufEngine where len can be verified easily
-          if (call.getRpcResponse() instanceof ProtobufRpcEngine.RpcWrapper) {
-            ProtobufRpcEngine.RpcWrapper resWrapper = 
-                (ProtobufRpcEngine.RpcWrapper) call.getRpcResponse();
-            if (totalLen != headerLen + resWrapper.getLength()) { 
-              throw new RpcClientException(
-                  "RPC response length mismatch on rpc success");
-            }
-          }
-        } else { // Rpc Request failed
-          // Verify that length was correct
-          if (totalLen != headerLen) {
-            throw new RpcClientException(
-                "RPC response length mismatch on rpc error");
-          }
-          
+        }
+        // verify that packet length was correct
+        if (packet.remaining() > 0) {
+          throw new RpcClientException("RPC response length mismatch");
+        }
+        if (status != RpcStatusProto.SUCCESS) { // Rpc Request failed
           final String exceptionClassName = header.hasExceptionClassName() ?
                 header.getExceptionClassName() : 
                   "ServerDidNotSetExceptionClassName";

http://git-wip-us.apache.org/repos/asf/hadoop/blob/23abb09c/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java
index eb30aa2..83e4b9e 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ProtobufRpcEngine.java
@@ -27,29 +27,22 @@ import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceStability;
 import org.apache.hadoop.classification.InterfaceStability.Unstable;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.io.DataOutputOutputStream;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.retry.RetryPolicy;
 import org.apache.hadoop.ipc.Client.ConnectionId;
 import org.apache.hadoop.ipc.RPC.RpcInvoker;
 import org.apache.hadoop.ipc.RpcWritable;
 import org.apache.hadoop.ipc.protobuf.ProtobufRpcEngineProtos.RequestHeaderProto;
-import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto;
-import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.hadoop.security.token.SecretManager;
 import org.apache.hadoop.security.token.TokenIdentifier;
-import org.apache.hadoop.util.ProtoUtil;
 import org.apache.hadoop.util.Time;
 import org.apache.hadoop.util.concurrent.AsyncGet;
 import org.apache.htrace.core.TraceScope;
 import org.apache.htrace.core.Tracer;
 
 import javax.net.SocketFactory;
-import java.io.DataInput;
-import java.io.DataOutput;
 import java.io.IOException;
-import java.io.OutputStream;
 import java.lang.reflect.Method;
 import java.lang.reflect.Proxy;
 import java.net.InetSocketAddress;
@@ -146,7 +139,7 @@ public class ProtobufRpcEngine implements RpcEngine {
     private Invoker(Class<?> protocol, Client.ConnectionId connId,
         Configuration conf, SocketFactory factory) {
       this.remoteId = connId;
-      this.client = CLIENTS.getClient(conf, factory, RpcResponseWrapper.class);
+      this.client = CLIENTS.getClient(conf, factory, RpcWritable.Buffer.class);
       this.protocolName = RPC.getProtocolName(protocol);
       this.clientProtocolVersion = RPC
           .getProtocolVersion(protocol);
@@ -193,7 +186,7 @@ public class ProtobufRpcEngine implements RpcEngine {
      * the server.
      */
     @Override
-    public Object invoke(Object proxy, final Method method, Object[] args)
+    public Message invoke(Object proxy, final Method method, Object[] args)
         throws ServiceException {
       long startTime = 0;
       if (LOG.isDebugEnabled()) {
@@ -228,11 +221,11 @@ public class ProtobufRpcEngine implements RpcEngine {
       }
 
 
-      Message theRequest = (Message) args[1];
-      final RpcResponseWrapper val;
+      final Message theRequest = (Message) args[1];
+      final RpcWritable.Buffer val;
       try {
-        val = (RpcResponseWrapper) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER,
-            new RpcRequestWrapper(rpcRequestHeader, theRequest), remoteId,
+        val = (RpcWritable.Buffer) client.call(RPC.RpcKind.RPC_PROTOCOL_BUFFER,
+            new RpcProtobufRequest(rpcRequestHeader, theRequest), remoteId,
             fallbackToSimpleAuth);
 
       } catch (Throwable e) {
@@ -256,7 +249,7 @@ public class ProtobufRpcEngine implements RpcEngine {
       }
       
       if (Client.isAsynchronousMode()) {
-        final AsyncGet<RpcResponseWrapper, IOException> arr
+        final AsyncGet<RpcWritable.Buffer, IOException> arr
             = Client.getAsyncRpcResponse();
         final AsyncGet<Message, Exception> asyncGet
             = new AsyncGet<Message, Exception>() {
@@ -278,7 +271,7 @@ public class ProtobufRpcEngine implements RpcEngine {
     }
 
     private Message getReturnMessage(final Method method,
-        final RpcResponseWrapper rrw) throws ServiceException {
+        final RpcWritable.Buffer buf) throws ServiceException {
       Message prototype = null;
       try {
         prototype = getReturnProtoType(method);
@@ -287,8 +280,7 @@ public class ProtobufRpcEngine implements RpcEngine {
       }
       Message returnMessage;
       try {
-        returnMessage = prototype.newBuilderForType()
-            .mergeFrom(rrw.theResponseRead).build();
+        returnMessage = buf.getValue(prototype.getDefaultInstanceForType());
 
         if (LOG.isTraceEnabled()) {
           LOG.trace(Thread.currentThread().getId() + ": Response <- " +
@@ -329,201 +321,12 @@ public class ProtobufRpcEngine implements RpcEngine {
     }
   }
 
-  interface RpcWrapper extends Writable {
-    int getLength();
-  }
-  /**
-   * Wrapper for Protocol Buffer Requests
-   * 
-   * Note while this wrapper is writable, the request on the wire is in
-   * Protobuf. Several methods on {@link org.apache.hadoop.ipc.Server and RPC} 
-   * use type Writable as a wrapper to work across multiple RpcEngine kinds.
-   */
-  private static abstract class RpcMessageWithHeader<T extends GeneratedMessage>
-    implements RpcWrapper {
-    T requestHeader;
-    Message theRequest; // for clientSide, the request is here
-    byte[] theRequestRead; // for server side, the request is here
-
-    public RpcMessageWithHeader() {
-    }
-
-    public RpcMessageWithHeader(T requestHeader, Message theRequest) {
-      this.requestHeader = requestHeader;
-      this.theRequest = theRequest;
-    }
-
-    @Override
-    public void write(DataOutput out) throws IOException {
-      OutputStream os = DataOutputOutputStream.constructOutputStream(out);
-      
-      ((Message)requestHeader).writeDelimitedTo(os);
-      theRequest.writeDelimitedTo(os);
-    }
-
-    @Override
-    public void readFields(DataInput in) throws IOException {
-      requestHeader = parseHeaderFrom(readVarintBytes(in));
-      theRequestRead = readMessageRequest(in);
-    }
-
-    abstract T parseHeaderFrom(byte[] bytes) throws IOException;
-
-    byte[] readMessageRequest(DataInput in) throws IOException {
-      return readVarintBytes(in);
-    }
-
-    private static byte[] readVarintBytes(DataInput in) throws IOException {
-      final int length = ProtoUtil.readRawVarint32(in);
-      final byte[] bytes = new byte[length];
-      in.readFully(bytes);
-      return bytes;
-    }
-
-    public T getMessageHeader() {
-      return requestHeader;
-    }
-
-    public byte[] getMessageBytes() {
-      return theRequestRead;
-    }
-    
-    @Override
-    public int getLength() {
-      int headerLen = requestHeader.getSerializedSize();
-      int reqLen;
-      if (theRequest != null) {
-        reqLen = theRequest.getSerializedSize();
-      } else if (theRequestRead != null ) {
-        reqLen = theRequestRead.length;
-      } else {
-        throw new IllegalArgumentException(
-            "getLength on uninitialized RpcWrapper");      
-      }
-      return CodedOutputStream.computeRawVarint32Size(headerLen) +  headerLen
-          + CodedOutputStream.computeRawVarint32Size(reqLen) + reqLen;
-    }
-  }
-  
-  private static class RpcRequestWrapper
-  extends RpcMessageWithHeader<RequestHeaderProto> {
-    @SuppressWarnings("unused")
-    public RpcRequestWrapper() {}
-    
-    public RpcRequestWrapper(
-        RequestHeaderProto requestHeader, Message theRequest) {
-      super(requestHeader, theRequest);
-    }
-    
-    @Override
-    RequestHeaderProto parseHeaderFrom(byte[] bytes) throws IOException {
-      return RequestHeaderProto.parseFrom(bytes);
-    }
-    
-    @Override
-    public String toString() {
-      return requestHeader.getDeclaringClassProtocolName() + "." +
-          requestHeader.getMethodName();
-    }
-  }
-
-  @InterfaceAudience.LimitedPrivate({"RPC"})
-  public static class RpcRequestMessageWrapper
-  extends RpcMessageWithHeader<RpcRequestHeaderProto> {
-    public RpcRequestMessageWrapper() {}
-    
-    public RpcRequestMessageWrapper(
-        RpcRequestHeaderProto requestHeader, Message theRequest) {
-      super(requestHeader, theRequest);
-    }
-    
-    @Override
-    RpcRequestHeaderProto parseHeaderFrom(byte[] bytes) throws IOException {
-      return RpcRequestHeaderProto.parseFrom(bytes);
-    }
-  }
-
-  @InterfaceAudience.LimitedPrivate({"RPC"})
-  public static class RpcResponseMessageWrapper
-  extends RpcMessageWithHeader<RpcResponseHeaderProto> {
-    public RpcResponseMessageWrapper() {}
-    
-    public RpcResponseMessageWrapper(
-        RpcResponseHeaderProto responseHeader, Message theRequest) {
-      super(responseHeader, theRequest);
-    }
-    
-    @Override
-    byte[] readMessageRequest(DataInput in) throws IOException {
-      // error message contain no message body
-      switch (requestHeader.getStatus()) {
-        case ERROR:
-        case FATAL:
-          return null;
-        default:
-          return super.readMessageRequest(in);
-      }
-    }
-    
-    @Override
-    RpcResponseHeaderProto parseHeaderFrom(byte[] bytes) throws IOException {
-      return RpcResponseHeaderProto.parseFrom(bytes);
-    }
-  }
-
-  /**
-   *  Wrapper for Protocol Buffer Responses
-   * 
-   * Note while this wrapper is writable, the request on the wire is in
-   * Protobuf. Several methods on {@link org.apache.hadoop.ipc.Server and RPC} 
-   * use type Writable as a wrapper to work across multiple RpcEngine kinds.
-   */
-  @InterfaceAudience.LimitedPrivate({"RPC"}) // temporarily exposed 
-  public static class RpcResponseWrapper implements RpcWrapper {
-    Message theResponse; // for senderSide, the response is here
-    byte[] theResponseRead; // for receiver side, the response is here
-
-    public RpcResponseWrapper() {
-    }
-
-    public RpcResponseWrapper(Message message) {
-      this.theResponse = message;
-    }
-
-    @Override
-    public void write(DataOutput out) throws IOException {
-      OutputStream os = DataOutputOutputStream.constructOutputStream(out);
-      theResponse.writeDelimitedTo(os);   
-    }
-
-    @Override
-    public void readFields(DataInput in) throws IOException {
-      int length = ProtoUtil.readRawVarint32(in);
-      theResponseRead = new byte[length];
-      in.readFully(theResponseRead);
-    }
-    
-    @Override
-    public int getLength() {
-      int resLen;
-      if (theResponse != null) {
-        resLen = theResponse.getSerializedSize();
-      } else if (theResponseRead != null ) {
-        resLen = theResponseRead.length;
-      } else {
-        throw new IllegalArgumentException(
-            "getLength on uninitialized RpcWrapper");      
-      }
-      return CodedOutputStream.computeRawVarint32Size(resLen) + resLen;
-    }
-  }
-
   @VisibleForTesting
   @InterfaceAudience.Private
   @InterfaceStability.Unstable
   static Client getClient(Configuration conf) {
     return CLIENTS.getClient(conf, SocketFactory.getDefault(),
-        RpcResponseWrapper.class);
+        RpcWritable.Buffer.class);
   }
   
  
@@ -691,16 +494,30 @@ public class ProtobufRpcEngine implements RpcEngine {
   // which uses the rpc header.  in the normal case we want to defer decoding
   // the rpc header until needed by the rpc engine.
   static class RpcProtobufRequest extends RpcWritable.Buffer {
-    private RequestHeaderProto lazyHeader;
+    private volatile RequestHeaderProto requestHeader;
+    private Message payload;
 
     public RpcProtobufRequest() {
     }
 
-    synchronized RequestHeaderProto getRequestHeader() throws IOException {
-      if (lazyHeader == null) {
-        lazyHeader = getValue(RequestHeaderProto.getDefaultInstance());
+    RpcProtobufRequest(RequestHeaderProto header, Message payload) {
+      this.requestHeader = header;
+      this.payload = payload;
+    }
+
+    RequestHeaderProto getRequestHeader() throws IOException {
+      if (getByteBuffer() != null && requestHeader == null) {
+        requestHeader = getValue(RequestHeaderProto.getDefaultInstance());
+      }
+      return requestHeader;
+    }
+
+    @Override
+    public void writeTo(ResponseBuffer out) throws IOException {
+      requestHeader.writeDelimitedTo(out);
+      if (payload != null) {
+        payload.writeDelimitedTo(out);
       }
-      return lazyHeader;
     }
 
     // this is used by htrace to name the span.

http://git-wip-us.apache.org/repos/asf/hadoop/blob/23abb09c/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java
index ac96a24..a789d83 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/ResponseBuffer.java
@@ -27,8 +27,14 @@ import java.util.Arrays;
 import org.apache.hadoop.classification.InterfaceAudience;
 
 @InterfaceAudience.Private
-class ResponseBuffer extends DataOutputStream {
-  ResponseBuffer(int capacity) {
+/** generates byte-length framed buffers. */
+public class ResponseBuffer extends DataOutputStream {
+
+  public ResponseBuffer() {
+    this(1024);
+  }
+
+  public ResponseBuffer(int capacity) {
     super(new FramedBuffer(capacity));
   }
 
@@ -39,7 +45,7 @@ class ResponseBuffer extends DataOutputStream {
     return buf;
   }
 
-  void writeTo(OutputStream out) throws IOException {
+  public void writeTo(OutputStream out) throws IOException {
     getFramedBuffer().writeTo(out);
   }
 

http://git-wip-us.apache.org/repos/asf/hadoop/blob/23abb09c/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/RpcWritable.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/RpcWritable.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/RpcWritable.java
index 5125939..54fb98e 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/RpcWritable.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/RpcWritable.java
@@ -24,7 +24,6 @@ import java.io.DataInputStream;
 import java.io.DataOutput;
 import java.io.IOException;
 import java.nio.ByteBuffer;
-
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.conf.Configurable;
 import org.apache.hadoop.conf.Configuration;
@@ -34,6 +33,7 @@ import com.google.protobuf.CodedInputStream;
 import com.google.protobuf.CodedOutputStream;
 import com.google.protobuf.Message;
 
+// note anything marked public is solely for access by SaslRpcClient
 @InterfaceAudience.Private
 public abstract class RpcWritable implements Writable {
 
@@ -99,6 +99,10 @@ public abstract class RpcWritable implements Writable {
       this.message = message;
     }
 
+    Message getMessage() {
+      return message;
+    }
+
     @Override
     void writeTo(ResponseBuffer out) throws IOException {
       int length = message.getSerializedSize();
@@ -128,11 +132,13 @@ public abstract class RpcWritable implements Writable {
     }
   }
 
-  // adapter to allow decoding of writables and protobufs from a byte buffer.
-  static class Buffer extends RpcWritable {
+  /**
+   * adapter to allow decoding of writables and protobufs from a byte buffer.
+   */
+  public static class Buffer extends RpcWritable {
     private ByteBuffer bb;
 
-    static Buffer wrap(ByteBuffer bb) {
+    public static Buffer wrap(ByteBuffer bb) {
       return new Buffer(bb);
     }
 
@@ -142,6 +148,10 @@ public abstract class RpcWritable implements Writable {
       this.bb = bb;
     }
 
+    ByteBuffer getByteBuffer() {
+      return bb;
+    }
+
     @Override
     void writeTo(ResponseBuffer out) throws IOException {
       out.ensureCapacity(bb.remaining());
@@ -177,7 +187,7 @@ public abstract class RpcWritable implements Writable {
       return RpcWritable.wrap(value).readFrom(bb);
     }
 
-    int remaining() {
+    public int remaining() {
       return bb.remaining();
     }
   }

http://git-wip-us.apache.org/repos/asf/hadoop/blob/23abb09c/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java
index c360937..60ae3b0 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/security/SaslRpcClient.java
@@ -53,11 +53,11 @@ import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceStability;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.GlobPattern;
-import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcRequestMessageWrapper;
-import org.apache.hadoop.ipc.ProtobufRpcEngine.RpcResponseMessageWrapper;
 import org.apache.hadoop.ipc.RPC.RpcKind;
 import org.apache.hadoop.ipc.RemoteException;
+import org.apache.hadoop.ipc.ResponseBuffer;
 import org.apache.hadoop.ipc.RpcConstants;
+import org.apache.hadoop.ipc.RpcWritable;
 import org.apache.hadoop.ipc.Server.AuthProtocol;
 import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto;
 import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcRequestHeaderProto.OperationProto;
@@ -368,11 +368,13 @@ public class SaslRpcClient {
     // loop until sasl is complete or a rpc error occurs
     boolean done = false;
     do {
-      int totalLen = inStream.readInt();
-      RpcResponseMessageWrapper responseWrapper =
-          new RpcResponseMessageWrapper();
-      responseWrapper.readFields(inStream);
-      RpcResponseHeaderProto header = responseWrapper.getMessageHeader();
+      int rpcLen = inStream.readInt();
+      ByteBuffer bb = ByteBuffer.allocate(rpcLen);
+      inStream.readFully(bb.array());
+
+      RpcWritable.Buffer saslPacket = RpcWritable.Buffer.wrap(bb);
+      RpcResponseHeaderProto header =
+          saslPacket.getValue(RpcResponseHeaderProto.getDefaultInstance());
       switch (header.getStatus()) {
         case ERROR: // might get a RPC error during 
         case FATAL:
@@ -380,15 +382,14 @@ public class SaslRpcClient {
                                     header.getErrorMsg());
         default: break;
       }
-      if (totalLen != responseWrapper.getLength()) {
-        throw new SaslException("Received malformed response length");
-      }
-      
       if (header.getCallId() != AuthProtocol.SASL.callId) {
         throw new SaslException("Non-SASL response during negotiation");
       }
       RpcSaslProto saslMessage =
-          RpcSaslProto.parseFrom(responseWrapper.getMessageBytes());
+          saslPacket.getValue(RpcSaslProto.getDefaultInstance());
+      if (saslPacket.remaining() > 0) {
+        throw new SaslException("Received malformed response length");
+      }
       // handle sasl negotiation process
       RpcSaslProto.Builder response = null;
       switch (saslMessage.getState()) {
@@ -452,16 +453,16 @@ public class SaslRpcClient {
     return authMethod;
   }
 
-  private void sendSaslMessage(DataOutputStream out, RpcSaslProto message)
+  private void sendSaslMessage(OutputStream out, RpcSaslProto message)
       throws IOException {
     if (LOG.isDebugEnabled()) {
       LOG.debug("Sending sasl message "+message);
     }
-    RpcRequestMessageWrapper request =
-        new RpcRequestMessageWrapper(saslHeader, message);
-    out.writeInt(request.getLength());
-    request.write(out);
-    out.flush();    
+    ResponseBuffer buf = new ResponseBuffer();
+    saslHeader.writeDelimitedTo(buf);
+    message.writeDelimitedTo(buf);
+    buf.writeTo(out);
+    out.flush();
   }
 
   /**
@@ -634,12 +635,8 @@ public class SaslRpcClient {
           .setState(SaslState.WRAP)
           .setToken(ByteString.copyFrom(buf, 0, buf.length))
           .build();
-      RpcRequestMessageWrapper request =
-          new RpcRequestMessageWrapper(saslHeader, saslMessage);
-      DataOutputStream dob = new DataOutputStream(out);
-      dob.writeInt(request.getLength());
-      request.write(dob);
-     }
+      sendSaslMessage(out, saslMessage);
+    }
   }
 
   /** Release resources used by wrapped saslClient */


---------------------------------------------------------------------
To unsubscribe, e-mail: common-commits-unsubscr...@hadoop.apache.org
For additional commands, e-mail: common-commits-h...@hadoop.apache.org

Reply via email to