lihaosky commented on code in PR #14714:
URL: https://github.com/apache/kafka/pull/14714#discussion_r1409723863


##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/BalanceSubtopologyGraphConstructor.java:
##########
@@ -0,0 +1,186 @@
+package org.apache.kafka.streams.processor.internals.assignment;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Objects;
+import java.util.Set;
+import java.util.SortedMap;
+import java.util.SortedSet;
+import java.util.TreeMap;
+import java.util.TreeSet;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiPredicate;
+import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.GetCostFunction;
+
+public class BalanceSubtopologyGraphConstructor implements 
RackAwareGraphConstructor {
+
+    private final Map<Subtopology, Set<TaskId>> tasksForTopicGroup;
+
+    public BalanceSubtopologyGraphConstructor(final Map<Subtopology, 
Set<TaskId>> tasksForTopicGroup) {
+        this.tasksForTopicGroup = tasksForTopicGroup;
+    }
+
+    @Override
+    public int getSinkNodeID(final List<TaskId> taskIdList, final List<UUID> 
clientList,
+        final Map<Subtopology, Set<TaskId>> tasksForTopicGroup) {
+        return clientList.size() + taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size();
+    }
+
+
+    @Override
+    public int getClientNodeId(final int clientIndex, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return taskIdList.size() + clientList.size() * topicGroupIndex + 
clientIndex;
+    }
+
+    @Override
+    public int getClientIndex(final int clientNodeId, final List<TaskId> 
taskIdList, final List<UUID> clientList, final int topicGroupIndex) {
+        return clientNodeId - taskIdList.size() - clientList.size() * 
topicGroupIndex;
+    }
+
+    private static int getSecondStageClientNodeId(final List<TaskId> 
taskIdList, final List<UUID> clientList, final Map<Subtopology, Set<TaskId>> 
tasksForTopicGroup, final int clientIndex) {
+        return taskIdList.size() + clientList.size() * 
tasksForTopicGroup.size() + clientIndex;
+    }
+
+    @Override
+    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
+        final List<TaskId> taskIdList, final SortedMap<UUID, ClientState> 
clientStates,
+        final Map<TaskId, UUID> taskClientMap, final Map<UUID, Integer> 
originalAssignedTaskNumber,
+        final BiPredicate<ClientState, TaskId> hasAssignedTask, final 
GetCostFunction getCostFunction, final int trafficCost,
+        final int nonOverlapCost, final boolean hasReplica, final boolean 
isStandby) {
+        final Graph<Integer> graph = new Graph<>();
+
+        for (final TaskId taskId : taskIdList) {
+            for (final Entry<UUID, ClientState> clientState : 
clientStates.entrySet()) {
+                if (hasAssignedTask.test(clientState.getValue(), taskId)) {
+                    originalAssignedTaskNumber.merge(clientState.getKey(), 1, 
Integer::sum);
+                }
+            }
+        }
+
+        // TODO: validate tasks in tasksForTopicGroup and taskIdList
+        final SortedMap<Subtopology, Set<TaskId>> sortedTasksForTopicGroup = 
new TreeMap<>(tasksForTopicGroup);
+        final int sinkId = getSinkNodeID(taskIdList, clientList, 
tasksForTopicGroup);
+
+        int taskNodeId = 0;
+        int topicGroupIndex = 0;
+        for (final Entry<Subtopology, Set<TaskId>> kv : 
sortedTasksForTopicGroup.entrySet()) {
+            final SortedSet<TaskId> taskIds = new TreeSet<>(kv.getValue());
+            for (int clientIndex = 0; clientIndex < clientList.size(); 
clientIndex++) {
+                final UUID processId = clientList.get(clientIndex);
+                final int clientNodeId = getClientNodeId(clientIndex, 
taskIdList, clientList, topicGroupIndex);
+                int startingTaskNodeId = taskNodeId;
+                for (final TaskId taskId : taskIds) {
+                    final int flow = 
hasAssignedTask.test(clientStates.get(processId), taskId) ? 1 : 0;
+                    graph.addEdge(startingTaskNodeId, clientNodeId, 1, 
getCostFunction.getCost(taskId, processId, false, trafficCost, nonOverlapCost, 
isStandby), flow);
+                    graph.addEdge(SOURCE_ID, startingTaskNodeId, 1, 0, 0);
+                    startingTaskNodeId++;
+                }
+
+                final int secondStageClientNodeId = 
getSecondStageClientNodeId(taskIdList, clientList, tasksForTopicGroup, 
clientIndex);

Review Comment:
   Right, I initially iterated over the tasks in the outer loop and ran into 
this issue. I will make this more readable.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to