HADOOP-12909. Change ipc.Client to support asynchronous calls. Contributed by Xiaobing Zhou


Project: http://git-wip-us.apache.org/repos/asf/hadoop/repo
Commit: http://git-wip-us.apache.org/repos/asf/hadoop/commit/a62637a4
Tree: http://git-wip-us.apache.org/repos/asf/hadoop/tree/a62637a4
Diff: http://git-wip-us.apache.org/repos/asf/hadoop/diff/a62637a4

Branch: refs/heads/YARN-3368
Commit: a62637a413ad88c4273d3251892b8fc1c05afa34
Parents: 3c18a53
Author: Tsz-Wo Nicholas Sze <szets...@hortonworks.com>
Authored: Thu Apr 7 14:01:33 2016 +0800
Committer: Tsz-Wo Nicholas Sze <szets...@hortonworks.com>
Committed: Thu Apr 7 14:02:51 2016 +0800

----------------------------------------------------------------------
 .../main/java/org/apache/hadoop/ipc/Client.java |  73 +++-
 .../org/apache/hadoop/ipc/TestAsyncIPC.java     | 346 +++++++++++++++++++
 .../java/org/apache/hadoop/ipc/TestIPC.java     |  29 +-
 3 files changed, 436 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/hadoop/blob/a62637a4/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
index fb11cb7..489c354 100644
--- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
+++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/Client.java
@@ -62,6 +62,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.classification.InterfaceAudience;
 import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.classification.InterfaceStability.Unstable;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.CommonConfigurationKeys;
 import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
@@ -96,6 +97,7 @@ import org.apache.htrace.core.Tracer;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
+import com.google.common.util.concurrent.AbstractFuture;
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
 import com.google.protobuf.CodedOutputStream;
 
@@ -107,7 +109,7 @@ import com.google.protobuf.CodedOutputStream;
  */
 @InterfaceAudience.LimitedPrivate(value = { "Common", "HDFS", "MapReduce", 
"Yarn" })
 @InterfaceStability.Evolving
-public class Client {
+public class Client implements AutoCloseable {
   
   public static final Log LOG = LogFactory.getLog(Client.class);
 
@@ -116,6 +118,20 @@ public class Client {
 
   private static final ThreadLocal<Integer> callId = new ThreadLocal<Integer>();
   private static final ThreadLocal<Integer> retryCount = new ThreadLocal<Integer>();
+  /** Future produced by the latest asynchronous call made on this thread. */
+  private static final ThreadLocal<Future<?>> returnValue = new ThreadLocal<>();
+  /** Whether calls made on this thread are asynchronous; false by default. */
+  private static final ThreadLocal<Boolean> asynchronousMode =
+      new ThreadLocal<Boolean>() {
+        @Override
+        protected Boolean initialValue() {
+          return false;
+        }
+      };
+
+  /**
+   * Returns the future stored by the latest asynchronous call made on the
+   * current thread, or {@code null} if this thread has made none.
+   *
+   * <p>The value is only overwritten by the next asynchronous call. A later
+   * synchronous call leaves it untouched, so it may still hold an older
+   * future. The cast to {@code Future<T>} is unchecked: {@code T} must match
+   * the response type of that call.
+   *
+   * @param <T> the expected response type
+   * @return the pending result of the latest asynchronous call, or null
+   */
+  @SuppressWarnings("unchecked")
+  @Unstable
+  public static <T> Future<T> getReturnValue() {
+    return (Future<T>) returnValue.get();
+  }
 
   /** Set call id and retry count for the next call. */
   public static void setCallIdAndRetryCount(int cid, int rc) {
@@ -1354,8 +1370,8 @@ public class Client {
       ConnectionId remoteId, int serviceClass,
       AtomicBoolean fallbackToSimpleAuth) throws IOException {
     final Call call = createCall(rpcKind, rpcRequest);
-    Connection connection = getConnection(remoteId, call, serviceClass,
-      fallbackToSimpleAuth);
+    final Connection connection = getConnection(remoteId, call, serviceClass,
+        fallbackToSimpleAuth);
     try {
       connection.sendRpcRequest(call);                 // send the rpc request
     } catch (RejectedExecutionException e) {
@@ -1366,6 +1382,51 @@ public class Client {
       throw new IOException(e);
     }
 
+    if (isAsynchronousMode()) {
+      // Asynchronous mode: hand back a future instead of blocking here. The
+      // RPC response is collected lazily by whichever thread calls get().
+      Future<Writable> returnFuture = new AbstractFuture<Writable>() {
+        /**
+         * Completes this future from the call's response unless it is
+         * already complete. May block until the response arrives.
+         */
+        private void completeFromCall() {
+          if (isDone()) {
+            return;
+          }
+          try {
+            set(getRpcResponse(call, connection));
+          } catch (IOException ie) {
+            setException(ie);
+          }
+        }
+
+        @Override
+        public Writable get() throws InterruptedException, ExecutionException {
+          completeFromCall();
+          return super.get();
+        }
+
+        /**
+         * Waits at most the given time for the call to finish, then returns
+         * its result. Without this override the inherited timed get() never
+         * triggers the lazy fetch above, so it could only ever time out.
+         */
+        @Override
+        public Writable get(long timeout, java.util.concurrent.TimeUnit unit)
+            throws InterruptedException, ExecutionException,
+            java.util.concurrent.TimeoutException {
+          if (!isDone()) {
+            final long deadline = System.nanoTime() + unit.toNanos(timeout);
+            // Same condition and monitor that getRpcResponse() loops on.
+            // NOTE(review): relies on the call's monitor being notified when
+            // the call completes, as getRpcResponse()'s wait loop does.
+            synchronized (call) {
+              while (!call.done) {
+                final long remaining = deadline - System.nanoTime();
+                if (remaining <= 0) {
+                  throw new java.util.concurrent.TimeoutException(
+                      "Timed out waiting for the RPC response");
+                }
+                java.util.concurrent.TimeUnit.NANOSECONDS.timedWait(call,
+                    remaining);
+              }
+            }
+          }
+          // The call has finished, so this returns without further blocking.
+          return get();
+        }
+      };
+
+      returnValue.set(returnFuture);
+      return null;
+    } else {
+      return getRpcResponse(call, connection);
+    }
+  }
+
+  /**
+   * Checks whether RPC calls made by the current thread are in asynchronous
+   * mode. The mode is tracked per thread and defaults to synchronous.
+   *
+   * @return true if RPC is in asynchronous mode for the current thread,
+   *         false if it is in synchronous mode
+   */
+  @Unstable
+  static boolean isAsynchronousMode() {
+    return asynchronousMode.get();
+  }
+
+  /**
+   * Switches RPC calls made by the current thread between asynchronous and
+   * synchronous mode. In asynchronous mode a call returns {@code null} once
+   * its request is sent, and its result is obtained through
+   * {@link #getReturnValue()}.
+   *
+   * @param async
+   *          {@code true} for asynchronous mode, {@code false} for
+   *          synchronous mode
+   */
+  @Unstable
+  public static void setAsynchronousMode(boolean async) {
+    final Boolean mode = Boolean.valueOf(async);
+    asynchronousMode.set(mode);
+  }
+
+  private Writable getRpcResponse(final Call call, final Connection connection)
+      throws IOException {
     synchronized (call) {
       while (!call.done) {
         try {
@@ -1640,4 +1701,10 @@ public class Client {
   public static int nextCallId() {
+    // The mask clears the sign bit, so the returned id is never negative.
     return callIdCounter.getAndIncrement() & 0x7FFFFFFF;
   }
+
+  /**
+   * Stops this client by delegating to {@link #stop()}, which lets a Client
+   * be used in a try-with-resources statement.
+   */
+  @Override
+  @Unstable
+  public void close() throws Exception {
+    stop();
+  }
 }

http://git-wip-us.apache.org/repos/asf/hadoop/blob/a62637a4/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestAsyncIPC.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestAsyncIPC.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestAsyncIPC.java
new file mode 100644
index 0000000..de4395e
--- /dev/null
+++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestAsyncIPC.java
@@ -0,0 +1,346 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.ipc;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.ipc.RPC.RpcKind;
+import org.apache.hadoop.ipc.TestIPC.CallInfo;
+import org.apache.hadoop.ipc.TestIPC.TestServer;
+import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
+import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.util.StringUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestAsyncIPC {
+
+  private static Configuration conf;
+  private static final Log LOG = LogFactory.getLog(TestAsyncIPC.class);
+
+  @Before
+  public void setupConf() {
+    conf = new Configuration();
+    Client.setPingInterval(conf, TestIPC.PING_INTERVAL);
+    // Set asynchronous mode for the thread running @Before. A test body with
+    // a timeout may run on a different thread, which is why SerialCaller
+    // also enables the mode itself.
+    Client.setAsynchronousMode(true);
+  }
+
+  /**
+   * Turns asynchronous mode back off on the thread that ran
+   * {@link #setupConf()}. The mode is a per-thread setting in {@link Client}.
+   * Leaving it on would make later tests that reuse this thread issue
+   * asynchronous calls unexpectedly.
+   */
+  @org.junit.After
+  public void resetAsynchronousMode() {
+    Client.setAsynchronousMode(false);
+  }
+
+  /**
+   * Issues {@code count} asynchronous calls against one server. For each
+   * call index it records the returned future and the value that was sent.
+   * The futures are resolved later by {@link #waitForReturnValues()}.
+   */
+  protected static class SerialCaller extends Thread {
+    private Client client;
+    private InetSocketAddress server;
+    private int count;
+    private boolean failed;
+    Map<Integer, Future<LongWritable>> returnFutures =
+        new HashMap<Integer, Future<LongWritable>>();
+    Map<Integer, Long> expectedValues = new HashMap<Integer, Long>();
+
+    public SerialCaller(Client client, InetSocketAddress server, int count) {
+      this.client = client;
+      this.server = server;
+      this.count = count;
+      // Enable asynchronous mode for the constructing thread. That thread
+      // makes the calls when run() is invoked directly rather than start()ed.
+      Client.setAsynchronousMode(true);
+    }
+
+    @Override
+    public void run() {
+      // When started via Thread#start(), run() executes on a new thread whose
+      // thread-local asynchronous mode must be enabled separately.
+      Client.setAsynchronousMode(true);
+      for (int i = 0; i < count; i++) {
+        try {
+          final long param = TestIPC.RANDOM.nextLong();
+          TestIPC.call(client, param, server, conf);
+          Future<LongWritable> returnFuture = Client.getReturnValue();
+          returnFutures.put(i, returnFuture);
+          expectedValues.put(i, param);
+        } catch (Exception e) {
+          LOG.fatal("Caught: " + StringUtils.stringifyException(e));
+          failed = true;
+        }
+      }
+    }
+
+    /**
+     * Blocks on every recorded future. The caller is marked failed at the
+     * first value that differs from what was sent. It is also marked failed
+     * if a call was never issued (no recorded future), instead of throwing a
+     * NullPointerException.
+     */
+    public void waitForReturnValues() throws InterruptedException,
+        ExecutionException {
+      for (int i = 0; i < count; i++) {
+        final Future<LongWritable> returnFuture = returnFutures.get(i);
+        if (returnFuture == null) {
+          LOG.fatal(String.format("Call-%d was never issued!", i));
+          failed = true;
+          break;
+        }
+        LongWritable value = returnFuture.get();
+        if (expectedValues.get(i).longValue() != value.get()) {
+          LOG.fatal(String.format("Call-%d failed!", i));
+          failed = true;
+          break;
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testSerial() throws IOException, InterruptedException,
+      ExecutionException {
+    internalTestSerial(3, false, 2, 5, 100);
+    internalTestSerial(3, true, 2, 5, 10);
+  }
+
+  /**
+   * Runs {@code callerCount} SerialCaller threads, spread over
+   * {@code clientCount} clients, against one TestServer. It then checks that
+   * every caller got back the values it sent. Clients and server are stopped
+   * even when an assertion fails.
+   */
+  public void internalTestSerial(int handlerCount, boolean handlerSleep,
+      int clientCount, int callerCount, int callCount) throws IOException,
+      InterruptedException, ExecutionException {
+    Server server = new TestIPC.TestServer(handlerCount, handlerSleep, conf);
+    Client[] clients = new Client[clientCount];
+    try {
+      InetSocketAddress addr = NetUtils.getConnectAddress(server);
+      server.start();
+
+      for (int i = 0; i < clientCount; i++) {
+        clients[i] = new Client(LongWritable.class, conf);
+      }
+
+      SerialCaller[] callers = new SerialCaller[callerCount];
+      for (int i = 0; i < callerCount; i++) {
+        callers[i] = new SerialCaller(clients[i % clientCount], addr, callCount);
+        callers[i].start();
+      }
+      for (int i = 0; i < callerCount; i++) {
+        callers[i].join();
+        callers[i].waitForReturnValues();
+        String msg = String.format("Expected not failed for caller-%d: %s.", i,
+            callers[i]);
+        assertFalse(msg, callers[i].failed);
+      }
+    } finally {
+      for (Client client : clients) {
+        if (client != null) {
+          client.stop();
+        }
+      }
+      server.stop();
+    }
+  }
+
+  /**
+   * Test if (1) the rpc server uses the call id/retry provided by the rpc
+   * client, and (2) the rpc client receives the same call id/retry from the
+   * rpc server.
+   *
+   * @throws ExecutionException
+   * @throws InterruptedException
+   */
+  @Test(timeout = 60000)
+  public void testCallIdAndRetry() throws IOException, InterruptedException,
+      ExecutionException {
+    final Map<Integer, CallInfo> infoMap = new HashMap<Integer, CallInfo>();
+
+    // Override client to store the call info and check response
+    final Client client = new Client(LongWritable.class, conf) {
+      @Override
+      Call createCall(RpcKind rpcKind, Writable rpcRequest) {
+        // Set different call id and retry count for the next call
+        Client.setCallIdAndRetryCount(Client.nextCallId(),
+            TestIPC.RANDOM.nextInt(255));
+
+        final Call call = super.createCall(rpcKind, rpcRequest);
+
+        CallInfo info = new CallInfo();
+        info.id = call.id;
+        info.retry = call.retry;
+        infoMap.put(call.id, info);
+
+        return call;
+      }
+
+      @Override
+      void checkResponse(RpcResponseHeaderProto header) throws IOException {
+        super.checkResponse(header);
+        Assert.assertEquals(infoMap.get(header.getCallId()).retry,
+            header.getRetryCount());
+      }
+    };
+
+    // Attach a listener that tracks every call received by the server.
+    final TestServer server = new TestIPC.TestServer(1, false, conf);
+    server.callListener = new Runnable() {
+      @Override
+      public void run() {
+        Assert.assertEquals(infoMap.get(Server.getCallId()).retry,
+            Server.getCallRetryCount());
+      }
+    };
+
+    try {
+      InetSocketAddress addr = NetUtils.getConnectAddress(server);
+      server.start();
+      final SerialCaller caller = new SerialCaller(client, addr, 4);
+      caller.run();
+      caller.waitForReturnValues();
+      String msg = String.format("Expected not failed for caller: %s.",
+          caller);
+      assertFalse(msg, caller.failed);
+    } finally {
+      client.stop();
+      server.stop();
+    }
+  }
+
+  /**
+   * Test if the rpc server gets the retry count from client.
+   *
+   * @throws ExecutionException
+   * @throws InterruptedException
+   */
+  @Test(timeout = 60000)
+  public void testCallRetryCount() throws IOException, InterruptedException,
+      ExecutionException {
+    final int retryCount = 255;
+    final Client client = new Client(LongWritable.class, conf);
+    // Set the retry count on this thread. The calls below are made from this
+    // thread because the caller is run() directly rather than start()ed.
+    Client.setCallIdAndRetryCount(Client.nextCallId(), retryCount);
+
+    // Attach a listener that checks the retry count of every call received by
+    // the server.
+    final TestServer server = new TestIPC.TestServer(1, false, conf);
+    server.callListener = new Runnable() {
+      @Override
+      public void run() {
+        // the retry count was set explicitly on the client side above, so the
+        // server should see that value
+        Assert.assertEquals(retryCount, Server.getCallRetryCount());
+      }
+    };
+
+    try {
+      InetSocketAddress addr = NetUtils.getConnectAddress(server);
+      server.start();
+      final SerialCaller caller = new SerialCaller(client, addr, 10);
+      caller.run();
+      caller.waitForReturnValues();
+      String msg = String.format("Expected not failed for caller: %s.",
+          caller);
+      assertFalse(msg, caller.failed);
+    } finally {
+      client.stop();
+      server.stop();
+    }
+  }
+
+  /**
+   * Test if the rpc server gets the default retry count (0) from client.
+   *
+   * @throws ExecutionException
+   * @throws InterruptedException
+   */
+  @Test(timeout = 60000)
+  public void testInitialCallRetryCount() throws IOException,
+      InterruptedException, ExecutionException {
+    // A plain client; no retry count is set before the calls are made.
+    final Client client = new Client(LongWritable.class, conf);
+
+    // Attach a listener that checks the retry count of every call received by
+    // the server.
+    final TestServer server = new TestIPC.TestServer(1, false, conf);
+    server.callListener = new Runnable() {
+      @Override
+      public void run() {
+        // we have not set the retry count for the client, thus on the server
+        // side we should see retry count as 0
+        Assert.assertEquals(0, Server.getCallRetryCount());
+      }
+    };
+
+    try {
+      InetSocketAddress addr = NetUtils.getConnectAddress(server);
+      server.start();
+      final SerialCaller caller = new SerialCaller(client, addr, 10);
+      caller.run();
+      caller.waitForReturnValues();
+      String msg = String.format("Expected not failed for caller: %s.",
+          caller);
+      assertFalse(msg, caller.failed);
+    } finally {
+      client.stop();
+      server.stop();
+    }
+  }
+
+  /**
+   * Tests that client generates a unique sequential call ID for each RPC call,
+   * even if multiple threads are using the same client.
+   *
+   * @throws InterruptedException
+   * @throws ExecutionException
+   */
+  @Test(timeout = 60000)
+  public void testUniqueSequentialCallIds() throws IOException,
+      InterruptedException, ExecutionException {
+    int serverThreads = 10, callerCount = 100, perCallerCallCount = 100;
+    TestServer server = new TestIPC.TestServer(serverThreads, false, conf);
+
+    // Attach a listener that tracks every call ID received by the server. This
+    // list must be synchronized, because multiple server threads will add to
+    // it.
+    final List<Integer> callIds = Collections
+        .synchronizedList(new ArrayList<Integer>());
+    server.callListener = new Runnable() {
+      @Override
+      public void run() {
+        callIds.add(Server.getCallId());
+      }
+    };
+
+    Client client = new Client(LongWritable.class, conf);
+
+    try {
+      InetSocketAddress addr = NetUtils.getConnectAddress(server);
+      server.start();
+      SerialCaller[] callers = new SerialCaller[callerCount];
+      for (int i = 0; i < callerCount; ++i) {
+        callers[i] = new SerialCaller(client, addr, perCallerCallCount);
+        callers[i].start();
+      }
+      for (int i = 0; i < callerCount; ++i) {
+        callers[i].join();
+        callers[i].waitForReturnValues();
+        String msg = String.format("Expected not failed for caller-%d: %s.", i,
+            callers[i]);
+        assertFalse(msg, callers[i].failed);
+      }
+    } finally {
+      client.stop();
+      server.stop();
+    }
+
+    int expectedCallCount = callerCount * perCallerCallCount;
+    assertEquals(expectedCallCount, callIds.size());
+
+    // It is not guaranteed that the server executes requests in sequential
+    // order of client call ID, so we must sort the call IDs before checking
+    // that it contains every expected value.
+    Collections.sort(callIds);
+    final int startID = callIds.get(0).intValue();
+    for (int i = 0; i < expectedCallCount; ++i) {
+      assertEquals(startID + i, callIds.get(i).intValue());
+    }
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/hadoop/blob/a62637a4/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java
----------------------------------------------------------------------
diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java
index d658182..6bfcc53 100644
--- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java
+++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestIPC.java
@@ -99,7 +99,7 @@ public class TestIPC {
     LogFactory.getLog(TestIPC.class);
   
   private static Configuration conf;
-  final static private int PING_INTERVAL = 1000;
+  // Package-private so TestAsyncIPC can reuse the same ping interval.
+  final static int PING_INTERVAL = 1000;
   final static private int MIN_SLEEP_TIME = 1000;
   /**
    * Flag used to turn off the fault injection behavior
@@ -114,7 +114,7 @@ public class TestIPC {
     Client.setPingInterval(conf, PING_INTERVAL);
   }
 
-  private static final Random RANDOM = new Random();
+  // Package-private so TestAsyncIPC can draw call parameters from it.
+  static final Random RANDOM = new Random();
 
   private static final String ADDRESS = "0.0.0.0";
 
@@ -148,22 +148,33 @@ public class TestIPC {
         RPC.RPC_SERVICE_CLASS_DEFAULT, null);
   }
 
-  private static class TestServer extends Server {
+  static class TestServer extends Server {
     // Tests can set callListener to run a piece of code each time the server
     // receives a call.  This code executes on the server thread, so it has
     // visibility of that thread's thread-local storage.
-    private Runnable callListener;
+    Runnable callListener;
     private boolean sleep;
     private Class<? extends Writable> responseClass;
 
+    /**
+     * Creates a server that takes {@link LongWritable} parameters, with a null
+     * response class, configured by the test class's static {@code conf}.
+     */
     public TestServer(int handlerCount, boolean sleep) throws IOException {
       this(handlerCount, sleep, LongWritable.class, null);
     }
-    
+
+    /**
+     * Creates a server that takes {@link LongWritable} parameters, with a null
+     * response class, configured by the given {@code conf}.
+     */
+    public TestServer(int handlerCount, boolean sleep, Configuration conf)
+        throws IOException {
+      this(handlerCount, sleep, LongWritable.class, null, conf);
+    }
+
+    /**
+     * Same as the five-argument constructor, but configured by the test
+     * class's static {@code conf}.
+     */
+    public TestServer(int handlerCount, boolean sleep,
+        Class<? extends Writable> paramClass,
+        Class<? extends Writable> responseClass) throws IOException {
+      this(handlerCount, sleep, paramClass, responseClass, conf);
+    }
+
     public TestServer(int handlerCount, boolean sleep,
         Class<? extends Writable> paramClass,
-        Class<? extends Writable> responseClass) 
-      throws IOException {
+        Class<? extends Writable> responseClass, Configuration conf)
+        throws IOException {
       super(ADDRESS, 0, paramClass, handlerCount, conf);
       this.sleep = sleep;
       this.responseClass = responseClass;
@@ -1070,7 +1081,7 @@ public class TestIPC {
     assertRetriesOnSocketTimeouts(conf, 4);
   }
 
-  private static class CallInfo {
+  /**
+   * Call id and retry count recorded for one RPC call. Package-private so
+   * that TestAsyncIPC can reuse it.
+   */
+  static class CallInfo {
     int id = RpcConstants.INVALID_CALL_ID;
     int retry = RpcConstants.INVALID_RETRY_COUNT;
   }
@@ -1125,7 +1136,7 @@ public class TestIPC {
   }
   
   /** A dummy protocol */
-  private interface DummyProtocol {
+  interface DummyProtocol {
+    // NOTE(review): TestAsyncIPC does not use this interface; confirm that
+    // widening it from private to package-private is needed.
     @Idempotent
     public void dummyRun() throws IOException;
   }

Reply via email to