mjsax commented on code in PR #14714:
URL: https://github.com/apache/kafka/pull/14714#discussion_r1408740752


##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareGraphConstructor.java:
##########
@@ -0,0 +1,47 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+/**
+ * Construct graph for rack aware task assignor
+ */
+public interface RackAwareGraphConstructor {
+    int SOURCE_ID = -1;
+
+    int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup);
+
+    int getClientNodeId(final int clientIndex, final List<TaskId> taskIdList, 
final List<UUID> clientList, final int topicGroupIndex);
+
+    int getClientIndex(final int clientNodeId, final List<TaskId> taskIdList, 
final List<UUID> clientList, final int topicGroupIndex);
+
+    Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList,

Review Comment:
   nit: align indentation (same below)



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/MinTrafficGraphConstructor.java:
##########
@@ -0,0 +1,162 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientIndex + taskIdList.size();
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size();
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {
+
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : taskIdList) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = getClientNodeId(j, taskIdList, 
clientList, 0);

Review Comment:
   Given that this is the min-cost case, do we need to pass `clientList` or could 
we pass `null` (similar to calling `getSinkNodeID(...)` with `null` below)?
   
   Should we pass `0` or `-1` to highlight "unused" better?
   
   Might be worth adding a comment?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/MinTrafficGraphConstructor.java:
##########
@@ -0,0 +1,162 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {

Review Comment:
   nit: formatting
   ```
   // either
   public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
   
   // or
   public int getSinkNodeID(
       final List<TaskId> taskIdList,
       final List<UUID> clientList,
       final Map<Subtopology, Set<TaskId>> tasksForTopicGroup
   ) {
   ```



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareGraphConstructor.java:
##########
@@ -0,0 +1,47 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+/**
+ * Construct graph for rack aware task assignor
+ */
+public interface RackAwareGraphConstructor {
+    int SOURCE_ID = -1;
+
+    int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList, final Map<Subtopology, Set<TaskId>> tasksForTopicGroup);
+
+    int getClientNodeId(final int clientIndex, final List<TaskId> taskIdList, 
final List<UUID> clientList, final int topicGroupIndex);
+
+    int getClientIndex(final int clientNodeId, final List<TaskId> taskIdList, 
final List<UUID> clientList, final int topicGroupIndex);
+
+    Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList,
+        final SortedMap<UUID, ClientState> clientStates,

Review Comment:
   This was a `Map` in the old code -- why the change to `SortedMap` ?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -61,19 +59,30 @@ boolean canMove(final ClientState source,
                         final Map<UUID, ClientState> clientStateMap);
     }
 
+    @FunctionalInterface
+    public interface GetCostFunction {

Review Comment:
   `GetCostFunction` -> `CostFunction`
   
   (This interface does not _return_ a cost function, but _is_ a cost function, 
right?)



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/MinTrafficGraphConstructor.java:
##########
@@ -0,0 +1,162 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientIndex + taskIdList.size();
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size();
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {

Review Comment:
   nit: formatting (one parameter per line -- cf comment above)



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/MinTrafficGraphConstructor.java:
##########
@@ -0,0 +1,162 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientIndex + taskIdList.size();
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size();
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {
+
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : taskIdList) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = getClientNodeId(j, taskIdList, 
clientList, 0);
+                final UUID processId = clientList.get(j);
+
+                final int flow = 
hasAssignedTask.test(clientStates.get(processId), taskId) ? 1 : 0;
+                final int cost = getCostFunction.getCost(taskId, processId, 
flow == 1, trafficCost,
+                    nonOverlapCost, isStandby);
+                if (flow == 1) {
+                    if (!hasReplica && taskClientMap.containsKey(taskId)) {
+                        throw new IllegalArgumentException("Task " + taskId + 
" assigned to multiple clients "
+                            + processId + ", " + taskClientMap.get(taskId));
+                    }
+                    taskClientMap.put(taskId, processId);
+                }
+
+                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
+            }
+            if (!taskClientMap.containsKey(taskId)) {
+                throw new IllegalArgumentException("Task " + taskId + " not 
assigned to any client");
+            }
+
+            // Add edge from source to task
+            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
+        }
+
+        final int sinkId = getSinkNodeID(taskIdList, clientList, null);
+        // It's possible that some clients have 0 task assign. These clients 
will have 0 tasks assigned
+        // even though it may have higher traffic cost. This is to maintain 
the original assigned task count
+        for (int i = 0; i < clientList.size(); i++) {
+            final int clientNodeId = getClientNodeId(i, taskIdList, 
clientList, 0);
+            final int capacity = 
originalAssignedTaskNumber.getOrDefault(clientList.get(i), 0);
+            // Flow equals to capacity for edges to sink
+            graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
+        }
+
+        graph.setSourceNode(SOURCE_ID);
+        graph.setSinkNode(sinkId);
+
+        return graph;
+    }
+
+    @Override
+    public boolean assignTaskFromMinCostFlow(final Graph<Integer> graph,
+        final List<UUID> clientList, final List<TaskId> taskIdList,
+        final Map<UUID, ClientState> clientStates,
+        final Map<UUID, Integer> originalAssignedTaskNumber, final Map<TaskId, 
UUID> taskClientMap,
+        final BiConsumer<ClientState, TaskId> assignTask,
+        final BiConsumer<ClientState, TaskId> unAssignTask,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask) {
+
+        int tasksAssigned = 0;
+        boolean taskMoved = false;
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            final Map<Integer, Graph<Integer>.Edge> edges = 
graph.edges(taskNodeId);
+            for (final Graph<Integer>.Edge edge : edges.values()) {
+                if (edge.flow > 0) {
+                    tasksAssigned++;
+                    final int clientIndex = getClientIndex(edge.destination, 
taskIdList, clientList, 0);

Review Comment:
   as above?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/BalanceSubtopologyGraphConstructor.java:
##########
@@ -0,0 +1,186 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.SortedSet;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class BalanceSubtopologyGraphConstructor implements 
RackAwareGraphConstructor {
+
+    private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup;
+
+    public BalanceSubtopologyGraphConstructor(final Map<Subtopology, 
Set<TaskId>> tasksForTopicGroup) {
+        this.tasksForTopicGroup = tasksForTopicGroup;
+    }
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {

Review Comment:
   formatting (cf other comments)



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -327,8 +325,8 @@ private long tasksCost(final SortedSet<TaskId> tasks,
         }
         final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
         final List<TaskId> taskIdList = new ArrayList<>(tasks);
-        final Graph<Integer> graph = constructTaskGraph(clientList, taskIdList,
-            clientStates, new HashMap<>(), new HashMap<>(), hasAssignedTask, 
trafficCost, nonOverlapCost, hasReplica, isStandby);
+        final Graph<Integer> graph = 
RackAwareGraphConstructorFactory.create(assignmentConfigs, 
tasksForTopicGroup).constructTaskGraph(clientList, taskIdList,
+            clientStates, new HashMap<>(), new HashMap<>(), hasAssignedTask, 
this::getCost, trafficCost, nonOverlapCost, hasReplica, isStandby);

Review Comment:
   nit: formatting (pass one parameter per line):
   ```
   final Graph<Integer> graph = RackAwareGraphConstructorFactory
       .create(assignmentConfigs, tasksForTopicGroup)
       .constructTaskGraph(
           clientList,
           taskIdList,
           ...,
           hasReplica,
           isStandby
       );
   ```



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/BalanceSubtopologyGraphConstructor.java:
##########
@@ -0,0 +1,186 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.SortedSet;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class BalanceSubtopologyGraphConstructor implements 
RackAwareGraphConstructor {
+
+    private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup;
+
+    public BalanceSubtopologyGraphConstructor(final Map<Subtopology, 
Set<TaskId>> tasksForTopicGroup) {
+        this.tasksForTopicGroup = tasksForTopicGroup;
+    }
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size();
+    }
+
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return taskIdList.size() + clientList.size() * topicGroupIndex + 
clientIndex;
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size() - clientList.size() * 
topicGroupIndex;
+    }
+
+    private static int getSecondStageClientNodeId(final List<TaskId> 
taskIdList, final List<UUID> clientList, final Map<Subtopology, Set<TaskId>> 
tasksForTopicGroup, final int clientIndex) {
+        return taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size() + clientIndex;
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : taskIdList) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // TODO: validate tasks in tasksForTopicGroup and taskIdList
+        final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = 
new TreeMap<>(tasksForTopicGroup);
+        final int sinkId = getSinkNodeID(taskIdList, clientList, 
tasksForTopicGroup);
+
+        int taskNodeId = 0;
+        int topicGroupIndex = 0;
+        for (final Entry<Subtopology, Set<TaskId>> kv : 
sortedTasksForTopicGroup.entrySet()) {
+            final SortedSet<TaskId> taskIds = new TreeSet<>(kv.getValue());
+            for (int clientIndex = 0; clientIndex < clientList.size(); 
clientIndex++) {
+                final UUID processId = clientList.get(clientIndex);
+                final int clientNodeId = getClientNodeId(clientIndex, 
taskIdList, clientList, topicGroupIndex);
+                int startingTaskNodeId = taskNodeId;
+                for (final TaskId taskId : taskIds) {
+                    final int flow = 
hasAssignedTask.test(clientStates.get(processId), taskId) ? 1 : 0;
+                    graph.addEdge(startingTaskNodeId, clientNodeId, 1, 
getCostFunction.getCost(taskId, processId, false, trafficCost, nonOverlapCost, 
isStandby), flow);
+                    graph.addEdge(SOURCE_ID, startingTaskNodeId, 1, 0, 0);
+                    startingTaskNodeId++;
+                }
+
+                final int secondStageClientNodeId = 
getSecondStageClientNodeId(taskIdList, clientList, tasksForTopicGroup, 
clientIndex);
+                final int capacity = (int) 
Math.ceil(originalAssignedTaskNumber.get(processId) * 1.0 / taskIdList.size() * 
taskIds.size());
+                graph.addEdge(clientNodeId, secondStageClientNodeId, capacity, 
0, 0);
+            }
+
+            taskNodeId += taskIds.size();
+            topicGroupIndex++;
+        }
+
+        for (int clientIndex = 0; clientIndex < clientList.size(); 
clientIndex++) {
+            final UUID processId = clientList.get(clientIndex);
+            final int capacity = originalAssignedTaskNumber.get(processId);
+            final int secondStageClientNodeId = 
getSecondStageClientNodeId(taskIdList, clientList, tasksForTopicGroup, 
clientIndex);
+            graph.addEdge(secondStageClientNodeId, sinkId, capacity, 0, 0);
+        }
+
+        graph.setSourceNode(SOURCE_ID);
+        graph.setSinkNode(sinkId);
+
+        // Run max flow algorithm to get a solution first
+        final long maxFlow = graph.calculateMaxFlow();
+        if (maxFlow != taskIdList.size()) {
+            throw new IllegalStateException("max flow calculated: " + maxFlow 
+ " doesn't match taskSize: " + taskIdList.size());
+        }
+
+        return graph;
+    }
+
+    @Override
+    public boolean assignTaskFromMinCostFlow(final Graph<Integer> graph,
+        final List<UUID> clientList, final List<TaskId> taskIdList,
+        final Map<UUID, ClientState> clientStates,
+        final Map<UUID, Integer> originalAssignedTaskNumber, final Map<TaskId, 
UUID> taskClientMap,
+        final BiConsumer<ClientState, TaskId> assignTask,
+        final BiConsumer<ClientState, TaskId> unAssignTask,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask) {
+
+        final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = 
new TreeMap<>(tasksForTopicGroup);
+
+        int taskNodeId = 0;
+        int topicGroupIndex = 0;
+        int tasksAssigned = 0;
+        boolean taskMoved = false;
+        for (final Entry<Subtopology, Set<TaskId>> kv : 
sortedTasksForTopicGroup.entrySet()) {
+            final SortedSet<TaskId> taskIds = new TreeSet<>(kv.getValue());
+            for (final TaskId taskId : taskIds) {
+                final Map<Integer, Graph<Integer>.Edge> edges = 
graph.edges(taskNodeId);
+                for (final Graph<Integer>.Edge edge : edges.values()) {
+                    if (edge.flow > 0) {
+                        tasksAssigned++;
+                        final int clientIndex = 
getClientIndex(edge.destination, taskIdList, clientList, topicGroupIndex);
+                        final UUID processId = clientList.get(clientIndex);
+                        final UUID originalProcessId = 
taskClientMap.get(taskId);
+
+                        // Don't need to assign this task to other client
+                        if (processId.equals(originalProcessId)) {
+                            break;
+                        }
+
+                        
unAssignTask.accept(clientStates.get(originalProcessId), taskId);
+                        assignTask.accept(clientStates.get(processId), taskId);
+                        taskMoved = true;
+                    }
+                }
+                taskNodeId++;
+            }
+            topicGroupIndex++;
+        }
+
+        // Validate task assigned

Review Comment:
   Seems this validation code (starting here to the end of the method) is the 
same as for min-traffic? Should we extract a helper method and share code?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/Graph.java:
##########
@@ -220,6 +222,78 @@ public void solveMinCostFlow() {
         }
     }
 
+    public long calculateMaxFlow() {
+        final Graph<V> residualGraph = residualGraph();
+        residualGraph.fordFulkson();
+
+        long maxFlow = 0;
+        for (final Entry<V, SortedMap<V, Edge>> nodeEdges : 
adjList.entrySet()) {
+            final V node = nodeEdges.getKey();
+            for (final Entry<V, Edge> nodeEdge : 
nodeEdges.getValue().entrySet()) {
+                final V destination = nodeEdge.getKey();
+                final Edge edge = nodeEdge.getValue();
+                final Edge residualEdge = 
residualGraph.adjList.get(node).get(destination);
+                edge.flow = residualEdge.flow;
+                edge.residualFlow = residualEdge.residualFlow;
+
+                if (node == sourceNode) {
+                    maxFlow += edge.flow;
+                }
+            }
+        }
+
+        return maxFlow;
+    }
+
+    private void fordFulkson() {
+        if (!isResidualGraph) {
+            throw new IllegalStateException("Should be residual graph to 
cancel negative cycles");
+        }
+
+        Map<V, V> parents = new HashMap<>();
+        while (bfs(sourceNode, sinkNode, parents)) {
+            int flow = Integer.MAX_VALUE;
+            for (V node = sinkNode; node != sourceNode; node = 
parents.get(node)) {
+                final V parent = parents.get(node);
+                flow = Math.min(flow, 
adjList.get(parent).get(node).residualFlow);
+            }
+
+            for (V node = sinkNode; node != sourceNode; node = 
parents.get(node)) {
+                final V parent = parents.get(node);
+                adjList.get(parent).get(node).residualFlow -= flow;
+                adjList.get(node).get(parent).residualFlow += flow;
+            }
+
+            parents = new HashMap<>();
+        }
+    }
+
+    private boolean bfs(final V source, final V target, final Map<V, V> 
parents) {

Review Comment:
   `bfs` -> avoid abbreviations



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/BalanceSubtopologyGraphConstructor.java:
##########
@@ -0,0 +1,186 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.SortedSet;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class BalanceSubtopologyGraphConstructor implements 
RackAwareGraphConstructor {
+
+    private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup;
+
+    public BalanceSubtopologyGraphConstructor(final Map<Subtopology, 
Set<TaskId>> tasksForTopicGroup) {
+        this.tasksForTopicGroup = tasksForTopicGroup;
+    }
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size();
+    }
+
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return taskIdList.size() + clientList.size() * topicGroupIndex + 
clientIndex;
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size() - clientList.size() * 
topicGroupIndex;
+    }
+
+    private static int getSecondStageClientNodeId(final List<TaskId> 
taskIdList, final List<UUID> clientList, final Map<Subtopology, Set<TaskId>> 
tasksForTopicGroup, final int clientIndex) {
+        return taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size() + clientIndex;
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : taskIdList) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // TODO: validate tasks in tasksForTopicGroup and taskIdList
+        final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = 
new TreeMap<>(tasksForTopicGroup);
+        final int sinkId = getSinkNodeID(taskIdList, clientList, 
tasksForTopicGroup);
+
+        int taskNodeId = 0;
+        int topicGroupIndex = 0;
+        for (final Entry<Subtopology, Set<TaskId>> kv : 
sortedTasksForTopicGroup.entrySet()) {
+            final SortedSet<TaskId> taskIds = new TreeSet<>(kv.getValue());
+            for (int clientIndex = 0; clientIndex < clientList.size(); 
clientIndex++) {
+                final UUID processId = clientList.get(clientIndex);
+                final int clientNodeId = getClientNodeId(clientIndex, 
taskIdList, clientList, topicGroupIndex);
+                int startingTaskNodeId = taskNodeId;
+                for (final TaskId taskId : taskIds) {
+                    final int flow = 
hasAssignedTask.test(clientStates.get(processId), taskId) ? 1 : 0;
+                    graph.addEdge(startingTaskNodeId, clientNodeId, 1, 
getCostFunction.getCost(taskId, processId, false, trafficCost, nonOverlapCost, 
isStandby), flow);
+                    graph.addEdge(SOURCE_ID, startingTaskNodeId, 1, 0, 0);
+                    startingTaskNodeId++;
+                }
+
+                final int secondStageClientNodeId = 
getSecondStageClientNodeId(taskIdList, clientList, tasksForTopicGroup, 
clientIndex);
+                final int capacity = (int) 
Math.ceil(originalAssignedTaskNumber.get(processId) * 1.0 / taskIdList.size() * 
taskIds.size());
+                graph.addEdge(clientNodeId, secondStageClientNodeId, capacity, 
0, 0);
+            }
+
+            taskNodeId += taskIds.size();
+            topicGroupIndex++;
+        }
+
+        for (int clientIndex = 0; clientIndex < clientList.size(); 
clientIndex++) {
+            final UUID processId = clientList.get(clientIndex);
+            final int capacity = originalAssignedTaskNumber.get(processId);
+            final int secondStageClientNodeId = 
getSecondStageClientNodeId(taskIdList, clientList, tasksForTopicGroup, 
clientIndex);
+            graph.addEdge(secondStageClientNodeId, sinkId, capacity, 0, 0);
+        }
+
+        graph.setSourceNode(SOURCE_ID);
+        graph.setSinkNode(sinkId);
+
+        // Run max flow algorithm to get a solution first
+        final long maxFlow = graph.calculateMaxFlow();
+        if (maxFlow != taskIdList.size()) {
+            throw new IllegalStateException("max flow calculated: " + maxFlow 
+ " doesn't match taskSize: " + taskIdList.size());
+        }
+
+        return graph;
+    }
+
+    @Override
+    public boolean assignTaskFromMinCostFlow(final Graph<Integer> graph,
+        final List<UUID> clientList, final List<TaskId> taskIdList,
+        final Map<UUID, ClientState> clientStates,
+        final Map<UUID, Integer> originalAssignedTaskNumber, final Map<TaskId, 
UUID> taskClientMap,
+        final BiConsumer<ClientState, TaskId> assignTask,
+        final BiConsumer<ClientState, TaskId> unAssignTask,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask) {

Review Comment:
   nit: formatting



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/MinTrafficGraphConstructor.java:
##########
@@ -0,0 +1,162 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientIndex + taskIdList.size();
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size();
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {
+
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : taskIdList) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = getClientNodeId(j, taskIdList, 
clientList, 0);
+                final UUID processId = clientList.get(j);
+
+                final int flow = 
hasAssignedTask.test(clientStates.get(processId), taskId) ? 1 : 0;
+                final int cost = getCostFunction.getCost(taskId, processId, 
flow == 1, trafficCost,
+                    nonOverlapCost, isStandby);
+                if (flow == 1) {
+                    if (!hasReplica && taskClientMap.containsKey(taskId)) {
+                        throw new IllegalArgumentException("Task " + taskId + 
" assigned to multiple clients "
+                            + processId + ", " + taskClientMap.get(taskId));
+                    }
+                    taskClientMap.put(taskId, processId);
+                }
+
+                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
+            }
+            if (!taskClientMap.containsKey(taskId)) {
+                throw new IllegalArgumentException("Task " + taskId + " not 
assigned to any client");
+            }
+
+            // Add edge from source to task
+            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
+        }
+
+        final int sinkId = getSinkNodeID(taskIdList, clientList, null);
+        // It's possible that some clients have 0 task assign. These clients 
will have 0 tasks assigned
+        // even though it may have higher traffic cost. This is to maintain 
the original assigned task count
+        for (int i = 0; i < clientList.size(); i++) {
+            final int clientNodeId = getClientNodeId(i, taskIdList, 
clientList, 0);
+            final int capacity = 
originalAssignedTaskNumber.getOrDefault(clientList.get(i), 0);
+            // Flow equals to capacity for edges to sink
+            graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
+        }
+
+        graph.setSourceNode(SOURCE_ID);
+        graph.setSinkNode(sinkId);
+
+        return graph;
+    }
+
+    @Override
+    public boolean assignTaskFromMinCostFlow(final Graph<Integer> graph,
+        final List<UUID> clientList, final List<TaskId> taskIdList,
+        final Map<UUID, ClientState> clientStates,
+        final Map<UUID, Integer> originalAssignedTaskNumber, final Map<TaskId, 
UUID> taskClientMap,
+        final BiConsumer<ClientState, TaskId> assignTask,
+        final BiConsumer<ClientState, TaskId> unAssignTask,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask) {

Review Comment:
   nit: fix indentation (cf. other comments).



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/MinTrafficGraphConstructor.java:
##########
@@ -0,0 +1,162 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class MinTrafficGraphConstructor implements RackAwareGraphConstructor {
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientIndex + taskIdList.size();
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size();
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {
+
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : taskIdList) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = getClientNodeId(j, taskIdList, 
clientList, 0);
+                final UUID processId = clientList.get(j);
+
+                final int flow = 
hasAssignedTask.test(clientStates.get(processId), taskId) ? 1 : 0;
+                final int cost = getCostFunction.getCost(taskId, processId, 
flow == 1, trafficCost,
+                    nonOverlapCost, isStandby);
+                if (flow == 1) {
+                    if (!hasReplica && taskClientMap.containsKey(taskId)) {
+                        throw new IllegalArgumentException("Task " + taskId + 
" assigned to multiple clients "
+                            + processId + ", " + taskClientMap.get(taskId));
+                    }
+                    taskClientMap.put(taskId, processId);
+                }
+
+                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
+            }
+            if (!taskClientMap.containsKey(taskId)) {
+                throw new IllegalArgumentException("Task " + taskId + " not 
assigned to any client");
+            }
+
+            // Add edge from source to task
+            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
+        }
+
+        final int sinkId = getSinkNodeID(taskIdList, clientList, null);
+        // It's possible that some clients have 0 task assign. These clients 
will have 0 tasks assigned
+        // even though it may have higher traffic cost. This is to maintain 
the original assigned task count
+        for (int i = 0; i < clientList.size(); i++) {
+            final int clientNodeId = getClientNodeId(i, taskIdList, 
clientList, 0);

Review Comment:
   as above?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -363,13 +361,14 @@ public long optimizeActiveTasks(final SortedSet<TaskId> 
activeTasks,
         final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
         final Map<TaskId, UUID> taskClientMap = new HashMap<>();
         final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
-        final Graph<Integer> graph = constructTaskGraph(clientList, taskIdList,
-            clientStates, taskClientMap, originalAssignedTaskNumber, 
ClientState::hasActiveTask, trafficCost, nonOverlapCost, false, false);
+        final RackAwareGraphConstructor graphConstructor = 
RackAwareGraphConstructorFactory.create(assignmentConfigs, tasksForTopicGroup);
+        final Graph<Integer> graph = 
graphConstructor.constructTaskGraph(clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, 
ClientState::hasActiveTask, this::getCost, trafficCost, nonOverlapCost, false, 
false);

Review Comment:
   nit: formatting (pass one parameter per line):
   ```
   final Graph<Integer> graph = graphConstructor.constructTaskGraph(
       clientList,
       taskIdList,
       ...,
       false,
       false
   );
   ```



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/BalanceSubtopologyGraphConstructor.java:
##########
@@ -0,0 +1,186 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.SortedSet;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class BalanceSubtopologyGraphConstructor implements 
RackAwareGraphConstructor {
+
+    private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup;
+
+    public BalanceSubtopologyGraphConstructor(final Map<Subtopology, 
Set<TaskId>> tasksForTopicGroup) {
+        this.tasksForTopicGroup = tasksForTopicGroup;
+    }
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size();
+    }
+
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return taskIdList.size() + clientList.size() * topicGroupIndex + 
clientIndex;
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size() - clientList.size() * 
topicGroupIndex;
+    }
+
+    private static int getSecondStageClientNodeId(final List<TaskId> 
taskIdList, final List<UUID> clientList, final Map<Subtopology, Set<TaskId>> 
tasksForTopicGroup, final int clientIndex) {
+        return taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size() + clientIndex;
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : taskIdList) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // TODO: validate tasks in tasksForTopicGroup and taskIdList
+        final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = 
new TreeMap<>(tasksForTopicGroup);
+        final int sinkId = getSinkNodeID(taskIdList, clientList, 
tasksForTopicGroup);
+
+        int taskNodeId = 0;
+        int topicGroupIndex = 0;
+        for (final Entry<Subtopology, Set<TaskId>> kv : 
sortedTasksForTopicGroup.entrySet()) {
+            final SortedSet<TaskId> taskIds = new TreeSet<>(kv.getValue());
+            for (int clientIndex = 0; clientIndex < clientList.size(); 
clientIndex++) {
+                final UUID processId = clientList.get(clientIndex);
+                final int clientNodeId = getClientNodeId(clientIndex, 
taskIdList, clientList, topicGroupIndex);
+                int startingTaskNodeId = taskNodeId;
+                for (final TaskId taskId : taskIds) {
+                    final int flow = 
hasAssignedTask.test(clientStates.get(processId), taskId) ? 1 : 0;
+                    graph.addEdge(startingTaskNodeId, clientNodeId, 1, 
getCostFunction.getCost(taskId, processId, false, trafficCost, nonOverlapCost, 
isStandby), flow);
+                    graph.addEdge(SOURCE_ID, startingTaskNodeId, 1, 0, 0);

Review Comment:
   It seems we add this edge multiple times?
   
   Would it be better to flip the nesting of `for (final TaskId taskId : 
taskIds)` and `for (int clientIndex = 0; clientIndex < clientList.size(); 
clientIndex++) {` to allow us to assign this edge only once?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/BalanceSubtopologyGraphConstructor.java:
##########
@@ -0,0 +1,186 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.SortedSet;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class BalanceSubtopologyGraphConstructor implements 
RackAwareGraphConstructor {
+
+    private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup;
+
+    public BalanceSubtopologyGraphConstructor(final Map<Subtopology, 
Set<TaskId>> tasksForTopicGroup) {
+        this.tasksForTopicGroup = tasksForTopicGroup;
+    }
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size();
+    }
+
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return taskIdList.size() + clientList.size() * topicGroupIndex + 
clientIndex;
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size() - clientList.size() * 
topicGroupIndex;
+    }
+
+    private static int getSecondStageClientNodeId(final List<TaskId> 
taskIdList, final List<UUID> clientList, final Map<Subtopology, Set<TaskId>> 
tasksForTopicGroup, final int clientIndex) {
+        return taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size() + clientIndex;
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : taskIdList) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // TODO: validate tasks in tasksForTopicGroup and taskIdList
+        final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = 
new TreeMap<>(tasksForTopicGroup);
+        final int sinkId = getSinkNodeID(taskIdList, clientList, 
tasksForTopicGroup);
+
+        int taskNodeId = 0;
+        int topicGroupIndex = 0;
+        for (final Entry<Subtopology, Set<TaskId>> kv : 
sortedTasksForTopicGroup.entrySet()) {
+            final SortedSet<TaskId> taskIds = new TreeSet<>(kv.getValue());
+            for (int clientIndex = 0; clientIndex < clientList.size(); 
clientIndex++) {
+                final UUID processId = clientList.get(clientIndex);
+                final int clientNodeId = getClientNodeId(clientIndex, 
taskIdList, clientList, topicGroupIndex);
+                int startingTaskNodeId = taskNodeId;
+                for (final TaskId taskId : taskIds) {
+                    final int flow = 
hasAssignedTask.test(clientStates.get(processId), taskId) ? 1 : 0;
+                    graph.addEdge(startingTaskNodeId, clientNodeId, 1, 
getCostFunction.getCost(taskId, processId, false, trafficCost, nonOverlapCost, 
isStandby), flow);
+                    graph.addEdge(SOURCE_ID, startingTaskNodeId, 1, 0, 0);
+                    startingTaskNodeId++;
+                }
+
+                final int secondStageClientNodeId = 
getSecondStageClientNodeId(taskIdList, clientList, tasksForTopicGroup, 
clientIndex);

Review Comment:
   Seems that if we flip the nesting, we move the issue to here? -- Maybe you 
are trying to do too many things at once in the same loops? Would it be better 
to break it up a little more, and loop over all sub-topologies two times 
(instead of once)?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/BalanceSubtopologyGraphConstructor.java:
##########
@@ -0,0 +1,186 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.SortedSet;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class BalanceSubtopologyGraphConstructor implements 
RackAwareGraphConstructor {
+
+    private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup;
+
+    public BalanceSubtopologyGraphConstructor(final Map<Subtopology, 
Set<TaskId>> tasksForTopicGroup) {
+        this.tasksForTopicGroup = tasksForTopicGroup;
+    }
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size();
+    }
+
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return taskIdList.size() + clientList.size() * topicGroupIndex + 
clientIndex;
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size() - clientList.size() * 
topicGroupIndex;
+    }
+
+    private static int getSecondStageClientNodeId(final List<TaskId> 
taskIdList, final List<UUID> clientList, final Map<Subtopology, Set<TaskId>> 
tasksForTopicGroup, final int clientIndex) {
+        return taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size() + clientIndex;
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {

Review Comment:
   nit: formatting



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/BalanceSubtopologyGraphConstructor.java:
##########
@@ -0,0 +1,186 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.SortedSet;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class BalanceSubtopologyGraphConstructor implements 
RackAwareGraphConstructor {
+
+    private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup;
+
+    public BalanceSubtopologyGraphConstructor(final Map<Subtopology, 
Set<TaskId>> tasksForTopicGroup) {
+        this.tasksForTopicGroup = tasksForTopicGroup;
+    }
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size();
+    }
+
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return taskIdList.size() + clientList.size() * topicGroupIndex + 
clientIndex;
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size() - clientList.size() * 
topicGroupIndex;
+    }
+
+    private static int getSecondStageClientNodeId(final List<TaskId> 
taskIdList, final List<UUID> clientList, final Map<Subtopology, Set<TaskId>> 
tasksForTopicGroup, final int clientIndex) {
+        return taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size() + clientIndex;
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : taskIdList) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // TODO: validate tasks in tasksForTopicGroup and taskIdList
+        final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = 
new TreeMap<>(tasksForTopicGroup);
+        final int sinkId = getSinkNodeID(taskIdList, clientList, 
tasksForTopicGroup);
+
+        int taskNodeId = 0;
+        int topicGroupIndex = 0;
+        for (final Entry<Subtopology, Set<TaskId>> kv : 
sortedTasksForTopicGroup.entrySet()) {
+            final SortedSet<TaskId> taskIds = new TreeSet<>(kv.getValue());
+            for (int clientIndex = 0; clientIndex < clientList.size(); 
clientIndex++) {
+                final UUID processId = clientList.get(clientIndex);
+                final int clientNodeId = getClientNodeId(clientIndex, 
taskIdList, clientList, topicGroupIndex);
+                int startingTaskNodeId = taskNodeId;
+                for (final TaskId taskId : taskIds) {
+                    final int flow = 
hasAssignedTask.test(clientStates.get(processId), taskId) ? 1 : 0;
+                    graph.addEdge(startingTaskNodeId, clientNodeId, 1, 
getCostFunction.getCost(taskId, processId, false, trafficCost, nonOverlapCost, 
isStandby), flow);
+                    graph.addEdge(SOURCE_ID, startingTaskNodeId, 1, 0, 0);
+                    startingTaskNodeId++;
+                }
+
+                final int secondStageClientNodeId = 
getSecondStageClientNodeId(taskIdList, clientList, tasksForTopicGroup, 
clientIndex);
+                final int capacity = (int) 
Math.ceil(originalAssignedTaskNumber.get(processId) * 1.0 / taskIdList.size() * 
taskIds.size());
+                graph.addEdge(clientNodeId, secondStageClientNodeId, capacity, 
0, 0);
+            }
+
+            taskNodeId += taskIds.size();
+            topicGroupIndex++;
+        }
+
+        for (int clientIndex = 0; clientIndex < clientList.size(); 
clientIndex++) {
+            final UUID processId = clientList.get(clientIndex);
+            final int capacity = originalAssignedTaskNumber.get(processId);
+            final int secondStageClientNodeId = 
getSecondStageClientNodeId(taskIdList, clientList, tasksForTopicGroup, 
clientIndex);
+            graph.addEdge(secondStageClientNodeId, sinkId, capacity, 0, 0);
+        }
+
+        graph.setSourceNode(SOURCE_ID);
+        graph.setSinkNode(sinkId);
+
+        // Run max flow algorithm to get a solution first
+        final long maxFlow = graph.calculateMaxFlow();
+        if (maxFlow != taskIdList.size()) {
+            throw new IllegalStateException("max flow calculated: " + maxFlow 
+ " doesn't match taskSize: " + taskIdList.size());
+        }
+
+        return graph;
+    }
+
+    @Override
+    public boolean assignTaskFromMinCostFlow(final Graph<Integer> graph,
+        final List<UUID> clientList, final List<TaskId> taskIdList,
+        final Map<UUID, ClientState> clientStates,
+        final Map<UUID, Integer> originalAssignedTaskNumber, final Map<TaskId, 
UUID> taskClientMap,
+        final BiConsumer<ClientState, TaskId> assignTask,
+        final BiConsumer<ClientState, TaskId> unAssignTask,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask) {
+
+        final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = 
new TreeMap<>(tasksForTopicGroup);
+
+        int taskNodeId = 0;
+        int topicGroupIndex = 0;
+        int tasksAssigned = 0;
+        boolean taskMoved = false;
+        for (final Entry<Subtopology, Set<TaskId>> kv : 
sortedTasksForTopicGroup.entrySet()) {
+            final SortedSet<TaskId> taskIds = new TreeSet<>(kv.getValue());
+            for (final TaskId taskId : taskIds) {
+                final Map<Integer, Graph<Integer>.Edge> edges = 
graph.edges(taskNodeId);
+                for (final Graph<Integer>.Edge edge : edges.values()) {

Review Comment:
   Seems this `for` loop is the same as for min-cost? Should we extract the 
code into a helper method?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to