showuon commented on a change in pull request #11493:
URL: https://github.com/apache/kafka/pull/11493#discussion_r805534017



##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
##########
@@ -784,21 +784,33 @@ private void populatePartitionsByHostMaps(final Map<HostInfo, Set<TopicPartition
             final ClientMetadata clientMetadata = clientEntry.getValue();
             final ClientState state = clientMetadata.state;
             final SortedSet<String> consumers = clientMetadata.consumers;
+            final Map<String, Integer> threadTaskCounts = new HashMap<>();
 
-            final Map<String, List<TaskId>> activeTaskAssignment = assignTasksToThreads(
+            final Map<String, List<TaskId>> activeTaskStatefulAssignment = assignStatefulTasksToThreads(
                 state.statefulActiveTasks(),
-                state.statelessActiveTasks(),
                 consumers,
-                state
+                state,
+                threadTaskCounts
             );
 
-            final Map<String, List<TaskId>> standbyTaskAssignment = assignTasksToThreads(
+            final Map<String, List<TaskId>> standbyTaskAssignment = assignStatefulTasksToThreads(
                 state.standbyTasks(),
-                Collections.emptySet(),
                 consumers,
-                state
+                state,
+                threadTaskCounts
+            );
+
+            final Map<String, List<TaskId>> activeTaskStatelessAssignment = assignStatelessTasksToThreads(
+                state.statelessActiveTasks(),
+                consumers,
+                threadTaskCounts
             );
 
+            final Map<String, List<TaskId>> activeTaskAssignment = activeTaskStatefulAssignment;

Review comment:
       could we add a comment here to mention that this combines the active stateful assignment + the active stateless assignment...?

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
##########
@@ -1029,104 +1041,148 @@ private boolean addClientAssignments(final Set<TaskId> statefulTasks,
 
     /**
      * Generate an assignment that tries to preserve thread-level stickiness of stateful tasks without violating
-     * balance. The stateful and total task load are both balanced across threads. Tasks without previous owners
-     * will be interleaved by group id to spread subtopologies across threads and further balance the workload.
+     * balance. The tasks are balanced across threads. Tasks without previous owners will be interleaved by
+     * group id to spread subtopologies across threads and further balance the workload.
+     * threadLoad is a map that keeps track of task load per thread across multiple calls so actives and standbys
+     * are evenly distributed
      */
-    static Map<String, List<TaskId>> assignTasksToThreads(final Collection<TaskId> statefulTasksToAssign,
-                                                          final Collection<TaskId> statelessTasksToAssign,
-                                                          final SortedSet<String> consumers,
-                                                          final ClientState state) {
+    static Map<String, List<TaskId>> assignStatefulTasksToThreads(final Collection<TaskId> tasksToAssign,
+                                                                  final SortedSet<String> consumers,
+                                                                  final ClientState state,
+                                                                  final Map<String, Integer> threadLoad) {
         final Map<String, List<TaskId>> assignment = new HashMap<>();
         for (final String consumer : consumers) {
             assignment.put(consumer, new ArrayList<>());
         }
 
-        final List<TaskId> unassignedStatelessTasks = new ArrayList<>(statelessTasksToAssign);
-        Collections.sort(unassignedStatelessTasks);
-
-        final Iterator<TaskId> unassignedStatelessTasksIter = unassignedStatelessTasks.iterator();
+        int totalTasks = tasksToAssign.size();
+        for (final Integer threadTaskCount : threadLoad.values()) {
+            totalTasks += threadTaskCount;
+        }
 
-        final int minStatefulTasksPerThread = (int) Math.floor(((double) statefulTasksToAssign.size()) / consumers.size());
-        final PriorityQueue<TaskId> unassignedStatefulTasks = new PriorityQueue<>(statefulTasksToAssign);
+        final int minTasksPerThread = (int) Math.floor(((double) totalTasks) / consumers.size());
+        final PriorityQueue<TaskId> unassignedTasks = new PriorityQueue<>(tasksToAssign);
 
         final Queue<String> consumersToFill = new LinkedList<>();
         // keep track of tasks that we have to skip during the first pass in case we can reassign them later
         // using tree-map to make sure the iteration ordering over keys are preserved
         final Map<TaskId, String> unassignedTaskToPreviousOwner = new TreeMap<>();
 
-        if (!unassignedStatefulTasks.isEmpty()) {
-            // First assign stateful tasks to previous owner, up to the min expected tasks/thread
+        if (!unassignedTasks.isEmpty()) {
+            // First assign tasks to previous owner, up to the min expected tasks/thread
             for (final String consumer : consumers) {
                 final List<TaskId> threadAssignment = assignment.get(consumer);
 
                 for (final TaskId task : state.prevTasksByLag(consumer)) {
-                    if (unassignedStatefulTasks.contains(task)) {
-                        if (threadAssignment.size() < minStatefulTasksPerThread) {
+                    if (unassignedTasks.contains(task)) {
+                        final int threadTaskCount = threadAssignment.size() + threadLoad.getOrDefault(consumer, 0);

Review comment:
       `threadLoad` is not changed before the end of the `assignStatefulTasksToThreads` method. Could we get the value `threadLoad.getOrDefault(consumer, 0)` once at the beginning of the consumers loop?
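
   The hoisting could look roughly like this; a minimal, self-contained sketch where the helper name `assignUpToMin` and the plain `Integer` task ids are hypothetical stand-ins for the PR's types, not the actual Kafka code:

   ```java
   import java.util.*;

   public class HoistedLoadSketch {
       // Simplified stand-in for the per-consumer loop in the assignor: the
       // consumer's carried-over load is read with getOrDefault once per
       // consumer instead of on every task iteration.
       static Map<String, List<Integer>> assignUpToMin(final Map<String, List<Integer>> prevTasks,
                                                       final Map<String, Integer> threadLoad,
                                                       final int minTasksPerThread) {
           final Map<String, List<Integer>> assignment = new HashMap<>();
           for (final Map.Entry<String, List<Integer>> entry : prevTasks.entrySet()) {
               final String consumer = entry.getKey();
               final List<Integer> threadAssignment = new ArrayList<>();
               final int previousLoad = threadLoad.getOrDefault(consumer, 0); // hoisted lookup
               for (final Integer task : entry.getValue()) {
                   if (threadAssignment.size() + previousLoad < minTasksPerThread) {
                       threadAssignment.add(task);
                   }
               }
               assignment.put(consumer, threadAssignment);
           }
           return assignment;
       }

       public static void main(final String[] args) {
           final Map<String, List<Integer>> prev = new HashMap<>();
           prev.put("c1", Arrays.asList(1, 2, 3));
           final Map<String, Integer> load = new HashMap<>();
           load.put("c1", 2); // c1 already carries two tasks from an earlier call
           // min of 3 per thread leaves room for only one more task on c1
           System.out.println(assignUpToMin(prev, load, 3).get("c1").size()); // prints 1
       }
   }
   ```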

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
##########
@@ -1029,104 +1041,148 @@ private boolean addClientAssignments(final Set<TaskId> statefulTasks,
 
     /**
      * Generate an assignment that tries to preserve thread-level stickiness of stateful tasks without violating
-     * balance. The stateful and total task load are both balanced across threads. Tasks without previous owners
-     * will be interleaved by group id to spread subtopologies across threads and further balance the workload.
+     * balance. The tasks are balanced across threads. Tasks without previous owners will be interleaved by
+     * group id to spread subtopologies across threads and further balance the workload.
+     * threadLoad is a map that keeps track of task load per thread across multiple calls so actives and standbys
+     * are evenly distributed
      */
-    static Map<String, List<TaskId>> assignTasksToThreads(final Collection<TaskId> statefulTasksToAssign,
-                                                          final Collection<TaskId> statelessTasksToAssign,
-                                                          final SortedSet<String> consumers,
-                                                          final ClientState state) {
+    static Map<String, List<TaskId>> assignStatefulTasksToThreads(final Collection<TaskId> tasksToAssign,
+                                                                  final SortedSet<String> consumers,
+                                                                  final ClientState state,
+                                                                  final Map<String, Integer> threadLoad) {
         final Map<String, List<TaskId>> assignment = new HashMap<>();
         for (final String consumer : consumers) {
             assignment.put(consumer, new ArrayList<>());
         }
 
-        final List<TaskId> unassignedStatelessTasks = new ArrayList<>(statelessTasksToAssign);
-        Collections.sort(unassignedStatelessTasks);
-
-        final Iterator<TaskId> unassignedStatelessTasksIter = unassignedStatelessTasks.iterator();
+        int totalTasks = tasksToAssign.size();
+        for (final Integer threadTaskCount : threadLoad.values()) {
+            totalTasks += threadTaskCount;
+        }
 
-        final int minStatefulTasksPerThread = (int) Math.floor(((double) statefulTasksToAssign.size()) / consumers.size());
-        final PriorityQueue<TaskId> unassignedStatefulTasks = new PriorityQueue<>(statefulTasksToAssign);
+        final int minTasksPerThread = (int) Math.floor(((double) totalTasks) / consumers.size());
+        final PriorityQueue<TaskId> unassignedTasks = new PriorityQueue<>(tasksToAssign);
 
         final Queue<String> consumersToFill = new LinkedList<>();
         // keep track of tasks that we have to skip during the first pass in case we can reassign them later
        // using tree-map to make sure the iteration ordering over keys are preserved
         final Map<TaskId, String> unassignedTaskToPreviousOwner = new TreeMap<>();
 
-        if (!unassignedStatefulTasks.isEmpty()) {
-            // First assign stateful tasks to previous owner, up to the min expected tasks/thread
+        if (!unassignedTasks.isEmpty()) {
+            // First assign tasks to previous owner, up to the min expected tasks/thread
             for (final String consumer : consumers) {
                 final List<TaskId> threadAssignment = assignment.get(consumer);
 
                 for (final TaskId task : state.prevTasksByLag(consumer)) {
-                    if (unassignedStatefulTasks.contains(task)) {
-                        if (threadAssignment.size() < minStatefulTasksPerThread) {
+                    if (unassignedTasks.contains(task)) {
+                        final int threadTaskCount = threadAssignment.size() + threadLoad.getOrDefault(consumer, 0);
+                        if (threadTaskCount < minTasksPerThread) {
                             threadAssignment.add(task);
-                            unassignedStatefulTasks.remove(task);
+                            unassignedTasks.remove(task);
                         } else {
                             unassignedTaskToPreviousOwner.put(task, consumer);
                         }
                     }
                 }
 
-                if (threadAssignment.size() < minStatefulTasksPerThread) {
+                final int threadTaskCount = threadAssignment.size() + threadLoad.getOrDefault(consumer, 0);
+                if (threadTaskCount < minTasksPerThread) {
                     consumersToFill.offer(consumer);
                 }
             }
 
             // Next interleave remaining unassigned tasks amongst unfilled consumers
             while (!consumersToFill.isEmpty()) {
-                final TaskId task = unassignedStatefulTasks.poll();
+                final TaskId task = unassignedTasks.poll();
                 if (task != null) {
                     final String consumer = consumersToFill.poll();
                     final List<TaskId> threadAssignment = assignment.get(consumer);
                     threadAssignment.add(task);
-                    if (threadAssignment.size() < minStatefulTasksPerThread) {
+                    final int threadTaskCount = threadAssignment.size() + threadLoad.getOrDefault(consumer, 0);
+                    if (threadTaskCount < minTasksPerThread) {
                         consumersToFill.offer(consumer);
                     }
                 } else {
                     throw new TaskAssignmentException("Ran out of unassigned stateful tasks but some members were not at capacity");
                 }
             }
 
-            // At this point all consumers are at the min capacity, so there may be up to N - 1 unassigned
-            // stateful tasks still remaining that should now be distributed over the consumers
-            if (!unassignedStatefulTasks.isEmpty()) {
-                consumersToFill.addAll(consumers);
+            // At this point all consumers are at the min or min + 1 capacity,
+            // the tasks still remaining that should now be distributed over the consumers that are still
+            // at min capacity
+            if (!unassignedTasks.isEmpty()) {
+                for (final String consumer : consumers) {
+                    final int taskCount = assignment.get(consumer).size() + threadLoad.getOrDefault(consumer, 0);
+                    if (taskCount == minTasksPerThread) {
+                        consumersToFill.add(consumer);

Review comment:
       I think the original algorithm assumes that when reaching this step, all consumers are at the min capacity, so it could just set `consumersToFill` like this:
   ```java
   if (!unassignedStatefulTasks.isEmpty())
          consumersToFill.addAll(consumers);
   ```
   
   But I didn't see where we changed this algorithm. Please let me know what I missed. Thanks.
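
   For what it's worth, the reason the blanket `addAll` may no longer hold seems to be the carried-over `threadLoad`: a consumer that kept many active tasks in the earlier call can already sit above `minTasksPerThread` when the standby call reaches this point. A small standalone sketch of the filtering (invented numbers and a hypothetical helper name, not a run of the real assignor):

   ```java
   import java.util.*;

   public class LeftoverFilterDemo {
       // Collect only the consumers whose total count (assignment size plus
       // carried-over threadLoad) is still exactly at the minimum, instead of
       // blindly adding every consumer.
       static List<String> consumersAtMinCapacity(final Map<String, Integer> taskCount,
                                                  final int minTasksPerThread) {
           final List<String> consumersToFill = new ArrayList<>();
           for (final Map.Entry<String, Integer> entry : taskCount.entrySet()) {
               if (entry.getValue() == minTasksPerThread) {
                   consumersToFill.add(entry.getKey());
               }
           }
           return consumersToFill;
       }

       public static void main(final String[] args) {
           final Map<String, Integer> taskCount = new LinkedHashMap<>();
           taskCount.put("c1", 4); // carried many active tasks, already above min
           taskCount.put("c2", 2); // at min
           taskCount.put("c3", 2); // at min
           // only c2 and c3 should receive leftover tasks
           System.out.println(consumersAtMinCapacity(taskCount, 2)); // prints [c2, c3]
       }
   }
   ```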

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
##########
@@ -1029,104 +1041,148 @@ private boolean addClientAssignments(final Set<TaskId> statefulTasks,
 
     /**
      * Generate an assignment that tries to preserve thread-level stickiness of stateful tasks without violating
-     * balance. The stateful and total task load are both balanced across threads. Tasks without previous owners
-     * will be interleaved by group id to spread subtopologies across threads and further balance the workload.
+     * balance. The tasks are balanced across threads. Tasks without previous owners will be interleaved by
+     * group id to spread subtopologies across threads and further balance the workload.
+     * threadLoad is a map that keeps track of task load per thread across multiple calls so actives and standbys
+     * are evenly distributed
      */
-    static Map<String, List<TaskId>> assignTasksToThreads(final Collection<TaskId> statefulTasksToAssign,
-                                                          final Collection<TaskId> statelessTasksToAssign,
-                                                          final SortedSet<String> consumers,
-                                                          final ClientState state) {
+    static Map<String, List<TaskId>> assignStatefulTasksToThreads(final Collection<TaskId> tasksToAssign,
+                                                                  final SortedSet<String> consumers,
+                                                                  final ClientState state,
+                                                                  final Map<String, Integer> threadLoad) {
         final Map<String, List<TaskId>> assignment = new HashMap<>();
         for (final String consumer : consumers) {
             assignment.put(consumer, new ArrayList<>());
         }
 
-        final List<TaskId> unassignedStatelessTasks = new ArrayList<>(statelessTasksToAssign);
-        Collections.sort(unassignedStatelessTasks);
-
-        final Iterator<TaskId> unassignedStatelessTasksIter = unassignedStatelessTasks.iterator();
+        int totalTasks = tasksToAssign.size();
+        for (final Integer threadTaskCount : threadLoad.values()) {
+            totalTasks += threadTaskCount;
+        }
 
-        final int minStatefulTasksPerThread = (int) Math.floor(((double) statefulTasksToAssign.size()) / consumers.size());
-        final PriorityQueue<TaskId> unassignedStatefulTasks = new PriorityQueue<>(statefulTasksToAssign);
+        final int minTasksPerThread = (int) Math.floor(((double) totalTasks) / consumers.size());
+        final PriorityQueue<TaskId> unassignedTasks = new PriorityQueue<>(tasksToAssign);
 
         final Queue<String> consumersToFill = new LinkedList<>();
         // keep track of tasks that we have to skip during the first pass in case we can reassign them later
         // using tree-map to make sure the iteration ordering over keys are preserved
         final Map<TaskId, String> unassignedTaskToPreviousOwner = new TreeMap<>();
 
-        if (!unassignedStatefulTasks.isEmpty()) {
-            // First assign stateful tasks to previous owner, up to the min expected tasks/thread
+        if (!unassignedTasks.isEmpty()) {
+            // First assign tasks to previous owner, up to the min expected tasks/thread
             for (final String consumer : consumers) {
                 final List<TaskId> threadAssignment = assignment.get(consumer);
 
                 for (final TaskId task : state.prevTasksByLag(consumer)) {
-                    if (unassignedStatefulTasks.contains(task)) {
-                        if (threadAssignment.size() < minStatefulTasksPerThread) {
+                    if (unassignedTasks.contains(task)) {
+                        final int threadTaskCount = threadAssignment.size() + threadLoad.getOrDefault(consumer, 0);

Review comment:
       Also, to compute `threadTaskCount` we always need to add `threadLoad.getOrDefault(consumer, 0)`. Could we instead compute a per-consumer `minTasksPerThread = minTasksPerThread - threadLoad.getOrDefault(consumer, 0)` once (with a comment to make it clear), so that we don't need to add `threadLoad.getOrDefault(consumer, 0)` each time? WDYT?
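
   The suggested rewrite is a pure algebraic shuffle: `size + load < min` is the same predicate as `size < min - load`, so the carried load can be folded into a per-consumer threshold once. A quick self-contained check (hypothetical method names, not PR code):

   ```java
   public class AdjustedThresholdSketch {
       // Form used in the PR: add the carried-over load on every comparison.
       static boolean belowCapacityOriginal(final int size, final int load, final int min) {
           return size + load < min;
       }

       // Suggested form: subtract the carried-over load from the shared minimum
       // once per consumer, then compare against the plain assignment size.
       static boolean belowCapacityAdjusted(final int size, final int load, final int min) {
           final int remainingCapacity = min - load;
           return size < remainingCapacity;
       }

       public static void main(final String[] args) {
           // exhaustively confirm the two predicates agree on small inputs
           for (int size = 0; size <= 5; size++) {
               for (int load = 0; load <= 5; load++) {
                   for (int min = 0; min <= 5; min++) {
                       if (belowCapacityOriginal(size, load, min) != belowCapacityAdjusted(size, load, min)) {
                           throw new AssertionError("predicates disagree");
                       }
                   }
               }
           }
           System.out.println("equivalent");
       }
   }
   ```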

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
##########
@@ -1140,6 +1196,13 @@ private boolean addClientAssignments(final Set<TaskId> statefulTasks,
             consumersToFill.offer(consumer);
         }
 
+        // Update threadLoad
+        for (final Map.Entry<String, List<TaskId>> taskEntry : assignment.entrySet()) {
+            final String consumer = taskEntry.getKey();
+            final int totalCount = threadLoad.getOrDefault(consumer, 0) + taskEntry.getValue().size();
+            threadLoad.put(consumer, totalCount);
+        }
+

Review comment:
       should we skip updating `threadLoad` at the end of `assignStatelessTasksToThreads`? We don't need it after `assignStatelessTasksToThreads`, right?

##########
File path: 
streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
##########
@@ -997,6 +997,77 @@ public void testAssignWithStandbyReplicas() {
         assertEquals(standbyPartitionsByHost, info20.standbyPartitionByHost());
     }
 
+    @Test
+    public void testAssignWithStandbyReplicasBalance() {
+        builder.addSource(null, "source1", null, null, null, "topic1");
+        builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source1");
+        builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor");
+
+        final List<String> topics = asList("topic1");
+
+        final Set<TaskId> prevTasks00 = mkSet(TASK_0_0);
+        final Set<TaskId> standbyTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
+
+        createMockTaskManager(prevTasks00, standbyTasks);

Review comment:
       The `prevTasks00` and `standbyTasks` are meaningless in this test. If we 
just want to create a mock taskManager, could we use 
`createMockTaskManager(EMPTY_TASKS, EMPTY_TASKS);` instead?

##########
File path: 
streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
##########
@@ -1029,104 +1041,148 @@ private boolean addClientAssignments(final Set<TaskId> statefulTasks,
 
     /**
      * Generate an assignment that tries to preserve thread-level stickiness of stateful tasks without violating
-     * balance. The stateful and total task load are both balanced across threads. Tasks without previous owners
-     * will be interleaved by group id to spread subtopologies across threads and further balance the workload.
+     * balance. The tasks are balanced across threads. Tasks without previous owners will be interleaved by
+     * group id to spread subtopologies across threads and further balance the workload.
+     * threadLoad is a map that keeps track of task load per thread across multiple calls so actives and standbys
+     * are evenly distributed
      */
-    static Map<String, List<TaskId>> assignTasksToThreads(final Collection<TaskId> statefulTasksToAssign,
-                                                          final Collection<TaskId> statelessTasksToAssign,
-                                                          final SortedSet<String> consumers,
-                                                          final ClientState state) {
+    static Map<String, List<TaskId>> assignStatefulTasksToThreads(final Collection<TaskId> tasksToAssign,
+                                                                  final SortedSet<String> consumers,
+                                                                  final ClientState state,
+                                                                  final Map<String, Integer> threadLoad) {
         final Map<String, List<TaskId>> assignment = new HashMap<>();
         for (final String consumer : consumers) {
             assignment.put(consumer, new ArrayList<>());
         }
 
-        final List<TaskId> unassignedStatelessTasks = new ArrayList<>(statelessTasksToAssign);
-        Collections.sort(unassignedStatelessTasks);
-
-        final Iterator<TaskId> unassignedStatelessTasksIter = unassignedStatelessTasks.iterator();
+        int totalTasks = tasksToAssign.size();
+        for (final Integer threadTaskCount : threadLoad.values()) {
+            totalTasks += threadTaskCount;
+        }
 
-        final int minStatefulTasksPerThread = (int) Math.floor(((double) statefulTasksToAssign.size()) / consumers.size());
-        final PriorityQueue<TaskId> unassignedStatefulTasks = new PriorityQueue<>(statefulTasksToAssign);
+        final int minTasksPerThread = (int) Math.floor(((double) totalTasks) / consumers.size());
+        final PriorityQueue<TaskId> unassignedTasks = new PriorityQueue<>(tasksToAssign);
 
         final Queue<String> consumersToFill = new LinkedList<>();
         // keep track of tasks that we have to skip during the first pass in case we can reassign them later
         // using tree-map to make sure the iteration ordering over keys are preserved
         final Map<TaskId, String> unassignedTaskToPreviousOwner = new TreeMap<>();
 
-        if (!unassignedStatefulTasks.isEmpty()) {
-            // First assign stateful tasks to previous owner, up to the min expected tasks/thread
+        if (!unassignedTasks.isEmpty()) {
+            // First assign tasks to previous owner, up to the min expected tasks/thread
             for (final String consumer : consumers) {
                 final List<TaskId> threadAssignment = assignment.get(consumer);
 
                 for (final TaskId task : state.prevTasksByLag(consumer)) {
-                    if (unassignedStatefulTasks.contains(task)) {
-                        if (threadAssignment.size() < minStatefulTasksPerThread) {
+                    if (unassignedTasks.contains(task)) {
+                        final int threadTaskCount = threadAssignment.size() + threadLoad.getOrDefault(consumer, 0);
+                        if (threadTaskCount < minTasksPerThread) {
                             threadAssignment.add(task);
-                            unassignedStatefulTasks.remove(task);
+                            unassignedTasks.remove(task);
                         } else {
                             unassignedTaskToPreviousOwner.put(task, consumer);
                         }
                     }
                 }
 
-                if (threadAssignment.size() < minStatefulTasksPerThread) {
+                final int threadTaskCount = threadAssignment.size() + threadLoad.getOrDefault(consumer, 0);
+                if (threadTaskCount < minTasksPerThread) {
                     consumersToFill.offer(consumer);
                 }
             }
 
             // Next interleave remaining unassigned tasks amongst unfilled consumers
             while (!consumersToFill.isEmpty()) {
-                final TaskId task = unassignedStatefulTasks.poll();
+                final TaskId task = unassignedTasks.poll();
                 if (task != null) {
                     final String consumer = consumersToFill.poll();
                     final List<TaskId> threadAssignment = assignment.get(consumer);
                     threadAssignment.add(task);
-                    if (threadAssignment.size() < minStatefulTasksPerThread) {
+                    final int threadTaskCount = threadAssignment.size() + threadLoad.getOrDefault(consumer, 0);
+                    if (threadTaskCount < minTasksPerThread) {
                         consumersToFill.offer(consumer);
                     }
                 } else {
                     throw new TaskAssignmentException("Ran out of unassigned stateful tasks but some members were not at capacity");
                 }
             }
 
-            // At this point all consumers are at the min capacity, so there may be up to N - 1 unassigned
-            // stateful tasks still remaining that should now be distributed over the consumers
-            if (!unassignedStatefulTasks.isEmpty()) {
-                consumersToFill.addAll(consumers);
+            // At this point all consumers are at the min or min + 1 capacity,
+            // the tasks still remaining that should now be distributed over the consumers that are still
+            // at min capacity
+            if (!unassignedTasks.isEmpty()) {
+                for (final String consumer : consumers) {
+                    final int taskCount = assignment.get(consumer).size() + threadLoad.getOrDefault(consumer, 0);
+                    if (taskCount == minTasksPerThread) {
+                        consumersToFill.add(consumer);
+                    }
+                }
 
                 // Go over the tasks we skipped earlier and assign them to their previous owner when possible
                 for (final Map.Entry<TaskId, String> taskEntry : unassignedTaskToPreviousOwner.entrySet()) {
                     final TaskId task = taskEntry.getKey();
                     final String consumer = taskEntry.getValue();
-                    if (consumersToFill.contains(consumer) && unassignedStatefulTasks.contains(task)) {
+                    if (consumersToFill.contains(consumer) && unassignedTasks.contains(task)) {
                         assignment.get(consumer).add(task);
-                        unassignedStatefulTasks.remove(task);
+                        unassignedTasks.remove(task);
                         // Remove this consumer since we know it is now at minCapacity + 1
                         consumersToFill.remove(consumer);
                     }
                 }
 
                 // Now just distribute the remaining unassigned stateful tasks over the consumers still at min capacity
-                for (final TaskId task : unassignedStatefulTasks) {
+                for (final TaskId task : unassignedTasks) {
                     final String consumer = consumersToFill.poll();
                     final List<TaskId> threadAssignment = assignment.get(consumer);
                     threadAssignment.add(task);
                 }
+            }
+        }
+        // Update threadLoad
+        for (final Map.Entry<String, List<TaskId>> taskEntry : assignment.entrySet()) {
+            final String consumer = taskEntry.getKey();
+            final int totalCount = threadLoad.getOrDefault(consumer, 0) + taskEntry.getValue().size();
+            threadLoad.put(consumer, totalCount);
+        }
 
+        return assignment;
+    }
 
-                // There must be at least one consumer still at min capacity while all the others are at min
-                // capacity + 1, so start distributing stateless tasks to get all consumers back to the same count
-                while (unassignedStatelessTasksIter.hasNext()) {
-                    final String consumer = consumersToFill.poll();
-                    if (consumer != null) {
-                        final TaskId task = unassignedStatelessTasksIter.next();
-                        unassignedStatelessTasksIter.remove();
-                        assignment.get(consumer).add(task);
-                    } else {
-                        break;
-                    }
-                }
+    static Map<String, List<TaskId>> assignStatelessTasksToThreads(final Collection<TaskId> statelessTasksToAssign,
+                                                                  final SortedSet<String> consumers,
+                                                                  final Map<String, Integer> threadLoad) {
+        final List<TaskId> tasksToAssign = new ArrayList<>(statelessTasksToAssign);
+        Collections.sort(tasksToAssign);
+        final Map<String, List<TaskId>> assignment = new HashMap<>();
+        for (final String consumer : consumers) {
+            assignment.put(consumer, new ArrayList<>());
+        }
+
+        int maxThreadLoad = 0;
+        for (final int load : threadLoad.values()) {
+            maxThreadLoad = Integer.max(maxThreadLoad, load);

Review comment:
       I was thinking we don't need these loops to get `maxThreadLoad` and `consumersToFill`. Do you think we could call `assignStatelessTasksToThreads` at the end of `assignStatefulTasksToThreads`? That way we would already have `maxThreadLoad` and `consumersToFill`. The `assignStatefulTasksToThreads` method signature would then take one more parameter like `boolean shouldAssignStatelessTasks`, or `isActiveTasksAssignment`... something like that. WDYT?
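
   If it helps the discussion, the flag-based merge might look schematically like this. Everything here is an assumption rather than code from the PR: the names, the boolean flag, and a plain round-robin standing in for the sticky/interleave logic, just to show the stateless pass reusing the queue left over from the stateful pass:

   ```java
   import java.util.*;

   public class CombinedAssignSketch {
       // Hypothetical single entry point: the stateful pass leaves
       // consumersToFill populated, and the optional stateless pass drains it
       // directly instead of rebuilding it from threadLoad.
       static Map<String, List<Integer>> assignTasksToThreads(final List<Integer> statefulTasks,
                                                              final List<Integer> statelessTasks,
                                                              final SortedSet<String> consumers,
                                                              final boolean shouldAssignStatelessTasks) {
           final Map<String, List<Integer>> assignment = new HashMap<>();
           for (final String consumer : consumers) {
               assignment.put(consumer, new ArrayList<>());
           }
           // stateful pass: round-robin stand-in for the real sticky logic
           final Deque<String> consumersToFill = new ArrayDeque<>(consumers);
           for (final Integer task : statefulTasks) {
               final String consumer = consumersToFill.poll();
               assignment.get(consumer).add(task);
               consumersToFill.offer(consumer);
           }
           if (shouldAssignStatelessTasks) {
               // stateless pass reuses consumersToFill as-is
               for (final Integer task : statelessTasks) {
                   final String consumer = consumersToFill.poll();
                   assignment.get(consumer).add(task);
                   consumersToFill.offer(consumer);
               }
           }
           return assignment;
       }

       public static void main(final String[] args) {
           final SortedSet<String> consumers = new TreeSet<>(Arrays.asList("c1", "c2"));
           final Map<String, List<Integer>> result =
                   assignTasksToThreads(Arrays.asList(1, 2, 3), Arrays.asList(4), consumers, true);
           System.out.println(result.get("c1").size() + result.get("c2").size()); // prints 4
       }
   }
   ```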

##########
File path: 
streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
##########
@@ -997,6 +997,77 @@ public void testAssignWithStandbyReplicas() {
         assertEquals(standbyPartitionsByHost, info20.standbyPartitionByHost());
     }
 
+    @Test
+    public void testAssignWithStandbyReplicasBalance() {
+        builder.addSource(null, "source1", null, null, null, "topic1");
+        builder.addProcessor("processor", new MockApiProcessorSupplier<>(), 
"source1");
+        builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), 
"processor");
+
+        final List<String> topics = asList("topic1");
+
+        final Set<TaskId> prevTasks00 = mkSet(TASK_0_0);
+        final Set<TaskId> standbyTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2);
+
+        createMockTaskManager(prevTasks00, standbyTasks);
+        adminClient = 
createMockAdminClientForAssignor(getTopicPartitionOffsetsMap(
+                singletonList(APPLICATION_ID + "-store1-changelog"),
+                singletonList(3))
+        );
+        
configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG,
 1));
+
+        subscriptions.put("consumer10",
+                new Subscription(
+                        topics,
+                        getInfo(UUID_1, EMPTY_TASKS, EMPTY_TASKS, USER_END_POINT).encode()));
+        subscriptions.put("consumer11",
+                new Subscription(
+                        emptyList(),
+                        getInfo(UUID_1, EMPTY_TASKS, EMPTY_TASKS, USER_END_POINT).encode()));
+        subscriptions.put("consumer12",
+                new Subscription(
+                        emptyList(),
+                        getInfo(UUID_1, EMPTY_TASKS, EMPTY_TASKS, USER_END_POINT).encode()));
+        subscriptions.put("consumer13",
+                new Subscription(
+                        emptyList(),
+                        getInfo(UUID_1, EMPTY_TASKS, EMPTY_TASKS, USER_END_POINT).encode()));
+        subscriptions.put("consumer20",
+                new Subscription(
+                        topics,
+                        getInfo(UUID_2, EMPTY_TASKS, EMPTY_TASKS, OTHER_END_POINT).encode()));
+        subscriptions.put("consumer21",
+                new Subscription(
+                        topics,
+                        getInfo(UUID_2, EMPTY_TASKS, EMPTY_TASKS, OTHER_END_POINT).encode()));
+        subscriptions.put("consumer22",
+                new Subscription(
+                        topics,
+                        getInfo(UUID_2, EMPTY_TASKS, EMPTY_TASKS, OTHER_END_POINT).encode()));
+
+
+        final Map<String, Assignment> assignments =
+                partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
+
+        // Consumers
+        final AssignmentInfo info10 = AssignmentInfo.decode(assignments.get("consumer10").userData());
+        final AssignmentInfo info11 = AssignmentInfo.decode(assignments.get("consumer11").userData());
+        final AssignmentInfo info12 = AssignmentInfo.decode(assignments.get("consumer12").userData());
+        final AssignmentInfo info13 = AssignmentInfo.decode(assignments.get("consumer13").userData());
+        final AssignmentInfo info20 = AssignmentInfo.decode(assignments.get("consumer20").userData());
+        final AssignmentInfo info21 = AssignmentInfo.decode(assignments.get("consumer21").userData());
+        final AssignmentInfo info22 = AssignmentInfo.decode(assignments.get("consumer22").userData());
+
+        // Check each consumer has no more than 1 task
+        // (client 1 has more consumers than needed so consumer13 won't get a task)
+        assertEquals(1, info10.activeTasks().size() + info10.standbyTasks().size());
+        assertEquals(1, info11.activeTasks().size() + info11.standbyTasks().size());
+        assertEquals(1, info12.activeTasks().size() + info12.standbyTasks().size());
+        assertEquals(0, info13.activeTasks().size() + info13.standbyTasks().size());
+        assertEquals(1, info20.activeTasks().size() + info20.standbyTasks().size());
+        assertEquals(1, info21.activeTasks().size() + info21.standbyTasks().size());
+        assertEquals(1, info22.activeTasks().size() + info22.standbyTasks().size());

Review comment:
       I'm thinking we can just verify the assignment size is `< 2`. We don't 
care about the exact number, right?
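As a sketch of the relaxed check this comment suggests (the helper and the per-consumer counts below are illustrative, not part of the PR), the exact-count assertions could collapse into a simple upper-bound check:

```java
public class BalanceBoundSketch {
    // Illustrative stand-in for what the real test computes as
    // info.activeTasks().size() + info.standbyTasks().size().
    static int totalTaskCount(int activeTasks, int standbyTasks) {
        return activeTasks + standbyTasks;
    }

    public static void main(String[] args) {
        // Hypothetical per-consumer (active, standby) counts for the 7 consumers.
        int[][] counts = {{1, 0}, {0, 1}, {1, 0}, {0, 0}, {1, 0}, {0, 1}, {0, 1}};
        for (int[] c : counts) {
            // The suggestion: assert only the bound (< 2), not the exact number.
            if (totalTaskCount(c[0], c[1]) >= 2) {
                throw new AssertionError("consumer assigned more than 1 task");
            }
        }
        System.out.println("all consumers within bound");
    }
}
```

This keeps the test's intent (no consumer is overloaded) while leaving the assignor free to pick which consumers carry a task.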

##########
File path: streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java
##########
@@ -997,6 +997,77 @@ public void testAssignWithStandbyReplicas() {
         assertEquals(standbyPartitionsByHost, info20.standbyPartitionByHost());
     }
 
+    @Test
+    public void testAssignWithStandbyReplicasBalance() {

Review comment:
       Could we add some tests to cover the case where there are fewer consumers than total tasks (active + standby), but we can still distribute them evenly? For example:
   tasks: TASK_0_0, TASK_0_1, TASK_0_2
   consumers: 2 in thread1, 3 in thread2
   We'd expect each thread to have 3 tasks, distributed evenly.
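To illustrate the expectation behind this suggested case (a toy least-loaded round-robin over a shared load map, not the assignor's actual algorithm; all names here are made up): with 3 active tasks, 3 standbys, and 5 consumers, carrying per-thread load across the two assignment calls leaves every consumer with at least one task and none with more than two.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class SharedLoadSketch {
    // Assign each task to the currently least-loaded consumer; the shared
    // `threadLoad` map carries load from one call into the next, which is
    // what lets actives and standbys balance jointly.
    static void assignRoundRobin(List<String> tasks, List<String> consumers,
                                 Map<String, Integer> threadLoad) {
        for (String task : tasks) {
            String target = consumers.get(0);
            for (String c : consumers) {
                if (threadLoad.getOrDefault(c, 0) < threadLoad.getOrDefault(target, 0)) {
                    target = c;
                }
            }
            threadLoad.merge(target, 1, Integer::sum);
        }
    }

    public static void main(String[] args) {
        List<String> consumers = List.of("c10", "c11", "c20", "c21", "c22");
        Map<String, Integer> threadLoad = new LinkedHashMap<>();

        // First call assigns the 3 actives, second call the 3 standbys.
        assignRoundRobin(List.of("0_0", "0_1", "0_2"), consumers, threadLoad);
        assignRoundRobin(List.of("s0_0", "s0_1", "s0_2"), consumers, threadLoad);

        // 6 tasks over 5 consumers: four consumers end with 1 task, one with 2.
        System.out.println(threadLoad);
    }
}
```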

##########
File path: streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
##########
@@ -1029,104 +1041,148 @@ private boolean addClientAssignments(final Set<TaskId> statefulTasks,
 
     /**
      * Generate an assignment that tries to preserve thread-level stickiness of stateful tasks without violating
-     * balance. The stateful and total task load are both balanced across threads. Tasks without previous owners
-     * will be interleaved by group id to spread subtopologies across threads and further balance the workload.
+     * balance. The tasks are balanced across threads. Tasks without previous owners will be interleaved by
+     * group id to spread subtopologies across threads and further balance the workload.
+     * threadLoad is a map that keeps track of task load per thread across multiple calls so actives and standbys
+     * are evenly distributed
      */
-    static Map<String, List<TaskId>> assignTasksToThreads(final Collection<TaskId> statefulTasksToAssign,
-                                                          final Collection<TaskId> statelessTasksToAssign,
-                                                          final SortedSet<String> consumers,
-                                                          final ClientState state) {
+    static Map<String, List<TaskId>> assignStatefulTasksToThreads(final Collection<TaskId> tasksToAssign,
+                                                                  final SortedSet<String> consumers,
+                                                                  final ClientState state,
+                                                                  final Map<String, Integer> threadLoad) {
         final Map<String, List<TaskId>> assignment = new HashMap<>();
         for (final String consumer : consumers) {
             assignment.put(consumer, new ArrayList<>());
         }
 
-        final List<TaskId> unassignedStatelessTasks = new ArrayList<>(statelessTasksToAssign);
-        Collections.sort(unassignedStatelessTasks);
-
-        final Iterator<TaskId> unassignedStatelessTasksIter = unassignedStatelessTasks.iterator();
+        int totalTasks = tasksToAssign.size();
+        for (final Integer threadTaskCount : threadLoad.values()) {
+            totalTasks += threadTaskCount;
+        }
 
-        final int minStatefulTasksPerThread = (int) Math.floor(((double) statefulTasksToAssign.size()) / consumers.size());
-        final PriorityQueue<TaskId> unassignedStatefulTasks = new PriorityQueue<>(statefulTasksToAssign);
+        final int minTasksPerThread = (int) Math.floor(((double) totalTasks) / consumers.size());
+        final PriorityQueue<TaskId> unassignedTasks = new PriorityQueue<>(tasksToAssign);
 
         final Queue<String> consumersToFill = new LinkedList<>();
         // keep track of tasks that we have to skip during the first pass in case we can reassign them later
         // using tree-map to make sure the iteration ordering over keys are preserved
         final Map<TaskId, String> unassignedTaskToPreviousOwner = new TreeMap<>();
 
-        if (!unassignedStatefulTasks.isEmpty()) {
-            // First assign stateful tasks to previous owner, up to the min expected tasks/thread
+        if (!unassignedTasks.isEmpty()) {
+            // First assign tasks to previous owner, up to the min expected tasks/thread
             for (final String consumer : consumers) {
                 final List<TaskId> threadAssignment = assignment.get(consumer);
 
                 for (final TaskId task : state.prevTasksByLag(consumer)) {
-                    if (unassignedStatefulTasks.contains(task)) {
-                        if (threadAssignment.size() < minStatefulTasksPerThread) {
+                    if (unassignedTasks.contains(task)) {
+                        final int threadTaskCount = threadAssignment.size() + threadLoad.getOrDefault(consumer, 0);
+                        if (threadTaskCount < minTasksPerThread) {
                             threadAssignment.add(task);
-                            unassignedStatefulTasks.remove(task);
+                            unassignedTasks.remove(task);
                         } else {
                             unassignedTaskToPreviousOwner.put(task, consumer);
                         }
                     }
                 }
 
-                if (threadAssignment.size() < minStatefulTasksPerThread) {
+                final int threadTaskCount = threadAssignment.size() + threadLoad.getOrDefault(consumer, 0);
+                if (threadTaskCount < minTasksPerThread) {
                     consumersToFill.offer(consumer);
                 }
             }
 
             // Next interleave remaining unassigned tasks amongst unfilled consumers
             while (!consumersToFill.isEmpty()) {
-                final TaskId task = unassignedStatefulTasks.poll();
+                final TaskId task = unassignedTasks.poll();
                 if (task != null) {
                     final String consumer = consumersToFill.poll();
                     final List<TaskId> threadAssignment = assignment.get(consumer);
                     threadAssignment.add(task);
-                    if (threadAssignment.size() < minStatefulTasksPerThread) {
+                    final int threadTaskCount = threadAssignment.size() + threadLoad.getOrDefault(consumer, 0);
+                    if (threadTaskCount < minTasksPerThread) {
                         consumersToFill.offer(consumer);
                     }
                 } else {
                     throw new TaskAssignmentException("Ran out of unassigned stateful tasks but some members were not at capacity");
                 }
             }
 
-            // At this point all consumers are at the min capacity, so there may be up to N - 1 unassigned
-            // stateful tasks still remaining that should now be distributed over the consumers
-            if (!unassignedStatefulTasks.isEmpty()) {
-                consumersToFill.addAll(consumers);
+            // At this point all consumers are at the min or min + 1 capacity,

Review comment:
       Where does the `min + 1` case come from? Could you elaborate more? 
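One way to see where `min + 1` can come from (an illustrative calculation, not the assignor's code): whenever the total task count is not a multiple of the consumer count, `totalTasks % consumers.size()` threads must end up one task above the floor; the same effect occurs when `threadLoad` already carries uneven counts from the earlier active-task call.

```java
public class MinPlusOneSketch {
    public static void main(String[] args) {
        // Hypothetical numbers: 7 tasks over 3 consumer threads.
        int totalTasks = 7;
        int consumers = 3;

        int min = (int) Math.floor((double) totalTasks / consumers); // floor(7/3) = 2
        int remainder = totalTasks % consumers;                      // 7 % 3 = 1

        // After the first pass fills every consumer to `min`, `remainder`
        // tasks are left over, so `remainder` consumers end at `min + 1`.
        System.out.println("min = " + min + ", consumers at min + 1 = " + remainder);
        // prints: min = 2, consumers at min + 1 = 1
    }
}
```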



