lihaosky commented on code in PR #14030:
URL: https://github.com/apache/kafka/pull/14030#discussion_r1272724945


##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java:
##########
@@ -185,4 +193,212 @@ public boolean validateClientRack() {
         }
         return true;
     }
+
+    private int getCost(final TaskId taskId, final UUID clientId, final 
boolean inCurrentAssignment, final boolean isStateful) {
+        final Map<String, Optional<String>> clientRacks = 
racksForProcess.get(clientId);
+        if (clientRacks == null) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
exist in processRacks");
+        }
+        final Optional<Optional<String>> clientRackOpt = 
clientRacks.values().stream().filter(Optional::isPresent).findFirst();
+        if (!clientRackOpt.isPresent() || !clientRackOpt.get().isPresent()) {
+            throw new IllegalStateException("Client " + clientId + " doesn't 
have rack configured. Maybe forgot to call canEnableRackAwareAssignor first");
+        }
+
+        final String clientRack = clientRackOpt.get().get();
+        final Set<TopicPartition> topicPartitions = 
partitionsForTask.get(taskId);
+        if (topicPartitions == null || topicPartitions.isEmpty()) {
+            throw new IllegalStateException("Task " + taskId + " has no 
TopicPartitions");
+        }
+
+        final int trafficCost = assignmentConfigs.trafficCost == null ? 
(isStateful ? DEFAULT_STATEFUL_TRAFFIC_COST : DEFAULT_STATELESS_TRAFFIC_COST)
+            : assignmentConfigs.trafficCost;
+        final int nonOverlapCost = assignmentConfigs.nonOverlapCost == null ? 
(isStateful ? DEFAULT_STATEFUL_NON_OVERLAP_COST : 
DEFAULT_STATELESS_NON_OVERLAP_COST)
+            : assignmentConfigs.nonOverlapCost;
+
+        int cost = 0;
+        for (final TopicPartition tp : topicPartitions) {
+            final Set<String> tpRacks = racksForPartition.get(tp);
+            if (tpRacks == null || tpRacks.isEmpty()) {
+                throw new IllegalStateException("TopicPartition " + tp + " has 
no rack information. Maybe forgot to call canEnableRackAwareAssignor first");
+            }
+            if (!tpRacks.contains(clientRack)) {
+                cost += trafficCost;
+            }
+        }
+
+        if (!inCurrentAssignment) {
+            cost += nonOverlapCost;
+        }
+
+        return cost;
+    }
+
+    private static int getSinkID(final List<UUID> clientList, final 
List<TaskId> taskIdList) {
+        return clientList.size() + taskIdList.size();
+    }
+
+    // For testing. canEnableRackAwareAssignor must be called first
+    long activeTasksCost(final SortedMap<UUID, ClientState> clientStates, 
final SortedSet<TaskId> activeTasks, final boolean isStateful) {
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, 
clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, 
isStateful);
+        return graph.totalCost();
+    }
+
+    /**
+     * Optimize active task assignment for rack awareness. 
canEnableRackAwareAssignor must be called first
+     * @param clientStates Client states
+     * @param activeTasks Tasks to reassign if needed. They must be assigned 
already in clientStates
+     * @param isStateful Whether the tasks are stateful
+     * @return Total cost after optimization
+     */
+    public long optimizeActiveTasks(final SortedMap<UUID, ClientState> 
clientStates,
+                                    final SortedSet<TaskId> activeTasks,
+                                    final boolean isStateful) {
+        if (activeTasks.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientList = new ArrayList<>(clientStates.keySet());
+        final List<TaskId> taskIdList = new ArrayList<>(activeTasks);
+        final Map<TaskId, UUID> taskClientMap = new HashMap<>();
+        final Map<UUID, Integer> originalAssignedTaskNumber = new HashMap<>();
+        final Graph<Integer> graph = constructActiveTaskGraph(activeTasks, 
clientList, taskIdList,
+            clientStates, taskClientMap, originalAssignedTaskNumber, 
isStateful);
+
+        graph.solveMinCostFlow();
+        final long cost = graph.totalCost();
+
+        assignActiveTaskFromMinCostFlow(graph, activeTasks, clientList, 
taskIdList,
+            clientStates, originalAssignedTaskNumber, taskClientMap);
+
+        return cost;
+    }
+
+    private Graph<Integer> constructActiveTaskGraph(final SortedSet<TaskId> 
activeTasks,
+                                                    final List<UUID> 
clientList,
+                                                    final List<TaskId> 
taskIdList,
+                                                    final Map<UUID, 
ClientState> clientStates,
+                                                    final Map<TaskId, UUID> 
taskClientMap,
+                                                    final Map<UUID, Integer> 
originalAssignedTaskNumber,
+                                                    final boolean isStateful) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : activeTasks) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (clientState.getValue().hasAssignedTask(taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // Make task and client Node id in graph deterministic
+        for (int taskNodeId  = 0; taskNodeId < taskIdList.size(); 
taskNodeId++) {
+            final TaskId taskId = taskIdList.get(taskNodeId);
+            for (int j = 0; j < clientList.size(); j++) {
+                final int clientNodeId = taskIdList.size() + j;
+                final UUID clientId = clientList.get(j);
+
+                final int flow = 
clientStates.get(clientId).hasAssignedTask(taskId) ? 1 : 0;
+                final int cost = getCost(taskId, clientId, flow == 1, 
isStateful);
+                if (flow == 1) {
+                    if (taskClientMap.containsKey(taskId)) {
+                        throw new IllegalArgumentException("Task " + taskId + 
" assigned to multiple clients "
+                            + clientId + ", " + taskClientMap.get(taskId));
+                    }
+                    taskClientMap.put(taskId, clientId);
+                }
+
+                graph.addEdge(taskNodeId, clientNodeId, 1, cost, flow);
+            }
+            if (!taskClientMap.containsKey(taskId)) {
+                throw new IllegalArgumentException("Task " + taskId + " not 
assigned to any client");
+            }
+        }
+
+        final int sinkId = getSinkID(clientList, taskIdList);
+        for (int taskNodeId = 0; taskNodeId < taskIdList.size(); taskNodeId++) 
{
+            graph.addEdge(SOURCE_ID, taskNodeId, 1, 0, 1);
+        }
+
+        // It's possible that some clients have 0 task assign. These clients 
will have 0 tasks assigned
+        // even though it may have lower traffic cost. This is to maintain the 
balance requirement.

Review Comment:
   The balance requirement I meant here is actually about keeping each client's 
`originalAssignedTaskNumber`, even though leaving some clients with 0 tasks 
assigned does not really look balanced... I'll reword it



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to