Repository: hive
Updated Branches:
  refs/heads/llap e35f5c92b -> ce2af3d52


HIVE-10425. LLAP: Limit number of threads used to communicate with a
single LLAP instance to 1. (Siddharth Seth)


Project: http://git-wip-us.apache.org/repos/asf/hive/repo
Commit: http://git-wip-us.apache.org/repos/asf/hive/commit/ce2af3d5
Tree: http://git-wip-us.apache.org/repos/asf/hive/tree/ce2af3d5
Diff: http://git-wip-us.apache.org/repos/asf/hive/diff/ce2af3d5

Branch: refs/heads/llap
Commit: ce2af3d52f7330854c8dc9863286ef0049425dc0
Parents: e35f5c9
Author: Siddharth Seth <ss...@apache.org>
Authored: Mon Jun 1 16:06:59 2015 -0700
Committer: Siddharth Seth <ss...@apache.org>
Committed: Mon Jun 1 16:06:59 2015 -0700

----------------------------------------------------------------------
 .../hive/llap/tezplugins/TaskCommunicator.java  | 344 ++++++++++++++++---
 .../llap/tezplugins/TestTaskCommunicator.java   | 143 ++++++++
 2 files changed, 438 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/hive/blob/ce2af3d5/llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/TaskCommunicator.java
----------------------------------------------------------------------
diff --git 
a/llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/TaskCommunicator.java
 
b/llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/TaskCommunicator.java
index d357d61..33e998c 100644
--- 
a/llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/TaskCommunicator.java
+++ 
b/llap-server/src/java/org/apache/hadoop/hive/llap/tezplugins/TaskCommunicator.java
@@ -15,13 +15,24 @@
 package org.apache.hadoop.hive.llap.tezplugins;
 
 import javax.net.SocketFactory;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
 import java.util.concurrent.Callable;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.locks.Condition;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
@@ -29,9 +40,8 @@ import 
com.google.common.util.concurrent.ListeningExecutorService;
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.common.util.concurrent.ThreadFactoryBuilder;
 import com.google.protobuf.Message;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hive.llap.LlapNodeId;
 import org.apache.hadoop.hive.llap.configuration.LlapConfiguration;
 import org.apache.hadoop.hive.llap.daemon.LlapDaemonProtocolBlockingPB;
 import org.apache.hadoop.hive.llap.daemon.impl.LlapDaemonProtocolClientImpl;
@@ -47,24 +57,26 @@ import org.apache.hadoop.io.retry.RetryPolicies;
 import org.apache.hadoop.io.retry.RetryPolicy;
 import org.apache.hadoop.net.NetUtils;
 import org.apache.hadoop.service.AbstractService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class TaskCommunicator extends AbstractService {
 
-  private static final Log LOG = LogFactory.getLog(TaskCommunicator.class);
+  private static final Logger LOG = 
LoggerFactory.getLogger(TaskCommunicator.class);
 
+  // RPC proxies cached per daemon, keyed by getHostIdentifier(hostname, port); see getProxy().
   private final ConcurrentMap<String, LlapDaemonProtocolBlockingPB> 
hostProxies;
-  private ListeningExecutorService executor;
 
+  // Dispatches every outgoing daemon request, limiting concurrent requests per LLAP node.
+  private final RequestManager requestManager;
   private final RetryPolicy retryPolicy;
   private final SocketFactory socketFactory;
 
+  // Single-threaded executor that hosts the RequestManager dispatch loop.
+  private final ListeningExecutorService requestManagerExecutor;
+  // Future of the running dispatch loop; assigned in serviceStart(), null before that.
+  private volatile ListenableFuture<Void> requestManagerFuture;
+
+
   public TaskCommunicator(int numThreads, Configuration conf) {
     super(TaskCommunicator.class.getSimpleName());
-    ExecutorService localExecutor = Executors.newFixedThreadPool(numThreads,
-        new ThreadFactoryBuilder().setNameFormat("TaskCommunicator 
#%2d").build());
     this.hostProxies = new ConcurrentHashMap<>();
-    executor = MoreExecutors.listeningDecorator(localExecutor);
-
     this.socketFactory = NetUtils.getDefaultSocketFactory(conf);
 
     long connectionTimeout =
@@ -75,6 +87,12 @@ public class TaskCommunicator extends AbstractService {
         
LlapConfiguration.LLAP_TASK_COMMUNICATOR_CONNECTION_SLEEP_BETWEEN_RETRIES_MILLIS_DEFAULT);
     this.retryPolicy = 
RetryPolicies.retryUpToMaximumTimeWithFixedSleep(connectionTimeout, retrySleep,
         TimeUnit.MILLISECONDS);
+
+    this.requestManager = new RequestManager(numThreads);
+    ExecutorService localExecutor = Executors.newFixedThreadPool(1,
+        new 
ThreadFactoryBuilder().setNameFormat("RequestManagerExecutor").build());
+    this.requestManagerExecutor = 
MoreExecutors.listeningDecorator(localExecutor);
+
     LOG.info("Setting up taskCommunicator with" +
         "numThreads=" + numThreads +
         "retryTime(millis)=" + connectionTimeout +
@@ -82,71 +100,294 @@ public class TaskCommunicator extends AbstractService {
   }
 
+  /**
+   * Starts the RequestManager dispatch loop on the single-threaded requestManagerExecutor
+   * and registers a callback that logs when the loop exits, normally or with an error.
+   */
   @Override
+  public void serviceStart() {
+    requestManagerFuture = requestManagerExecutor.submit(requestManager);
+    Futures.addCallback(requestManagerFuture, new FutureCallback<Void>() {
+      @Override
+      public void onSuccess(Void result) {
+        LOG.info("RequestManager shutdown");
+      }
+
+      @Override
+      public void onFailure(Throwable t) {
+        LOG.warn("RequestManager shutdown with error", t);
+      }
+    });
+  }
+
+  /**
+   * Stops the dispatch loop if serviceStart() ran. RequestManager.shutdown() also calls
+   * shutdownNow() on the RPC thread pool, and the loop future is cancelled with interruption.
+   * Finally the loop's own executor is shut down.
+   * NOTE(review): requests that are still queued at this point are dropped; their
+   * ExecuteRequestCallbacks are never invoked.
+   */
+  @Override
   public void serviceStop() {
-    executor.shutdownNow();
+    if (requestManagerFuture != null) {
+      requestManager.shutdown();
+      requestManagerFuture.cancel(true);
+    }
+    requestManagerExecutor.shutdown();
   }
 
+  /**
+   * Queues a submitWork RPC for the LLAP daemon at host:port and returns immediately.
+   * The RequestManager sends at most one request per node at a time, in queue order;
+   * the callback is invoked asynchronously once the RPC succeeds or fails.
+   */
   public void sendSubmitWork(SubmitWorkRequestProto request, String host, int 
port,
                          final ExecuteRequestCallback<SubmitWorkResponseProto> 
callback) {
-    ListenableFuture<SubmitWorkResponseProto> future = executor.submit(new 
SubmitWorkCallable(host, port, request));
-    Futures.addCallback(future, new ResponseCallback(callback));
+    LlapNodeId nodeId = LlapNodeId.getInstance(host, port);
+    requestManager.queueRequest(new SubmitWorkCallable(nodeId, request, 
callback));
   }
 
+  /**
+   * Queues a sourceStateUpdated RPC for the LLAP daemon at host:port and returns
+   * immediately; the callback is invoked asynchronously once the RPC succeeds or fails.
+   */
   public void sendSourceStateUpdate(final SourceStateUpdatedRequestProto 
request, final String host,
                                     final int port,
                                     final 
ExecuteRequestCallback<SourceStateUpdatedResponseProto> callback) {
-    ListenableFuture<SourceStateUpdatedResponseProto> future =
-        executor.submit(new SendSourceStateUpdateCallable(host, port, 
request));
-    Futures.addCallback(future, new ResponseCallback(callback));
+    LlapNodeId nodeId = LlapNodeId.getInstance(host, port);
+    requestManager.queueRequest(
+        new SendSourceStateUpdateCallable(nodeId, request, callback));
   }
 
+  /**
+   * Queues a queryComplete RPC for the LLAP daemon at host:port and returns immediately;
+   * the callback is invoked asynchronously once the RPC succeeds or fails.
+   */
   public void sendQueryComplete(final QueryCompleteRequestProto request, final 
String host,
                                 final int port,
                                 final 
ExecuteRequestCallback<QueryCompleteResponseProto> callback) {
-    ListenableFuture<QueryCompleteResponseProto> future =
-        executor.submit(new SendQueryCompleteCallable(host, port, request));
-    Futures.addCallback(future, new ResponseCallback(callback));
+    LlapNodeId nodeId = LlapNodeId.getInstance(host, port);
+    requestManager.queueRequest(new SendQueryCompleteCallable(nodeId, request, 
callback));
   }
 
+  /**
+   * Queues a terminateFragment RPC for the LLAP daemon at host:port and returns
+   * immediately; the callback is invoked asynchronously once the RPC succeeds or fails.
+   */
   public void sendTerminateFragment(final TerminateFragmentRequestProto 
request, final String host,
                                     final int port,
                                     final 
ExecuteRequestCallback<TerminateFragmentResponseProto> callback) {
-    ListenableFuture<TerminateFragmentResponseProto> future =
-        executor.submit(new SendTerminateFragmentCallable(host, port, 
request));
-    Futures.addCallback(future, new ResponseCallback(callback));
+    LlapNodeId nodeId = LlapNodeId.getInstance(host, port);
+    requestManager.queueRequest(new SendTerminateFragmentCallable(nodeId, 
request, callback));
   }
 
-  private static class ResponseCallback<TYPE extends Message> implements 
FutureCallback<TYPE> {
+  /**
+   * Dispatches queued {@link CallableRequest}s to a shared RPC thread pool while allowing at
+   * most {@code maxConcurrentRequestsPerNode} (1 by default) in-flight requests per LLAP node.
+   * Requests for the same node are submitted in the order in which they were queued.
+   *
+   * <p>{@link #call()} runs the dispatch loop; all loop state ({@code pendingRequests} and the
+   * {@code currentLoop*} fields) is only touched while holding {@code lock}.
+   * {@link #queueRequest}, {@link #requestFinished} and {@link #shutdown} may be called from
+   * any thread: they only touch atomics or separately synchronized lists, then signal the loop.
+   */
+  @VisibleForTesting
+  static class RequestManager implements Callable<Void> {
+
+    private final Lock lock = new ReentrantLock();
+    private final AtomicBoolean isShutdown = new AtomicBoolean(false);
+    private final Condition queueCondition = lock.newCondition();
+    // True when new requests or completions have arrived since the last drain.
+    private final AtomicBoolean shouldRun = new AtomicBoolean(false);
+
+    private final int maxConcurrentRequestsPerNode;
+    private final ListeningExecutorService executor;
+
+    // Tracks new additions via queueRequest while the loop is processing existing ones.
+    // Guarded by synchronizing on the list itself.
+    private final LinkedList<CallableRequest> newRequestList = new LinkedList<>();
+
+    // Requests not yet submitted; cycled through on every loop. Only used under lock.
+    private final LinkedList<CallableRequest> pendingRequests = new LinkedList<>();
+
+    // Tracks requests executing per node.
+    private final ConcurrentMap<LlapNodeId, AtomicInteger> runningRequests =
+        new ConcurrentHashMap<>();
+
+    // Tracks completed requests per node. Guarded by synchronizing on the list itself.
+    private final LinkedList<LlapNodeId> completedNodes = new LinkedList<>();
+
+    // Per-iteration scratch state, exposed so tests can inspect the outcome of process().
+    // Nodes that had a request skipped in the current loop; later requests for these nodes
+    // are skipped too, so a newer request never overtakes an older one for the same node.
+    @VisibleForTesting
+    final Set<LlapNodeId> currentLoopDisabledNodes = new HashSet<>();
+    // Requests skipped in the current loop, in their original order.
+    @VisibleForTesting
+    final List<CallableRequest> currentLoopSkippedRequests = new LinkedList<>();
+
+    /**
+     * @param numThreads size of the thread pool used to execute the RPCs
+     */
+    public RequestManager(int numThreads) {
+      this(numThreads, 1);
+    }
+
+    /**
+     * @param numThreads size of the thread pool used to execute the RPCs
+     * @param maxConcurrentRequestsPerNode max in-flight requests per node; must be >= 1
+     * @throws IllegalArgumentException if maxConcurrentRequestsPerNode is less than 1
+     */
+    RequestManager(int numThreads, int maxConcurrentRequestsPerNode) {
+      if (maxConcurrentRequestsPerNode < 1) {
+        throw new IllegalArgumentException(
+            "maxConcurrentRequestsPerNode must be >= 1, was " + maxConcurrentRequestsPerNode);
+      }
+      this.maxConcurrentRequestsPerNode = maxConcurrentRequestsPerNode;
+      ExecutorService localExecutor = Executors.newFixedThreadPool(numThreads,
+          new ThreadFactoryBuilder().setNameFormat("TaskCommunicator #%2d").build());
+      executor = MoreExecutors.listeningDecorator(localExecutor);
+    }
+
+    /**
+     * Runs the dispatch loop until {@link #shutdown()} is called.
+     *
+     * @throws RuntimeException wrapping an InterruptedException if the loop thread is
+     *     interrupted without shutdown having been requested (interrupt status is preserved)
+     */
+    @Override
+    public Void call() {
+      while (!isShutdown.get()) {
+        lock.lock();
+        try {
+          if (!shouldRun.get()) {
+            // Woken by queueRequest/requestFinished/shutdown (or spuriously). Either way, try
+            // a pass: process() exits on shutdown and is harmless when nothing is runnable.
+            queueCondition.await();
+          }
+          boolean shouldBreak = process();
+          if (shouldBreak) {
+            break;
+          }
+        } catch (InterruptedException e) {
+          if (isShutdown.get()) {
+            break;
+          } else {
+            LOG.warn("RunLoop interrupted without being shutdown first");
+            // Preserve the interrupt status for whoever owns this thread.
+            Thread.currentThread().interrupt();
+            throw new RuntimeException(e);
+          }
+        } finally {
+          lock.unlock();
+        }
+      }
+      LOG.info("CallScheduler loop exiting");
+      return null;
+    }
+
+    /* Add a new request to be executed */
+    public void queueRequest(CallableRequest request) {
+      synchronized (newRequestList) {
+        newRequestList.add(request);
+        shouldRun.set(true);
+      }
+      notifyRunLoop();
+    }
+
+    /* Indicates a request has completed on a node, freeing one of its slots */
+    public void requestFinished(LlapNodeId nodeId) {
+      synchronized (completedNodes) {
+        completedNodes.add(nodeId);
+        shouldRun.set(true);
+      }
+      notifyRunLoop();
+    }
+
+    /**
+     * Stops the loop and calls shutdownNow() on the RPC pool. Idempotent.
+     * NOTE(review): requests still pending or queued are dropped without invoking their
+     * callbacks.
+     */
+    public void shutdown() {
+      if (!isShutdown.getAndSet(true)) {
+        executor.shutdownNow();
+        notifyRunLoop();
+      }
+    }
+
+    /** Submits the request to the RPC pool; completion is reported back via requestFinished. */
+    @VisibleForTesting
+    void submitToExecutor(CallableRequest request, LlapNodeId nodeId) {
+      // CallableRequest is a raw Callable, so the response type is not known here.
+      ListenableFuture<?> future = executor.submit(request);
+      Futures.addCallback(future, new ResponseCallback(request.getCallback(), nodeId, this));
+    }
+
+    /**
+     * Runs one scheduling pass: drains newly queued requests and completions, then submits
+     * every pending request whose node has a free slot, keeping the rest in order.
+     * Must be called with {@code lock} held (or, in tests, from a single thread).
+     * NOTE(review): an exception from submitToExecutor (e.g. a rejected submission after a
+     * concurrent shutdown) is not caught here and terminates the loop in call().
+     *
+     * @return true if the manager has been shut down and the loop should exit
+     */
+    @VisibleForTesting
+    boolean process() {
+      if (isShutdown.get()) {
+        return true;
+      }
+      currentLoopDisabledNodes.clear();
+      currentLoopSkippedRequests.clear();
+
+      // Reset the flag BEFORE draining. If it were reset after draining, an add/completion
+      // landing between the drain and the reset would be lost until some unrelated wake-up.
+      // Resetting first may at worst cause one extra, empty pass.
+      shouldRun.compareAndSet(true, false);
+      // Drain any calls which may have come in during the last execution of the loop.
+      drainNewRequestList();  // Locks newRequestList
+      drainCompletedNodes();  // Locks completedNodes
+
+      Iterator<CallableRequest> iterator = pendingRequests.iterator();
+      while (iterator.hasNext()) {
+        CallableRequest request = iterator.next();
+        iterator.remove();
+        LlapNodeId nodeId = request.getNodeId();
+        if (canRunForNode(nodeId, currentLoopDisabledNodes)) {
+          submitToExecutor(request, nodeId);
+        } else {
+          currentLoopDisabledNodes.add(nodeId);
+          currentLoopSkippedRequests.add(request);
+        }
+      }
+      // Every pending request was removed above; re-queue the skipped ones in original order.
+      pendingRequests.addAll(0, currentLoopSkippedRequests);
+      return false;
+    }
+
+    private void drainNewRequestList() {
+      synchronized (newRequestList) {
+        if (!newRequestList.isEmpty()) {
+          pendingRequests.addAll(newRequestList);
+          newRequestList.clear();
+        }
+      }
+    }
+
+    private void drainCompletedNodes() {
+      synchronized (completedNodes) {
+        for (LlapNodeId nodeId : completedNodes) {
+          runningRequests.get(nodeId).decrementAndGet();
+        }
+        completedNodes.clear();
+      }
+    }
+
+    /** Claims a running slot for the node if it is under its limit and not disabled. */
+    private boolean canRunForNode(LlapNodeId nodeId, Set<LlapNodeId> currentRunDisabledNodes) {
+      if (currentRunDisabledNodes.contains(nodeId)) {
+        return false;
+      }
+      AtomicInteger count = runningRequests.get(nodeId);
+      if (count == null) {
+        count = new AtomicInteger(0);
+        AtomicInteger old = runningRequests.putIfAbsent(nodeId, count);
+        count = old != null ? old : count;
+      }
+      // Optimistically take a slot; hand it back if that exceeds the per-node limit.
+      if (count.incrementAndGet() <= maxConcurrentRequestsPerNode) {
+        return true;
+      }
+      count.decrementAndGet();
+      return false;
+    }
+
+    private void notifyRunLoop() {
+      lock.lock();
+      try {
+        queueCondition.signal();
+      } finally {
+        lock.unlock();
+      }
+    }
+  }
+
+
+  /**
+   * Relays the outcome of an RPC future to the caller's ExecuteRequestCallback, then
+   * (in a finally block, so even if that callback throws) reports the node's request as
+   * finished to the RequestManager so its per-node slot is released.
+   */
+  private static final class ResponseCallback<TYPE extends Message>
+      implements FutureCallback<TYPE> {
 
     private final ExecuteRequestCallback<TYPE> callback;
+    private final LlapNodeId nodeId;
+    private final RequestManager requestManager;
 
-    public ResponseCallback(ExecuteRequestCallback<TYPE> callback) {
+    public ResponseCallback(ExecuteRequestCallback<TYPE> callback, LlapNodeId 
nodeId,
+                            RequestManager requestManager) {
       this.callback = callback;
+      this.nodeId = nodeId;
+      this.requestManager = requestManager;
     }
 
     @Override
     public void onSuccess(TYPE result) {
-      callback.setResponse(result);
+      try {
+        callback.setResponse(result);
+      } finally {
+        requestManager.requestFinished(nodeId);
+      }
     }
 
     @Override
     public void onFailure(Throwable t) {
-      callback.indicateError(t);
+      try {
+        callback.indicateError(t);
+      } finally {
+        requestManager.requestFinished(nodeId);
+      }
     }
   }
 
-  private static abstract class CallableRequest<REQUEST extends Message, 
RESPONSE extends Message>
+  @VisibleForTesting
+  static abstract class CallableRequest<REQUEST extends Message, RESPONSE 
extends Message>
       implements Callable {
 
-    final String hostname;
-    final int port;
+    final LlapNodeId nodeId;
+    final ExecuteRequestCallback<RESPONSE> callback;
     final REQUEST request;
 
 
-    protected CallableRequest(String hostname, int port, REQUEST request) {
-      this.hostname = hostname;
-      this.port = port;
+    protected CallableRequest(LlapNodeId nodeId, REQUEST request, 
ExecuteRequestCallback<RESPONSE> callback) {
+      this.nodeId = nodeId;
       this.request = request;
+      this.callback = callback;
+    }
+
+    public LlapNodeId getNodeId() {
+      return nodeId;
+    }
+
+    public ExecuteRequestCallback<RESPONSE> getCallback() {
+      return callback;
     }
 
     public abstract RESPONSE call() throws Exception;
@@ -154,56 +395,60 @@ public class TaskCommunicator extends AbstractService {
 
+  /** Issues the submitWork RPC to the daemon at nodeId through its cached proxy. */
   private class SubmitWorkCallable extends 
CallableRequest<SubmitWorkRequestProto, SubmitWorkResponseProto> {
 
-    protected SubmitWorkCallable(String hostname, int port,
-                          SubmitWorkRequestProto submitWorkRequestProto) {
-      super(hostname, port, submitWorkRequestProto);
+    protected SubmitWorkCallable(LlapNodeId nodeId,
+                          SubmitWorkRequestProto submitWorkRequestProto,
+                                 
ExecuteRequestCallback<SubmitWorkResponseProto> callback) {
+      super(nodeId, submitWorkRequestProto, callback);
     }
 
     @Override
     public SubmitWorkResponseProto call() throws Exception {
-      return getProxy(hostname, port).submitWork(null, request);
+      return getProxy(nodeId).submitWork(null, request);
     }
   }
 
+  /** Issues the sourceStateUpdated RPC to the daemon at nodeId through its cached proxy. */
   private class SendSourceStateUpdateCallable
       extends CallableRequest<SourceStateUpdatedRequestProto, 
SourceStateUpdatedResponseProto> {
 
-    public SendSourceStateUpdateCallable(String hostname, int port,
-                                         SourceStateUpdatedRequestProto 
request) {
-      super(hostname, port, request);
+    public SendSourceStateUpdateCallable(LlapNodeId nodeId,
+                                         SourceStateUpdatedRequestProto 
request,
+                                         
ExecuteRequestCallback<SourceStateUpdatedResponseProto> callback) {
+      super(nodeId, request, callback);
     }
 
     @Override
     public SourceStateUpdatedResponseProto call() throws Exception {
-      return getProxy(hostname, port).sourceStateUpdated(null, request);
+      return getProxy(nodeId).sourceStateUpdated(null, request);
     }
   }
 
+  /** Issues the queryComplete RPC to the daemon at nodeId through its cached proxy. */
   private class SendQueryCompleteCallable
       extends CallableRequest<QueryCompleteRequestProto, 
QueryCompleteResponseProto> {
 
-    protected SendQueryCompleteCallable(String hostname, int port,
-                                        QueryCompleteRequestProto 
queryCompleteRequestProto) {
-      super(hostname, port, queryCompleteRequestProto);
+    protected SendQueryCompleteCallable(LlapNodeId nodeId,
+                                        QueryCompleteRequestProto 
queryCompleteRequestProto,
+                                        
ExecuteRequestCallback<QueryCompleteResponseProto> callback) {
+      super(nodeId, queryCompleteRequestProto, callback);
     }
 
     @Override
     public QueryCompleteResponseProto call() throws Exception {
-      return getProxy(hostname, port).queryComplete(null, request);
+      return getProxy(nodeId).queryComplete(null, request);
     }
   }
 
+  /** Issues the terminateFragment RPC to the daemon at nodeId through its cached proxy. */
   private class SendTerminateFragmentCallable
       extends CallableRequest<TerminateFragmentRequestProto, 
TerminateFragmentResponseProto> {
 
-    protected SendTerminateFragmentCallable(String hostname, int port,
-                                            TerminateFragmentRequestProto 
terminateFragmentRequestProto) {
-      super(hostname, port, terminateFragmentRequestProto);
+    protected SendTerminateFragmentCallable(LlapNodeId nodeId,
+                                            TerminateFragmentRequestProto 
terminateFragmentRequestProto,
+                                            
ExecuteRequestCallback<TerminateFragmentResponseProto> callback) {
+      super(nodeId, terminateFragmentRequestProto, callback);
     }
 
     @Override
     public TerminateFragmentResponseProto call() throws Exception {
-      return getProxy(hostname, port).terminateFragment(null, request);
+      return getProxy(nodeId).terminateFragment(null, request);
     }
   }
 
@@ -212,12 +457,13 @@ public class TaskCommunicator extends AbstractService {
     void indicateError(Throwable t);
   }
 
-  private LlapDaemonProtocolBlockingPB getProxy(String hostname, int port) {
-    String hostId = getHostIdentifier(hostname, port);
+  private LlapDaemonProtocolBlockingPB getProxy(LlapNodeId nodeId) {
+    String hostId = getHostIdentifier(nodeId.getHostname(), nodeId.getPort());
 
     LlapDaemonProtocolBlockingPB proxy = hostProxies.get(hostId);
     if (proxy == null) {
-      proxy = new LlapDaemonProtocolClientImpl(getConfig(), hostname, port, 
retryPolicy, socketFactory);
+      proxy = new LlapDaemonProtocolClientImpl(getConfig(), 
nodeId.getHostname(), nodeId.getPort(),
+          retryPolicy, socketFactory);
       LlapDaemonProtocolBlockingPB proxyOld = hostProxies.putIfAbsent(hostId, 
proxy);
       if (proxyOld != null) {
         // TODO Shutdown the new proxy.

http://git-wip-us.apache.org/repos/asf/hive/blob/ce2af3d5/llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestTaskCommunicator.java
----------------------------------------------------------------------
diff --git 
a/llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestTaskCommunicator.java
 
b/llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestTaskCommunicator.java
new file mode 100644
index 0000000..2aef4ed
--- /dev/null
+++ 
b/llap-server/src/test/org/apache/hadoop/hive/llap/tezplugins/TestTaskCommunicator.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.llap.tezplugins;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import com.google.protobuf.Message;
+import org.apache.commons.lang3.mutable.MutableInt;
+import org.apache.hadoop.hive.llap.LlapNodeId;
+import org.junit.Test;
+
+/**
+ * Unit tests for the scheduling decisions of {@link TaskCommunicator.RequestManager}.
+ * Executor submission is stubbed by {@link RequestManagerForTest}, so {@code process()} is
+ * driven synchronously from the test thread and only the per-node dispatch is observed.
+ */
+public class TestTaskCommunicator {
+
+  // Requests for different nodes do not block each other and go out in a single pass.
+  @Test (timeout = 5000)
+  public void testMultipleNodes() {
+    RequestManagerForTest requestManager = new RequestManagerForTest(1);
+
+    LlapNodeId nodeId1 = LlapNodeId.getInstance("host1", 1025);
+    LlapNodeId nodeId2 = LlapNodeId.getInstance("host2", 1025);
+
+    Message mockMessage = mock(Message.class);
+    TaskCommunicator.ExecuteRequestCallback mockExecuteRequestCallback = mock(
+        TaskCommunicator.ExecuteRequestCallback.class);
+
+    // Request two messages
+    requestManager.queueRequest(
+        new CallableRequestForTest(nodeId1, mockMessage, mockExecuteRequestCallback));
+    requestManager.queueRequest(
+        new CallableRequestForTest(nodeId2, mockMessage, mockExecuteRequestCallback));
+
+    // Should go through in a single process call
+    requestManager.process();
+    assertEquals(2, requestManager.numSubmissionsCounters);
+    assertNotNull(requestManager.numInvocationsPerNode.get(nodeId1));
+    assertNotNull(requestManager.numInvocationsPerNode.get(nodeId2));
+    assertEquals(1, requestManager.numInvocationsPerNode.get(nodeId1).getValue().intValue());
+    assertEquals(1, requestManager.numInvocationsPerNode.get(nodeId2).getValue().intValue());
+    assertEquals(0, requestManager.currentLoopSkippedRequests.size());
+    assertEquals(0, requestManager.currentLoopDisabledNodes.size());
+  }
+
+  // A second request for a busy node waits until the first one reports completion.
+  @Test(timeout = 5000)
+  public void testSingleInvocationPerNode() {
+    RequestManagerForTest requestManager = new RequestManagerForTest(1);
+
+    LlapNodeId nodeId1 = LlapNodeId.getInstance("host1", 1025);
+
+    Message mockMessage = mock(Message.class);
+    TaskCommunicator.ExecuteRequestCallback mockExecuteRequestCallback = mock(
+        TaskCommunicator.ExecuteRequestCallback.class);
+
+    // First request for host.
+    requestManager.queueRequest(
+        new CallableRequestForTest(nodeId1, mockMessage, mockExecuteRequestCallback));
+    requestManager.process();
+    assertEquals(1, requestManager.numSubmissionsCounters);
+    assertNotNull(requestManager.numInvocationsPerNode.get(nodeId1));
+    assertEquals(1, requestManager.numInvocationsPerNode.get(nodeId1).getValue().intValue());
+    assertEquals(0, requestManager.currentLoopSkippedRequests.size());
+
+    // Second request for host. Single invocation since the last has not completed.
+    requestManager.queueRequest(
+        new CallableRequestForTest(nodeId1, mockMessage, mockExecuteRequestCallback));
+    requestManager.process();
+    assertEquals(1, requestManager.numSubmissionsCounters);
+    assertNotNull(requestManager.numInvocationsPerNode.get(nodeId1));
+    assertEquals(1, requestManager.numInvocationsPerNode.get(nodeId1).getValue().intValue());
+    assertEquals(1, requestManager.currentLoopSkippedRequests.size());
+    assertEquals(1, requestManager.currentLoopDisabledNodes.size());
+    assertTrue(requestManager.currentLoopDisabledNodes.contains(nodeId1));
+
+    // Complete first request. Second pending request should go through.
+    requestManager.requestFinished(nodeId1);
+    requestManager.process();
+    assertEquals(2, requestManager.numSubmissionsCounters);
+    assertNotNull(requestManager.numInvocationsPerNode.get(nodeId1));
+    assertEquals(2, requestManager.numInvocationsPerNode.get(nodeId1).getValue().intValue());
+    assertEquals(0, requestManager.currentLoopSkippedRequests.size());
+    assertEquals(0, requestManager.currentLoopDisabledNodes.size());
+    assertFalse(requestManager.currentLoopDisabledNodes.contains(nodeId1));
+  }
+
+
+  /** RequestManager that records submissions instead of running them on the executor. */
+  static class RequestManagerForTest extends TaskCommunicator.RequestManager {
+
+    int numSubmissionsCounters = 0;
+    private Map<LlapNodeId, MutableInt> numInvocationsPerNode = new HashMap<>();
+
+    public RequestManagerForTest(int numThreads) {
+      super(numThreads);
+    }
+
+    @Override
+    protected void submitToExecutor(TaskCommunicator.CallableRequest request, LlapNodeId nodeId) {
+      numSubmissionsCounters++;
+      MutableInt nodeCount = numInvocationsPerNode.get(nodeId);
+      if (nodeCount == null) {
+        nodeCount = new MutableInt(0);
+        numInvocationsPerNode.put(nodeId, nodeCount);
+      }
+      nodeCount.increment();
+    }
+
+    void reset() {
+      numSubmissionsCounters = 0;
+      numInvocationsPerNode.clear();
+    }
+
+  }
+
+  /** No-op request; only its node id matters to the scheduler under test. */
+  static class CallableRequestForTest extends TaskCommunicator.CallableRequest<Message, Message> {
+
+    protected CallableRequestForTest(LlapNodeId nodeId, Message message,
+                                     TaskCommunicator.ExecuteRequestCallback<Message> callback) {
+      super(nodeId, message, callback);
+    }
+
+    @Override
+    public Message call() throws Exception {
+      return null;
+    }
+  }
+
+}

Reply via email to