http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala 
b/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
index 600b7a1..5fb71f3 100644
--- 
a/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
+++ 
b/samza-core/src/main/scala/org/apache/samza/coordinator/JobModelManager.scala
@@ -19,31 +19,33 @@
 
 package org.apache.samza.coordinator
 
-
 import java.util
 import java.util.concurrent.atomic.AtomicReference
-
 import org.apache.samza.config._
 import org.apache.samza.config.JobConfig.Config2Job
 import org.apache.samza.config.SystemConfig.Config2System
 import org.apache.samza.config.TaskConfig.Config2Task
 import org.apache.samza.config.Config
 import 
org.apache.samza.container.grouper.stream.SystemStreamPartitionGrouperFactory
-import org.apache.samza.container.grouper.task.BalancingTaskNameGrouper
-import org.apache.samza.container.grouper.task.TaskNameGrouperFactory
+import org.apache.samza.container.grouper.task._
 import org.apache.samza.container.LocalityManager
 import org.apache.samza.container.TaskName
 import org.apache.samza.coordinator.server.HttpServer
 import org.apache.samza.coordinator.server.JobServlet
+import org.apache.samza.job.model.ContainerModel
 import org.apache.samza.job.model.JobModel
 import org.apache.samza.job.model.TaskModel
+import org.apache.samza.metrics.MetricsRegistry
 import org.apache.samza.metrics.MetricsRegistryMap
 import org.apache.samza.system._
 import org.apache.samza.util.Logging
 import org.apache.samza.util.Util
 import org.apache.samza.Partition
+import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping
+import org.apache.samza.runtime.LocationId
 
 import scala.collection.JavaConverters._
+import scala.collection.JavaConversions._
 
 /**
  * Helper companion object that is responsible for wiring up a JobModelManager
@@ -51,66 +53,145 @@ import scala.collection.JavaConverters._
  */
 object JobModelManager extends Logging {
 
-  val SOURCE = "JobModelManager"
   /**
    * a volatile value to store the current instantiated 
<code>JobModelManager</code>
    */
-  @volatile var currentJobModelManager: JobModelManager = null
+  @volatile var currentJobModelManager: JobModelManager = _
   val jobModelRef: AtomicReference[JobModel] = new AtomicReference[JobModel]()
 
   /**
-   * Does the following actions for a job.
+   * Currently used only in the ApplicationMaster for yarn deployment model.
+   * Does the following:
    * a) Reads the jobModel from coordinator stream using the job's 
configuration.
-   * b) Recomputes changelog partition mapping based on jobModel and job's 
configuration.
+   * b) Recomputes the changelog partition mapping based on jobModel and job's 
configuration.
    * c) Builds JobModelManager using the jobModel read from coordinator stream.
-   * @param config Config from the coordinator stream.
-   * @param changelogPartitionMapping The changelog partition-to-task mapping.
-   * @return JobModelManager
+   * @param config config from the coordinator stream.
+   * @param changelogPartitionMapping changelog partition-to-task mapping of 
the samza job.
+   * @param metricsRegistry the registry for reporting metrics.
+   * @return the instantiated {@see JobModelManager}.
    */
-  def apply(config: Config, changelogPartitionMapping: util.Map[TaskName, 
Integer]) = {
-    val localityManager = new LocalityManager(config, new MetricsRegistryMap())
-
-    // Map the name of each system to the corresponding SystemAdmin
+  def apply(config: Config, changelogPartitionMapping: util.Map[TaskName, 
Integer], metricsRegistry: MetricsRegistry = new MetricsRegistryMap()): 
JobModelManager = {
+    val localityManager = new LocalityManager(config, metricsRegistry)
+    val taskAssignmentManager = new TaskAssignmentManager(config, 
metricsRegistry)
     val systemAdmins = new SystemAdmins(config)
-    val streamMetadataCache = new StreamMetadataCache(systemAdmins, 0)
+    try {
+      systemAdmins.start()
+      val streamMetadataCache = new StreamMetadataCache(systemAdmins, 0)
+      val grouperMetadata: GrouperMetadata = getGrouperMetadata(config, 
localityManager, taskAssignmentManager)
 
-    val containerCount = new JobConfig(config).getContainerCount
-    val processorList = List.range(0, containerCount).map(c => c.toString)
+      val jobModel: JobModel = readJobModel(config, changelogPartitionMapping, 
streamMetadataCache, grouperMetadata)
+      jobModelRef.set(new JobModel(jobModel.getConfig, jobModel.getContainers, 
localityManager))
 
-    systemAdmins.start()
-    val jobModelManager = getJobModelManager(config, 
changelogPartitionMapping, localityManager, streamMetadataCache, 
processorList.asJava)
-    systemAdmins.stop()
+      updateTaskAssignments(jobModel, taskAssignmentManager, grouperMetadata)
 
-    jobModelManager
+      val server = new HttpServer
+      server.addServlet("/", new JobServlet(jobModelRef))
+
+      currentJobModelManager = new JobModelManager(jobModel, server, 
localityManager)
+      currentJobModelManager
+    } finally {
+      taskAssignmentManager.close()
+      systemAdmins.stop()
+      // Not closing localityManager, since {@code ClusterBasedJobCoordinator} 
uses it to read container locality through {@code JobModel}.
+    }
   }
 
   /**
-   * Build a JobModelManager using a Samza job's configuration.
-   */
-  private def getJobModelManager(config: Config,
-                                changeLogMapping: util.Map[TaskName, Integer],
-                                localityManager: LocalityManager,
-                                streamMetadataCache: StreamMetadataCache,
-                                containerIds: java.util.List[String]) = {
-    val jobModel: JobModel = readJobModel(config, changeLogMapping, 
localityManager, streamMetadataCache, containerIds)
-    jobModelRef.set(jobModel)
-
-    val server = new HttpServer
-    server.addServlet("/", new JobServlet(jobModelRef))
-    currentJobModelManager = new JobModelManager(jobModel, server, 
localityManager)
-    currentJobModelManager
+    * Builds the {@see GrouperMetadataImpl} for the samza job.
+    * @param config represents the configurations defined by the user.
+    * @param localityManager provides the processor to host mapping persisted 
to the metadata store.
+    * @param taskAssignmentManager provides the processor to task assignments 
persisted to the metadata store.
+    * @return the instantiated {@see GrouperMetadata}.
+    */
+  def getGrouperMetadata(config: Config, localityManager: LocalityManager, taskAssignmentManager: TaskAssignmentManager) = {
+    val localityByProcessor: util.Map[String, LocationId] = getProcessorLocality(config, localityManager)
+    val savedAssignment: util.Map[String, String] = taskAssignmentManager.readTaskAssignment()
+    val processorByTask: util.Map[TaskName, String] = new util.HashMap[TaskName, String]()
+    val localityByTask: util.Map[TaskName, LocationId] = new util.HashMap[TaskName, LocationId]()
+    // One pass over the saved (task name -> processor id) entries fills both task-keyed maps.
+    for ((rawTaskName, processorId) <- savedAssignment) {
+      processorByTask.put(new TaskName(rawTaskName), processorId)
+      // A task gets a location only when its processor has an entry in the processor locality map.
+      if (localityByProcessor.containsKey(processorId)) {
+        localityByTask.put(new TaskName(rawTaskName), localityByProcessor.get(processorId))
+      }
+    }
+
+    new GrouperMetadataImpl(localityByProcessor, localityByTask, new util.HashMap[TaskName, util.List[SystemStreamPartition]](), processorByTask)
+  }
 
   /**
-   * For each input stream specified in config, exactly determine its
-   * partitions, returning a set of SystemStreamPartitions containing them all.
-   */
-  private def getInputStreamPartitions(config: Config, streamMetadataCache: 
StreamMetadataCache) = {
+    * Retrieves and returns the processor locality of a samza job using 
provided {@see Config} and {@see LocalityManager}.
+    * @param config provides the configurations defined by the user. Required 
to connect to the storage layer.
+    * @param localityManager provides the processor to host mapping persisted 
to the metadata store.
+    * @return the processor locality.
+    */
+  def getProcessorLocality(config: Config, localityManager: LocalityManager) = {
+    val containerToLocationId: util.Map[String, LocationId] = new util.HashMap[String, LocationId]()
+    val existingContainerLocality = localityManager.readContainerLocality()
+
+    // Container ids are 0-based ("0" .. count-1), so the range must exclude getContainerCount: `until`, not `to`.
+    for (containerId <- 0 until config.getContainerCount) {
+      val localityMapping = existingContainerLocality.get(containerId.toString)
+      // To handle the case when the container count is increased between two different runs of a samza-yarn job,
+      // set the locality of newly added containers to any_host.
+      var locationId: LocationId = new LocationId("ANY_HOST")
+      if (localityMapping != null && localityMapping.containsKey(SetContainerHostMapping.HOST_KEY)) {
+        locationId = new LocationId(localityMapping.get(SetContainerHostMapping.HOST_KEY))
+      }
+      containerToLocationId.put(containerId.toString, locationId)
+    }
+    containerToLocationId
+  }
+
+  /**
+    * This method does the following:
+    * 1. Deletes the existing task assignments if the partition-task grouping 
has changed from the previous run of the job.
+    * 2. Saves the newly generated task assignments to the storage layer through the {@param taskAssignmentManager}.
+    *
+    * @param jobModel              represents the {@see JobModel} of the samza 
job.
+    * @param taskAssignmentManager required to persist the processor to task 
assignments to the storage layer.
+    * @param grouperMetadata       provides the historical metadata of the 
application.
+    */
+  def updateTaskAssignments(jobModel: JobModel, taskAssignmentManager: TaskAssignmentManager, grouperMetadata: GrouperMetadata): Unit = {
+    val taskNames: util.Set[String] = new util.HashSet[String]()
+    for (container <- jobModel.getContainers.values()) {
+      for (taskModel <- container.getTasks.values()) {
+        taskNames.add(taskModel.getTaskName.getTaskName)
+      }
+    }
+    val taskToContainerId = grouperMetadata.getPreviousTaskToProcessorAssignment
+    if (taskNames.size() != taskToContainerId.size()) {
+      // NOTE(review): pre-format like the other log calls in this file; warn is assumed not to expand {} placeholders - confirm.
+      warn("Current task count %s does not match saved task count %s. Stateful jobs may observe misalignment of keys!" format (taskNames.size(), taskToContainerId.size()))
+      // If the tasks changed, then the partition-task grouping is also likely changed and we can't handle that
+      // without a much more complicated mapping. Further, the partition count may have changed, which means
+      // input message keys are likely reshuffled w.r.t. partitions, so the local state may not contain necessary
+      // data associated with the incoming keys. Warn the user and default to grouper
+      // Tasks may have been removed, so delete the previously saved mappings (assumed keyed by TaskName), not the current names.
+      taskAssignmentManager.deleteTaskContainerMappings(taskToContainerId.keySet().asScala.map(_.getTaskName).asJava)
+    }
+
+    for (container <- jobModel.getContainers.values()) {
+      for (taskName <- container.getTasks.keySet) {
+        taskAssignmentManager.writeTaskContainerMapping(taskName.getTaskName, container.getId)
+      }
+    }
+  }
+
+  /**
+    * Computes the input system stream partitions of a samza job using the 
provided {@param config}
+    * and {@param streamMetadataCache}.
+    * @param config the configuration of the job.
+    * @param streamMetadataCache to query the partition metadata of the input 
streams.
+    * @return the input {@see SystemStreamPartition} of the samza job.
+    */
+  private def getInputStreamPartitions(config: Config, streamMetadataCache: 
StreamMetadataCache): Set[SystemStreamPartition] = {
     val inputSystemStreams = config.getInputStreams
 
     // Get the set of partitions for each SystemStream from the stream metadata
     streamMetadataCache
-      .getStreamMetadata(inputSystemStreams, true)
+      .getStreamMetadata(inputSystemStreams, partitionsMetadataOnly = true)
       .flatMap {
         case (systemStream, metadata) =>
           metadata
@@ -121,55 +202,69 @@ object JobModelManager extends Logging {
       }.toSet
   }
 
+  /**
+    * Builds the input {@see SystemStreamPartition} based upon the {@param 
config} defined by the user.
+    * @param config configuration to fetch the metadata of the input streams.
+    * @param streamMetadataCache required to query the partition metadata of 
the input streams.
+    * @return the input SystemStreamPartitions of the job.
+    */
   private def getMatchedInputStreamPartitions(config: Config, 
streamMetadataCache: StreamMetadataCache): Set[SystemStreamPartition] = {
     val allSystemStreamPartitions = getInputStreamPartitions(config, 
streamMetadataCache)
     config.getSSPMatcherClass match {
-      case Some(s) => {
+      case Some(s) =>
         val jfr = config.getSSPMatcherConfigJobFactoryRegex.r
         config.getStreamJobFactoryClass match {
-          case Some(jfr(_*)) => {
-            info("before match: allSystemStreamPartitions.size = %s" format 
(allSystemStreamPartitions.size))
+          case Some(jfr(_*)) =>
+            info("before match: allSystemStreamPartitions.size = %s" format 
allSystemStreamPartitions.size)
             val sspMatcher = Util.getObj(s, 
classOf[SystemStreamPartitionMatcher])
             val matchedPartitions = 
sspMatcher.filter(allSystemStreamPartitions.asJava, config).asScala.toSet
             // Usually a small set hence ok to log at info level
-            info("after match: matchedPartitions = %s" format 
(matchedPartitions))
+            info("after match: matchedPartitions = %s" format 
matchedPartitions)
             matchedPartitions
-          }
           case _ => allSystemStreamPartitions
         }
-      }
       case _ => allSystemStreamPartitions
     }
   }
 
   /**
-   * Gets a SystemStreamPartitionGrouper object from the configuration.
-   */
+    * Finds the {@see SystemStreamPartitionGrouperFactory} from the {@param 
config}. Instantiates the  {@see SystemStreamPartitionGrouper}
+    * object through the factory.
+    * @param config the configuration of the samza job.
+    * @return the instantiated {@see SystemStreamPartitionGrouper}.
+    */
   private def getSystemStreamPartitionGrouper(config: Config) = {
     val factoryString = config.getSystemStreamPartitionGrouperFactory
     val factory = Util.getObj(factoryString, 
classOf[SystemStreamPartitionGrouperFactory])
     factory.getSystemStreamPartitionGrouper(config)
   }
 
+
   /**
-   * The function reads the latest checkpoint from the underlying coordinator 
stream and
-   * builds a new JobModel.
-   */
+    * Does the following:
+    * 1. Fetches metadata of the input streams defined in configuration 
through {@param streamMetadataCache}.
+    * 2. Applies the {@see SystemStreamPartitionGrouper}, {@see 
TaskNameGrouper} defined in the configuration
+    * to build the {@see JobModel}.
+    * @param config the configuration of the job.
+    * @param changeLogPartitionMapping the task to changelog partition mapping 
of the job.
+    * @param streamMetadataCache the cache that holds the partition metadata 
of the input streams.
+    * @param grouperMetadata provides the historical metadata of the 
application.
+    * @return the built {@see JobModel}.
+    */
   def readJobModel(config: Config,
                    changeLogPartitionMapping: util.Map[TaskName, Integer],
-                   localityManager: LocalityManager,
                    streamMetadataCache: StreamMetadataCache,
-                   containerIds: java.util.List[String]): JobModel = {
+                   grouperMetadata: GrouperMetadata): JobModel = {
     // Do grouping to fetch TaskName to SSP mapping
     val allSystemStreamPartitions = getMatchedInputStreamPartitions(config, 
streamMetadataCache)
 
     // processor list is required by some of the groupers. So, let's pass them 
as part of the config.
     // Copy the config and add the processor list to the config copy.
     val configMap = new util.HashMap[String, String](config)
-    configMap.put(JobConfig.PROCESSOR_LIST, String.join(",", containerIds))
+    configMap.put(JobConfig.PROCESSOR_LIST, String.join(",", 
grouperMetadata.getProcessorLocality.keySet()))
     val grouper = getSystemStreamPartitionGrouper(new MapConfig(configMap))
 
-    val groups = grouper.group(allSystemStreamPartitions.asJava)
+    val groups = grouper.group(allSystemStreamPartitions)
     info("SystemStreamPartitionGrouper %s has grouped the 
SystemStreamPartitions into %d tasks with the following taskNames: %s" 
format(grouper, groups.size(), groups.keySet()))
 
     val isHostAffinityEnabled = new 
ClusterManagerConfig(config).getHostAffinityEnabled
@@ -200,22 +295,18 @@ object JobModelManager extends Logging {
     // SSPTaskNameGrouper for locality, load-balancing, etc.
     val containerGrouperFactory = 
Util.getObj(config.getTaskNameGrouperFactory, classOf[TaskNameGrouperFactory])
     val containerGrouper = containerGrouperFactory.build(config)
-    val containerModels = {
-      containerGrouper match {
-        case grouper: BalancingTaskNameGrouper if isHostAffinityEnabled => 
grouper.balance(taskModels.asJava, localityManager)
-        case _ => containerGrouper.group(taskModels.asJava, containerIds)
-      }
-    }
-    val containerMap = containerModels.asScala.map { case (containerModel) => 
containerModel.getId -> containerModel }.toMap
-
-    if (isHostAffinityEnabled) {
-      new JobModel(config, containerMap.asJava, localityManager)
+    var containerModels: util.Set[ContainerModel] = null
+    if(isHostAffinityEnabled) {
+      containerModels = containerGrouper.group(taskModels, grouperMetadata)
     } else {
-      new JobModel(config, containerMap.asJava)
+      containerModels = containerGrouper.group(taskModels, new 
util.ArrayList[String](grouperMetadata.getProcessorLocality.keySet()))
     }
+    val containerMap = containerModels.asScala.map(containerModel => 
containerModel.getId -> containerModel).toMap
+
+    new JobModel(config, containerMap.asJava)
   }
 
-  private def getSystemNames(config: Config) = config.getSystemNames.toSet
+  private def getSystemNames(config: Config) = config.getSystemNames().toSet
 }
 
 /**
@@ -248,7 +339,7 @@ class JobModelManager(
 
   debug("Got job model: %s." format jobModel)
 
-  def start {
+  def start() {
     if (server != null) {
       debug("Starting HTTP server.")
       server.start
@@ -256,7 +347,7 @@ class JobModelManager(
     }
   }
 
-  def stop {
+  def stop() {
     if (server != null) {
       debug("Stopping HTTP server.")
       server.stop

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala 
b/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
index 64f516b..d16c294 100644
--- 
a/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
+++ 
b/samza-core/src/main/scala/org/apache/samza/job/local/ProcessJobFactory.scala
@@ -50,7 +50,7 @@ class ProcessJobFactory extends StreamJobFactory with Logging 
{
     coordinatorStreamManager.bootstrap
     val changelogStreamManager = new 
ChangelogStreamManager(coordinatorStreamManager)
 
-    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, 
changelogStreamManager.readPartitionMapping())
+    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, 
changelogStreamManager.readPartitionMapping(), metricsRegistry)
     val jobModel = coordinator.jobModel
 
     val taskPartitionMappings: util.Map[TaskName, Integer] = new 
util.HashMap[TaskName, Integer]

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
----------------------------------------------------------------------
diff --git 
a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala 
b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
index 5a8d2f8..e4a7838 100644
--- 
a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
+++ 
b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
@@ -52,7 +52,7 @@ class ThreadJobFactory extends StreamJobFactory with Logging {
     coordinatorStreamManager.bootstrap
     val changelogStreamManager = new 
ChangelogStreamManager(coordinatorStreamManager)
 
-    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, 
changelogStreamManager.readPartitionMapping())
+    val coordinator = JobModelManager(coordinatorStreamManager.getConfig, 
changelogStreamManager.readPartitionMapping(), metricsRegistry)
 
     val jobModel = coordinator.jobModel
 

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
 
b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
index 0c2f2fb..9e6e8d0 100644
--- 
a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
+++ 
b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerCount.java
@@ -24,54 +24,34 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
-import org.apache.samza.SamzaException;
-import org.apache.samza.config.Config;
-import org.apache.samza.config.JobConfig;
-import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.LocalityManager;
+
+import org.apache.samza.container.TaskName;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.TaskModel;
-import org.junit.Before;
+import org.apache.samza.SamzaException;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mockito;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
 
 import static org.apache.samza.container.mock.ContainerMocks.*;
 import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
 
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class})
 public class TestGroupByContainerCount {
-  private TaskAssignmentManager taskAssignmentManager;
-  private LocalityManager localityManager;
-  @Before
-  public void setup() throws Exception {
-    taskAssignmentManager = mock(TaskAssignmentManager.class);
-    localityManager = mock(LocalityManager.class);
-    
PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager);
-    Mockito.doNothing().when(taskAssignmentManager).init();
-  }
 
   @Test(expected = IllegalArgumentException.class)
   public void testGroupEmptyTasks() {
-    new GroupByContainerCount(getConfig(1)).group(new HashSet());
+    new GroupByContainerCount(1).group(new HashSet<>());
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testGroupFewerTasksThanContainers() {
     Set<TaskModel> taskModels = new HashSet<>();
     taskModels.add(getTaskModel(1));
-    new GroupByContainerCount(getConfig(2)).group(taskModels);
+    new GroupByContainerCount(2).group(taskModels);
   }
 
   @Test(expected = UnsupportedOperationException.class)
   public void testGrouperResultImmutable() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> containers = new 
GroupByContainerCount(getConfig(3)).group(taskModels);
+    Set<ContainerModel> containers = new 
GroupByContainerCount(3).group(taskModels);
     containers.remove(containers.iterator().next());
   }
 
@@ -79,7 +59,7 @@ public class TestGroupByContainerCount {
   public void testGroupHappyPath() {
     Set<TaskModel> taskModels = generateTaskModels(5);
 
-    Set<ContainerModel> containers = new 
GroupByContainerCount(getConfig(2)).group(taskModels);
+    Set<ContainerModel> containers = new 
GroupByContainerCount(2).group(taskModels);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -106,7 +86,7 @@ public class TestGroupByContainerCount {
   public void testGroupManyTasks() {
     Set<TaskModel> taskModels = generateTaskModels(21);
 
-    Set<ContainerModel> containers = new 
GroupByContainerCount(getConfig(2)).group(taskModels);
+    Set<ContainerModel> containers = new 
GroupByContainerCount(2).group(taskModels);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -174,11 +154,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerIncrease() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(getConfig(2)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(2).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new 
GroupByContainerCount(getConfig(4)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new 
GroupByContainerCount(4).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -213,22 +193,6 @@ public class TestGroupByContainerCount {
     assertTrue(container2.getTasks().containsKey(getTaskName(6)));
     assertTrue(container3.getTasks().containsKey(getTaskName(5)));
     assertTrue(container3.getTasks().containsKey(getTaskName(7)));
-
-    // Verify task mappings are saved
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(),
 "0");
-
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(),
 "1");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(),
 "1");
-
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(),
 "2");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(),
 "2");
-
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(),
 "3");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(),
 "3");
-
-    verify(taskAssignmentManager, 
never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -256,11 +220,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerDecrease() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(getConfig(4)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(4).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new 
GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new 
GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -290,20 +254,6 @@ public class TestGroupByContainerCount {
     assertTrue(container0.getTasks().containsKey(getTaskName(2)));
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
     assertTrue(container1.getTasks().containsKey(getTaskName(3)));
-
-    // Verify task mappings are saved
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(),
 "0");
-
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(),
 "1");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(),
 "1");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(),
 "1");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(),
 "1");
-
-    verify(taskAssignmentManager, 
never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -331,15 +281,15 @@ public class TestGroupByContainerCount {
    *  T8  T7  T3
    */
   @Test
-  public void testBalancerMultipleReblances() throws Exception {
+  public void testBalancerMultipleReblances() {
     // Before
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(getConfig(4)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(4).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
     // First balance
-    Set<ContainerModel> containers = new 
GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new 
GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -370,30 +320,11 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
     assertTrue(container1.getTasks().containsKey(getTaskName(3)));
 
-    // Verify task mappings are saved
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(8).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(6).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(),
 "0");
-
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(),
 "1");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(),
 "1");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(7).getTaskName(),
 "1");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(),
 "1");
-
-    verify(taskAssignmentManager, 
never()).deleteTaskContainerMappings(anyCollection());
-
-
     // Second balance
     prevTaskToContainerMapping = generateTaskContainerMapping(prevContainers);
 
-    TaskAssignmentManager taskAssignmentManager2 = 
mock(TaskAssignmentManager.class);
-    
when(taskAssignmentManager2.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
-    LocalityManager localityManager2 = mock(LocalityManager.class);
-    
PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager2);
-
-    containers = new GroupByContainerCount(getConfig(3)).balance(taskModels, 
localityManager2);
+    GrouperMetadataImpl grouperMetadata1 = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    containers = new GroupByContainerCount(3).group(taskModels, 
grouperMetadata1);
 
     containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -427,21 +358,6 @@ public class TestGroupByContainerCount {
     assertTrue(container2.getTasks().containsKey(getTaskName(6)));
     assertTrue(container2.getTasks().containsKey(getTaskName(2)));
     assertTrue(container2.getTasks().containsKey(getTaskName(3)));
-
-    // Verify task mappings are saved
-    
verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(0).getTaskName(),
 "0");
-    
verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(4).getTaskName(),
 "0");
-    
verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(8).getTaskName(),
 "0");
-
-    
verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(1).getTaskName(),
 "1");
-    
verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(5).getTaskName(),
 "1");
-    
verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(7).getTaskName(),
 "1");
-
-    
verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(6).getTaskName(),
 "2");
-    
verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(2).getTaskName(),
 "2");
-    
verify(taskAssignmentManager2).writeTaskContainerMapping(getTaskName(3).getTaskName(),
 "2");
-
-    verify(taskAssignmentManager2, 
never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -466,11 +382,11 @@ public class TestGroupByContainerCount {
   @Test
   public void testBalancerAfterContainerSame() {
     Set<TaskModel> taskModels = generateTaskModels(9);
-    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(getConfig(2)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(2).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
 
-    Set<ContainerModel> containers = new 
GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    Set<ContainerModel> containers = new 
GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -496,9 +412,6 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(3)));
     assertTrue(container1.getTasks().containsKey(getTaskName(5)));
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
-
-    verify(taskAssignmentManager, 
never()).writeTaskContainerMapping(anyString(), anyString());
-    verify(taskAssignmentManager, 
never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -528,19 +441,19 @@ public class TestGroupByContainerCount {
   public void testBalancerAfterContainerSameCustomAssignment() {
     Set<TaskModel> taskModels = generateTaskModels(9);
 
-    Map<String, String> prevTaskToContainerMapping = new HashMap<>();
-    prevTaskToContainerMapping.put(getTaskName(0).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(1).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(2).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(3).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(4).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(6).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(7).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(8).getTaskName(), "1");
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Map<TaskName, String> prevTaskToContainerMapping = new HashMap<>();
+    prevTaskToContainerMapping.put(getTaskName(0), "0");
+    prevTaskToContainerMapping.put(getTaskName(1), "0");
+    prevTaskToContainerMapping.put(getTaskName(2), "0");
+    prevTaskToContainerMapping.put(getTaskName(3), "0");
+    prevTaskToContainerMapping.put(getTaskName(4), "0");
+    prevTaskToContainerMapping.put(getTaskName(5), "0");
+    prevTaskToContainerMapping.put(getTaskName(6), "1");
+    prevTaskToContainerMapping.put(getTaskName(7), "1");
+    prevTaskToContainerMapping.put(getTaskName(8), "1");
 
-    Set<ContainerModel> containers = new 
GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    Set<ContainerModel> containers = new 
GroupByContainerCount(2).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -566,9 +479,6 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(6)));
     assertTrue(container1.getTasks().containsKey(getTaskName(7)));
     assertTrue(container1.getTasks().containsKey(getTaskName(8)));
-
-    verify(taskAssignmentManager, 
never()).writeTaskContainerMapping(anyString(), anyString());
-    verify(taskAssignmentManager, 
never()).deleteTaskContainerMappings(anyCollection());
   }
 
   /**
@@ -597,16 +507,16 @@ public class TestGroupByContainerCount {
   public void 
testBalancerAfterContainerSameCustomAssignmentAndContainerIncrease() {
     Set<TaskModel> taskModels = generateTaskModels(6);
 
-    Map<String, String> prevTaskToContainerMapping = new HashMap<>();
-    prevTaskToContainerMapping.put(getTaskName(0).getTaskName(), "0");
-    prevTaskToContainerMapping.put(getTaskName(1).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(2).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(3).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(4).getTaskName(), "1");
-    prevTaskToContainerMapping.put(getTaskName(5).getTaskName(), "1");
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Map<TaskName, String> prevTaskToContainerMapping = new HashMap<>();
+    prevTaskToContainerMapping.put(getTaskName(0), "0");
+    prevTaskToContainerMapping.put(getTaskName(1), "1");
+    prevTaskToContainerMapping.put(getTaskName(2), "1");
+    prevTaskToContainerMapping.put(getTaskName(3), "1");
+    prevTaskToContainerMapping.put(getTaskName(4), "1");
+    prevTaskToContainerMapping.put(getTaskName(5), "1");
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new 
GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new 
GroupByContainerCount(3).group(taskModels, grouperMetadata);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -633,146 +543,106 @@ public class TestGroupByContainerCount {
     assertTrue(container1.getTasks().containsKey(getTaskName(2)));
     assertTrue(container2.getTasks().containsKey(getTaskName(4)));
     assertTrue(container2.getTasks().containsKey(getTaskName(3)));
-
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(),
 "1");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(),
 "1");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(3).getTaskName(),
 "2");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(4).getTaskName(),
 "2");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(5).getTaskName(),
 "0");
-
-    verify(taskAssignmentManager, 
never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testBalancerOldContainerCountOne() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(getConfig(1)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(1).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(getConfig(3)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(getConfig(3)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(3).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    // Verify task mappings are saved
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(),
 "1");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(),
 "2");
-
-    verify(taskAssignmentManager, 
never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testBalancerNewContainerCountOne() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(),
 "0");
-
-    verify(taskAssignmentManager, 
never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testBalancerEmptyTaskMapping() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    when(taskAssignmentManager.readTaskAssignment()).thenReturn(new 
HashMap<>());
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), new HashMap<>());
 
-    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(),
 "0");
-
-    verify(taskAssignmentManager, 
never()).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testGroupTaskCountIncrease() {
     int taskCount = 3;
     Set<TaskModel> taskModels = generateTaskModels(taskCount);
-    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(getConfig(2)).group(generateTaskModels(taskCount - 1)); 
// Here's the key step
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(2).group(generateTaskModels(taskCount - 1)); // Here's the key step
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(),
 "0");
-
-    verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test
   public void testGroupTaskCountDecrease() {
     int taskCount = 3;
     Set<TaskModel> taskModels = generateTaskModels(taskCount);
-    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(getConfig(3)).group(generateTaskModels(taskCount + 1)); 
// Here's the key step
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new GroupByContainerCount(3).group(generateTaskModels(taskCount + 1)); // Here's the key step
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(getConfig(1)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(getConfig(1)).balance(taskModels, localityManager);
+    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(1).group(taskModels);
+    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(1).group(taskModels, grouperMetadata);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
-
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(0).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(1).getTaskName(),
 "0");
-    
verify(taskAssignmentManager).writeTaskContainerMapping(getTaskName(2).getTaskName(),
 "0");
-
-    verify(taskAssignmentManager).deleteTaskContainerMappings(anyCollection());
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testBalancerNewContainerCountGreaterThanTasks() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    new GroupByContainerCount(getConfig(5)).balance(taskModels, 
localityManager);     // Should throw
+    new GroupByContainerCount(5).group(taskModels, grouperMetadata);     // Should throw
   }
 
   @Test(expected = IllegalArgumentException.class)
   public void testBalancerEmptyTasks() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    new GroupByContainerCount(getConfig(5)).balance(new HashSet<>(), 
localityManager);     // Should throw
+    new GroupByContainerCount(5).group(new HashSet<>(), grouperMetadata);
   }
 
   @Test(expected = UnsupportedOperationException.class)
   public void testBalancerResultImmutable() {
     Set<TaskModel> taskModels = generateTaskModels(3);
-    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(getConfig(3)).group(taskModels);
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
+    Set<ContainerModel> prevContainers = new 
GroupByContainerCount(3).group(taskModels);
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
 
-    Set<ContainerModel> containers = new 
GroupByContainerCount(getConfig(2)).balance(taskModels, localityManager);
+    Set<ContainerModel> containers = new 
GroupByContainerCount(2).group(taskModels, grouperMetadata);
     containers.remove(containers.iterator().next());
   }
 
@@ -780,32 +650,20 @@ public class TestGroupByContainerCount {
   public void testBalancerThrowsOnNonIntegerContainerIds() {
     Set<TaskModel> taskModels = generateTaskModels(3);
     Set<ContainerModel> prevContainers = new HashSet<>();
-    taskModels.forEach(model -> {
-        prevContainers.add(
-          new ContainerModel(UUID.randomUUID().toString(), 
Collections.singletonMap(model.getTaskName(), model)));
-      });
-    Map<String, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
-    
when(taskAssignmentManager.readTaskAssignment()).thenReturn(prevTaskToContainerMapping);
-
-    new GroupByContainerCount(getConfig(3)).balance(taskModels, 
localityManager); //Should throw
-
+    taskModels.forEach(model -> prevContainers.add(new 
ContainerModel(UUID.randomUUID().toString(), 
Collections.singletonMap(model.getTaskName(), model))));
+    Map<TaskName, String> prevTaskToContainerMapping = 
generateTaskContainerMapping(prevContainers);
+    GrouperMetadataImpl grouperMetadata = new GrouperMetadataImpl(new 
HashMap<>(), new HashMap<>(), new HashMap<>(), prevTaskToContainerMapping);
+    new GroupByContainerCount(3).group(taskModels, grouperMetadata); //Should throw
   }
 
   @Test
   public void testBalancerWithNullLocalityManager() {
     Set<TaskModel> taskModels = generateTaskModels(3);
 
-    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(getConfig(3)).group(taskModels);
-    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(getConfig(3)).balance(taskModels, null);
+    Set<ContainerModel> groupContainers = new 
GroupByContainerCount(3).group(taskModels);
+    Set<ContainerModel> balanceContainers = new 
GroupByContainerCount(3).balance(taskModels, null);
 
     // Results should be the same as calling group()
     assertEquals(groupContainers, balanceContainers);
   }
-
-
-  Config getConfig(int containerCount) {
-    Map<String, String> config = new HashMap<>();
-    config.put(JobConfig.JOB_CONTAINER_COUNT(), 
String.valueOf(containerCount));
-    return new MapConfig(config);
-  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
 
b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
index 5bb78e8..12b6b1e 100644
--- 
a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
+++ 
b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestGroupByContainerIds.java
@@ -20,6 +20,7 @@
 package org.apache.samza.container.grouper.task;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import java.util.ArrayList;
 import java.util.Collections;
@@ -29,35 +30,24 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
+
+import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
-import org.apache.samza.container.LocalityManager;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.TaskModel;
-import org.junit.Before;
+import org.apache.samza.runtime.LocationId;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
-
-import static org.apache.samza.container.mock.ContainerMocks.*;
-import static org.junit.Assert.*;
-import static org.mockito.Mockito.*;
 
+import static 
org.apache.samza.container.mock.ContainerMocks.generateTaskModels;
+import static org.apache.samza.container.mock.ContainerMocks.getTaskName;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
 
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({TaskAssignmentManager.class, GroupByContainerIds.class})
 public class TestGroupByContainerIds {
 
-  @Before
-  public void setup() throws Exception {
-    TaskAssignmentManager taskAssignmentManager = 
mock(TaskAssignmentManager.class);
-    LocalityManager localityManager = mock(LocalityManager.class);
-    
PowerMockito.whenNew(TaskAssignmentManager.class).withAnyArguments().thenReturn(taskAssignmentManager);
-  }
-
   private Config buildConfigForContainerCount(int count) {
     Map<String, String> map = new HashMap<>();
     map.put("job.container.count", String.valueOf(count));
@@ -67,6 +57,7 @@ public class TestGroupByContainerIds {
   private TaskNameGrouper buildSimpleGrouper() {
     return buildSimpleGrouper(1);
   }
+
   private TaskNameGrouper buildSimpleGrouper(int containerCount) {
     return new 
GroupByContainerIdsFactory().build(buildConfigForContainerCount(containerCount));
   }
@@ -114,7 +105,8 @@ public class TestGroupByContainerIds {
   public void testGroupWithNullContainerIds() {
     Set<TaskModel> taskModels = generateTaskModels(5);
 
-    Set<ContainerModel> containers = buildSimpleGrouper(2).group(taskModels, 
null);
+    List<String> containerIds = null;
+    Set<ContainerModel> containers = buildSimpleGrouper(2).group(taskModels, 
containerIds);
 
     Map<String, ContainerModel> containersMap = new HashMap<>();
     for (ContainerModel container : containers) {
@@ -251,4 +243,264 @@ public class TestGroupByContainerIds {
     assertEquals(1, actualContainerModels.size());
     assertEquals(ImmutableSet.of(expectedContainerModel), 
actualContainerModels);
   }
+
+  /**
+   * Three processors and three tasks are registered on three distinct locations, with each
+   * task sharing its location with exactly one processor. The grouper is expected to emit one
+   * container per processor, each holding the task that shares the processor's location.
+   */
+  @Test
+  public void testShouldUseTaskLocalityWhenGeneratingContainerModels() {
+    TaskNameGrouper grouper = buildSimpleGrouper(3);
+
+    String processor1 = "testProcessorId1";
+    String processor2 = "testProcessorId2";
+    String processor3 = "testProcessorId3";
+
+    LocationId location1 = new LocationId("testLocationId1");
+    LocationId location2 = new LocationId("testLocationId2");
+    LocationId location3 = new LocationId("testLocationId3");
+
+    TaskName task1 = new TaskName("testTasKId1");
+    TaskName task2 = new TaskName("testTaskId2");
+    TaskName task3 = new TaskName("testTaskId3");
+
+    TaskModel model1 = new TaskModel(task1, new HashSet<>(), new Partition(0));
+    TaskModel model2 = new TaskModel(task2, new HashSet<>(), new Partition(1));
+    TaskModel model3 = new TaskModel(task3, new HashSet<>(), new Partition(2));
+
+    // Locality metadata: processor N and task N both live on location N.
+    Map<String, LocationId> processorLocality = ImmutableMap.of(
+        processor1, location1,
+        processor2, location2,
+        processor3, location3);
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(
+        task1, location1,
+        task2, location2,
+        task3, location3);
+    GrouperMetadataImpl metadata =
+        new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<TaskModel> allTasks = ImmutableSet.of(model1, model2, model3);
+
+    Set<ContainerModel> expected = ImmutableSet.of(
+        new ContainerModel(processor1, ImmutableMap.of(task1, model1)),
+        new ContainerModel(processor2, ImmutableMap.of(task2, model2)),
+        new ContainerModel(processor3, ImmutableMap.of(task3, model3)));
+
+    Set<ContainerModel> actual = grouper.group(allTasks, metadata);
+
+    assertEquals(expected, actual);
+  }
+
+  /**
+   * Only one processor is registered while the three tasks are recorded on three different
+   * locations. The grouper is expected to put all three tasks into that processor's container.
+   */
+  @Test
+  public void testGenerateContainerModelForSingleContainer() {
+    TaskNameGrouper grouper = buildSimpleGrouper(1);
+
+    String onlyProcessor = "testProcessorId1";
+
+    LocationId location1 = new LocationId("testLocationId1");
+    LocationId location2 = new LocationId("testLocationId2");
+    LocationId location3 = new LocationId("testLocationId3");
+
+    TaskName task1 = new TaskName("testTasKId1");
+    TaskName task2 = new TaskName("testTaskId2");
+    TaskName task3 = new TaskName("testTaskId3");
+
+    TaskModel model1 = new TaskModel(task1, new HashSet<>(), new Partition(0));
+    TaskModel model2 = new TaskModel(task2, new HashSet<>(), new Partition(1));
+    TaskModel model3 = new TaskModel(task3, new HashSet<>(), new Partition(2));
+
+    // A single processor on location 1; the tasks are spread over locations 1..3.
+    Map<String, LocationId> processorLocality = ImmutableMap.of(onlyProcessor, location1);
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(
+        task1, location1,
+        task2, location2,
+        task3, location3);
+    GrouperMetadataImpl metadata =
+        new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<TaskModel> allTasks = ImmutableSet.of(model1, model2, model3);
+
+    Map<TaskName, TaskModel> everyTask = ImmutableMap.of(
+        task1, model1,
+        task2, model2,
+        task3, model3);
+    Set<ContainerModel> expected = ImmutableSet.of(new ContainerModel(onlyProcessor, everyTask));
+
+    Set<ContainerModel> actual = grouper.group(allTasks, metadata);
+
+    assertEquals(expected, actual);
+  }
+
+  /**
+   * Three processors are registered on three locations, but locality is known for the first
+   * task only. The grouper is still expected to produce one container per processor with one
+   * task each.
+   *
+   * NOTE(review): despite the method name, the task-locality map passed in is not empty; it
+   * holds a single entry for the first task.
+   */
+  @Test
+  public void testShouldGenerateCorrectContainerModelWhenTaskLocalityIsEmpty() {
+    TaskNameGrouper grouper = buildSimpleGrouper(3);
+
+    String processor1 = "testProcessorId1";
+    String processor2 = "testProcessorId2";
+    String processor3 = "testProcessorId3";
+
+    LocationId location1 = new LocationId("testLocationId1");
+    LocationId location2 = new LocationId("testLocationId2");
+    LocationId location3 = new LocationId("testLocationId3");
+
+    TaskName task1 = new TaskName("testTasKId1");
+    TaskName task2 = new TaskName("testTaskId2");
+    TaskName task3 = new TaskName("testTaskId3");
+
+    TaskModel model1 = new TaskModel(task1, new HashSet<>(), new Partition(0));
+    TaskModel model2 = new TaskModel(task2, new HashSet<>(), new Partition(1));
+    TaskModel model3 = new TaskModel(task3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(
+        processor1, location1,
+        processor2, location2,
+        processor3, location3);
+    // Only the first task has a recorded location.
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(task1, location1);
+    GrouperMetadataImpl metadata =
+        new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<TaskModel> allTasks = ImmutableSet.of(model1, model2, model3);
+
+    Set<ContainerModel> expected = ImmutableSet.of(
+        new ContainerModel(processor1, ImmutableMap.of(task1, model1)),
+        new ContainerModel(processor2, ImmutableMap.of(task2, model2)),
+        new ContainerModel(processor3, ImmutableMap.of(task3, model3)));
+
+    Set<ContainerModel> actual = grouper.group(allTasks, metadata);
+
+    assertEquals(expected, actual);
+  }
+
+  /**
+   * Grouping with metadata built entirely from empty maps (so no processor locality at all)
+   * is expected to raise an {@link IllegalArgumentException}.
+   */
+  @Test(expected = IllegalArgumentException.class)
+  public void testShouldFailWhenProcessorLocalityIsEmpty() {
+    TaskNameGrouper grouper = buildSimpleGrouper(3);
+
+    GrouperMetadataImpl emptyMetadata =
+        new GrouperMetadataImpl(new HashMap<>(), new HashMap<>(), new HashMap<>(), new HashMap<>());
+    Set<TaskModel> noTasks = new HashSet<>();
+
+    grouper.group(noTasks, emptyMetadata);
+  }
+
+  /**
+   * Calls the grouper twice with the very same task models and metadata and expects both
+   * invocations to return the same one-task-per-processor container models.
+   */
+  @Test
+  public void testShouldGenerateIdenticalTaskDistributionWhenNoChangeInProcessorGroup() {
+    TaskNameGrouper grouper = buildSimpleGrouper(3);
+
+    String processor1 = "testProcessorId1";
+    String processor2 = "testProcessorId2";
+    String processor3 = "testProcessorId3";
+
+    LocationId location1 = new LocationId("testLocationId1");
+    LocationId location2 = new LocationId("testLocationId2");
+    LocationId location3 = new LocationId("testLocationId3");
+
+    TaskName task1 = new TaskName("testTasKId1");
+    TaskName task2 = new TaskName("testTaskId2");
+    TaskName task3 = new TaskName("testTaskId3");
+
+    TaskModel model1 = new TaskModel(task1, new HashSet<>(), new Partition(0));
+    TaskModel model2 = new TaskModel(task2, new HashSet<>(), new Partition(1));
+    TaskModel model3 = new TaskModel(task3, new HashSet<>(), new Partition(2));
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(
+        processor1, location1,
+        processor2, location2,
+        processor3, location3);
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(
+        task1, location1,
+        task2, location2,
+        task3, location3);
+    GrouperMetadataImpl metadata =
+        new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<TaskModel> allTasks = ImmutableSet.of(model1, model2, model3);
+
+    Set<ContainerModel> expected = ImmutableSet.of(
+        new ContainerModel(processor1, ImmutableMap.of(task1, model1)),
+        new ContainerModel(processor2, ImmutableMap.of(task2, model2)),
+        new ContainerModel(processor3, ImmutableMap.of(task3, model3)));
+
+    // First pass.
+    Set<ContainerModel> firstResult = grouper.group(allTasks, metadata);
+    assertEquals(expected, firstResult);
+
+    // Second pass with unchanged inputs must yield the same distribution.
+    Set<ContainerModel> secondResult = grouper.group(allTasks, metadata);
+    assertEquals(expected, secondResult);
+  }
+
+  /**
+   * Groups three tasks over three processors, then repeats the grouping after the third
+   * processor has been dropped from the processor locality. In the second round the first
+   * two tasks are expected to stay on their processors, and the third task is expected to
+   * join the first processor's container.
+   */
+  @Test
+  public void testShouldMinimizeTaskShuffleWhenAvailableProcessorInGroupChanges() {
+    TaskNameGrouper grouper = buildSimpleGrouper(3);
+
+    String processor1 = "testProcessorId1";
+    String processor2 = "testProcessorId2";
+    String processor3 = "testProcessorId3";
+
+    LocationId location1 = new LocationId("testLocationId1");
+    LocationId location2 = new LocationId("testLocationId2");
+    LocationId location3 = new LocationId("testLocationId3");
+
+    TaskName task1 = new TaskName("testTasKId1");
+    TaskName task2 = new TaskName("testTaskId2");
+    TaskName task3 = new TaskName("testTaskId3");
+
+    TaskModel model1 = new TaskModel(task1, new HashSet<>(), new Partition(0));
+    TaskModel model2 = new TaskModel(task2, new HashSet<>(), new Partition(1));
+    TaskModel model3 = new TaskModel(task3, new HashSet<>(), new Partition(2));
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(
+        task1, location1,
+        task2, location2,
+        task3, location3);
+    Set<TaskModel> allTasks = ImmutableSet.of(model1, model2, model3);
+
+    // Round one: all three processors are available.
+    Map<String, LocationId> allProcessors = ImmutableMap.of(
+        processor1, location1,
+        processor2, location2,
+        processor3, location3);
+    GrouperMetadataImpl fullMetadata =
+        new GrouperMetadataImpl(allProcessors, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<ContainerModel> expectedWithThree = ImmutableSet.of(
+        new ContainerModel(processor1, ImmutableMap.of(task1, model1)),
+        new ContainerModel(processor2, ImmutableMap.of(task2, model2)),
+        new ContainerModel(processor3, ImmutableMap.of(task3, model3)));
+
+    Set<ContainerModel> actualWithThree = grouper.group(allTasks, fullMetadata);
+
+    assertEquals(expectedWithThree, actualWithThree);
+
+    // Round two: the third processor is gone, task locality is unchanged.
+    Map<String, LocationId> remainingProcessors = ImmutableMap.of(
+        processor1, location1,
+        processor2, location2);
+    GrouperMetadataImpl reducedMetadata =
+        new GrouperMetadataImpl(remainingProcessors, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<ContainerModel> actualWithTwo = grouper.group(allTasks, reducedMetadata);
+
+    Set<ContainerModel> expectedWithTwo = ImmutableSet.of(
+        new ContainerModel(processor1, ImmutableMap.of(task1, model1, task3, model3)),
+        new ContainerModel(processor2, ImmutableMap.of(task2, model2)));
+
+    assertEquals(expectedWithTwo, actualWithTwo);
+  }
+
+  /**
+   * Groups the task models returned by {@code generateTaskModels(1)} with metadata that
+   * registers two processors, using the default single-container grouper. Exactly one
+   * container model is expected: the first processor's, holding every generated task.
+   *
+   * NOTE(review): the method name suggests more tasks than processors, but the metadata
+   * registers two processors while only one task model is requested from the generator.
+   * NOTE(review): the task names in the task-locality map are created locally and presumably
+   * do not match the generated task models; confirm against ContainerMocks.
+   */
+  @Test
+  public void testMoreTasksThanProcessors() {
+    String testProcessorId1 = "testProcessorId1";
+    String testProcessorId2 = "testProcessorId2";
+
+    LocationId testLocationId1 = new LocationId("testLocationId1");
+    LocationId testLocationId2 = new LocationId("testLocationId2");
+    LocationId testLocationId3 = new LocationId("testLocationId3");
+
+    TaskName testTaskName1 = new TaskName("testTasKId1");
+    TaskName testTaskName2 = new TaskName("testTaskId2");
+    TaskName testTaskName3 = new TaskName("testTaskId3");
+
+    Map<String, LocationId> processorLocality = ImmutableMap.of(
+        testProcessorId1, testLocationId1,
+        testProcessorId2, testLocationId2);
+
+    Map<TaskName, LocationId> taskLocality = ImmutableMap.of(
+        testTaskName1, testLocationId1,
+        testTaskName2, testLocationId2,
+        testTaskName3, testLocationId3);
+
+    GrouperMetadataImpl grouperMetadata =
+        new GrouperMetadataImpl(processorLocality, taskLocality, new HashMap<>(), new HashMap<>());
+
+    Set<TaskModel> taskModels = generateTaskModels(1);
+
+    // The processor ids now reach the grouper through grouperMetadata, so no separate
+    // container-id list is built here.
+    Map<TaskName, TaskModel> expectedTasks = taskModels.stream()
+        .collect(Collectors.toMap(TaskModel::getTaskName, x -> x));
+    ContainerModel expectedContainerModel = new ContainerModel(testProcessorId1, expectedTasks);
+
+    Set<ContainerModel> actualContainerModels = buildSimpleGrouper().group(taskModels, grouperMetadata);
+
+    assertEquals(1, actualContainerModels.size());
+    assertEquals(ImmutableSet.of(expectedContainerModel), actualContainerModels);
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
 
b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
index fcdbf08..60164b2 100644
--- 
a/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
+++ 
b/samza-core/src/test/java/org/apache/samza/container/grouper/task/TestTaskAssignmentManager.java
@@ -68,7 +68,6 @@ public class TestTaskAssignmentManager {
   @Test
   public void testTaskAssignmentManager() {
     TaskAssignmentManager taskAssignmentManager = new 
TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", 
"1", "Task2", "2", "Task3", "0", "Task4", "1");
 
@@ -86,7 +85,6 @@ public class TestTaskAssignmentManager {
   @Test
   public void testDeleteMappings() {
     TaskAssignmentManager taskAssignmentManager = new 
TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = ImmutableMap.of("Task0", "0", "Task1", 
"1");
 
@@ -108,7 +106,6 @@ public class TestTaskAssignmentManager {
   @Test
   public void testTaskAssignmentManagerEmptyCoordinatorStream() {
     TaskAssignmentManager taskAssignmentManager = new 
TaskAssignmentManager(config, new MetricsRegistryMap());
-    taskAssignmentManager.init();
 
     Map<String, String> expectedMap = new HashMap<>();
     Map<String, String> localMap = taskAssignmentManager.readTaskAssignment();

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java 
b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
index ca9def2..be240b1 100644
--- 
a/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
+++ 
b/samza-core/src/test/java/org/apache/samza/container/mock/ContainerMocks.java
@@ -117,11 +117,11 @@ public class ContainerMocks {
     return values;
   }
 
-  public static Map<String, String> 
generateTaskContainerMapping(Set<ContainerModel> containers) {
-    Map<String, String> taskMapping = new HashMap<>();
+  public static Map<TaskName, String> 
generateTaskContainerMapping(Set<ContainerModel> containers) {
+    Map<TaskName, String> taskMapping = new HashMap<>();
     for (ContainerModel container : containers) {
       for (TaskName taskName : container.getTasks().keySet()) {
-        taskMapping.put(taskName.getTaskName(), container.getId());
+        taskMapping.put(taskName, container.getId());
       }
     }
     return taskMapping;

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
 
b/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
index ea25ec1..02aaaa7 100644
--- 
a/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
+++ 
b/samza-core/src/test/java/org/apache/samza/coordinator/JobModelManagerTestUtil.java
@@ -19,15 +19,15 @@
 
 package org.apache.samza.coordinator;
 
-import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 import org.apache.samza.config.Config;
 import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.coordinator.server.HttpServer;
 import org.apache.samza.job.model.ContainerModel;
 import org.apache.samza.job.model.JobModel;
+import org.apache.samza.runtime.LocationId;
 import org.apache.samza.system.StreamMetadataCache;
 
 /**
@@ -49,15 +49,8 @@ public class JobModelManagerTestUtil {
     return new JobModelManager(jobModel, server, null);
   }
 
-  public static JobModelManager getJobModelManagerUsingReadModel(Config 
config, int containerCount, StreamMetadataCache streamMetadataCache,
-    LocalityManager locManager, HttpServer server) {
-    List<String> containerIds = new ArrayList<>();
-    for (int i = 0; i < containerCount; i++) {
-      containerIds.add(String.valueOf(i));
-    }
-    JobModel jobModel = JobModelManager.readJobModel(config, new HashMap<>(), 
locManager, streamMetadataCache, containerIds);
-    return new JobModelManager(jobModel, server, null);
+  public static JobModelManager getJobModelManagerUsingReadModel(Config 
config, StreamMetadataCache streamMetadataCache, HttpServer server, 
LocalityManager localityManager, Map<String, LocationId> processorLocality) {
+    JobModel jobModel = JobModelManager.readJobModel(config, new HashMap<>(), 
streamMetadataCache, new GrouperMetadataImpl(processorLocality, new 
HashMap<>(), new HashMap<>(), new HashMap<>()));
+    return new JobModelManager(new JobModel(jobModel.getConfig(), 
jobModel.getContainers(), localityManager), server, localityManager);
   }
-
-
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/5ea72584/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
 
b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
index 1dbf132..6048466 100644
--- 
a/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
+++ 
b/samza-core/src/test/java/org/apache/samza/coordinator/TestJobModelManager.java
@@ -19,20 +19,30 @@
 
 package org.apache.samza.coordinator;
 
+import com.google.common.collect.ImmutableMap;
+import java.util.HashSet;
+import java.util.Set;
 import org.apache.samza.Partition;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.container.LocalityManager;
+import org.apache.samza.container.TaskName;
 import org.apache.samza.container.grouper.task.GroupByContainerCount;
+import org.apache.samza.container.grouper.task.GrouperMetadataImpl;
 import org.apache.samza.container.grouper.task.TaskAssignmentManager;
 import org.apache.samza.coordinator.server.HttpServer;
 import org.apache.samza.coordinator.stream.messages.SetContainerHostMapping;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.JobModel;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.runtime.LocationId;
 import org.apache.samza.system.StreamMetadataCache;
 import org.apache.samza.system.SystemStream;
 import org.apache.samza.system.SystemStreamMetadata;
 import org.apache.samza.testUtils.MockHttpServer;
 import org.eclipse.jetty.servlet.DefaultServlet;
 import org.eclipse.jetty.servlet.ServletHolder;
+import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -40,14 +50,15 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Collections;
 
+import static org.apache.samza.coordinator.JobModelManager.*;
 import static org.junit.Assert.assertEquals;
 import static org.mockito.Matchers.anyBoolean;
 import static org.mockito.Matchers.argThat;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.*;
 
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentMatcher;
+import org.mockito.Mockito;
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
@@ -60,7 +71,6 @@ import scala.collection.JavaConversions;
 @PrepareForTest({TaskAssignmentManager.class, GroupByContainerCount.class})
 public class TestJobModelManager {
   private final TaskAssignmentManager mockTaskManager = 
mock(TaskAssignmentManager.class);
-  private final LocalityManager mockLocalityManager = 
mock(LocalityManager.class);
   private final Map<String, Map<String, String>> localityMappings = new 
HashMap<>();
   private final HttpServer server = new MockHttpServer("/", 7777, null, new 
ServletHolder(DefaultServlet.class));
   private final SystemStream inputStream = new SystemStream("test-system", 
"test-stream");
@@ -75,7 +85,6 @@ public class TestJobModelManager {
 
   @Before
   public void setup() throws Exception {
-    
when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings);
     when(mockStreamMetadataCache.getStreamMetadata(argThat(new 
ArgumentMatcher<scala.collection.immutable.Set<SystemStream>>() {
       @Override
       public boolean matches(Object argument) {
@@ -105,11 +114,15 @@ public class TestJobModelManager {
         put("job.host-affinity.enabled", "true");
       }
     });
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
 
-    this.localityMappings.put("0", new HashMap<String, String>() { {
+    localityMappings.put("0", new HashMap<String, String>() { {
         put(SetContainerHostMapping.HOST_KEY, "abc-affinity");
       } });
-    this.jobModelManager = 
JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, 
mockStreamMetadataCache, mockLocalityManager, server);
+    
when(mockLocalityManager.readContainerLocality()).thenReturn(this.localityMappings);
+
+    Map<String, LocationId> containerLocality = ImmutableMap.of("0", new 
LocationId("abc-affinity"));
+    this.jobModelManager = 
JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 
mockStreamMetadataCache, server, mockLocalityManager, containerLocality);
 
     assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new 
HashMap<String, String>() { { this.put("0", "abc-affinity"); } });
   }
@@ -132,11 +145,96 @@ public class TestJobModelManager {
       }
     });
 
-    this.localityMappings.put("0", new HashMap<String, String>() { {
+    LocalityManager mockLocalityManager = mock(LocalityManager.class);
+
+    localityMappings.put("0", new HashMap<String, String>() { {
         put(SetContainerHostMapping.HOST_KEY, "abc-affinity");
       } });
-    this.jobModelManager = 
JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 1, 
mockStreamMetadataCache, mockLocalityManager, server);
+    when(mockLocalityManager.readContainerLocality()).thenReturn(new 
HashMap<>());
+
+    Map<String, LocationId> containerLocality = ImmutableMap.of("0", new 
LocationId("abc-affinity"));
+
+    this.jobModelManager = 
JobModelManagerTestUtil.getJobModelManagerUsingReadModel(config, 
mockStreamMetadataCache, server, mockLocalityManager, containerLocality);
 
     assertEquals(jobModelManager.jobModel().getAllContainerLocality(), new 
HashMap<String, String>() { { this.put("0", null); } });
   }
+
+  /**
+   * Feeds {@code JobModelManager.getGrouperMetadata} a mocked locality manager (container "0"
+   * on host "abc-affinity") and a mocked task-assignment manager ("task-0" assigned to
+   * container "0"), then checks that both mocks were read and that the resulting processor
+   * and task locality match the expected maps.
+   */
+  @Test
+  public void testGetGrouperMetadata() {
+    // Collaborators.
+    LocalityManager localityManager = mock(LocalityManager.class);
+    TaskAssignmentManager taskAssignmentManager = mock(TaskAssignmentManager.class);
+
+    // Container "0" is recorded on host "abc-affinity".
+    Map<String, Map<String, String>> containerLocality = new HashMap<>();
+    containerLocality.put("0", ImmutableMap.of(SetContainerHostMapping.HOST_KEY, "abc-affinity"));
+    when(localityManager.readContainerLocality()).thenReturn(containerLocality);
+
+    // Task "task-0" is recorded on container "0".
+    Map<String, String> taskToContainer = ImmutableMap.of("task-0", "0");
+    when(taskAssignmentManager.readTaskAssignment()).thenReturn(taskToContainer);
+
+    GrouperMetadataImpl metadata =
+        JobModelManager.getGrouperMetadata(new MapConfig(), localityManager, taskAssignmentManager);
+
+    verify(localityManager).readContainerLocality();
+    verify(taskAssignmentManager).readTaskAssignment();
+
+    Map<String, LocationId> expectedProcessorLocality = ImmutableMap.of(
+        "0", new LocationId("abc-affinity"),
+        "1", new LocationId("ANY_HOST"));
+    Map<TaskName, LocationId> expectedTaskLocality =
+        ImmutableMap.of(new TaskName("task-0"), new LocationId("abc-affinity"));
+
+    Assert.assertEquals(expectedProcessorLocality, metadata.getProcessorLocality());
+    Assert.assertEquals(expectedTaskLocality, metadata.getTaskLocality());
+  }
+
+  /**
+   * Calls {@code JobModelManager.getProcessorLocality} with an empty config and a mocked
+   * locality manager reporting container "0" on host "abc-affinity". Expects the mock to be
+   * read and the result to map "0" to "abc-affinity" and "1" to "ANY_HOST".
+   */
+  @Test
+  public void testGetProcessorLocality() {
+    LocalityManager localityManager = mock(LocalityManager.class);
+
+    // Stub the stored container locality.
+    Map<String, Map<String, String>> containerLocality = new HashMap<>();
+    containerLocality.put("0", ImmutableMap.of(SetContainerHostMapping.HOST_KEY, "abc-affinity"));
+    when(localityManager.readContainerLocality()).thenReturn(containerLocality);
+
+    Map<String, LocationId> actual = JobModelManager.getProcessorLocality(new MapConfig(), localityManager);
+
+    verify(localityManager).readContainerLocality();
+
+    Map<String, LocationId> expected = ImmutableMap.of(
+        "0", new LocationId("abc-affinity"),
+        "1", new LocationId("ANY_HOST"));
+    Assert.assertEquals(expected, actual);
+  }
+
+  /**
+   * Runs {@code JobModelManager.updateTaskAssignments} against a mocked job model that holds
+   * one container with four tasks, and a mocked grouper metadata reporting no previous
+   * task-to-processor assignment. Verifies that the containers were read, that the mappings
+   * of all four tasks were deleted, and that each task was written against the container id.
+   */
+  @Test
+  public void testUpdateTaskAssignments() {
+    // Collaborators are mocks; only interactions are checked.
+    JobModel jobModel = Mockito.mock(JobModel.class);
+    GrouperMetadataImpl grouperMetadata = Mockito.mock(GrouperMetadataImpl.class);
+    TaskAssignmentManager taskAssignmentManager = Mockito.mock(TaskAssignmentManager.class);
+
+    // One container owning four tasks on partitions 0..3.
+    String[] names = {"task-1", "task-2", "task-3", "task-4"};
+    Map<TaskName, TaskModel> tasksByName = new HashMap<>();
+    for (int i = 0; i < names.length; i++) {
+      TaskModel model = new TaskModel(new TaskName(names[i]), new HashSet<>(), new Partition(i));
+      tasksByName.put(new TaskName(names[i]), model);
+    }
+    ContainerModel container = new ContainerModel("test-container-id", tasksByName);
+    Map<String, ContainerModel> containersById = ImmutableMap.of("test-container-id", container);
+
+    when(jobModel.getContainers()).thenReturn(containersById);
+    when(grouperMetadata.getPreviousTaskToProcessorAssignment()).thenReturn(new HashMap<>());
+    Mockito.doNothing().when(taskAssignmentManager).writeTaskContainerMapping(Mockito.any(), Mockito.any());
+
+    JobModelManager.updateTaskAssignments(jobModel, taskAssignmentManager, grouperMetadata);
+
+    Set<String> deletedTaskNames = new HashSet<String>();
+    for (String name : names) {
+      deletedTaskNames.add(name);
+    }
+
+    // Verifications
+    Mockito.verify(jobModel, atLeast(1)).getContainers();
+    Mockito.verify(taskAssignmentManager).deleteTaskContainerMappings((Iterable<String>) deletedTaskNames);
+    for (String name : names) {
+      Mockito.verify(taskAssignmentManager).writeTaskContainerMapping(name, "test-container-id");
+    }
+  }
 }

Reply via email to