lihaosky commented on code in PR #14030:
URL: https://github.com/apache/kafka/pull/14030#discussion_r1272692608


##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java:
##########
@@ -268,24 +268,44 @@ public static class AssignmentConfigs {
         public final long probingRebalanceIntervalMs;
         public final List<String> rackAwareAssignmentTags;
 
+        // TODO: get from streamsConfig after we add the config

Review Comment:
   I refactored this a bit and will add comments in the `RackAwareTaskAssignor` method.



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +191,224 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;

Review Comment:
   Or `moveTaskCost`? `moveAssignmentCost`? I have no strong preference here.



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? (isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? (isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());

Review Comment:
   Yes. The `List` is there to make it deterministic: a list built from the `keySet` of a `SortedMap` keeps the map's order. That's why line 241 requires a `SortedMap`.
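
A minimal standalone sketch of the ordering argument, using plain `TreeMap` and `UUID` instead of `ClientState` (the class name, the `String` values, and the task count are hypothetical): because a `SortedMap`'s `keySet()` iterates in ascending key order, the list built from it, and therefore every `taskIdList.size() + j` client node id, does not depend on insertion order.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.UUID;

public class DeterministicClientOrder {
    // Builds the client list the same way the diff does: new ArrayList<>(clientStates.keySet()).
    // For a SortedMap, keySet() iterates in ascending key order, so list indices are stable.
    static <V> List<UUID> clientOrder(final SortedMap<UUID, V> clientStates) {
        return new ArrayList<>(clientStates.keySet());
    }

    public static void main(final String[] args) {
        final UUID first = new UUID(0L, 1L);
        final UUID second = new UUID(0L, 2L);
        final UUID third = new UUID(0L, 3L);

        // Insert in scrambled order; TreeMap orders keys by UUID.compareTo.
        final SortedMap<UUID, String> clientStates = new TreeMap<>();
        clientStates.put(third, "c");
        clientStates.put(first, "a");
        clientStates.put(second, "b");

        final int taskCount = 4; // hypothetical number of task nodes
        final List<UUID> clients = clientOrder(clientStates);
        for (int j = 0; j < clients.size(); j++) {
            // Mirrors clientNodeId = taskIdList.size() + j in the diff.
            System.out.println(clients.get(j) + " -> node " + (taskCount + j));
        }
    }
}
```

With a plain `HashMap`, the same loop could hand out different node ids across runs or JVMs, which is what the `SortedMap` parameter type rules out.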



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active task assignment for rack awareness. canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned already in clientStates
+     * @param isStateful Whether the tasks are stateful
+     * @return Total cost after optimization
+     */
+    public long optimizeActiveTasks(final SortedMap<UUID, ClientState> clientStates,
+                                    final SortedSet<TaskId> activeTasks,
+                                    final boolean isStateful) {
+        if (activeTasks.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, isStateful);
+
+        graph.solveMinCostFlow();
+        final long cost = graph.totalCost();
+
+        assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, taskIdList,
+            clientStates, originalAssignedTaskNumber, taskClientMap);
+
+        return cost;
+    }
+
+    private Graph<Integer> constructActiveTaskGraph(final SortedSet<TaskId> activeTasks,
+                                                    final List<UUID> clientList,
+                                                    final List<TaskId> taskIdList,
+                                                    final Map<UUID, ClientState> clientStates,
+                                                    final Map<TaskId, UUID> taskClientMap,
+                                                    final Map<UUID, Integer> originalAssignedTaskNumber,
+                                                    final boolean isStateful) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : activeTasks) {
+            for (final Entry<UUID, ClientState> clientState : clientStates.entrySet()) {
+                if (clientState.getValue().hasAssignedTask(taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId  = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = taskIdList.size() + j;
+                final UUID clientId = clientList.get(j);
+
+                final int flow = clientStates.get(clientId).hasAssignedTask(taskId) ? 1 : 0;
+                final int cost = getCost(taskId, clientId, flow == 1, isStateful);
+                if (flow == 1) {
+                    if (taskClientMap.containsKey(taskId)) {
+                        throw new IllegalArgumentException("Task " + taskId + " assigned to multiple clients "
+                            + clientId + ", " + taskClientMap.get(taskId));
+                    }
+                    taskClientMap.put(taskId, clientId);
+                }
+
+                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
+            }
+            if (!taskClientMap.containsKey(taskId)) {
+                throw new IllegalArgumentException("Task " + taskId + " not assigned to any client");
+            }
+        }
+
+        final int sinkId = getSinkID(clientList, taskIdList);
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
+            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
+        }
+
+        // It's possible that some clients have 0 task assign. These clients will have 0 tasks assigned
+        // even though it may have lower traffic cost. This is to maintain the balance requirement.
+        for (int i = 0; i < clientList.size(); i++) {
+            final int clientNodeId = taskIdList.size() + i;
+            final int capacity = originalAssignedTaskNumber.getOrDefault(clientList.get(i), 0);
+            // Flow equals to capacity for edges to sink
+            graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
+        }
+
+        graph.setSourceNode(SOURCE_ID);
+        graph.setSinkNode(sinkId);
+
+        return graph;
+    }
+
+    private void assignActiveTaskFromMinCostFlow(final Graph<Integer> graph,
+                                                 final SortedSet<TaskId> activeTasks,
+                                                 final List<UUID> clientList,
+                                                 final List<TaskId> taskIdList,
+                                                 final Map<UUID, ClientState> clientStates,
+                                                 final Map<UUID, Integer> originalClientCapacity,
+                                                 final Map<TaskId, UUID> taskClientMap) {
+        int taskAssigned = 0;
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            final Map<Integer, Graph<Integer>.Edge> edges = graph.edges(taskNodeId);
+            for (final Graph<Integer>.Edge edge : edges.values()) {
+                if (edge.flow > 0) {
+                    taskAssigned++;
+                    final int clientIndex = edge.destination - taskIdList.size();
+                    final UUID clientId = clientList.get(clientIndex);
+                    final UUID originalClientId = taskClientMap.get(taskId);
+
+                    // Don't need to assign this task to other client
+                    if (clientId.equals(originalClientId)) {
+                        break;
+                    }
+
+                    final ClientState clientState = clientStates.get(clientId);
+                    final ClientState originalClientState = clientStates.get(originalClientId);
+                    originalClientState.unassignActive(taskId);
+                    clientState.assignActive(taskId);
+                }

Review Comment:
   Yeah. There should be only one such edge. I didn't `break` here so that the validations below can catch anything wrong.
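
The diff keeps scanning after an edge with flow and increments `taskAssigned`, presumably so the checks further down (outside this excerpt) can compare it against the task count. A minimal standalone sketch of the same "exactly one edge carries flow" invariant, using a hypothetical `Edge` class with only the `destination` and `flow` fields the diff reads, and the diff's node layout where client node ids start at `taskIdList.size()`:

```java
import java.util.ArrayList;
import java.util.List;

public class FlowDecoding {
    // Hypothetical stand-in for Graph<Integer>.Edge: only the fields the diff reads.
    static final class Edge {
        final int destination;
        final int flow;

        Edge(final int destination, final int flow) {
            this.destination = destination;
            this.flow = flow;
        }
    }

    // Returns the client index chosen for one task node. It scans every edge instead of
    // stopping at the first one with flow, so a malformed solution (zero or several edges
    // carrying flow) is reported rather than silently accepted.
    static int chosenClientIndex(final List<Edge> taskEdges, final int taskCount) {
        int chosen = -1;
        for (final Edge edge : taskEdges) {
            if (edge.flow > 0) {
                if (chosen != -1) {
                    throw new IllegalStateException("Task has flow to more than one client");
                }
                // Client node ids start right after the task node ids.
                chosen = edge.destination - taskCount;
            }
        }
        if (chosen == -1) {
            throw new IllegalStateException("Task has no flow to any client");
        }
        return chosen;
    }

    public static void main(final String[] args) {
        final int taskCount = 2;
        final List<Edge> edges = new ArrayList<>();
        edges.add(new Edge(2, 0)); // task -> client 0, no flow
        edges.add(new Edge(3, 1)); // task -> client 1, carries the flow
        System.out.println(chosenClientIndex(edges, taskCount)); // prints 1
    }
}
```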



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
+    /**
+     * Optimize active task assignment for rack awareness. canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned already in clientStates
+     * @param isStateful Whether the tasks are stateful

Review Comment:
   I'll refactor `isStateful` away and pass in `trafficCost` and `nonOverlapCost` directly here.
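
A sketch of what the cost computation could look like once `isStateful` is refactored away, assuming the caller passes `trafficCost` and `nonOverlapCost` directly as proposed above. The rack lookups are replaced with plain collections, and the class name, parameter layout, and the cost values `10` and `1` are hypothetical (the real `DEFAULT_*` constants are not shown in this excerpt):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class ExplicitCostSketch {
    // Hypothetical signature: the caller decides the per-partition traffic cost and the
    // cost of moving a task, so no isStateful flag is needed here.
    static int cost(final List<Set<String>> partitionRacks,
                    final String clientRack,
                    final boolean inCurrentAssignment,
                    final int trafficCost,
                    final int nonOverlapCost) {
        int cost = 0;
        for (final Set<String> racks : partitionRacks) {
            // Each input partition with no replica on the client's rack adds cross-rack traffic.
            if (!racks.contains(clientRack)) {
                cost += trafficCost;
            }
        }
        // Placing the task on a client other than its current one is penalized once.
        if (!inCurrentAssignment) {
            cost += nonOverlapCost;
        }
        return cost;
    }

    public static void main(final String[] args) {
        final List<Set<String>> partitions = new ArrayList<>();
        partitions.add(new HashSet<>(Arrays.asList("rack-a", "rack-b")));
        partitions.add(new HashSet<>(Arrays.asList("rack-c")));
        // One partition is off-rack (10) and the task would move (1): prints 11.
        System.out.println(cost(partitions, "rack-a", false, 10, 1));
    }
}
```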



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {

Review Comment:
   The `keySet` or `List` we get should be deterministic, which ensures our min-cost assignment is deterministic. If it isn't sorted when passed in, we need to sort it ourselves.
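
The fallback mentioned above, sorting ourselves when the input isn't already sorted, could look like this sketch: copying into a `TreeMap` makes iteration order depend only on the `UUID` keys, not on how the map was built. `sortedCopy` is a hypothetical helper, not part of the PR.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.UUID;

public class SortBeforeOptimizing {
    // Hypothetical helper: copy an arbitrarily ordered map into a TreeMap so that
    // iteration order (and therefore graph node ids) depends only on the keys.
    static <V> SortedMap<UUID, V> sortedCopy(final Map<UUID, V> clientStates) {
        return new TreeMap<>(clientStates);
    }

    public static void main(final String[] args) {
        final Map<UUID, String> unordered = new HashMap<>();
        unordered.put(new UUID(0L, 2L), "b");
        unordered.put(new UUID(0L, 1L), "a");
        // prints 00000000-0000-0000-0000-000000000001
        System.out.println(sortedCopy(unordered).firstKey());
    }
}
```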



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);

Review Comment:
   This is also to make it deterministic: `activeTasks` is a `SortedSet`.



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
+        final Optional<Optional<String>> clientRackOpt = clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {

Review Comment:
   Yeah. This is to silence warnings from IntelliJ, Checkstyle, or SpotBugs.
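
For comparison, one possible way to avoid the nested `Optional<Optional<String>>` altogether is to map each present `Optional` to its value before `findFirst()`, which leaves a single `Optional<String>` to check. This is not what the PR does, and whether it avoids the same IntelliJ/Checkstyle/SpotBugs warnings is untested; `firstRack` and the map keys are hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;

public class ClientRackLookup {
    // Mapping each present Optional to its value before findFirst() yields a single
    // Optional<String>, so there is no inner Optional left to re-check.
    static Optional<String> firstRack(final Map<String, Optional<String>> clientRacks) {
        return clientRacks.values().stream()
            .filter(Optional::isPresent)
            .map(Optional::get)
            .findFirst();
    }

    public static void main(final String[] args) {
        // LinkedHashMap keeps insertion order, so "first" is well defined here.
        final Map<String, Optional<String>> racks = new LinkedHashMap<>();
        racks.put("consumer-1", Optional.empty());
        racks.put("consumer-2", Optional.of("rack-a"));
        System.out.println(firstRack(racks).orElse("none")); // prints rack-a
    }
}
```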



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final 
boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = 
racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = 
clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = 
partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no 
TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? 
(isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? 
(isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : 
DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has 
no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final 
List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, 
final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, 
clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, 
isStateful);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active task assignment for rack awareness. 
canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned 
already in clientStates
+     * @param isStateful Whether the tasks are stateful
+     * @return Total cost after optimization
+     */
+    public long optimizeActiveTasks(final SortedMap<UUID, ClientState> 
clientStates,
+                                    final SortedSet<TaskId> activeTasks,
+                                    final boolean isStateful) {
+        if (activeTasks.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, 
clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, 
isStateful);
+
+        graph.solveMinCostFlow();
+        final long cost = graph.totalCost();
+
+        assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, 
taskIdList,
+            clientStates, originalAssignedTaskNumber, taskClientMap);
+
+        return cost;
+    }
+
+    private Graph<Integer> constructActiveTaskGraph(final SortedSet<TaskId> 
activeTasks,
+                                                    final List<UUID> 
clientList,
+                                                    final List<TaskId> 
taskIdList,
+                                                    final Map<UUID, 
ClientState> clientStates,
+                                                    final Map<TaskId, UUID> 
taskClientMap,
+                                                    final Map<UUID, Integer> 
originalAssignedTaskNumber,
+                                                    final boolean isStateful) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : activeTasks) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (clientState.getValue().hasAssignedTask(taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId  = 0; taskNodeId < taskIdList.size(); 
taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = taskIdList.size() + j;
+                final UUID clientId = clientList.get(j);
+
+                final int flow = clientStates.get(clientId).hasAssignedTask(taskId) ? 1 : 0;
+                final int cost = getCost(taskId, clientId, flow == 1, isStateful);
+                if (flow == 1) {
+                    if (taskClientMap.containsKey(taskId)) {
+                        throw new IllegalArgumentException("Task " + taskId + " assigned to multiple clients "
+                            + clientId + ", " + taskClientMap.get(taskId));
+                    }
+                    taskClientMap.put(taskId, clientId);
+                }
+
+                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
+            }
+            if (!taskClientMap.containsKey(taskId)) {
+                throw new IllegalArgumentException("Task " + taskId + " not assigned to any client");
+            }
+        }
+
+        final int sinkId = getSinkID(clientList, taskIdList);
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
+            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
+        }
+
+        // Some clients may currently have 0 tasks assigned. Such clients keep 0 tasks, even if
+        // assigning tasks to them would lower the traffic cost, to maintain the balance requirement.
+        for (int i = 0; i < clientList.size(); i++) {
+            final int clientNodeId = taskIdList.size() + i;
+            final int capacity = originalAssignedTaskNumber.getOrDefault(clientList.get(i), 0);
+            // Flow equals capacity for edges to sink
+            graph.addEdge(clientNodeId, sinkId, capacity, 0, capacity);
+        }
+
+        graph.setSourceNode(SOURCE_ID);
+        graph.setSinkNode(sinkId);
+
+        return graph;
+    }
+
+    private void assignActiveTaskFromMinCostFlow(final Graph<Integer> graph,
+                                                 final SortedSet<TaskId> activeTasks,
+                                                 final List<UUID> clientList,
+                                                 final List<TaskId> taskIdList,
+                                                 final Map<UUID, ClientState> clientStates,
+                                                 final Map<UUID, Integer> originalClientCapacity,
+                                                 final Map<TaskId, UUID> taskClientMap) {
+        int taskAssigned = 0;
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            final Map<Integer, Graph<Integer>.Edge> edges = graph.edges(taskNodeId);
+            for (final Graph<Integer>.Edge edge : edges.values()) {
+                if (edge.flow > 0) {
+                    taskAssigned++;
+                    final int clientIndex = edge.destination - taskIdList.size();
+                    final UUID clientId = clientList.get(clientIndex);
+                    final UUID originalClientId = taskClientMap.get(taskId);
+
+                    // Task stays on its original client, so no reassignment is needed
+                    if (clientId.equals(originalClientId)) {
+                        break;
+                    }
+
+                    final ClientState clientState = clientStates.get(clientId);
+                    final ClientState originalClientState = clientStates.get(originalClientId);
+                    originalClientState.unassignActive(taskId);
+                    clientState.assignActive(taskId);
+                }
+            }
+        }
+
+        // Validate that every active task was assigned
+        if (taskAssigned != activeTasks.size()) {
+            throw new IllegalStateException("Number of computed active task assignments "
+                + taskAssigned + " differs from number of active tasks " + activeTasks.size());
+        }
+
+        // Validate capacity constraint

Review Comment:
   This is the `originalAssignedTaskNum`. "load" is more accurate than `capacity`.
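
A minimal, standalone sketch of the network that `constructActiveTaskGraph` builds may help when naming these variables. Each task gets a capacity-1 edge from the source, each task-to-client edge has capacity 1 and a per-pair cost, and each client-to-sink edge has a capacity equal to the client's original task count (the "load" mentioned above). Because every active task is assigned to exactly one client, the loads sum to the number of tasks, so a maximum flow saturates every sink edge: each client keeps its task count and the solver only decides which tasks move. The sketch assumes an integer cost matrix in place of `getCost`; `RackAwareFlowSketch`, `assign`, and the node layout are hypothetical and are not Kafka's `Graph` API. The diff seeds the graph with the current assignment as an initial flow, whereas the sketch starts from zero flow and uses successive shortest paths with Bellman-Ford, a swapped-in technique that reaches a minimum-cost assignment under the same constraints, though ties may be broken differently.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch, not Kafka's Graph class: min-cost flow assignment of tasks
// to clients in which every client keeps its original task count ("load").
public final class RackAwareFlowSketch {
    // Residual edge; rev is the index of the paired edge in graph.get(to).
    private static final class Edge {
        final int to;
        final int rev;
        final int cost;
        int cap;
        Edge(final int to, final int rev, final int cap, final int cost) {
            this.to = to;
            this.rev = rev;
            this.cap = cap;
            this.cost = cost;
        }
    }

    private final List<List<Edge>> graph = new ArrayList<>();
    private RackAwareFlowSketch(final int nodes) {
        for (int i = 0; i < nodes; i++) {
            graph.add(new ArrayList<>());
        }
    }
    private void addEdge(final int from, final int to, final int cap, final int cost) {
        graph.get(from).add(new Edge(to, graph.get(to).size(), cap, cost));
        graph.get(to).add(new Edge(from, graph.get(from).size() - 1, 0, -cost));
    }
    // Successive shortest paths; Bellman-Ford tolerates negative reverse-edge costs. Pushing
    // one unit per path is valid and is the bottleneck, as paths start with a cap-1 edge.
    private void minCostMaxFlow(final int source, final int sink) {
        final int n = graph.size();
        while (true) {
            final long[] dist = new long[n];
            final int[] prevNode = new int[n];
            final int[] prevEdge = new int[n];
            Arrays.fill(dist, Long.MAX_VALUE);
            dist[source] = 0;
            for (int round = 0; round < n - 1; round++) {
                for (int u = 0; u < n; u++) {
                    for (int i = 0; dist[u] != Long.MAX_VALUE && i < graph.get(u).size(); i++) {
                        final Edge e = graph.get(u).get(i);
                        if (e.cap > 0 && dist[u] + e.cost < dist[e.to]) {
                            dist[e.to] = dist[u] + e.cost;
                            prevNode[e.to] = u;
                            prevEdge[e.to] = i;
                        }
                    }
                }
            }
            if (dist[sink] == Long.MAX_VALUE) {
                return; // no augmenting path left, so the flow is maximum
            }
            for (int v = sink; v != source; v = prevNode[v]) {
                final Edge e = graph.get(prevNode[v]).get(prevEdge[v]);
                e.cap -= 1;
                graph.get(v).get(e.rev).cap += 1;
            }
        }
    }

    // cost[t][c] is the cost of task t on client c; load[c] is how many tasks client c keeps.
    public static int[] assign(final int[][] cost, final int[] load) {
        final int tasks = cost.length;
        final int clients = load.length;
        if (Arrays.stream(load).sum() != tasks) {
            throw new IllegalArgumentException("Total load must equal the number of tasks");
        }
        final int source = tasks + clients, sink = source + 1; // tasks, clients, source, sink
        final RackAwareFlowSketch g = new RackAwareFlowSketch(tasks + clients + 2);
        for (int t = 0; t < tasks; t++) {
            g.addEdge(source, t, 1, 0);
            for (int c = 0; c < clients; c++) {
                g.addEdge(t, tasks + c, 1, cost[t][c]);
            }
        }
        for (int c = 0; c < clients; c++) {
            g.addEdge(tasks + c, sink, load[c], 0); // capacity = load keeps each client's count
        }
        g.minCostMaxFlow(source, sink);
        final int[] assignment = new int[tasks]; // assignment[t] = chosen client index
        Arrays.fill(assignment, -1);
        for (int t = 0; t < tasks; t++) {
            for (final Edge e : g.graph.get(t)) {
                // A task->client edge whose capacity dropped from 1 to 0 carries the flow
                if (e.to >= tasks && e.to < tasks + clients && e.cap == 0) {
                    assignment[t] = e.to - tasks;
                }
            }
            if (assignment[t] == -1) {
                throw new IllegalStateException("Task " + t + " not assigned");
            }
        }
        return assignment;
    }
}
```

For example, with costs `{{0, 5}, {5, 0}}` and loads `{1, 1}`, `assign` places task 0 on client 0 and task 1 on client 1, and each client keeps exactly one task.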



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +191,224 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final boolean inCurrentAssignment, final boolean isStateful) {

Review Comment:
   Sure. I can add a comment to say it's `processId`.
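
Only the signature of `getCost` appears in this hunk: it takes the client identifier (a process ID, per the reply above), whether the task is in the current assignment, and whether it is stateful, and returns an integer edge cost. The sketch below is purely an assumption of how such a cost could be composed: a traffic cost for each of the task's partitions with no replica on the client's rack, plus a hypothetical `nonOverlapCost` penalty when the task would leave its current client. `TaskCostSketch`, `partitionReplicaRacks`, and `nonOverlapCost` are illustrative names, not Kafka's API or the PR's actual formula.

```java
import java.util.Map;
import java.util.Set;

// Hypothetical per-(task, client) cost, NOT Kafka's getCost: the formula and the
// nonOverlapCost penalty are assumptions for illustration only.
public final class TaskCostSketch {
    private TaskCostSketch() {
    }

    // partitionReplicaRacks: for each of the task's partitions, the racks holding a replica.
    // clientRack: rack of the client, i.e. of the process identified by its processId.
    public static int cost(final Map<String, Set<String>> partitionReplicaRacks,
                           final String clientRack,
                           final boolean inCurrentAssignment,
                           final int trafficCost,
                           final int nonOverlapCost) {
        int crossRackPartitions = 0;
        for (final Set<String> replicaRacks : partitionReplicaRacks.values()) {
            if (!replicaRacks.contains(clientRack)) {
                crossRackPartitions++; // this partition would be read across racks
            }
        }
        // Moving the task off its current client adds a penalty, discouraging churn
        return crossRackPartitions * trafficCost + (inCurrentAssignment ? 0 : nonOverlapCost);
    }
}
```

The two weights express a trade-off: with `trafficCost = 10` and `nonOverlapCost = 1`, avoiding a single cross-rack partition outweighs the penalty for moving a task.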



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

