Repository: storm
Updated Branches:
  refs/heads/master 37403d17d -> c8947c2fe


[STORM-2686] Add locality awareness to LoadAwareShuffleGrouping


Project: http://git-wip-us.apache.org/repos/asf/storm/repo
Commit: http://git-wip-us.apache.org/repos/asf/storm/commit/01bd4f82
Tree: http://git-wip-us.apache.org/repos/asf/storm/tree/01bd4f82
Diff: http://git-wip-us.apache.org/repos/asf/storm/diff/01bd4f82

Branch: refs/heads/master
Commit: 01bd4f821c940e979c360a4667c27c0477fde9a7
Parents: 352cd46
Author: Ethan Li <ethanopensou...@gmail.com>
Authored: Mon Oct 9 18:41:39 2017 -0500
Committer: Ethan Li <ethanopensou...@gmail.com>
Committed: Tue Oct 10 17:06:31 2017 -0500

----------------------------------------------------------------------
 conf/defaults.yaml                              |   2 +
 .../src/jvm/org/apache/storm/Config.java        |  18 ++
 .../apache/storm/daemon/worker/WorkerState.java |   2 +-
 .../grouping/LoadAwareShuffleGrouping.java      | 178 ++++++++++++++++++-
 .../storm/task/GeneralTopologyContext.java      |   4 +
 .../storm/task/WorkerTopologyContext.java       |  40 ++++-
 .../grouping/LoadAwareShuffleGroupingTest.java  |  37 +++-
 7 files changed, 267 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/storm/blob/01bd4f82/conf/defaults.yaml
----------------------------------------------------------------------
diff --git a/conf/defaults.yaml b/conf/defaults.yaml
index 103b04f..ad34054 100644
--- a/conf/defaults.yaml
+++ b/conf/defaults.yaml
@@ -261,6 +261,8 @@ topology.disruptor.batch.size: 100
 topology.disruptor.batch.timeout.millis: 1
 topology.disable.loadaware.messaging: false
 topology.state.checkpoint.interval.ms: 1000
+topology.localityaware.higher.bound.percent: 0.8
+topology.localityaware.lower.bound.percent: 0.2
 
 # Configs for Resource Aware Scheduler
 # topology priority describing the importance of the topology in decreasing importance starting from 0 (i.e. 0 is the highest priority and the priority importance decreases as the priority number increases).

http://git-wip-us.apache.org/repos/asf/storm/blob/01bd4f82/storm-client/src/jvm/org/apache/storm/Config.java
----------------------------------------------------------------------
diff --git a/storm-client/src/jvm/org/apache/storm/Config.java b/storm-client/src/jvm/org/apache/storm/Config.java
index e296e8f..6be0c21 100644
--- a/storm-client/src/jvm/org/apache/storm/Config.java
+++ b/storm-client/src/jvm/org/apache/storm/Config.java
@@ -65,6 +65,24 @@ public class Config extends HashMap<String, Object> {
     public static final String TOPOLOGY_DISABLE_LOADAWARE_MESSAGING = "topology.disable.loadaware.messaging";
 
     /**
+     * This signifies the load congestion among target tasks in scope. Currently it is only used in LoadAwareShuffleGrouping.
+     * When the average load is higher than this higher bound, the executor should choose target tasks in a higher scope.
+     * The scopes, from highest to lowest, are: EVERYTHING > RACK_LOCAL > HOST_LOCAL > WORKER_LOCAL.
+     */
+    @isPositiveNumber
+    @NotNull
+    public static final String TOPOLOGY_LOCALITYAWARE_HIGHER_BOUND_PERCENT = "topology.localityaware.higher.bound.percent";
+
+    /**
+     * This signifies the load congestion among target tasks in scope. Currently it is only used in LoadAwareShuffleGrouping.
+     * When the average load is lower than this lower bound, the executor should choose target tasks in a lower scope.
+     * The scopes, from highest to lowest, are: EVERYTHING > RACK_LOCAL > HOST_LOCAL > WORKER_LOCAL.
+     */
+    @isPositiveNumber
+    @NotNull
+    public static final String TOPOLOGY_LOCALITYAWARE_LOWER_BOUND_PERCENT = "topology.localityaware.lower.bound.percent";
+
+    /**
      * Try to serialize all tuples, even for local transfers.  This should only be used
      * for testing, as a sanity check that all of your tuples are setup properly.
      */
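
[Editor's note] For readers wiring this up: below is a minimal sketch of how a topology could override the new bounds. It is illustrative only; the config keys come from this commit, but the class, values, and topology name are made up.

import org.apache.storm.Config;
import org.apache.storm.StormSubmitter;
import org.apache.storm.topology.TopologyBuilder;

public class LocalityBoundsExample {
    public static void main(String[] args) throws Exception {
        TopologyBuilder builder = new TopologyBuilder();
        // ... set up spouts and bolts with shuffle groupings here ...

        Config conf = new Config();
        // Illustrative values; defaults.yaml ships 0.8 and 0.2.
        // Above the higher bound the grouping widens its scope; below the
        // lower bound it narrows back toward WORKER_LOCAL.
        conf.put(Config.TOPOLOGY_LOCALITYAWARE_HIGHER_BOUND_PERCENT, 0.9);
        conf.put(Config.TOPOLOGY_LOCALITYAWARE_LOWER_BOUND_PERCENT, 0.1);

        StormSubmitter.submitTopology("locality-demo", conf, builder.createTopology());
    }
}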

http://git-wip-us.apache.org/repos/asf/storm/blob/01bd4f82/storm-client/src/jvm/org/apache/storm/daemon/worker/WorkerState.java
----------------------------------------------------------------------
diff --git a/storm-client/src/jvm/org/apache/storm/daemon/worker/WorkerState.java b/storm-client/src/jvm/org/apache/storm/daemon/worker/WorkerState.java
index 825de4b..ec2ff59 100644
--- a/storm-client/src/jvm/org/apache/storm/daemon/worker/WorkerState.java
+++ b/storm-client/src/jvm/org/apache/storm/daemon/worker/WorkerState.java
@@ -585,7 +585,7 @@ public class WorkerState {
             return new WorkerTopologyContext(systemTopology, topologyConf, taskToComponent, componentToSortedTasks,
                 componentToStreamToFields, topologyId, codeDir, pidDir, port, taskIds,
                 defaultSharedResources,
-                userSharedResources);
+                userSharedResources, cachedTaskToNodePort, assignmentId);
         } catch (IOException e) {
             throw Utils.wrapInRuntime(e);
         }

http://git-wip-us.apache.org/repos/asf/storm/blob/01bd4f82/storm-client/src/jvm/org/apache/storm/grouping/LoadAwareShuffleGrouping.java
----------------------------------------------------------------------
diff --git a/storm-client/src/jvm/org/apache/storm/grouping/LoadAwareShuffleGrouping.java b/storm-client/src/jvm/org/apache/storm/grouping/LoadAwareShuffleGrouping.java
index f5b63ec..3fd75e5 100644
--- a/storm-client/src/jvm/org/apache/storm/grouping/LoadAwareShuffleGrouping.java
+++ b/storm-client/src/jvm/org/apache/storm/grouping/LoadAwareShuffleGrouping.java
@@ -19,15 +19,29 @@
 package org.apache.storm.grouping;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Sets;
+
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.storm.Config;
 import org.apache.storm.generated.GlobalStreamId;
+import org.apache.storm.generated.NodeInfo;
+import org.apache.storm.networktopography.DNSToSwitchMapping;
 import org.apache.storm.task.WorkerTopologyContext;
+import org.apache.storm.utils.ObjectReader;
+import org.apache.storm.utils.ReflectionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 public class LoadAwareShuffleGrouping implements LoadAwareCustomStreamGrouping, Serializable {
     static final int CAPACITY = 1000;
@@ -40,8 +54,13 @@ public class LoadAwareShuffleGrouping implements LoadAwareCustomStreamGrouping,
             this.index = index;
             weight = MAX_WEIGHT;
         }
+
+        void resetWeight() {
+            weight = MAX_WEIGHT;
+        }
     }
 
+    private static final Logger LOG = LoggerFactory.getLogger(LoadAwareShuffleGrouping.class);
     private final Map<Integer, IndexAndWeights> orig = new HashMap<>();
     private Random random;
     @VisibleForTesting
@@ -50,10 +69,28 @@ public class LoadAwareShuffleGrouping implements LoadAwareCustomStreamGrouping,
     volatile int[] choices;
     private volatile int[] prepareChoices;
     private AtomicInteger current;
+    private Scope currentScope;
+    private NodeInfo sourceNodeInfo;
+    private List<Integer> targetTasks;
+    private AtomicReference<Map<Integer, NodeInfo>> taskToNodePort;
+    private Map<String, Object> conf;
+    private DNSToSwitchMapping dnsToSwitchMapping;
+    private Map<Scope, List<Integer>> localityGroup;
+    private double higherBound;
+    private double lowerBound;
 
     @Override
     public void prepare(WorkerTopologyContext context, GlobalStreamId stream, List<Integer> targetTasks) {
         random = new Random();
+        sourceNodeInfo = new NodeInfo(context.getThisWorkerHost(), Sets.newHashSet((long) context.getThisWorkerPort()));
+        taskToNodePort = context.getTaskToNodePort();
+        this.targetTasks = targetTasks;
+        conf = context.getConf();
+        dnsToSwitchMapping = ReflectionUtils.newInstance((String) conf.get(Config.STORM_NETWORK_TOPOGRAPHY_PLUGIN));
+        localityGroup = new HashMap<>();
+        currentScope = Scope.WORKER_LOCAL;
+        higherBound = ObjectReader.getDouble(conf.get(Config.TOPOLOGY_LOCALITYAWARE_HIGHER_BOUND_PERCENT));
+        lowerBound = ObjectReader.getDouble(conf.get(Config.TOPOLOGY_LOCALITYAWARE_LOWER_BOUND_PERCENT));
 
         rets = (List<Integer>[]) new List<?>[targetTasks.size()];
         int i = 0;
@@ -93,12 +130,81 @@ public class LoadAwareShuffleGrouping implements LoadAwareCustomStreamGrouping,
         updateRing(loadMapping);
     }
 
+    private void refreshLocalityGroup() {
+        Map<Integer, NodeInfo> cachedTaskToNodePort = taskToNodePort.get();
+        Map<String, String> hostToRack = getHostToRackMapping(cachedTaskToNodePort);
+
+        localityGroup.values().forEach(List::clear);
+
+        for (int target: targetTasks) {
+            Scope scope = calculateScope(cachedTaskToNodePort, hostToRack, target);
+            localityGroup.computeIfAbsent(scope, s -> new ArrayList<>()).add(target);
+        }
+    }
+
+    private List<Integer> getTargetsInScope(Scope scope) {
+        List<Integer> result = new ArrayList<>();
+        List<Integer> targetInScope = localityGroup.get(scope);
+        if (null != targetInScope) {
+            result.addAll(targetInScope);
+        }
+        Scope downgradeScope = Scope.downgrade(scope);
+        if (downgradeScope != scope) {
+            result.addAll(getTargetsInScope(downgradeScope));
+        }
+        return result;
+    }
+
+    private Scope transition(LoadMapping load) {
+        List<Integer> targetInScope = getTargetsInScope(currentScope);
+        if (targetInScope.isEmpty()) {
+            Scope upScope = Scope.upgrade(currentScope);
+            if (upScope == currentScope) {
+                throw new RuntimeException("This executor has no target tasks.");
+            }
+            currentScope = upScope;
+            return transition(load);
+        }
+
+        if (null == load) {
+            return currentScope;
+        }
+
+        double avg = targetInScope.stream().mapToDouble((key) -> load.get(key)).average().getAsDouble();
+        Scope nextScope;
+        if (avg < lowerBound) {
+            nextScope = Scope.downgrade(currentScope);
+            if (getTargetsInScope(nextScope).isEmpty()) {
+                nextScope = currentScope;
+            }
+        } else if (avg > higherBound) {
+            nextScope = Scope.upgrade(currentScope);
+        } else {
+            nextScope = currentScope;
+        }
+
+        return nextScope;
+    }
+
     private synchronized void updateRing(LoadMapping load) {
+        refreshLocalityGroup();
+        Scope prevScope = currentScope;
+        currentScope = transition(load);
+        if (currentScope != prevScope) {
+            // reset all the weights
+            orig.values().forEach(IndexAndWeights::resetWeight);
+        }
+
+        List<Integer> targetsInScope = getTargetsInScope(currentScope);
+
         //We will adjust weights based off of the minimum load
-        double min = load == null ? 0 : orig.keySet().stream().mapToDouble((key) -> load.get(key)).min().getAsDouble();
-        for (Map.Entry<Integer, IndexAndWeights> target: orig.entrySet()) {
-            IndexAndWeights val = target.getValue();
-            double l = load == null ? 0.0 : load.get(target.getKey());
+        double min = load == null ? 0 : targetsInScope.stream().mapToDouble((key) -> load.get(key)).min().getAsDouble();
+        for (int target: targetsInScope) {
+            IndexAndWeights val = orig.get(target);
+            double l = load == null ? 0.0 : load.get(target);
             if (l <= min + (0.05)) {
                 //We assume that within 5% of the minimum congestion is still fine.
                 //Not congested we grow (but slowly)
@@ -109,12 +215,13 @@ public class LoadAwareShuffleGrouping implements LoadAwareCustomStreamGrouping,
             }
         }
         //Now we need to build the array
-        long weightSum = orig.values().stream().mapToLong((w) -> w.weight).sum();
+        long weightSum = targetsInScope.stream().mapToLong((target) -> orig.get(target).weight).sum();
         //Now we can calculate a percentage
 
         int currentIdx = 0;
         if (weightSum > 0) {
-            for (IndexAndWeights indexAndWeights : orig.values()) {
+            for (int target: targetsInScope) {
+                IndexAndWeights indexAndWeights = orig.get(target);
                 int count = (int) ((indexAndWeights.weight / (double) weightSum) * CAPACITY);
                 for (int i = 0; i < count && currentIdx < CAPACITY; i++) {
                     prepareChoices[currentIdx] = indexAndWeights.index;
@@ -156,4 +263,63 @@ public class LoadAwareShuffleGrouping implements LoadAwareCustomStreamGrouping,
         arr[i] = arr[j];
         arr[j] = tmp;
     }
+
+
+    private Scope calculateScope(Map<Integer, NodeInfo> taskToNodePort, Map<String, String> hostToRack, int target) {
+        NodeInfo targetNodeInfo = taskToNodePort.get(target);
+
+        if (targetNodeInfo == null) {
+            return Scope.EVERYTHING;
+        }
+
+        String sourceRack = hostToRack.get(sourceNodeInfo.get_node());
+        String targetRack = hostToRack.get(targetNodeInfo.get_node());
+
+        if (sourceRack != null && targetRack != null && sourceRack.equals(targetRack)) {
+            if (sourceNodeInfo.get_node().equals(targetNodeInfo.get_node())) {
+                if (sourceNodeInfo.get_port().equals(targetNodeInfo.get_port())) {
+                    return Scope.WORKER_LOCAL;
+                }
+                return Scope.HOST_LOCAL;
+            }
+            return Scope.RACK_LOCAL;
+        } else {
+            return Scope.EVERYTHING;
+        }
+    }
+
+    private Map<String, String> getHostToRackMapping(Map<Integer, NodeInfo> taskToNodePort) {
+        Set<String> hosts = new HashSet<>();
+        for (int task: targetTasks) {
+            hosts.add(taskToNodePort.get(task).get_node());
+        }
+        hosts.add(sourceNodeInfo.get_node());
+        return dnsToSwitchMapping.resolve(new ArrayList<>(hosts));
+    }
+
+    enum Scope {
+        WORKER_LOCAL, HOST_LOCAL, RACK_LOCAL, EVERYTHING;
+
+        public static Scope downgrade(Scope current) {
+            switch (current) {
+                case EVERYTHING: return RACK_LOCAL;
+                case RACK_LOCAL: return HOST_LOCAL;
+                case HOST_LOCAL:
+                case WORKER_LOCAL:
+                default:
+                    return WORKER_LOCAL;
+            }
+        }
+
+        public static Scope upgrade(Scope current) {
+            switch (current) {
+                case WORKER_LOCAL: return HOST_LOCAL;
+                case HOST_LOCAL: return RACK_LOCAL;
+                case RACK_LOCAL:
+                case EVERYTHING:
+                default:
+                    return EVERYTHING;
+            }
+        }
+    }
 }
\ No newline at end of file
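
[Editor's note] For reviewers, the scope machinery above boils down to a small state machine. The following self-contained sketch mirrors the ladder logic in transition(); it is a simplification, not the shipped code, and omits the empty-scope upgrade loop and the weight reset that updateRing performs on a scope change.

public class ScopeLadderSketch {
    enum Scope { WORKER_LOCAL, HOST_LOCAL, RACK_LOCAL, EVERYTHING }

    // Pick the next scope from the average load of the tasks currently in
    // scope: widen when congested, narrow back when mostly idle.
    static Scope next(Scope current, double avgLoad, double lowerBound, double higherBound) {
        if (avgLoad > higherBound && current != Scope.EVERYTHING) {
            return Scope.values()[current.ordinal() + 1]; // congested: widen
        }
        if (avgLoad < lowerBound && current != Scope.WORKER_LOCAL) {
            return Scope.values()[current.ordinal() - 1]; // idle: narrow back
        }
        return current; // within bounds: stay put
    }

    public static void main(String[] args) {
        Scope s = Scope.WORKER_LOCAL;
        s = next(s, 0.95, 0.2, 0.8); // heavy load -> HOST_LOCAL
        s = next(s, 0.05, 0.2, 0.8); // load drained -> back to WORKER_LOCAL
        System.out.println(s);       // prints WORKER_LOCAL
    }
}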

http://git-wip-us.apache.org/repos/asf/storm/blob/01bd4f82/storm-client/src/jvm/org/apache/storm/task/GeneralTopologyContext.java
----------------------------------------------------------------------
diff --git a/storm-client/src/jvm/org/apache/storm/task/GeneralTopologyContext.java b/storm-client/src/jvm/org/apache/storm/task/GeneralTopologyContext.java
index 6614f94..deae7cf 100644
--- a/storm-client/src/jvm/org/apache/storm/task/GeneralTopologyContext.java
+++ b/storm-client/src/jvm/org/apache/storm/task/GeneralTopologyContext.java
@@ -199,4 +199,8 @@ public class GeneralTopologyContext implements JSONAware {
         }
         return max;
     }
+
+    public Map<String, Object> getConf() {
+        return _topoConf;
+    }
 }

http://git-wip-us.apache.org/repos/asf/storm/blob/01bd4f82/storm-client/src/jvm/org/apache/storm/task/WorkerTopologyContext.java
----------------------------------------------------------------------
diff --git a/storm-client/src/jvm/org/apache/storm/task/WorkerTopologyContext.java b/storm-client/src/jvm/org/apache/storm/task/WorkerTopologyContext.java
index 2817b65..8c63eb6 100644
--- a/storm-client/src/jvm/org/apache/storm/task/WorkerTopologyContext.java
+++ b/storm-client/src/jvm/org/apache/storm/task/WorkerTopologyContext.java
@@ -17,6 +17,7 @@
  */
 package org.apache.storm.task;
 
+import org.apache.storm.generated.NodeInfo;
 import org.apache.storm.generated.StormTopology;
 import org.apache.storm.tuple.Fields;
 import java.io.File;
@@ -24,6 +25,7 @@ import java.io.IOException;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicReference;
 
 public class WorkerTopologyContext extends GeneralTopologyContext {
     public static final String SHARED_EXECUTOR = "executor";
@@ -34,6 +36,8 @@ public class WorkerTopologyContext extends GeneralTopologyContext {
     private String _pidDir;
     Map<String, Object> _userResources;
     Map<String, Object> _defaultResources;
+    private AtomicReference<Map<Integer, NodeInfo>> taskToNodePort;
+    private String assignmentId;
     
     public WorkerTopologyContext(
             StormTopology topology,
@@ -47,7 +51,9 @@ public class WorkerTopologyContext extends GeneralTopologyContext {
             Integer workerPort,
             List<Integer> workerTasks,
             Map<String, Object> defaultResources,
-            Map<String, Object> userResources
+            Map<String, Object> userResources,
+            AtomicReference<Map<Integer, NodeInfo>> taskToNodePort,
+            String assignmentId
             ) {
         super(topology, topoConf, taskToComponent, componentToSortedTasks, componentToStreamToFields, stormId);
         _codeDir = codeDir;
@@ -64,6 +70,26 @@ public class WorkerTopologyContext extends GeneralTopologyContext {
         }
         _workerPort = workerPort;
         _workerTasks = workerTasks;
+        this.taskToNodePort = taskToNodePort;
+        this.assignmentId = assignmentId;
+    }
+
+    public WorkerTopologyContext(
+            StormTopology topology,
+            Map<String, Object> topoConf,
+            Map<Integer, String> taskToComponent,
+            Map<String, List<Integer>> componentToSortedTasks,
+            Map<String, Map<String, Fields>> componentToStreamToFields,
+            String stormId,
+            String codeDir,
+            String pidDir,
+            Integer workerPort,
+            List<Integer> workerTasks,
+            Map<String, Object> defaultResources,
+            Map<String, Object> userResources) {
+        this(topology, topoConf, taskToComponent, componentToSortedTasks, componentToStreamToFields, stormId,
+                codeDir, pidDir, workerPort, workerTasks, defaultResources, userResources, null, null);
     }
 
     /**
@@ -78,6 +104,18 @@ public class WorkerTopologyContext extends GeneralTopologyContext {
         return _workerPort;
     }
 
+    public String getThisWorkerHost() {
+        return assignmentId;
+    }
+
+    /**
+     * Get a map from task id to the NodeInfo (node and port) it is assigned to.
+     * @return a map from task id to NodeInfo
+     */
+    public AtomicReference<Map<Integer, NodeInfo>> getTaskToNodePort() {
+        return taskToNodePort;
+    }
+
     /**
      * Gets the location of the external resources for this worker on the
      * local filesystem. These external resources typically include bolts implemented
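
[Editor's note] A usage sketch: the two new accessors give a custom grouping enough to reason about task placement. The helper below is hypothetical, not part of this commit; the null checks matter because the backward-compatible constructor above passes null for taskToNodePort.

import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.storm.generated.NodeInfo;
import org.apache.storm.task.WorkerTopologyContext;

final class PlacementLookup {
    // Returns the node/port a task runs on, or null when placement is unknown.
    static NodeInfo locate(WorkerTopologyContext context, int taskId) {
        AtomicReference<Map<Integer, NodeInfo>> ref = context.getTaskToNodePort();
        Map<Integer, NodeInfo> snapshot = (ref == null) ? null : ref.get();
        return (snapshot == null) ? null : snapshot.get(taskId);
    }
}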

http://git-wip-us.apache.org/repos/asf/storm/blob/01bd4f82/storm-client/test/jvm/org/apache/storm/grouping/LoadAwareShuffleGroupingTest.java
----------------------------------------------------------------------
diff --git a/storm-client/test/jvm/org/apache/storm/grouping/LoadAwareShuffleGroupingTest.java b/storm-client/test/jvm/org/apache/storm/grouping/LoadAwareShuffleGroupingTest.java
index a5f9304..d704900 100644
--- a/storm-client/test/jvm/org/apache/storm/grouping/LoadAwareShuffleGroupingTest.java
+++ b/storm-client/test/jvm/org/apache/storm/grouping/LoadAwareShuffleGroupingTest.java
@@ -18,13 +18,19 @@
 package org.apache.storm.grouping;
 
 import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 import com.google.common.util.concurrent.MoreExecutors;
 import java.util.Arrays;
+
+import org.apache.storm.Config;
 import org.apache.storm.daemon.GrouperFactory;
 import org.apache.storm.generated.GlobalStreamId;
 import org.apache.storm.generated.Grouping;
+import org.apache.storm.generated.NodeInfo;
 import org.apache.storm.generated.NullStruct;
 import org.apache.storm.task.WorkerTopologyContext;
+import org.junit.Before;
 import org.junit.Ignore;
 import org.junit.Test;
 import org.slf4j.Logger;
@@ -44,20 +50,39 @@ import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class LoadAwareShuffleGroupingTest {
     public static final double ACCEPTABLE_MARGIN = 0.015;
     private static final Logger LOG = LoggerFactory.getLogger(LoadAwareShuffleGroupingTest.class);
 
+    private WorkerTopologyContext mockContext(List<Integer> availableTaskIds) {
+        Map<String, Object> conf = new HashMap<>();
+        conf.put(Config.STORM_NETWORK_TOPOGRAPHY_PLUGIN, "org.apache.storm.networktopography.DefaultRackDNSToSwitchMapping");
+        conf.put(Config.TOPOLOGY_LOCALITYAWARE_HIGHER_BOUND_PERCENT, 0.8);
+        conf.put(Config.TOPOLOGY_LOCALITYAWARE_LOWER_BOUND_PERCENT, 0.2);
+
+        WorkerTopologyContext context = mock(WorkerTopologyContext.class);
+        when(context.getConf()).thenReturn(conf);
+        Map<Integer, NodeInfo> taskToNodePort = new HashMap<>();
+        NodeInfo nodeInfo = new NodeInfo("node-id", Sets.newHashSet(6700L));
+        availableTaskIds.forEach(e -> taskToNodePort.put(e, nodeInfo));
+        when(context.getTaskToNodePort()).thenReturn(new AtomicReference<>(taskToNodePort));
+        when(context.getThisWorkerHost()).thenReturn("node-id");
+        when(context.getThisWorkerPort()).thenReturn(6700);
+        return context;
+    }
+
     @Test
     public void testUnevenLoadOverTime() throws Exception {
         LoadAwareShuffleGrouping grouping = new LoadAwareShuffleGrouping();
-        WorkerTopologyContext context = mock(WorkerTopologyContext.class);
+        WorkerTopologyContext context = mockContext(Arrays.asList(1, 2));
         grouping.prepare(context, new GlobalStreamId("a", "default"), Arrays.asList(1, 2));
         double expectedOneWeight = 100.0;
         double expectedTwoWeight = 100.0;
@@ -122,7 +147,7 @@ public class LoadAwareShuffleGroupingTest {
         final List<Integer> availableTaskIds = getAvailableTaskIds(numTasks);
         final LoadMapping loadMapping = buildLocalTasksEvenLoadMapping(availableTaskIds);
 
-        WorkerTopologyContext context = mock(WorkerTopologyContext.class);
+        final WorkerTopologyContext context = mockContext(availableTaskIds);
         grouper.prepare(context, null, availableTaskIds);
 
         // Keep track of how many times we see each taskId
@@ -152,7 +177,7 @@ public class LoadAwareShuffleGroupingTest {
         final List<Integer> availableTaskIds = getAvailableTaskIds(numTasks);
         final LoadMapping loadMapping = buildLocalTasksEvenLoadMapping(availableTaskIds);
 
-        final WorkerTopologyContext context = mock(WorkerTopologyContext.class);
+        final WorkerTopologyContext context = mockContext(availableTaskIds);
         grouper.prepare(context, null, availableTaskIds);
 
         // force triggers building ring
@@ -222,7 +247,7 @@ public class LoadAwareShuffleGroupingTest {
     public void testShuffleLoadEven() {
         // port test-shuffle-load-even
         LoadAwareCustomStreamGrouping shuffler = GrouperFactory
-            .mkGrouper(null, "comp", "stream", null, Grouping.shuffle(new NullStruct()),
+            .mkGrouper(mockContext(Lists.newArrayList(1, 2)), "comp", "stream", null, Grouping.shuffle(new NullStruct()),
                 Lists.newArrayList(1, 2), Collections.emptyMap());
         int numMessages = 100000;
         int minPrCount = (int) (numMessages * (0.5 - ACCEPTABLE_MARGIN));
@@ -367,7 +392,7 @@ public class LoadAwareShuffleGroupingTest {
         // Task Id not used, so just pick a static value
         final int inputTaskId = 100;
 
-        WorkerTopologyContext context = mock(WorkerTopologyContext.class);
+        WorkerTopologyContext context = mockContext(availableTaskIds);
         grouper.prepare(context, null, availableTaskIds);
 
        // periodically calls refreshLoad every 1 sec to simulate the worker load update timer
@@ -405,7 +430,7 @@ public class LoadAwareShuffleGroupingTest {
         // Task Id not used, so just pick a static value
         final int inputTaskId = 100;
 
-        final WorkerTopologyContext context = mock(WorkerTopologyContext.class);
+        final WorkerTopologyContext context = mockContext(availableTaskIds);
 
         // Call prepare with our available taskIds
         grouper.prepare(context, null, availableTaskIds);
