This is an automated email from the ASF dual-hosted git repository.

ableegoldman pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 370e5ea1f81 KAFKA-15045: (KIP-924 pt. 15) Implement 
#defaultStandbyTaskAssignment and finish rack-aware standby optimization 
(#16129)
370e5ea1f81 is described below

commit 370e5ea1f81b0cd64760f740981a388b76b5e29d
Author: Antoine Pourchet <anto...@responsive.dev>
AuthorDate: Thu May 30 16:11:33 2024 -0600

    KAFKA-15045: (KIP-924 pt. 15) Implement #defaultStandbyTaskAssignment and 
finish rack-aware standby optimization (#16129)
    
    This fills in the implementation details of the standby task assignment 
utility functions within TaskAssignmentUtils.
    
    Reviewers: Anna Sophie Blee-Goldman <ableegold...@apache.org>
---
 .../assignment/KafkaStreamsAssignment.java         |  51 ++-
 .../processor/assignment/TaskAssignmentUtils.java  | 481 ++++++++++++++++++++-
 .../assignment/assignors/StickyTaskAssignor.java   |   2 +-
 .../internals/StreamsPartitionAssignor.java        |  53 +--
 .../internals/assignment/ClientState.java          |  16 +-
 .../internals/assignment/ClientStateTask.java      |   2 +-
 .../assignment/ConstrainedPrioritySet.java         |  14 +-
 .../assignment/DefaultApplicationState.java        |   6 +-
 .../assignment/RackAwareTaskAssignor.java          |   2 +-
 9 files changed, 543 insertions(+), 84 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/assignment/KafkaStreamsAssignment.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/KafkaStreamsAssignment.java
index a4947c36467..f5205c8422b 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/assignment/KafkaStreamsAssignment.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/KafkaStreamsAssignment.java
@@ -16,11 +16,13 @@
  */
 package org.apache.kafka.streams.processor.assignment;
 
+import static java.util.Collections.unmodifiableMap;
+
 import java.time.Instant;
-import java.util.HashSet;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 import org.apache.kafka.streams.processor.TaskId;
 
@@ -31,7 +33,7 @@ import org.apache.kafka.streams.processor.TaskId;
 public class KafkaStreamsAssignment {
 
     private final ProcessId processId;
-    private final Map<TaskId, AssignedTask> assignment;
+    private final Map<TaskId, AssignedTask> tasks;
     private final Optional<Instant> followupRebalanceDeadline;
 
     /**
@@ -45,7 +47,8 @@ public class KafkaStreamsAssignment {
      * @return a new KafkaStreamsAssignment object with the given processId 
and assignment
      */
     public static KafkaStreamsAssignment of(final ProcessId processId, final Set<AssignedTask> assignment) {
-        return new KafkaStreamsAssignment(processId, assignment, Optional.empty());
+        // Index the tasks by id; Collectors.toMap rejects two tasks that share the same TaskId
+        final Map<TaskId, AssignedTask> tasksById = assignment.stream()
+            .collect(Collectors.toMap(AssignedTask::id, Function.identity()));
+        final Optional<Instant> noFollowupRebalance = Optional.empty();
+        return new KafkaStreamsAssignment(processId, tasksById, noFollowupRebalance);
     }
 
     /**
@@ -62,14 +65,14 @@ public class KafkaStreamsAssignment {
      * @return a new KafkaStreamsAssignment object with the same processId and 
assignment but with the given rebalanceDeadline
      */
     public KafkaStreamsAssignment withFollowupRebalance(final Instant rebalanceDeadline) {
-        return new KafkaStreamsAssignment(this.processId(), this.assignment(), Optional.of(rebalanceDeadline));
+        // Same process and same tasks as this assignment; only the follow-up deadline is set
+        final ProcessId sameProcess = processId();
+        final Map<TaskId, AssignedTask> sameTasks = tasks();
+        return new KafkaStreamsAssignment(sameProcess, sameTasks, Optional.of(rebalanceDeadline));
     }
 
     private KafkaStreamsAssignment(final ProcessId processId,
-                                   final Set<AssignedTask> assignment,
+                                   final Map<TaskId, AssignedTask> tasks,
                                    final Optional<Instant> followupRebalanceDeadline) {
         this.processId = processId;
-        this.assignment = assignment.stream().collect(Collectors.toMap(AssignedTask::id, t -> t));
+        // Keep a private, mutable copy. withFollowupRebalance passes the read-only view returned by
+        // tasks(), and assignTask/removeTask mutate this map, so storing the argument directly would
+        // make those calls throw and would alias the map of the originating assignment.
+        this.tasks = new HashMap<>(tasks);
         this.followupRebalanceDeadline = followupRebalanceDeadline;
     }
 
@@ -83,24 +86,18 @@ public class KafkaStreamsAssignment {
 
     /**
      *
-     * @return a set of assigned tasks that are part of this {@code 
KafkaStreamsAssignment}
+     * Returns the tasks that are part of this {@code KafkaStreamsAssignment}, keyed by task id.
+     *
+     * @return a read-only view of the map from {@link TaskId} to {@link AssignedTask}
      */
-    public Set<AssignedTask> assignment() {
-        // TODO change assignment to return a map so we aren't forced to copy 
this into a Set
-        return new HashSet<>(assignment.values());
-    }
-
-    // TODO: merge this with #assignment by having it return a Map<TaskId, 
AssignedTask>
-    public Set<TaskId> assignedTaskIds() {
-        return assignment.keySet();
+    public Map<TaskId, AssignedTask> tasks() {
+        return unmodifiableMap(tasks);
     }
 
+    /**
+     * Adds {@code newTask} to this assignment, replacing any task already stored under the same
+     * {@link TaskId} (regardless of that task's type).
+     */
     public void assignTask(final AssignedTask newTask) {
-        assignment.put(newTask.id(), newTask);
+        tasks.put(newTask.id(), newTask);
     }
 
+    /**
+     * Removes whichever task is stored under the id of {@code removedTask}. Only the id is used
+     * for the lookup, so the type of {@code removedTask} is not compared against the stored task.
+     */
     public void removeTask(final AssignedTask removedTask) {
-        assignment.remove(removedTask.id());
+        tasks.remove(removedTask.id());
     }
 
     /**
@@ -140,5 +137,25 @@ public class KafkaStreamsAssignment {
         public Type type() {
             return taskType;
         }
+
+        @Override
+        public int hashCode() {
+            // 31-based combination of id and type (the same formula Objects.hash(id, taskType) uses)
+            final int idPart = 31 + id.hashCode();
+            return 31 * idPart + taskType.hashCode();
+        }
+
+        @Override
+        public boolean equals(final Object obj) {
+            if (obj == this) {
+                return true;
+            }
+            // Exact runtime-class match (not instanceof), so a subclass instance is never equal to this one
+            if (obj == null || obj.getClass() != getClass()) {
+                return false;
+            }
+            final AssignedTask that = (AssignedTask) obj;
+            return id.equals(that.id()) && taskType == that.taskType;
+        }
     }
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java
index d7179226a9b..d2e30d53805 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java
@@ -16,6 +16,8 @@
  */
 package org.apache.kafka.streams.processor.assignment;
 
+import static 
org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor.STANDBY_OPTIMIZER_MAX_ITERATION;
+
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.HashMap;
@@ -26,9 +28,15 @@ import java.util.Optional;
 import java.util.Set;
 import java.util.SortedSet;
 import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
+import java.util.function.Function;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.kafka.streams.KeyValue;
 import org.apache.kafka.streams.processor.TaskId;
 import 
org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment.AssignedTask;
+import 
org.apache.kafka.streams.processor.internals.assignment.ConstrainedPrioritySet;
 import org.apache.kafka.streams.processor.internals.assignment.Graph;
 import 
org.apache.kafka.streams.processor.internals.assignment.MinTrafficGraphConstructor;
 import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructor;
@@ -72,6 +80,27 @@ public final class TaskAssignmentUtils {
         return assignments;
     }
 
+    /**
+     * Assign standby tasks to KafkaStreams clients according to the default logic.
+     * <p>
+     * The tag-based standby assignment is used when rack-aware assignment tags are configured, or
+     * when rack-aware optimization can be performed for standby tasks. Otherwise standby tasks are
+     * distributed purely according to client load.
+     *
+     * @param applicationState        the metadata and other info describing the current application state
+     * @param kafkaStreamsAssignments the current assignment of tasks to KafkaStreams clients
+     *
+     * @return the given {@code kafkaStreamsAssignments} map, whose KafkaStreamsAssignments have been
+     *         updated in place with the default standby assignment
+     */
+    public static Map<ProcessId, KafkaStreamsAssignment> defaultStandbyTaskAssignment(
+        final ApplicationState applicationState,
+        final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments
+    ) {
+        final boolean hasAssignmentTags = !applicationState.assignmentConfigs().rackAwareAssignmentTags().isEmpty();
+        // NOTE(review): the rack-aware case shares the tag-based path with the explicit-tags case;
+        // confirm this is intentional rather than a copy/paste slip.
+        if (hasAssignmentTags || canPerformRackAwareOptimization(applicationState, AssignedTask.Type.STANDBY)) {
+            return tagBasedStandbyTaskAssignment(applicationState, kafkaStreamsAssignments);
+        }
+        return loadBasedStandbyTaskAssignment(applicationState, kafkaStreamsAssignments);
+    }
+
     /**
      * Optimize active task assignment for rack awareness. This optimization 
is based on the
      * {@link StreamsConfig#RACK_AWARE_ASSIGNMENT_TRAFFIC_COST_CONFIG 
trafficCost}
@@ -98,11 +127,9 @@ public final class TaskAssignmentUtils {
      *
      * @return a map with the KafkaStreamsAssignments updated to minimize 
cross-rack traffic for active tasks
      */
-    public static Map<ProcessId, KafkaStreamsAssignment> 
optimizeRackAwareActiveTasks(
-        final ApplicationState applicationState,
-        final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments,
-        final SortedSet<TaskId> tasks
-    ) {
+    public static Map<ProcessId, KafkaStreamsAssignment> 
optimizeRackAwareActiveTasks(final ApplicationState applicationState,
+                                                                               
       final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments,
+                                                                               
       final SortedSet<TaskId> tasks) {
         if (tasks.isEmpty()) {
             return kafkaStreamsAssignments;
         }
@@ -131,7 +158,7 @@ public final class TaskAssignmentUtils {
             clientRacks.put(uuid, 
kafkaStreamsStates.get(entry.getKey()).rackId());
         }
 
-        final long initialCost = computeInitialCost(
+        final long initialCost = computeTotalAssignmentCost(
             topicPartitionsByTaskId,
             taskIds,
             clientIds,
@@ -172,7 +199,7 @@ public final class TaskAssignmentUtils {
             assignmentGraph.clientByTask,
             (assignment, taskId) -> assignment.assignTask(new 
AssignedTask(taskId, AssignedTask.Type.ACTIVE)),
             (assignment, taskId) -> assignment.removeTask(new 
AssignedTask(taskId, AssignedTask.Type.ACTIVE)),
-            (assignment, taskId) -> 
assignment.assignedTaskIds().contains(taskId)
+            (assignment, taskId) -> assignment.tasks().containsKey(taskId)
         );
 
         return kafkaStreamsAssignments;
@@ -231,7 +258,7 @@ public final class TaskAssignmentUtils {
             clientRacks.put(uuid, 
kafkaStreamsStates.get(entry.getKey()).rackId());
         }
 
-        final long initialCost = computeInitialCost(
+        final long initialCost = computeTotalAssignmentCost(
             topicPartitionsByTaskId,
             taskIds,
             clientIds,
@@ -244,18 +271,112 @@ public final class TaskAssignmentUtils {
         );
         LOG.info("Assignment before standby task optimization has cost {}", 
initialCost);
 
-        throw new UnsupportedOperationException("Not yet Implemented.");
+        final MoveStandbyTaskPredicate movePredicate = getStandbyTaskMovePredicate(applicationState);
+        // Lists the standby tasks of `source` that may be moved to `destination`: the destination must not
+        // already hold any copy of the task, and the configured move predicate must allow the move.
+        final BiFunction<KafkaStreamsAssignment, KafkaStreamsAssignment, List<TaskId>> getMovableTasks = (source, destination) -> {
+            final KafkaStreamsState sourceState = kafkaStreamsStates.get(source.processId());
+            // The destination's own state must be passed to the predicate, not the source's
+            final KafkaStreamsState destinationState = kafkaStreamsStates.get(destination.processId());
+            return source.tasks().values().stream()
+                .filter(task -> task.type() == AssignedTask.Type.STANDBY)
+                .filter(task -> !destination.tasks().containsKey(task.id()))
+                .filter(task -> movePredicate.canMoveStandbyTask(sourceState, destinationState, task.id(), kafkaStreamsAssignments))
+                .map(AssignedTask::id)
+                .sorted()
+                .collect(Collectors.toList());
+        };
+
+        final long startTime = System.currentTimeMillis();
+        boolean taskMoved = true;
+        int round = 0;
+        final RackAwareGraphConstructor<KafkaStreamsAssignment> graphConstructor = RackAwareGraphConstructorFactory.create(
+            applicationState.assignmentConfigs().rackAwareAssignmentStrategy(), taskIds);
+        while (taskMoved && round < STANDBY_OPTIMIZER_MAX_ITERATION) {
+            taskMoved = false;
+            round++;
+            // Visit every unordered pair of distinct clients exactly once
+            for (int i = 0; i < clientIds.size(); i++) {
+                final UUID clientId1 = clientIds.get(i);
+                final KafkaStreamsAssignment clientState1 = kafkaStreamsAssignments.get(new ProcessId(clientId1));
+                for (int j = i + 1; j < clientIds.size(); j++) {
+                    final UUID clientId2 = clientIds.get(j);
+                    final KafkaStreamsAssignment clientState2 = kafkaStreamsAssignments.get(new ProcessId(clientId2));
+
+                    final String rack1 = clientRacks.get(clientState1.processId().id()).get();
+                    final String rack2 = clientRacks.get(clientState2.processId().id()).get();
+                    // Cross rack traffic can not be reduced if racks are the same
+                    if (rack1.equals(rack2)) {
+                        continue;
+                    }
+
+                    final List<TaskId> movable1 = getMovableTasks.apply(clientState1, clientState2);
+                    final List<TaskId> movable2 = getMovableTasks.apply(clientState2, clientState1);
+
+                    // There's no need to optimize if either side is empty, because the optimization
+                    // can only swap tasks in order to keep the clients' load balanced
+                    if (movable1.isEmpty() || movable2.isEmpty()) {
+                        continue;
+                    }
+
+                    final List<TaskId> taskIdList = Stream.concat(movable1.stream(), movable2.stream())
+                        .sorted()
+                        .collect(Collectors.toList());
+                    final List<UUID> clients = Stream.of(clientId1, clientId2).sorted().collect(Collectors.toList());
+
+                    // NOTE(review): the two boolean flags are false here while the final standby cost
+                    // computation after this loop passes true, true; confirm which is intended.
+                    final AssignmentGraph assignmentGraph = buildTaskGraph(
+                        assignmentsByUuid,
+                        clientRacks,
+                        taskIdList,
+                        clients,
+                        topicPartitionsByTaskId,
+                        crossRackTrafficCost,
+                        nonOverlapCost,
+                        false,
+                        false,
+                        graphConstructor
+                    );
+                    assignmentGraph.graph.solveMinCostFlow();
+
+                    // The flow must be read back with the same client/task lists the graph was built from
+                    taskMoved |= graphConstructor.assignTaskFromMinCostFlow(
+                        assignmentGraph.graph,
+                        clients,
+                        taskIdList,
+                        assignmentsByUuid,
+                        assignmentGraph.taskCountByClient,
+                        assignmentGraph.clientByTask,
+                        (assignment, taskId) -> assignment.assignTask(new AssignedTask(taskId, AssignedTask.Type.STANDBY)),
+                        (assignment, taskId) -> assignment.removeTask(new AssignedTask(taskId, AssignedTask.Type.STANDBY)),
+                        (assignment, taskId) -> assignment.tasks().containsKey(taskId) && assignment.tasks().get(taskId).type() == AssignedTask.Type.STANDBY
+                    );
+                }
+            }
+        }
+        final long finalCost = computeTotalAssignmentCost(
+            topicPartitionsByTaskId,
+            taskIds,
+            clientIds,
+            assignmentsByUuid,
+            clientRacks,
+            crossRackTrafficCost,
+            nonOverlapCost,
+            true,
+            true
+        );
+
+        final long duration = System.currentTimeMillis() - startTime;
+        LOG.info("Assignment after {} rounds and {} milliseconds for standby 
task optimization is {}\n with cost {}",
+            round, duration, kafkaStreamsAssignments, finalCost);
+        return kafkaStreamsAssignments;
     }
 
-    private static long computeInitialCost(final Map<TaskId, 
Set<TaskTopicPartition>> topicPartitionsByTaskId,
-                                           final List<TaskId> taskIds,
-                                           final List<UUID> clientIds,
-                                           final Map<UUID, 
KafkaStreamsAssignment> assignmentsByUuid,
-                                           final Map<UUID, Optional<String>> 
clientRacks,
-                                           final int crossRackTrafficCost,
-                                           final int nonOverlapCost,
-                                           final boolean hasReplica,
-                                           final boolean isStandby) {
+    private static long computeTotalAssignmentCost(final Map<TaskId, 
Set<TaskTopicPartition>> topicPartitionsByTaskId,
+                                                   final List<TaskId> taskIds,
+                                                   final List<UUID> clientIds,
+                                                   final Map<UUID, 
KafkaStreamsAssignment> assignmentsByUuid,
+                                                   final Map<UUID, 
Optional<String>> clientRacks,
+                                                   final int 
crossRackTrafficCost,
+                                                   final int nonOverlapCost,
+                                                   final boolean hasReplica,
+                                                   final boolean isStandby) {
         if (taskIds.isEmpty()) {
             return 0;
         }
@@ -296,7 +417,7 @@ public final class TaskAssignmentUtils {
             assignmentsByUuid,
             clientByTask,
             taskCountByClient,
-            (assignment, taskId) -> 
assignment.assignedTaskIds().contains(taskId),
+            (assignment, taskId) -> assignment.tasks().containsKey(taskId),
             (taskId, processId, inCurrentAssignment, unused0, unused1, 
unused2) -> {
                 final String clientRack = clientRacks.get(processId).get();
                 final int assignmentChangeCost = !inCurrentAssignment ? 
nonOverlapCost : 0;
@@ -328,6 +449,14 @@ public final class TaskAssignmentUtils {
         }
     }
 
+    /**
+     * Decides whether a standby task may be moved from one KafkaStreams client to another during
+     * rack-aware standby task optimization.
+     */
+    @FunctionalInterface
+    public interface MoveStandbyTaskPredicate {
+        /**
+         * @param source                 the state of the client currently holding the standby task
+         * @param destination            the state of the client the standby task would be moved to
+         * @param taskId                 the id of the standby task being considered
+         * @param kafkaStreamsAssignment the current assignments of all clients
+         * @return {@code true} if the standby task may be moved from {@code source} to {@code destination}
+         */
+        boolean canMoveStandbyTask(final KafkaStreamsState source,
+                                   final KafkaStreamsState destination,
+                                   final TaskId taskId,
+                                   final Map<ProcessId, 
KafkaStreamsAssignment> kafkaStreamsAssignment);
+    }
+
     /**
      *
      * @return the traffic cost of assigning this {@param task} to the client 
{@param streamsState}.
@@ -354,7 +483,6 @@ public final class TaskAssignmentUtils {
                                                            final 
AssignedTask.Type taskType) {
         final String rackAwareAssignmentStrategy = 
applicationState.assignmentConfigs().rackAwareAssignmentStrategy();
         if 
(StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_NONE.equals(rackAwareAssignmentStrategy))
 {
-            LOG.warn("KafkaStreams rack aware task assignment optimization was 
disabled in the StreamsConfig.");
             return false;
         }
         return hasValidRackInformation(applicationState, taskType);
@@ -407,4 +535,317 @@ public final class TaskAssignmentUtils {
         }
         return true;
     }
+
+    /**
+     * Assigns standby replicas for every stateful task, preferring clients whose rack-aware assignment
+     * tags differ from those of the clients already holding a copy of the task. Standbys that cannot be
+     * placed this way are handed to {@code assignPendingStandbyTasksToLeastLoadedClients}.
+     *
+     * @return the given {@code kafkaStreamsAssignments}, updated in place
+     */
+    private static Map<ProcessId, KafkaStreamsAssignment> tagBasedStandbyTaskAssignment(final ApplicationState applicationState,
+                                                                                         final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
+        final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas();
+        final Map<ProcessId, KafkaStreamsState> streamStates = applicationState.kafkaStreamsStates(false);
+
+        final Set<String> rackAwareAssignmentTags = new HashSet<>(applicationState.assignmentConfigs().rackAwareAssignmentTags());
+        final TagStatistics tagStatistics = new TagStatistics(applicationState);
+
+        final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = standbyTaskPriorityListByLoad(streamStates, kafkaStreamsAssignments);
+
+        final Set<TaskId> statefulTaskIds = applicationState.allTasks().values().stream()
+            .filter(TaskInfo::isStateful)
+            .map(TaskInfo::id)
+            .collect(Collectors.toSet());
+        final Map<TaskId, Integer> tasksToRemainingStandbys = statefulTaskIds.stream()
+            .collect(Collectors.toMap(Function.identity(), t -> numStandbyReplicas));
+        final Map<UUID, KafkaStreamsAssignment> clientsByUuid = kafkaStreamsAssignments.entrySet().stream().collect(Collectors.toMap(
+            entry -> entry.getKey().id(),
+            Map.Entry::getValue
+        ));
+
+        final Map<TaskId, ProcessId> pendingStandbyTasksToClientId = new HashMap<>();
+        for (final TaskId statefulTaskId : statefulTaskIds) {
+            for (final KafkaStreamsAssignment assignment : clientsByUuid.values()) {
+                // Only the client owning the ACTIVE copy seeds standby placement. Matching any copy would
+                // also fire for clients that just received a standby of this task in this loop, and that
+                // second call would unbox a null once the task's remaining count has been removed.
+                final AssignedTask assignedTask = assignment.tasks().get(statefulTaskId);
+                if (assignedTask != null && assignedTask.type() == AssignedTask.Type.ACTIVE) {
+                    assignStandbyTasksToClientsWithDifferentTags(
+                        numStandbyReplicas,
+                        standbyTaskClientsByTaskLoad,
+                        statefulTaskId,
+                        assignment.processId(),
+                        rackAwareAssignmentTags,
+                        streamStates,
+                        kafkaStreamsAssignments,
+                        tasksToRemainingStandbys,
+                        tagStatistics.tagKeyToValues,
+                        tagStatistics.tagEntryToClients,
+                        pendingStandbyTasksToClientId
+                    );
+                }
+            }
+        }
+
+        if (!tasksToRemainingStandbys.isEmpty()) {
+            assignPendingStandbyTasksToLeastLoadedClients(clientsByUuid,
+                numStandbyReplicas,
+                standbyTaskClientsByTaskLoad,
+                tasksToRemainingStandbys);
+        }
+
+        return kafkaStreamsAssignments;
+    }
+
+    /**
+     * Assigns standby replicas for every stateful task by repeatedly taking clients from a
+     * load-ordered priority set, without considering client tags.
+     *
+     * @return the given {@code kafkaStreamsAssignments}, updated in place
+     */
+    private static Map<ProcessId, KafkaStreamsAssignment> loadBasedStandbyTaskAssignment(
+        final ApplicationState applicationState,
+        final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments
+    ) {
+        final int standbysPerTask = applicationState.assignmentConfigs().numStandbyReplicas();
+        final Map<ProcessId, KafkaStreamsState> states = applicationState.kafkaStreamsStates(false);
+
+        final Set<TaskId> statefulTasks = applicationState.allTasks().values().stream()
+            .filter(TaskInfo::isStateful)
+            .map(TaskInfo::id)
+            .collect(Collectors.toSet());
+
+        // Every stateful task starts out needing the full number of standby replicas
+        final Map<TaskId, Integer> remainingStandbysByTask = new HashMap<>();
+        for (final TaskId statefulTask : statefulTasks) {
+            remainingStandbysByTask.put(statefulTask, standbysPerTask);
+        }
+
+        final Map<UUID, KafkaStreamsAssignment> assignmentsByClient = kafkaStreamsAssignments.entrySet().stream()
+            .collect(Collectors.toMap(entry -> entry.getKey().id(), Map.Entry::getValue));
+
+        final ConstrainedPrioritySet clientsByLoad = standbyTaskPriorityListByLoad(states, kafkaStreamsAssignments);
+        final Set<UUID> everyClient = states.keySet().stream().map(ProcessId::id).collect(Collectors.toSet());
+        clientsByLoad.offerAll(everyClient);
+
+        statefulTasks.forEach(statefulTask -> assignStandbyTasksForActiveTask(
+            standbysPerTask,
+            assignmentsByClient,
+            remainingStandbysByTask,
+            clientsByLoad,
+            statefulTask
+        ));
+        return kafkaStreamsAssignments;
+    }
+
+    /**
+     * Places standby copies of {@code activeTaskId}, up to the count currently recorded for it in
+     * {@code tasksToRemainingStandbys}. Each copy goes to the client polled from
+     * {@code standbyTaskClientsByTaskLoad}, which is offered back afterwards. Stops early when no
+     * client can be polled. The number of copies still unplaced is written back into
+     * {@code tasksToRemainingStandbys}, and a warning is logged if it is positive.
+     */
+    private static void assignStandbyTasksForActiveTask(final int numStandbyReplicas,
+                                                        final Map<UUID, KafkaStreamsAssignment> clients,
+                                                        final Map<TaskId, Integer> tasksToRemainingStandbys,
+                                                        final ConstrainedPrioritySet standbyTaskClientsByTaskLoad,
+                                                        final TaskId activeTaskId) {
+        int remaining = tasksToRemainingStandbys.get(activeTaskId);
+        boolean clientsAvailable = true;
+        while (remaining > 0 && clientsAvailable) {
+            final UUID nextClient = standbyTaskClientsByTaskLoad.poll(activeTaskId);
+            if (nextClient == null) {
+                clientsAvailable = false;
+            } else {
+                clients.get(nextClient).assignTask(new AssignedTask(activeTaskId, AssignedTask.Type.STANDBY));
+                remaining--;
+                standbyTaskClientsByTaskLoad.offer(nextClient);
+            }
+        }
+
+        tasksToRemainingStandbys.put(activeTaskId, remaining);
+        if (remaining <= 0) {
+            return;
+        }
+        LOG.warn("Unable to assign {} of {} standby tasks for task [{}]. " +
+                 "There is not enough available capacity. You should " +
+                 "increase the number of application instances " +
+                 "to maintain the requested number of standby replicas.",
+            remaining, numStandbyReplicas, activeTaskId);
+    }
+
+    /**
+     * Assigns standby copies of {@code activeTaskId} one at a time. Each copy goes to a client polled
+     * from {@code standbyTaskClientsByTaskLoad} that is not listed under any entry of the used-tag
+     * bookkeeping maintained by {@code updateClientsOnAlreadyUsedTagEntries}. That bookkeeping is seeded
+     * with the active client and updated after every placement.
+     * <p>
+     * If not every standby can be placed, the remaining count is stored in {@code tasksToRemainingStandbys}
+     * and the task is recorded in {@code pendingStandbyTasksToClientId}. Otherwise the task is removed from
+     * {@code tasksToRemainingStandbys}.
+     */
+    private static void assignStandbyTasksToClientsWithDifferentTags(final int numberOfStandbyClients,
+                                                                     final ConstrainedPrioritySet standbyTaskClientsByTaskLoad,
+                                                                     final TaskId activeTaskId,
+                                                                     final ProcessId activeTaskClient,
+                                                                     final Set<String> rackAwareAssignmentTags,
+                                                                     final Map<ProcessId, KafkaStreamsState> clientStates,
+                                                                     final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments,
+                                                                     final Map<TaskId, Integer> tasksToRemainingStandbys,
+                                                                     final Map<String, Set<String>> tagKeyToValues,
+                                                                     final Map<KeyValue<String, String>, Set<ProcessId>> tagEntryToClients,
+                                                                     final Map<TaskId, ProcessId> pendingStandbyTasksToClientId) {
+        standbyTaskClientsByTaskLoad.offerAll(clientStates.keySet().stream()
+            .map(ProcessId::id).collect(Collectors.toSet()));
+
+        // We set countOfUsedClients as 1 because client where active task is located has to be considered as used.
+        int countOfUsedClients = 1;
+        int numRemainingStandbys = tasksToRemainingStandbys.get(activeTaskId);
+
+        final Map<KeyValue<String, String>, Set<ProcessId>> tagEntryToUsedClients = new HashMap<>();
+
+        ProcessId lastUsedClient = activeTaskClient;
+        // A guarded while (not do/while): with nothing left to place, e.g. zero configured standby
+        // replicas, no standby may be assigned and the remaining count must not go negative.
+        while (numRemainingStandbys > 0) {
+            updateClientsOnAlreadyUsedTagEntries(
+                clientStates.get(lastUsedClient),
+                countOfUsedClients,
+                rackAwareAssignmentTags,
+                tagEntryToClients,
+                tagKeyToValues,
+                tagEntryToUsedClients
+            );
+
+            final UUID clientOnUnusedTagDimensions = standbyTaskClientsByTaskLoad.poll(
+                activeTaskId, uuid -> !isClientUsedOnAnyOfTheTagEntries(new ProcessId(uuid), tagEntryToUsedClients)
+            );
+
+            if (clientOnUnusedTagDimensions == null) {
+                break;
+            }
+
+            final KafkaStreamsState clientStateOnUsedTagDimensions = clientStates.get(new ProcessId(clientOnUnusedTagDimensions));
+            countOfUsedClients++;
+            numRemainingStandbys--;
+
+            LOG.debug("Assigning {} out of {} standby tasks for an active task [{}] with client tags {}. " +
+                      "Standby task client tags are {}.",
+                numberOfStandbyClients - numRemainingStandbys, numberOfStandbyClients, activeTaskId,
+                clientStates.get(activeTaskClient).clientTags(),
+                clientStateOnUsedTagDimensions.clientTags());
+
+            kafkaStreamsAssignments.get(clientStateOnUsedTagDimensions.processId()).assignTask(
+                new AssignedTask(activeTaskId, AssignedTask.Type.STANDBY)
+            );
+            lastUsedClient = new ProcessId(clientOnUnusedTagDimensions);
+        }
+
+        if (numRemainingStandbys > 0) {
+            pendingStandbyTasksToClientId.put(activeTaskId, activeTaskClient);
+            tasksToRemainingStandbys.put(activeTaskId, numRemainingStandbys);
+            LOG.warn("Rack aware standby task assignment was not able to assign {} of {} standby tasks for the " +
+                     "active task [{}] with the rack aware assignment tags {}. " +
+                     "This may happen when there aren't enough application instances on different tag " +
+                     "dimensions compared to an active and corresponding standby task. " +
+                     "Consider launching application instances on different tag dimensions than [{}]. " +
+                     "Standby task assignment will fall back to assigning standby tasks to the least loaded clients.",
+                numRemainingStandbys, numberOfStandbyClients,
+                activeTaskId, rackAwareAssignmentTags,
+                clientStates.get(activeTaskClient).clientTags());
+
+        } else {
+            tasksToRemainingStandbys.remove(activeTaskId);
+        }
+    }
+
+    /**
+     * Returns {@code true} if {@code client} appears in the client set of any entry in
+     * {@code tagEntryToUsedClients}.
+     */
+    private static boolean isClientUsedOnAnyOfTheTagEntries(final ProcessId client,
+                                                            final Map<KeyValue<String, String>, Set<ProcessId>> tagEntryToUsedClients) {
+        for (final Set<ProcessId> clientsOnEntry : tagEntryToUsedClients.values()) {
+            if (clientsOnEntry.contains(client)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Updates {@code tagEntryToUsedClients} with the tag entries of {@code usedClient}.
+     * <ul>
+     *   <li>Tag keys that are not in {@code rackAwareAssignmentTags} are skipped with a warning.</li>
+     *   <li>If a key has more distinct values than {@code countOfUsedClients}, the client's entry for
+     *       that key is mapped to all clients carrying it.</li>
+     *   <li>Otherwise every entry of that key is cleared, so the key no longer excludes any client.</li>
+     * </ul>
+     */
+    private static void updateClientsOnAlreadyUsedTagEntries(final KafkaStreamsState usedClient,
+                                                             final int countOfUsedClients,
+                                                             final Set<String> rackAwareAssignmentTags,
+                                                             final Map<KeyValue<String, String>, Set<ProcessId>> tagEntryToClients,
+                                                             final Map<String, Set<String>> tagKeyToValues,
+                                                             final Map<KeyValue<String, String>, Set<ProcessId>> tagEntryToUsedClients) {
+        for (final Map.Entry<String, String> clientTag : usedClient.clientTags().entrySet()) {
+            final String key = clientTag.getKey();
+
+            if (rackAwareAssignmentTags.contains(key)) {
+                final Set<String> knownValues = tagKeyToValues.get(key);
+                if (knownValues.size() > countOfUsedClients) {
+                    final KeyValue<String, String> usedEntry = new KeyValue<>(key, clientTag.getValue());
+                    tagEntryToUsedClients.put(usedEntry, tagEntryToClients.get(usedEntry));
+                } else {
+                    // This key has no more distinct values than clients used so far: stop excluding on it
+                    for (final String value : knownValues) {
+                        tagEntryToUsedClients.remove(new KeyValue<>(key, value));
+                    }
+                }
+            } else {
+                LOG.warn("Client tag with key [{}] will be ignored when computing rack aware standby " +
+                         "task assignment because it is not part of the configured rack awareness [{}].",
+                    key, rackAwareAssignmentTags);
+            }
+        }
+    }
+
+    private static MoveStandbyTaskPredicate getStandbyTaskMovePredicate(final 
ApplicationState applicationState) {
+        final boolean hasRackAwareAssignmentTags = 
!applicationState.assignmentConfigs().rackAwareAssignmentTags().isEmpty();
+        if (hasRackAwareAssignmentTags) {
+            final BiConsumer<KafkaStreamsState, Set<KeyValue<String, String>>> 
addTags = (cs, tagSet) -> {
+                final Map<String, String> tags = cs.clientTags();
+                if (tags != null) {
+                    tagSet.addAll(tags.entrySet().stream()
+                        .map(entry -> KeyValue.pair(entry.getKey(), 
entry.getValue()))
+                        .collect(Collectors.toList())
+                    );
+                }
+            };
+
+            final Map<ProcessId, KafkaStreamsState> clients = 
applicationState.kafkaStreamsStates(false);
+
+            return (source, destination, sourceTask, kafkaStreamsAssignments) 
-> {
+                final Set<KeyValue<String, String>> tagsWithSource = new 
HashSet<>();
+                final Set<KeyValue<String, String>> tagsWithDestination = new 
HashSet<>();
+                for (final KafkaStreamsAssignment assignment: 
kafkaStreamsAssignments.values()) {
+                    final boolean hasAssignedTask = 
assignment.tasks().containsKey(sourceTask);
+                    final boolean isSourceProcess = 
assignment.processId().equals(source.processId());
+                    final boolean isDestinationProcess = 
assignment.processId().equals(destination.processId());
+                    if (hasAssignedTask && !isSourceProcess && 
!isDestinationProcess) {
+                        final KafkaStreamsState clientState = 
clients.get(assignment.processId());
+                        addTags.accept(clientState, tagsWithSource);
+                        addTags.accept(clientState, tagsWithDestination);
+                    }
+                }
+                addTags.accept(source, tagsWithSource);
+                addTags.accept(destination, tagsWithDestination);
+                return tagsWithDestination.size() >= tagsWithSource.size();
+            };
+        } else {
+            return (a, b, c, d) -> true;
+        }
+    }
+
+    private static ConstrainedPrioritySet standbyTaskPriorityListByLoad(final 
Map<ProcessId, KafkaStreamsState> streamStates,
+                                                                        final 
Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) {
+        return new ConstrainedPrioritySet(
+            (processId, taskId) -> kafkaStreamsAssignments.get(new 
ProcessId(processId)).tasks().containsKey(taskId),
+            processId -> {
+                final double capacity = streamStates.get(new 
ProcessId(processId)).numProcessingThreads();
+                final double numTasks = kafkaStreamsAssignments.get(new 
ProcessId(processId)).tasks().size();
+                return numTasks / capacity;
+            }
+        );
+    }
+
+    private static void assignPendingStandbyTasksToLeastLoadedClients(final 
Map<UUID, KafkaStreamsAssignment> clients,
+                                                                      final 
int numStandbyReplicas,
+                                                                      final 
ConstrainedPrioritySet standbyTaskClientsByTaskLoad,
+                                                                      final 
Map<TaskId, Integer> pendingStandbyTaskToNumberRemainingStandbys) {
+        // We need to re offer all the clients to find the least loaded ones
+        standbyTaskClientsByTaskLoad.offerAll(clients.keySet());
+
+        for (final Map.Entry<TaskId, Integer> 
pendingStandbyTaskAssignmentEntry : 
pendingStandbyTaskToNumberRemainingStandbys.entrySet()) {
+            final TaskId activeTaskId = 
pendingStandbyTaskAssignmentEntry.getKey();
+
+            assignStandbyTasksForActiveTask(
+                numStandbyReplicas,
+                clients,
+                pendingStandbyTaskToNumberRemainingStandbys,
+                standbyTaskClientsByTaskLoad,
+                activeTaskId
+            );
+        }
+    }
+
+    private static class TagStatistics {
+        private final Map<String, Set<String>> tagKeyToValues;
+        private final Map<KeyValue<String, String>, Set<ProcessId>> 
tagEntryToClients;
+
+        private TagStatistics(final Map<String, Set<String>> tagKeyToValues,
+                              final Map<KeyValue<String, String>, 
Set<ProcessId>> tagEntryToClients) {
+            this.tagKeyToValues = tagKeyToValues;
+            this.tagEntryToClients = tagEntryToClients;
+        }
+
+        public TagStatistics(final ApplicationState applicationState) {
+            final Map<ProcessId, KafkaStreamsState> clientStates = 
applicationState.kafkaStreamsStates(false);
+
+            final Map<String, Set<String>> tagKeyToValues = new HashMap<>();
+            final Map<KeyValue<String, String>, Set<ProcessId>> 
tagEntryToClients = new HashMap<>();
+            for (final KafkaStreamsState state : clientStates.values()) {
+                state.clientTags().forEach((tagKey, tagValue) -> {
+                    tagKeyToValues.computeIfAbsent(tagKey, ignored -> new 
HashSet<>()).add(tagValue);
+                    tagEntryToClients.computeIfAbsent(new KeyValue<>(tagKey, 
tagValue), ignored -> new HashSet<>()).add(state.processId());
+                });
+            }
+
+            this.tagKeyToValues = tagKeyToValues;
+            this.tagEntryToClients = tagEntryToClients;
+        }
+    }
 }
\ No newline at end of file
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java
index 74d272ef1ec..2174b6d823b 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/assignment/assignors/StickyTaskAssignor.java
@@ -278,7 +278,7 @@ public class StickyTaskAssignor implements TaskAssignor {
 
             for (final Map.Entry<ProcessId, KafkaStreamsAssignment> entry : 
optimizedAssignments.entrySet()) {
                 final ProcessId processId = entry.getKey();
-                final Set<AssignedTask> assignedTasks = 
optimizedAssignments.get(processId).assignment();
+                final Set<AssignedTask> assignedTasks = new 
HashSet<>(optimizedAssignments.get(processId).tasks().values());
                 newAssignments.put(processId, assignedTasks);
 
                 for (final AssignedTask task : assignedTasks) {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
index 3fd54d70ab8..658ba1540da 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
@@ -18,6 +18,7 @@ package org.apache.kafka.streams.processor.internals;
 
 import java.time.Instant;
 import java.util.Optional;
+import java.util.function.Function;
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.admin.ListOffsetsResult;
 import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
@@ -555,18 +556,21 @@ public class StreamsPartitionAssignor implements 
ConsumerPartitionAssignor, Conf
 
         RackUtils.annotateTopicPartitionsWithRackInfo(cluster, 
internalTopicManager, allTopicPartitions);
 
-        final Set<TaskInfo> logicalTasks = logicalTaskIds.stream().map(taskId 
-> {
-            final Set<String> stateStoreNames = topologyMetadata
-                .stateStoreNameToSourceTopicsForTopology(taskId.topologyName())
-                .keySet();
-            final Set<TaskTopicPartition> topicPartitions = 
topicPartitionsForTask.get(taskId);
-            return new DefaultTaskInfo(
-                taskId,
-                !stateStoreNames.isEmpty(),
-                stateStoreNames,
-                topicPartitions
-            );
-        }).collect(Collectors.toSet());
+        final Map<TaskId, TaskInfo> logicalTasks = 
logicalTaskIds.stream().collect(Collectors.toMap(
+            Function.identity(),
+            taskId -> {
+                final Set<String> stateStoreNames = topologyMetadata
+                    
.stateStoreNameToSourceTopicsForTopology(taskId.topologyName())
+                    .keySet();
+                final Set<TaskTopicPartition> topicPartitions = 
topicPartitionsForTask.get(taskId);
+                return new DefaultTaskInfo(
+                    taskId,
+                    !stateStoreNames.isEmpty(),
+                    stateStoreNames,
+                    topicPartitions
+                );
+            }
+        ));
 
         return new DefaultApplicationState(
             assignmentConfigs.toPublicAssignmentConfigs(),
@@ -728,12 +732,12 @@ public class StreamsPartitionAssignor implements 
ConsumerPartitionAssignor, Conf
      * populate the stateful tasks that have been assigned to the clients
      */
     private UserTaskAssignmentListener assignTasksToClients(final Cluster 
fullMetadata,
-                                      final Set<String> allSourceTopics,
-                                      final Map<Subtopology, TopicsInfo> 
topicGroups,
-                                      final Map<UUID, ClientMetadata> 
clientMetadataMap,
-                                      final Map<TaskId, Set<TopicPartition>> 
partitionsForTask,
-                                      final Map<UUID, Map<String, 
Optional<String>>> racksForProcessConsumer,
-                                      final Set<TaskId> statefulTasks) {
+                                                            final Set<String> 
allSourceTopics,
+                                                            final 
Map<Subtopology, TopicsInfo> topicGroups,
+                                                            final Map<UUID, 
ClientMetadata> clientMetadataMap,
+                                                            final Map<TaskId, 
Set<TopicPartition>> partitionsForTask,
+                                                            final Map<UUID, 
Map<String, Optional<String>>> racksForProcessConsumer,
+                                                            final Set<TaskId> 
statefulTasks) {
         if (!statefulTasks.isEmpty()) {
             throw new TaskAssignmentException("The stateful tasks should not 
be populated before assigning tasks to clients");
         }
@@ -775,7 +779,7 @@ public class StreamsPartitionAssignor implements 
ConsumerPartitionAssignor, Conf
 
         final 
Optional<org.apache.kafka.streams.processor.assignment.TaskAssignor> 
userTaskAssignor =
             userTaskAssignorSupplier.get();
-        UserTaskAssignmentListener userTaskAssignmentListener = 
(GroupAssignment assignment, GroupSubscription subscription) -> { };
+        final UserTaskAssignmentListener userTaskAssignmentListener;
         if (userTaskAssignor.isPresent()) {
             final ApplicationState applicationState = buildApplicationState(
                 taskManager.topologyMetadata(),
@@ -785,12 +789,11 @@ public class StreamsPartitionAssignor implements 
ConsumerPartitionAssignor, Conf
             );
             final org.apache.kafka.streams.processor.assignment.TaskAssignor 
assignor = userTaskAssignor.get();
             final TaskAssignment taskAssignment = 
assignor.assign(applicationState);
-            processStreamsPartitionAssignment(clientMetadataMap, 
taskAssignment);
             final AssignmentError assignmentError = 
validateTaskAssignment(applicationState, taskAssignment);
-            userTaskAssignmentListener = (GroupAssignment assignment, 
GroupSubscription subscription) -> {
-                assignor.onAssignmentComputed(assignment, subscription, 
assignmentError);
-            };
+            processStreamsPartitionAssignment(clientMetadataMap, 
taskAssignment);
+            userTaskAssignmentListener = (assignment, subscription) -> 
assignor.onAssignmentComputed(assignment, subscription, assignmentError);
         } else {
+            userTaskAssignmentListener = (assignment, subscription) -> { };
             final TaskAssignor taskAssignor = 
createTaskAssignor(lagComputationSuccessful);
             final RackAwareTaskAssignor rackAwareTaskAssignor = new 
RackAwareTaskAssignor(
                 fullMetadata,
@@ -1564,7 +1567,7 @@ public class StreamsPartitionAssignor implements 
ConsumerPartitionAssignor, Conf
         final Map<TaskId, ProcessId> standbyTasksInOutput = new HashMap<>();
         for (final KafkaStreamsAssignment assignment : assignments) {
             final Set<TaskId> tasksForAssignment = new HashSet<>();
-            for (final KafkaStreamsAssignment.AssignedTask task : 
assignment.assignment()) {
+            for (final KafkaStreamsAssignment.AssignedTask task : 
assignment.tasks().values()) {
                 if (activeTasksInOutput.containsKey(task.id()) && task.type() 
== KafkaStreamsAssignment.AssignedTask.Type.ACTIVE) {
                     log.error("Assignment is invalid: active task {} was 
assigned to multiple KafkaStreams clients: {} and {}",
                         task.id(), assignment.processId().id(), 
activeTasksInOutput.get(task.id()).id());
@@ -1614,7 +1617,7 @@ public class StreamsPartitionAssignor implements 
ConsumerPartitionAssignor, Conf
 
         final Set<TaskId> taskIdsInInput = 
applicationState.allTasks().keySet();
         for (final KafkaStreamsAssignment assignment : assignments) {
-            for (final KafkaStreamsAssignment.AssignedTask task : 
assignment.assignment()) {
+            for (final KafkaStreamsAssignment.AssignedTask task : 
assignment.tasks().values()) {
                 if (!taskIdsInInput.contains(task.id())) {
                     log.error("Assignment is invalid: task {} assigned to 
KafkaStreams client {} was unknown",
                         task.id(), assignment.processId().id());
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
index d24e9c19167..704c1d50885 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java
@@ -85,8 +85,8 @@ public class ClientState {
     }
 
     ClientState(final UUID processId, final int capacity, final Map<String, 
String> clientTags) {
-        previousStandbyTasks.taskIds(new TreeSet<>());
-        previousActiveTasks.taskIds(new TreeSet<>());
+        previousStandbyTasks.setTaskIds(new TreeSet<>());
+        previousActiveTasks.setTaskIds(new TreeSet<>());
         taskOffsetSums = new TreeMap<>();
         taskLagTotals = new TreeMap<>();
         this.capacity = capacity;
@@ -110,8 +110,8 @@ public class ClientState {
                        final Map<String, String> clientTags,
                        final int capacity,
                        final UUID processId) {
-        this.previousStandbyTasks.taskIds(unmodifiableSet(new 
TreeSet<>(previousStandbyTasks)));
-        this.previousActiveTasks.taskIds(unmodifiableSet(new 
TreeSet<>(previousActiveTasks)));
+        this.previousStandbyTasks.setTaskIds(unmodifiableSet(new 
TreeSet<>(previousStandbyTasks)));
+        this.previousActiveTasks.setTaskIds(unmodifiableSet(new 
TreeSet<>(previousActiveTasks)));
         taskOffsetSums = emptyMap();
         this.taskLagTotals = unmodifiableMap(taskLagTotals);
         this.capacity = capacity;
@@ -489,14 +489,14 @@ public class ClientState {
     }
 
     public void setAssignedTasks(final KafkaStreamsAssignment assignment) {
-        final Set<TaskId> activeTasks = assignment.assignment().stream()
+        final Set<TaskId> activeTasks = assignment.tasks().values().stream()
             .filter(task -> task.type() == 
ACTIVE).map(KafkaStreamsAssignment.AssignedTask::id)
             .collect(Collectors.toSet());
-        final Set<TaskId> standbyTasks = assignment.assignment().stream()
+        final Set<TaskId> standbyTasks = assignment.tasks().values().stream()
             .filter(task -> task.type() == 
STANDBY).map(KafkaStreamsAssignment.AssignedTask::id)
             .collect(Collectors.toSet());
-        assignedActiveTasks.taskIds(activeTasks);
-        assignedStandbyTasks.taskIds(standbyTasks);
+        assignedActiveTasks.setTaskIds(activeTasks);
+        assignedStandbyTasks.setTaskIds(standbyTasks);
     }
 
     public String currentAssignment() {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTask.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTask.java
index 92769699ccc..1f098dc9aea 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTask.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTask.java
@@ -31,7 +31,7 @@ class ClientStateTask {
         this.consumerToTaskIds = consumerToTaskIds;
     }
 
-    void taskIds(final Set<TaskId> clientToTaskIds) {
+    void setTaskIds(final Set<TaskId> clientToTaskIds) {
         taskIds = clientToTaskIds;
     }
 
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySet.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySet.java
index 1de9dfc3c67..3705ee39af3 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySet.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySet.java
@@ -30,14 +30,14 @@ import java.util.function.Function;
 /**
  * Wraps a priority queue of clients and returns the next valid candidate(s) 
based on the current task assignment
  */
-class ConstrainedPrioritySet {
+public class ConstrainedPrioritySet {
 
     private final PriorityQueue<UUID> clientsByTaskLoad;
     private final BiFunction<UUID, TaskId, Boolean> constraint;
     private final Set<UUID> uniqueClients = new HashSet<>();
 
-    ConstrainedPrioritySet(final BiFunction<UUID, TaskId, Boolean> constraint,
-                           final Function<UUID, Double> weight) {
+    public ConstrainedPrioritySet(final BiFunction<UUID, TaskId, Boolean> 
constraint,
+                                  final Function<UUID, Double> weight) {
         this.constraint = constraint;
         clientsByTaskLoad = new 
PriorityQueue<>(Comparator.comparing(weight).thenComparing(clientId -> 
clientId));
     }
@@ -45,7 +45,7 @@ class ConstrainedPrioritySet {
     /**
      * @return the next least loaded client that satisfies the given criteria, 
or null if none do
      */
-    UUID poll(final TaskId task, final Function<UUID, Boolean> 
extraConstraint) {
+    public UUID poll(final TaskId task, final Function<UUID, Boolean> 
extraConstraint) {
         final Set<UUID> invalidPolledClients = new HashSet<>();
         while (!clientsByTaskLoad.isEmpty()) {
             final UUID candidateClient = pollNextClient();
@@ -66,17 +66,17 @@ class ConstrainedPrioritySet {
     /**
      * @return the next least loaded client that satisfies the given criteria, 
or null if none do
      */
-    UUID poll(final TaskId task) {
+    public UUID poll(final TaskId task) {
         return poll(task, client -> true);
     }
 
-    void offerAll(final Collection<UUID> clients) {
+    public void offerAll(final Collection<UUID> clients) {
         for (final UUID client : clients) {
             offer(client);
         }
     }
 
-    void offer(final UUID client) {
+    public void offer(final UUID client) {
         if (uniqueClients.contains(client)) {
             clientsByTaskLoad.remove(client);
         } else {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultApplicationState.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultApplicationState.java
index f92898ccf34..b3b3084bc17 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultApplicationState.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/DefaultApplicationState.java
@@ -21,9 +21,7 @@ import static java.util.Collections.unmodifiableMap;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Optional;
-import java.util.Set;
 import java.util.UUID;
-import java.util.stream.Collectors;
 import org.apache.kafka.streams.processor.assignment.TaskInfo;
 import 
org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor.ClientMetadata;
 import org.apache.kafka.streams.processor.TaskId;
@@ -42,10 +40,10 @@ public class DefaultApplicationState implements 
ApplicationState {
     private final Map<Boolean, Map<ProcessId, KafkaStreamsState>> 
cachedKafkaStreamStates;
 
     public DefaultApplicationState(final AssignmentConfigs assignmentConfigs,
-                                   final Set<TaskInfo> tasks,
+                                   final Map<TaskId, TaskInfo> tasks,
                                    final Map<UUID, ClientMetadata> 
clientStates) {
         this.assignmentConfigs = assignmentConfigs;
-        this.tasks = 
unmodifiableMap(tasks.stream().collect(Collectors.toMap(TaskInfo::id, task -> 
task)));
+        this.tasks = unmodifiableMap(tasks);
         this.clientStates = clientStates;
         this.cachedKafkaStreamStates = new HashMap<>();
     }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java
index 1d671603e4c..4b430fbb2c1 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java
@@ -77,7 +77,7 @@ public class RackAwareTaskAssignor {
 
     // This is number is picked based on testing. Usually the optimization for 
standby assignment
     // stops after 3 rounds
-    private static final int STANDBY_OPTIMIZER_MAX_ITERATION = 4;
+    public static final int STANDBY_OPTIMIZER_MAX_ITERATION = 4;
 
     private final Cluster fullMetadata;
     private final Map<TaskId, Set<TopicPartition>> partitionsForTask;

Reply via email to