mjsax commented on code in PR #14030:
URL: https://github.com/apache/kafka/pull/14030#discussion_r1272606193


##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {

Review Comment:
   We already `filter()` for `isPresent` -- seems we only need the first check?
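
For illustration, a minimal self-contained sketch (hypothetical class and values, not the PR's code) of the simplification the comment points at: once the stream is filtered with `Optional::isPresent`, any element returned by `findFirst()` is non-empty, so only the outer `Optional` needs to be checked.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

public class FindFirstPresentExample {
    public static void main(final String[] args) {
        final List<Optional<String>> racks = Arrays.asList(Optional.empty(), Optional.of("rack1"));

        // The stream is already filtered for present values, so the inner isPresent()
        // check on the wrapped Optional is redundant; only the outer one matters.
        final Optional<Optional<String>> firstRack =
            racks.stream().filter(Optional::isPresent).findFirst();
        if (!firstRack.isPresent()) {
            throw new IllegalStateException("no rack configured");
        }
        System.out.println(firstRack.get().get()); // prints "rack1"
    }
}
```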



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);

Review Comment:
   Given that we don't read `taskClientMap` nor `originalAssignedTaskNumber`, 
should we just pass `new HashMap` directly?
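
A runnable stand-in (hypothetical names, not the Kafka Streams code) showing the call shape the comment suggests: when the caller never reads the maps back, the throw-away instances can be passed inline.

```java
import java.util.HashMap;
import java.util.Map;

public class ThrowAwayMapExample {
    // Stand-in for constructActiveTaskGraph: it only fills the maps it is handed.
    static int buildGraph(final Map<String, String> taskClientMap, final Map<String, Integer> assignedCounts) {
        taskClientMap.put("task_0_0", "client-1");
        assignedCounts.merge("client-1", 1, Integer::sum);
        return taskClientMap.size() + assignedCounts.size();
    }

    public static void main(final String[] args) {
        // The caller never inspects the maps afterwards, so no local variables are needed.
        System.out.println(buildGraph(new HashMap<>(), new HashMap<>()));
    }
}
```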



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active task assignment for rack awareness. canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned already in clientStates
+     * @param isStateful Whether the tasks are stateful
+     * @return Total cost after optimization
+     */
+    public long optimizeActiveTasks(final SortedMap<UUID, ClientState> clientStates,
+                                    final SortedSet<TaskId> activeTasks,
+                                    final boolean isStateful) {
+        if (activeTasks.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);
+
+        graph.solveMinCostFlow();
+        final long cost = graph.totalCost();
+
+        assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, taskIdList,
+            clientStates, originalAssignedTaskNumber, taskClientMap);
+
+        return cost;
+    }
+
+    private Graph<Integer> constructActiveTaskGraph(final SortedSet<TaskId> activeTasks,
+                                                    final List<UUID> clientList,
+                                                    final List<TaskId> taskIdList,
+                                                    final Map<UUID, ClientState> clientStates,
+                                                    final Map<TaskId, UUID> taskClientMap,
+                                                    final Map<UUID, Integer> originalAssignedTaskNumber,
+                                                    final boolean isStateful) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : activeTasks) {
+            for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
+                if (clientState.getValue().hasAssignedTask(taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId  = 0; taskNodeId < taskIdList.size(); taskNodeId++) {

Review Comment:
   nit: double whitespace `taskNodeId  = 0`



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);

Review Comment:
   Same question as for `clientList`.



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active task assignment for rack awareness. canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned already in clientStates
+     * @param isStateful Whether the tasks are stateful
+     * @return Total cost after optimization
+     */
+    public long optimizeActiveTasks(final SortedMap<UUID, ClientState> clientStates,
+                                    final SortedSet<TaskId> activeTasks,
+                                    final boolean isStateful) {
+        if (activeTasks.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);
+
+        graph.solveMinCostFlow();
+        final long cost = graph.totalCost();
+
+        assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, taskIdList,
+            clientStates, originalAssignedTaskNumber, taskClientMap);
+
+        return cost;
+    }
+
+    private Graph<Integer> constructActiveTaskGraph(final SortedSet<TaskId> activeTasks,
+                                                    final List<UUID> clientList,
+                                                    final List<TaskId> taskIdList,
+                                                    final Map<UUID, ClientState> clientStates,
+                                                    final Map<TaskId, UUID> taskClientMap,
+                                                    final Map<UUID, Integer> originalAssignedTaskNumber,
+                                                    final boolean isStateful) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : activeTasks) {
+            for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
+                if (clientState.getValue().hasAssignedTask(taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId  = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = taskIdList.size() + j;
+                final UUID clientId = clientList.get(j);

Review Comment:
   `clientId` is `processId` here, right? (rename it, or add a comment?)



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignorTest.java:
##########
@@ -145,147 +188,517 @@ public void disableActiveSinceRackMissingInClient() {
 
         // False since process1 doesn't have rackId
         assertFalse(assignor.validateClientRack());
-        assertFalse(assignor.canEnableRackAwareAssignorForActiveTasks());
+        assertFalse(assignor.canEnableRackAwareAssignor());
     }
 
     @Test
-    public void disableActiveSinceRackDiffersInSameProcess() {
+    public void shouldDisableActiveWhenRackDiffersInSameProcess() {
         final Map<UUID, Map<String, Optional<String>>> processRacks = new HashMap<>();
 
         // Different consumers in same process have different rack ID. This shouldn't happen.
         // If happens, there's a bug somewhere
-        processRacks.computeIfAbsent(process0UUID, k -> new HashMap<>()).put("consumer1", Optional.of("rack1"));
-        processRacks.computeIfAbsent(process0UUID, k -> new HashMap<>()).put("consumer2", Optional.of("rack2"));
+        processRacks.computeIfAbsent(UUID_1, k -> new HashMap<>()).put("consumer1", Optional.of("rack1"));
+        processRacks.computeIfAbsent(UUID_1, k -> new HashMap<>()).put("consumer2", Optional.of("rack2"));
 
         final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
             getClusterForTopic0(),
-            getTaskTopicPartitionMapForTask1(),
+            getTaskTopicPartitionMapForTask0(),
             getTopologyGroupTaskMap(),
             processRacks,
             mockInternalTopicManager,
             new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
         );
 
         assertFalse(assignor.validateClientRack());
-        assertFalse(assignor.canEnableRackAwareAssignorForActiveTasks());
+        assertFalse(assignor.canEnableRackAwareAssignor());
     }
 
     @Test
-    public void enableRackAwareAssignorForActiveWithoutDescribingTopics() {
+    public void shouldEnableRackAwareAssignorForActiveWithoutDescribingTopics() {
         final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
             getClusterForTopic0(),
-            getTaskTopicPartitionMapForTask1(),
+            getTaskTopicPartitionMapForTask0(),
             getTopologyGroupTaskMap(),
             getProcessRacksForProcess0(),
             mockInternalTopicManager,
             new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
         );
 
         // partitionWithoutInfo00 has rackInfo in cluster metadata
-        assertTrue(assignor.canEnableRackAwareAssignorForActiveTasks());
+        assertTrue(assignor.canEnableRackAwareAssignor());
     }
 
     @Test
-    public void enableRackAwareAssignorForActiveWithDescribingTopics() {
+    public void shouldEnableRackAwareAssignorForActiveWithDescribingTopics() {
         final MockInternalTopicManager spyTopicManager = spy(mockInternalTopicManager);
         doReturn(
             Collections.singletonMap(
-                TOPIC0,
+                TP_0_NAME,
                 Collections.singletonList(
-                    new TopicPartitionInfo(0, node0, Arrays.asList(replicas1), Collections.emptyList())
+                    new TopicPartitionInfo(0, NODE_0, Arrays.asList(REPLICA_1), Collections.emptyList())
                 )
             )
-        ).when(spyTopicManager).getTopicPartitionInfo(Collections.singleton(TOPIC0));
+        ).when(spyTopicManager).getTopicPartitionInfo(Collections.singleton(TP_0_NAME));
 
         final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
             getClusterWithNoNode(),
-            getTaskTopicPartitionMapForTask1(),
+            getTaskTopicPartitionMapForTask0(),
             getTopologyGroupTaskMap(),
             getProcessRacksForProcess0(),
             spyTopicManager,
             new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
         );
 
-        assertTrue(assignor.canEnableRackAwareAssignorForActiveTasks());
+        assertTrue(assignor.canEnableRackAwareAssignor());
     }
 
     @Test
-    public void disableRackAwareAssignorForActiveWithDescribingTopicsFailure() {
+    public void shouldDisableRackAwareAssignorForActiveWithDescribingTopicsFailure() {
         final MockInternalTopicManager spyTopicManager = spy(mockInternalTopicManager);
-        doThrow(new TimeoutException("Timeout describing topic")).when(spyTopicManager).getTopicPartitionInfo(Collections.singleton(TOPIC0));
+        doThrow(new TimeoutException("Timeout describing topic")).when(spyTopicManager).getTopicPartitionInfo(Collections.singleton(
+            TP_0_NAME));
 
         final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
             getClusterWithNoNode(),
-            getTaskTopicPartitionMapForTask1(),
+            getTaskTopicPartitionMapForTask0(),
             getTopologyGroupTaskMap(),
             getProcessRacksForProcess0(),
             spyTopicManager,
             new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
         );
 
-        assertFalse(assignor.canEnableRackAwareAssignorForActiveTasks());
+        assertFalse(assignor.canEnableRackAwareAssignor());
         assertTrue(assignor.populateTopicsToDiscribe(new HashSet<>()));
     }
 
+    @Test
+    public void shouldOptimizeEmptyActiveTasks() {
+        final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
+            getClusterForTopic0And1(),
+            getTaskTopicPartitionMapForAllTasks(),
+            getTopologyGroupTaskMap(),
+            getProcessRacksForAllProcess(),
+            mockInternalTopicManager,
+            new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
+        );
+
+        final ClientState clientState0 = new ClientState(emptySet(), emptySet(), emptyMap(), EMPTY_CLIENT_TAGS, 1);
+
+        clientState0.assignActiveTasks(mkSet(TASK_0_1, TASK_1_1));
+
+        final SortedMap<UUID, ClientState> clientStateMap = new TreeMap<>(mkMap(
+            mkEntry(UUID_1, clientState0)
+        ));
+        final SortedSet<TaskId> taskIds = mkSortedSet();
+
+        if (assignor.canEnableRackAwareAssignor()) {
+            final long originalCost = assignor.activeTasksCost(clientStateMap, taskIds, stateful);
+            assertEquals(0, originalCost);
+
+            final long cost = assignor.optimizeActiveTasks(clientStateMap, taskIds, stateful);
+            assertEquals(0, cost);
+        }
+
+        assertEquals(mkSet(TASK_0_1, TASK_1_1), clientState0.activeTasks());
+    }
+
+    @Test
+    public void shouldOptimizeActiveTasks() {
+        final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
+            getClusterForTopic0And1(),
+            getTaskTopicPartitionMapForAllTasks(),
+            getTopologyGroupTaskMap(),
+            getProcessRacksForAllProcess(),
+            mockInternalTopicManager,
+            new AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
+        );
+
+        final ClientState clientState0 = new ClientState(emptySet(), emptySet(), emptyMap(), EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), emptyMap(), EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), emptyMap(), EMPTY_CLIENT_TAGS, 1);
+
+        clientState0.assignActiveTasks(mkSet(TASK_0_1, TASK_1_1));
+        clientState1.assignActive(TASK_1_0);
+        clientState2.assignActive(TASK_0_0);
+
+        // task_0_0 has same rack as UUID_1
+        // task_0_1 has same rack as UUID_2 and UUID_3
+        // task_1_0 has same rack as UUID_1 and UUID_3
+        // task_1_1 has same rack as UUID_2
+        // Optimal assignment is UUID_1: {0_0, 1_0}, UUID_2: {1_1}, UUID_3: {0_1} which result in no cross rack traffic
+        final SortedMap<UUID, ClientState> clientStateMap = new TreeMap<>(mkMap(
+            mkEntry(UUID_1, clientState0),
+            mkEntry(UUID_2, clientState1),
+            mkEntry(UUID_3, clientState2)
+        ));
+        final SortedSet<TaskId> taskIds = mkSortedSet(TASK_0_0, TASK_0_1, TASK_1_0, TASK_1_1);
+
+        if (assignor.canEnableRackAwareAssignor()) {
+            int expected = stateful ? 40 : 4;
+            final long originalCost = assignor.activeTasksCost(clientStateMap, taskIds, stateful);
+            assertEquals(expected, originalCost);
+
+            expected = stateful ? 4 : 0;
+            final long cost = assignor.optimizeActiveTasks(clientStateMap, taskIds, stateful);
+            assertEquals(expected, cost);
+        }
+
+        assertEquals(mkSet(TASK_0_0, TASK_1_0), clientState0.activeTasks());
+        assertEquals(mkSet(TASK_1_1), clientState1.activeTasks());
+        assertEquals(mkSet(TASK_0_1), clientState2.activeTasks());
+    }
+
+    @Test
+    public void shouldOptimizeActiveTasksWithWeightOverride() {
+        final AssignmentConfigs assignmentConfigs = new AssignmentConfigs(1L, 2, 2, 60000L, Collections.emptyList(), 1, 10);
+        final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
+            getClusterForTopic0And1(),
+            getTaskTopicPartitionMapForAllTasks(),
+            getTopologyGroupTaskMap(),
+            getProcessRacksForAllProcess(),
+            mockInternalTopicManager,
+            assignmentConfigs
+        );
+
+        final ClientState clientState0 = new ClientState(emptySet(), emptySet(), emptyMap(), EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), emptyMap(), EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), emptyMap(), EMPTY_CLIENT_TAGS, 1);
+
+        clientState0.assignActiveTasks(mkSet(TASK_0_1, TASK_1_1));
+        clientState1.assignActive(TASK_1_0);
+        clientState2.assignActive(TASK_0_0);
+
+        final SortedMap<UUID, ClientState> clientStateMap = new TreeMap<>(mkMap(
+            mkEntry(UUID_1, clientState0),
+            mkEntry(UUID_2, clientState1),
+            mkEntry(UUID_3, clientState2)
+        ));
+        final SortedSet<TaskId> taskIds = mkSortedSet(TASK_0_0, TASK_0_1, TASK_1_0, TASK_1_1);
+
+        // Because non_overlap_cost is very high, this basically will stick to original assignment
+        if (assignor.canEnableRackAwareAssignor()) {
+            final long originalCost = assignor.activeTasksCost(clientStateMap, taskIds, stateful);
+            assertEquals(4, originalCost);

Review Comment:
   Why is expected cost not 40 for stateful as in the test above?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active task assignment for rack awareness. canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned already in clientStates
+     * @param isStateful Whether the tasks are stateful
+     * @return Total cost after optimization
+     */
+    public long optimizeActiveTasks(final SortedMap<UUID, ClientState> clientStates,
+                                    final SortedSet<TaskId> activeTasks,
+                                    final boolean isStateful) {
+        if (activeTasks.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);
+
+        graph.solveMinCostFlow();
+        final long cost = graph.totalCost();
+
+        assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, taskIdList,
+            clientStates, originalAssignedTaskNumber, taskClientMap);
+
+        return cost;
+    }
+
+    private Graph<Integer> constructActiveTaskGraph(final SortedSet<TaskId> activeTasks,
+                                                    final List<UUID> clientList,
+                                                    final List<TaskId> taskIdList,
+                                                    final Map<UUID, ClientState> clientStates,
+                                                    final Map<TaskId, UUID> taskClientMap,
+                                                    final Map<UUID, Integer> originalAssignedTaskNumber,
+                                                    final boolean isStateful) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : activeTasks) {
+            for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
+                if (clientState.getValue().hasAssignedTask(taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId  = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = taskIdList.size() + j;
+                final UUID clientId = clientList.get(j);
+
+                final int flow = clientStates.get(clientId).hasAssignedTask(taskId) ? 1 : 0;
+                final int cost = getCost(taskId, clientId, flow == 1, isStateful);
+                if (flow == 1) {
+                    if (taskClientMap.containsKey(taskId)) {
+                        throw new IllegalArgumentException("Task " + taskId + " assigned to multiple clients "
+                            + clientId + ", " + taskClientMap.get(taskId));
+                    }
+                    taskClientMap.put(taskId, clientId);
+                }
+
+                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
+            }
+            if (!taskClientMap.containsKey(taskId)) {
+                throw new IllegalArgumentException("Task " + taskId + " not assigned to any client");
+            }
+        }
+
+        final int sinkId = getSinkID(clientList, taskIdList);
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
+            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
+        }
+
+        // It's possible that some clients have 0 task assign. These clients will have 0 tasks assigned
+        // even though it may have lower traffic cost. This is to maintain the balance requirement.
+        for (int i = 0; i < clientList.size(); i++) {
+            final int clientNodeId = taskIdList.size() + i;
+            final int capacity = originalAssignedTaskNumber.getOrDefault(clientList.get(i), 0);
+            // Flow equals to capacity for edges to sink
+            graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
+        }
+
+        graph.setSourceNode(SOURCE_ID);
+        graph.setSinkNode(sinkId);
+
+        return graph;
+    }
+
+    private void assignActiveTaskFromMinCostFlow(final Graph<Integer> graph,
+                                                 final SortedSet<TaskId> activeTasks,
+                                                 final List<UUID> clientList,
+                                                 final List<TaskId> taskIdList,
+                                                 final Map<UUID, ClientState> clientStates,
+                                                 final Map<UUID, Integer> originalClientCapacity,
+                                                 final Map<TaskId, UUID> taskClientMap) {
+        int taskAssigned = 0;
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            final Map<Integer, Graph<Integer>.Edge> edges = graph.edges(taskNodeId);
+            for (final Graph<Integer>.Edge edge : edges.values()) {
+                if (edge.flow > 0) {
+                    taskAssigned++;
+                    final int clientIndex = edge.destination - taskIdList.size();
+                    final UUID clientId = clientList.get(clientIndex);
+                    final UUID originalClientId = taskClientMap.get(taskId);
+
+                    // Don't need to assign this task to other client
+                    if (clientId.equals(originalClientId)) {
+                        break;
+                    }
+
+                    final ClientState clientState = clientStates.get(clientId);
+                    final ClientState originalClientState = clientStates.get(originalClientId);
+                    originalClientState.unassignActive(taskId);
+                    clientState.assignActive(taskId);
+                }
+            }
+        }
+
+        // Validate task assigned
+        if (taskAssigned != activeTasks.size()) {
+            throw new IllegalStateException("Computed active task assignment number "
+                + taskAssigned + " is different size " + activeTasks.size());
+        }
+
+        // Validate capacity constraint

Review Comment:
   Is this really a "capacity" or rather the "load" of a client?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +191,224 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> statefulTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(statefulTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> clientCapacity = new HashMap<>();
+        final Graph<Integer> graph = new Graph<>();
+
+        constructStatefulActiveTaskGraph(graph, statefulTasks, clientList, taskIdList,
+            clientStates, taskClientMap, clientCapacity, isStateful);
+
+        final int sourceId = taskIdList.size() + clientList.size();
+        final int sinkId = sourceId + 1;
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
+            graph.addEdge(sourceId, taskNodeId, 1, 0, 1);
+        }
+        for (int i = 0; i < clientList.size(); i++) {
+            final int capacity = clientCapacity.getOrDefault(clientList.get(i), 0);
+            final int clientNodeId = taskIdList.size() + i;
+            graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
+        }
+        graph.setSourceNode(sourceId);
+        graph.setSinkNode(sinkId);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active stateful task assignment for rack awareness. canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param taskIds Tasks to reassign if needed. They must be assigned already in clientStates
+     * @param isStateful Whether the tasks are stateful
+     * @return Total cost after optimization
+     */
+    public long optimizeActiveTasks(final SortedMap<UUID, ClientState> clientStates,
+                                    final SortedSet<TaskId> taskIds,
+                                    final boolean isStateful) {
+        if (taskIds.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());

Review Comment:
   Question as above: why do we need to make a deep copy into a list? Can't we 
just pass `clientStates.keySet()`  instead? Seems both methods only read but 
don't modify.
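
A small self-contained illustration (hypothetical types, not the assignor's code) of the trade-off behind this question: the key-set view is enough when the callee only iterates, but index-based lookups like `clientList.get(j)` still need a `List` copy.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;

public class KeySetVsListExample {
    public static void main(final String[] args) {
        final SortedMap<String, Integer> clientStates = new TreeMap<>();
        clientStates.put("client-b", 2);
        clientStates.put("client-a", 1);

        // The keySet() view suffices if the callee only iterates over the clients.
        final Set<String> clients = clientStates.keySet();
        clients.forEach(System.out::println); // client-a, client-b (sorted order)

        // Index-based access, as used when building graph nodes, still needs a List copy.
        final List<String> clientList = new ArrayList<>(clients);
        System.out.println(clientList.get(1)); // client-b
    }
}
```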



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active task assignment for rack awareness. canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned already in clientStates
+     * @param isStateful Whether the tasks are stateful

Review Comment:
   So all tasks passed in are stateful? If yes, maybe add "all" to the comments 
(also, if it's `false` does it mean all tasks are stateless?)



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {

Review Comment:
   Is it important for the method that the maps are sorted? Should we use the `Map` (least restrictive) interface instead?
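
For illustration, a tiny sketch (hypothetical method, not the assignor's API) of the "least restrictive interface" idea: declaring the parameter as `Map` still accepts a `TreeMap` from callers, while the signature no longer implies the method relies on ordering.

```java
import java.util.Map;
import java.util.TreeMap;

public class LeastRestrictiveParameterExample {
    // Hypothetical helper: it only reads the map, so Map is enough as the parameter type.
    static int clientCount(final Map<String, Integer> clientStates) {
        return clientStates.size();
    }

    public static void main(final String[] args) {
        final Map<String, Integer> sortedClientStates = new TreeMap<>();
        sortedClientStates.put("client-a", 1);
        sortedClientStates.put("client-b", 2);
        System.out.println(clientCount(sortedClientStates)); // prints 2
    }
}
```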



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());

Review Comment:
   Seems `constructActiveTaskGraph` is only reading but not modifying `clientList` -- why do we need to pass a deep copy? Could we pass a `Set` instead of a `List`?
   
   Looking into `constructActiveTaskGraph`, it seems we access by index and try to make things deterministic. Is this the reason why we need a list here? If yes, given that we go from `keySet` to a list, is this translation actually deterministic itself (it could be, given that it's a sorted map)?
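
A runnable check (hypothetical values) of the determinism point raised here: a `SortedMap`'s `keySet()` iterates in key order, so copying it into an `ArrayList` yields the same client ordering on every invocation.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.UUID;

public class SortedKeySetOrderExample {
    public static void main(final String[] args) {
        final SortedMap<UUID, String> clientStates = new TreeMap<>();
        clientStates.put(new UUID(0, 2), "client-2");
        clientStates.put(new UUID(0, 1), "client-1");

        // keySet() of a TreeMap iterates in sorted key order, so the List copy is deterministic.
        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
        System.out.println(clientList.get(0).equals(new UUID(0, 1))); // prints true
    }
}
```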



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active task assignment for rack awareness. canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned already in clientStates
+     * @param isStateful Whether the tasks are stateful
+     * @return Total cost after optimization
+     */
+    public long optimizeActiveTasks(final SortedMap<UUID, ClientState> clientStates,
+                                    final SortedSet<TaskId> activeTasks,
+                                    final boolean isStateful) {
+        if (activeTasks.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);
+
+        graph.solveMinCostFlow();
+        final long cost = graph.totalCost();
+
+        assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, taskIdList,
+            clientStates, originalAssignedTaskNumber, taskClientMap);
+
+        return cost;
+    }
+
+    private Graph<Integer> constructActiveTaskGraph(final SortedSet<TaskId> activeTasks,
+                                                    final List<UUID> clientList,
+                                                    final List<TaskId> taskIdList,
+                                                    final Map<UUID, ClientState> clientStates,
+                                                    final Map<TaskId, UUID> taskClientMap,
+                                                    final Map<UUID, Integer> originalAssignedTaskNumber,
+                                                    final boolean isStateful) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : activeTasks) {
+            for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
+                if (clientState.getValue().hasAssignedTask(taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId  = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = taskIdList.size() + j;
+                final UUID clientId = clientList.get(j);
+
+                final int flow = 
clientStates.get(clientId).hasAssignedTask(taskId) ? 1 : 0;
+                final int cost = getCost(taskId, clientId, flow == 1, 
isStateful);
+                if (flow == 1) {
+                    if (taskClientMap.containsKey(taskId)) {
+                        throw new IllegalArgumentException("Task " + taskId + 
" assigned to multiple clients "
+                            + clientId + ", " + taskClientMap.get(taskId));
+                    }
+                    taskClientMap.put(taskId, clientId);
+                }
+
+                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
+            }
+            if (!taskClientMap.containsKey(taskId)) {
+                throw new IllegalArgumentException("Task " + taskId + " not 
assigned to any client");
+            }
+        }
+
+        final int sinkId = getSinkID(clientList, taskIdList);
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
+        }
+
+        // It's possible that some clients have 0 task assign. These clients 
will have 0 tasks assigned
+        // even though it may have lower traffic cost. This is to maintain the 
balance requirement.
+        for (int i = 0; i < clientList.size(); i++) {
+            final int clientNodeId = taskIdList.size() + i;
+            final int capacity = 
originalAssignedTaskNumber.getOrDefault(clientList.get(i), 0);
+            // Flow equals to capacity for edges to sink
+            graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
+        }
+
+        graph.setSourceNode(SOURCE_ID);
+        graph.setSinkNode(sinkId);
+
+        return graph;
+    }
+
+    private void assignActiveTaskFromMinCostFlow(final Graph<Integer> graph,
+                                                 final SortedSet<TaskId> 
activeTasks,
+                                                 final List<UUID> clientList,
+                                                 final List<TaskId> taskIdList,
+                                                 final Map<UUID, ClientState> 
clientStates,
+                                                 final Map<UUID, Integer> 
originalClientCapacity,
+                                                 final Map<TaskId, UUID> 
taskClientMap) {
+        int taskAssigned = 0;

Review Comment:
   `task[s]Assigned` ?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final 
boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = 
racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = 
clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = 
partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no 
TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? 
(isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? 
(isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : 
DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has 
no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final 
List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, 
final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, 
clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, 
isStateful);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active task assignment for rack awareness. 
canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned 
already in clientStates
+     * @param isStateful Whether the tasks are stateful
+     * @return Total cost after optimization
+     */
+    public long optimizeActiveTasks(final SortedMap<UUID, ClientState> 
clientStates,
+                                    final SortedSet<TaskId> activeTasks,
+                                    final boolean isStateful) {
+        if (activeTasks.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, 
clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, 
isStateful);
+
+        graph.solveMinCostFlow();
+        final long cost = graph.totalCost();
+
+        assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, 
taskIdList,
+            clientStates, originalAssignedTaskNumber, taskClientMap);
+
+        return cost;
+    }
+
+    private Graph<Integer> constructActiveTaskGraph(final SortedSet<TaskId> 
activeTasks,
+                                                    final List<UUID> 
clientList,
+                                                    final List<TaskId> 
taskIdList,
+                                                    final Map<UUID, 
ClientState> clientStates,
+                                                    final Map<TaskId, UUID> 
taskClientMap,
+                                                    final Map<UUID, Integer> 
originalAssignedTaskNumber,
+                                                    final boolean isStateful) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : activeTasks) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (clientState.getValue().hasAssignedTask(taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId  = 0; taskNodeId < taskIdList.size(); 
taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = taskIdList.size() + j;
+                final UUID clientId = clientList.get(j);
+
+                final int flow = 
clientStates.get(clientId).hasAssignedTask(taskId) ? 1 : 0;
+                final int cost = getCost(taskId, clientId, flow == 1, 
isStateful);
+                if (flow == 1) {
+                    if (taskClientMap.containsKey(taskId)) {
+                        throw new IllegalArgumentException("Task " + taskId + 
" assigned to multiple clients "
+                            + clientId + ", " + taskClientMap.get(taskId));
+                    }
+                    taskClientMap.put(taskId, clientId);
+                }
+
+                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
+            }
+            if (!taskClientMap.containsKey(taskId)) {
+                throw new IllegalArgumentException("Task " + taskId + " not 
assigned to any client");
+            }
+        }
+
+        final int sinkId = getSinkID(clientList, taskIdList);
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
+        }
+
+        // It's possible that some clients have 0 task assign. These clients 
will have 0 tasks assigned
+        // even though it may have lower traffic cost. This is to maintain the 
balance requirement.

Review Comment:
   Not sure I understand: how would a client with zero tasks improve balance?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final 
boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = 
racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = 
clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = 
partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no 
TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? 
(isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? 
(isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : 
DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has 
no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final 
List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, 
final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, 
clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, 
isStateful);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active task assignment for rack awareness. 
canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned 
already in clientStates
+     * @param isStateful Whether the tasks are stateful
+     * @return Total cost after optimization
+     */
+    public long optimizeActiveTasks(final SortedMap<UUID, ClientState> 
clientStates,
+                                    final SortedSet<TaskId> activeTasks,
+                                    final boolean isStateful) {
+        if (activeTasks.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, 
clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, 
isStateful);
+
+        graph.solveMinCostFlow();
+        final long cost = graph.totalCost();
+
+        assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, 
taskIdList,
+            clientStates, originalAssignedTaskNumber, taskClientMap);
+
+        return cost;
+    }
+
+    private Graph<Integer> constructActiveTaskGraph(final SortedSet<TaskId> 
activeTasks,
+                                                    final List<UUID> 
clientList,
+                                                    final List<TaskId> 
taskIdList,
+                                                    final Map<UUID, 
ClientState> clientStates,
+                                                    final Map<TaskId, UUID> 
taskClientMap,
+                                                    final Map<UUID, Integer> 
originalAssignedTaskNumber,
+                                                    final boolean isStateful) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : activeTasks) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (clientState.getValue().hasAssignedTask(taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId  = 0; taskNodeId < taskIdList.size(); 
taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = taskIdList.size() + j;
+                final UUID clientId = clientList.get(j);
+
+                final int flow = 
clientStates.get(clientId).hasAssignedTask(taskId) ? 1 : 0;
+                final int cost = getCost(taskId, clientId, flow == 1, 
isStateful);
+                if (flow == 1) {
+                    if (taskClientMap.containsKey(taskId)) {
+                        throw new IllegalArgumentException("Task " + taskId + 
" assigned to multiple clients "
+                            + clientId + ", " + taskClientMap.get(taskId));
+                    }
+                    taskClientMap.put(taskId, clientId);
+                }
+
+                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
+            }
+            if (!taskClientMap.containsKey(taskId)) {
+                throw new IllegalArgumentException("Task " + taskId + " not 
assigned to any client");
+            }
+        }
+
+        final int sinkId = getSinkID(clientList, taskIdList);
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
+        }
+
+        // It's possible that some clients have 0 task assign. These clients 
will have 0 tasks assigned
+        // even though it may have lower traffic cost. This is to maintain the 
balance requirement.
+        for (int i = 0; i < clientList.size(); i++) {
+            final int clientNodeId = taskIdList.size() + i;
+            final int capacity = 
originalAssignedTaskNumber.getOrDefault(clientList.get(i), 0);
+            // Flow equals to capacity for edges to sink
+            graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
+        }
+
+        graph.setSourceNode(SOURCE_ID);
+        graph.setSinkNode(sinkId);
+
+        return graph;
+    }
+
+    private void assignActiveTaskFromMinCostFlow(final Graph<Integer> graph,
+                                                 final SortedSet<TaskId> 
activeTasks,
+                                                 final List<UUID> clientList,
+                                                 final List<TaskId> taskIdList,
+                                                 final Map<UUID, ClientState> 
clientStates,
+                                                 final Map<UUID, Integer> 
originalClientCapacity,
+                                                 final Map<TaskId, UUID> 
taskClientMap) {
+        int taskAssigned = 0;
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            final Map<Integer, Graph<Integer>.Edge> edges = 
graph.edges(taskNodeId);
+            for (final Graph<Integer>.Edge edge : edges.values()) {
+                if (edge.flow > 0) {
+                    taskAssigned++;
+                    final int clientIndex = edge.destination - 
taskIdList.size();
+                    final UUID clientId = clientList.get(clientIndex);
+                    final UUID originalClientId = taskClientMap.get(taskId);
+
+                    // Don't need to assign this task to other client
+                    if (clientId.equals(originalClientId)) {
+                        break;
+                    }
+
+                    final ClientState clientState = clientStates.get(clientId);
+                    final ClientState originalClientState = 
clientStates.get(originalClientId);
+                    originalClientState.unassignActive(taskId);
+                    clientState.assignActive(taskId);

Review Comment:
   Could we simplify these 4 lines to:
   ```
   clientStates.get(originalClientId).unassignActive(taskId);
   clientStates.get(clientId).assignActive(taskId);
   ```



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignorTest.java:
##########
@@ -145,147 +188,517 @@ public void disableActiveSinceRackMissingInClient() {
 
         // False since process1 doesn't have rackId
         assertFalse(assignor.validateClientRack());
-        assertFalse(assignor.canEnableRackAwareAssignorForActiveTasks());
+        assertFalse(assignor.canEnableRackAwareAssignor());
     }
 
     @Test
-    public void disableActiveSinceRackDiffersInSameProcess() {
+    public void shouldDisableActiveWhenRackDiffersInSameProcess() {
         final Map<UUID, Map<String, Optional<String>>> processRacks = new 
HashMap<>();
 
         // Different consumers in same process have different rack ID. This 
shouldn't happen.
         // If happens, there's a bug somewhere
-        processRacks.computeIfAbsent(process0UUID, k -> new 
HashMap<>()).put("consumer1", Optional.of("rack1"));
-        processRacks.computeIfAbsent(process0UUID, k -> new 
HashMap<>()).put("consumer2", Optional.of("rack2"));
+        processRacks.computeIfAbsent(UUID_1, k -> new 
HashMap<>()).put("consumer1", Optional.of("rack1"));
+        processRacks.computeIfAbsent(UUID_1, k -> new 
HashMap<>()).put("consumer2", Optional.of("rack2"));
 
         final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
             getClusterForTopic0(),
-            getTaskTopicPartitionMapForTask1(),
+            getTaskTopicPartitionMapForTask0(),
             getTopologyGroupTaskMap(),
             processRacks,
             mockInternalTopicManager,
             new 
AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
         );
 
         assertFalse(assignor.validateClientRack());
-        assertFalse(assignor.canEnableRackAwareAssignorForActiveTasks());
+        assertFalse(assignor.canEnableRackAwareAssignor());
     }
 
     @Test
-    public void enableRackAwareAssignorForActiveWithoutDescribingTopics() {
+    public void 
shouldEnableRackAwareAssignorForActiveWithoutDescribingTopics() {
         final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
             getClusterForTopic0(),
-            getTaskTopicPartitionMapForTask1(),
+            getTaskTopicPartitionMapForTask0(),
             getTopologyGroupTaskMap(),
             getProcessRacksForProcess0(),
             mockInternalTopicManager,
             new 
AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
         );
 
         // partitionWithoutInfo00 has rackInfo in cluster metadata
-        assertTrue(assignor.canEnableRackAwareAssignorForActiveTasks());
+        assertTrue(assignor.canEnableRackAwareAssignor());
     }
 
     @Test
-    public void enableRackAwareAssignorForActiveWithDescribingTopics() {
+    public void shouldEnableRackAwareAssignorForActiveWithDescribingTopics() {
         final MockInternalTopicManager spyTopicManager = 
spy(mockInternalTopicManager);
         doReturn(
             Collections.singletonMap(
-                TOPIC0,
+                TP_0_NAME,
                 Collections.singletonList(
-                    new TopicPartitionInfo(0, node0, Arrays.asList(replicas1), 
Collections.emptyList())
+                    new TopicPartitionInfo(0, NODE_0, 
Arrays.asList(REPLICA_1), Collections.emptyList())
                 )
             )
-        
).when(spyTopicManager).getTopicPartitionInfo(Collections.singleton(TOPIC0));
+        
).when(spyTopicManager).getTopicPartitionInfo(Collections.singleton(TP_0_NAME));
 
         final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
             getClusterWithNoNode(),
-            getTaskTopicPartitionMapForTask1(),
+            getTaskTopicPartitionMapForTask0(),
             getTopologyGroupTaskMap(),
             getProcessRacksForProcess0(),
             spyTopicManager,
             new 
AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
         );
 
-        assertTrue(assignor.canEnableRackAwareAssignorForActiveTasks());
+        assertTrue(assignor.canEnableRackAwareAssignor());
     }
 
     @Test
-    public void disableRackAwareAssignorForActiveWithDescribingTopicsFailure() 
{
+    public void 
shouldDisableRackAwareAssignorForActiveWithDescribingTopicsFailure() {
         final MockInternalTopicManager spyTopicManager = 
spy(mockInternalTopicManager);
-        doThrow(new TimeoutException("Timeout describing 
topic")).when(spyTopicManager).getTopicPartitionInfo(Collections.singleton(TOPIC0));
+        doThrow(new TimeoutException("Timeout describing 
topic")).when(spyTopicManager).getTopicPartitionInfo(Collections.singleton(
+            TP_0_NAME));
 
         final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
             getClusterWithNoNode(),
-            getTaskTopicPartitionMapForTask1(),
+            getTaskTopicPartitionMapForTask0(),
             getTopologyGroupTaskMap(),
             getProcessRacksForProcess0(),
             spyTopicManager,
             new 
AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
         );
 
-        assertFalse(assignor.canEnableRackAwareAssignorForActiveTasks());
+        assertFalse(assignor.canEnableRackAwareAssignor());
         assertTrue(assignor.populateTopicsToDiscribe(new HashSet<>()));
     }
 
+    @Test
+    public void shouldOptimizeEmptyActiveTasks() {
+        final RackAwareTaskAssignor assignor = new RackAwareTaskAssignor(
+            getClusterForTopic0And1(),
+            getTaskTopicPartitionMapForAllTasks(),
+            getTopologyGroupTaskMap(),
+            getProcessRacksForAllProcess(),
+            mockInternalTopicManager,
+            new 
AssignorConfiguration(streamsConfig.originals()).assignmentConfigs()
+        );
+
+        final ClientState clientState0 = new ClientState(emptySet(), 
emptySet(), emptyMap(), EMPTY_CLIENT_TAGS, 1);
+
+        clientState0.assignActiveTasks(mkSet(TASK_0_1, TASK_1_1));
+
+        final SortedMap<UUID, ClientState> clientStateMap = new 
TreeMap<>(mkMap(
+            mkEntry(UUID_1, clientState0)
+        ));
+        final SortedSet<TaskId> taskIds = mkSortedSet();

Review Comment:
   Why is this set empty?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java:
##########
@@ -268,24 +268,44 @@ public static class AssignmentConfigs {
         public final long probingRebalanceIntervalMs;
         public final List<String> rackAwareAssignmentTags;
 
+        // TODO: get from streamsConfig after we add the config

Review Comment:
   Might be good to add an extensive comment explaining what both of them do.
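   
   For illustration only, a possible shape for such a comment, derived from how 
`getCost()` uses the two fields (the wording is just a suggestion):
   ```
   /**
    * Cost added for every TopicPartition of a task whose replica racks do not
    * include the candidate client's rack, i.e. the penalty for cross-rack
    * traffic. Larger values push the optimizer toward rack-local placement.
    */
   public final Integer trafficCost;

   /**
    * Cost added when a task would be placed on a client that does not
    * currently own it, i.e. the penalty for deviating from the existing
    * assignment. Larger values make the optimizer stickier.
    */
   public final Integer nonOverlapCost;
   ```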



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final 
boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = 
racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = 
clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = 
partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no 
TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? 
(isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? 
(isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : 
DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has 
no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final 
List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, 
final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, 
clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, 
isStateful);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active task assignment for rack awareness. 
canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned 
already in clientStates
+     * @param isStateful Whether the tasks are stateful
+     * @return Total cost after optimization
+     */
+    public long optimizeActiveTasks(final SortedMap<UUID, ClientState> 
clientStates,
+                                    final SortedSet<TaskId> activeTasks,
+                                    final boolean isStateful) {
+        if (activeTasks.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, 
clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, 
isStateful);
+
+        graph.solveMinCostFlow();
+        final long cost = graph.totalCost();
+
+        assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, 
taskIdList,
+            clientStates, originalAssignedTaskNumber, taskClientMap);
+
+        return cost;
+    }
+
+    private Graph<Integer> constructActiveTaskGraph(final SortedSet<TaskId> 
activeTasks,
+                                                    final List<UUID> 
clientList,
+                                                    final List<TaskId> 
taskIdList,
+                                                    final Map<UUID, 
ClientState> clientStates,
+                                                    final Map<TaskId, UUID> 
taskClientMap,
+                                                    final Map<UUID, Integer> 
originalAssignedTaskNumber,
+                                                    final boolean isStateful) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : activeTasks) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (clientState.getValue().hasAssignedTask(taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId  = 0; taskNodeId < taskIdList.size(); 
taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = taskIdList.size() + j;
+                final UUID clientId = clientList.get(j);
+
+                final int flow = 
clientStates.get(clientId).hasAssignedTask(taskId) ? 1 : 0;
+                final int cost = getCost(taskId, clientId, flow == 1, 
isStateful);
+                if (flow == 1) {
+                    if (taskClientMap.containsKey(taskId)) {
+                        throw new IllegalArgumentException("Task " + taskId + 
" assigned to multiple clients "
+                            + clientId + ", " + taskClientMap.get(taskId));
+                    }
+                    taskClientMap.put(taskId, clientId);
+                }
+
+                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
+            }
+            if (!taskClientMap.containsKey(taskId)) {
+                throw new IllegalArgumentException("Task " + taskId + " not 
assigned to any client");
+            }
+        }
+
+        final int sinkId = getSinkID(clientList, taskIdList);
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
+        }
+
+        // It's possible that some clients have 0 task assign. These clients 
will have 0 tasks assigned
+        // even though it may have lower traffic cost. This is to maintain the 
balance requirement.
+        for (int i = 0; i < clientList.size(); i++) {
+            final int clientNodeId = taskIdList.size() + i;
+            final int capacity = 
originalAssignedTaskNumber.getOrDefault(clientList.get(i), 0);
+            // Flow equals to capacity for edges to sink
+            graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
+        }
+
+        graph.setSourceNode(SOURCE_ID);
+        graph.setSinkNode(sinkId);
+
+        return graph;
+    }
+
+    private void assignActiveTaskFromMinCostFlow(final Graph<Integer> graph,
+                                                 final SortedSet<TaskId> 
activeTasks,
+                                                 final List<UUID> clientList,
+                                                 final List<TaskId> taskIdList,
+                                                 final Map<UUID, ClientState> 
clientStates,
+                                                 final Map<UUID, Integer> 
originalClientCapacity,
+                                                 final Map<TaskId, UUID> 
taskClientMap) {
+        int taskAssigned = 0;
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            final Map<Integer, Graph<Integer>.Edge> edges = 
graph.edges(taskNodeId);
+            for (final Graph<Integer>.Edge edge : edges.values()) {
+                if (edge.flow > 0) {
+                    taskAssigned++;
+                    final int clientIndex = edge.destination - 
taskIdList.size();
+                    final UUID clientId = clientList.get(clientIndex);
+                    final UUID originalClientId = taskClientMap.get(taskId);
+
+                    // Don't need to assign this task to other client
+                    if (clientId.equals(originalClientId)) {
+                        break;
+                    }
+
+                    final ClientState clientState = clientStates.get(clientId);
+                    final ClientState originalClientState = 
clientStates.get(originalClientId);
+                    originalClientState.unassignActive(taskId);
+                    clientState.assignActive(taskId);
+                }

Review Comment:
   Should we `break` here (i.e., inside the `if` block)? There should be only one 
edge with flow 1?
   
   Or even replace the `for` loop over the edges with 
`graph.edges(taskNodeId).values().stream().filter(e -> e.flow == 1).findFirst()` 
(and throw if we don't find any edge with flow == 1)?
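   
   A rough sketch of that stream variant (assuming `Graph.Edge` exposes `flow` 
and `destination` as in the diff above):
   ```
   final TaskId taskId = taskIdList.get(taskNodeId);
   final Graph<Integer>.Edge assignedEdge = graph.edges(taskNodeId).values().stream()
       .filter(e -> e.flow > 0)
       .findFirst()
       .orElseThrow(() -> new IllegalStateException(
           "Task " + taskId + " has no edge with flow > 0"));
   // Only after finding the single flow-carrying edge do we resolve the client.
   final UUID clientId = clientList.get(assignedEdge.destination - taskIdList.size());
   ```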



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +191,224 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final 
boolean inCurrentAssignment, final boolean isStateful) {

Review Comment:
   > Here I meant processId using the UUID.
   
   Should we rename it? It's unfortunate that it's overloaded, but `clientId` 
sounds like the `client.id` config, which it is not. (Or add a comment?)



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -38,29 +43,34 @@
 
 public class RackAwareTaskAssignor {
     private static final Logger log = 
LoggerFactory.getLogger(RackAwareTaskAssignor.class);
+    private static final int DEFAULT_STATEFUL_TRAFFIC_COST = 10;

Review Comment:
   Thanks for the in-person sync. To summarize what we discussed:
   
   For rack-aware assignment, we want to avoid moving a task from client A to 
client B if both clients are in the same rack (because the cross-rack-traffic 
cost stays the same anyway). However, if moving a task from client A to client B 
reduces the cross-rack-traffic cost, we should move the task and neglect the 
non-overlap cost.
   
   It makes sense to me to pass in the values as proposed, and use different 
ones for sticky vs. HA.
   
   About where to construct it: I would say whatever is simplest :) -- I guess 
we could pass it via `assign()` directly, or just create it inside `assign()` -- 
whatever is less complex to do.
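   
   A small worked example of that trade-off (the concrete numbers are only 
assumed defaults for illustration, following the `getCost()` logic in the diff):
   ```
   // Assume trafficCost = 10, nonOverlapCost = 1.
   // Task T reads a single partition whose replicas are all in rack R1.
   //
   // T currently on client A (rack R1):
   //   keep on A            -> 0 * 10 + 0 = 0
   //   move to B (rack R1)  -> 0 * 10 + 1 = 1    // same rack: stickiness wins, no move
   //   move to C (rack R2)  -> 1 * 10 + 1 = 11   // strictly worse, no move
   //
   // T currently on client A (rack R2), with client C in rack R1:
   //   keep on A            -> 1 * 10 + 0 = 10
   //   move to C (rack R1)  -> 0 * 10 + 1 = 1    // traffic win dominates stickiness
   ```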



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +191,224 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final 
boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = 
racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = 
clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = 
partitionsForTask.get(taskId);
+        if (topicPartitions == null) {
+            throw new IllegalStateException("Task " + taskId + " has no 
TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? 
(isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? 
(isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : 
DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has 
no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;

Review Comment:
   I already left a comment above about adding a better explanation. Also 
wondering if we could find a more descriptive name? Unfortunately, I don't have 
a good idea either.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
