ableegoldman commented on code in PR #16129: URL: https://github.com/apache/kafka/pull/16129#discussion_r1619808751
########## streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySet.java: ########## @@ -30,13 +30,13 @@ /** * Wraps a priority queue of clients and returns the next valid candidate(s) based on the current task assignment */ -class ConstrainedPrioritySet { +public class ConstrainedPrioritySet { private final PriorityQueue<UUID> clientsByTaskLoad; private final BiFunction<UUID, TaskId, Boolean> constraint; private final Set<UUID> uniqueClients = new HashSet<>(); - ConstrainedPrioritySet(final BiFunction<UUID, TaskId, Boolean> constraint, + public ConstrainedPrioritySet(final BiFunction<UUID, TaskId, Boolean> constraint, final Function<UUID, Double> weight) { Review Comment: nit: adjust parameter indentation ########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ########## @@ -407,4 +543,345 @@ private static boolean hasValidRackInformation(final TaskInfo task, } return true; } + + private static Map<ProcessId, KafkaStreamsAssignment> tagBasedStandbyTaskAssignment(final ApplicationState applicationState, + final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) { + final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas(); + final Map<TaskId, Integer> tasksToRemainingStandbys = applicationState.allTasks().values().stream() + .collect(Collectors.toMap(TaskInfo::id, taskInfo -> numStandbyReplicas)); + final Map<ProcessId, KafkaStreamsState> streamStates = applicationState.kafkaStreamsStates(false); + + final Set<String> rackAwareAssignmentTags = new HashSet<>(getRackAwareAssignmentTags(applicationState)); + final TagStatistics tagStatistics = new TagStatistics(applicationState); + + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = standbyTaskPriorityListByLoad(streamStates, kafkaStreamsAssignments); + + final Set<TaskId> statefulTaskIds = applicationState.allTasks().values().stream() + .filter(TaskInfo::isStateful) + 
.map(TaskInfo::id) + .collect(Collectors.toSet()); + final Map<UUID, KafkaStreamsAssignment> clientsByUuid = kafkaStreamsAssignments.entrySet().stream().collect(Collectors.toMap( + entry -> entry.getKey().id(), + Map.Entry::getValue + )); + final Function<KafkaStreamsState, Map<String, String>> clientTagGetter = createClientTagGetter(applicationState); + + final Map<TaskId, ProcessId> pendingStandbyTasksToClientId = new HashMap<>(); + for (final TaskId statefulTaskId : statefulTaskIds) { + for (final KafkaStreamsAssignment assignment : clientsByUuid.values()) { + if (assignment.tasks().containsKey(statefulTaskId)) { + assignStandbyTasksToClientsWithDifferentTags( + numStandbyReplicas, + standbyTaskClientsByTaskLoad, + statefulTaskId, + assignment.processId(), + rackAwareAssignmentTags, + streamStates, + kafkaStreamsAssignments, + tasksToRemainingStandbys, + tagStatistics.tagKeyToValues, + tagStatistics.tagEntryToClients, + pendingStandbyTasksToClientId, + clientTagGetter + ); + } + } + } + + if (!tasksToRemainingStandbys.isEmpty()) { + assignPendingStandbyTasksToLeastLoadedClients(clientsByUuid, + numStandbyReplicas, + standbyTaskClientsByTaskLoad, + tasksToRemainingStandbys); + } + + return kafkaStreamsAssignments; + } + + private static Map<ProcessId, KafkaStreamsAssignment> loadBasedStandbyTaskAssignment(final ApplicationState applicationState, + final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) { + final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas(); + final Map<TaskId, Integer> tasksToRemainingStandbys = applicationState.allTasks().values().stream() + .collect(Collectors.toMap(TaskInfo::id, taskInfo -> numStandbyReplicas)); + final Map<ProcessId, KafkaStreamsState> streamStates = applicationState.kafkaStreamsStates(false); + + final Set<TaskId> statefulTaskIds = applicationState.allTasks().values().stream() + .filter(TaskInfo::isStateful) + .map(TaskInfo::id) + .collect(Collectors.toSet()); + final 
Map<UUID, KafkaStreamsAssignment> clients = kafkaStreamsAssignments.entrySet().stream().collect(Collectors.toMap( + entry -> entry.getKey().id(), + Map.Entry::getValue + )); + + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = standbyTaskPriorityListByLoad(streamStates, kafkaStreamsAssignments); + standbyTaskClientsByTaskLoad.offerAll(streamStates.keySet().stream().map(ProcessId::id).collect(Collectors.toSet())); + for (final TaskId task : statefulTaskIds) { + assignStandbyTasksForActiveTask(numStandbyReplicas, clients, + tasksToRemainingStandbys, standbyTaskClientsByTaskLoad, task); + } + return kafkaStreamsAssignments; + } + + private static void assignStandbyTasksForActiveTask(final int numStandbyReplicas, + final Map<UUID, KafkaStreamsAssignment> clients, + final Map<TaskId, Integer> tasksToRemainingStandbys, + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad, + final TaskId activeTaskId) { + int numRemainingStandbys = tasksToRemainingStandbys.get(activeTaskId); + while (numRemainingStandbys > 0) { + final UUID client = standbyTaskClientsByTaskLoad.poll(activeTaskId); + if (client == null) { + break; + } + clients.get(client).assignTask(new AssignedTask(activeTaskId, AssignedTask.Type.STANDBY)); + numRemainingStandbys--; + standbyTaskClientsByTaskLoad.offer(client); + tasksToRemainingStandbys.put(activeTaskId, numRemainingStandbys); Review Comment: nit: kind of confusing to update the map on every iteration of this loop; we should just do it once, when we have the final value after exiting the loop (I also can't tell if the `tasksToRemainingStandbys` values are ever even accessed again after calling this method, but I'm not 100% sure and it seems best to leave that map in a consistent state just in case it does get used again in the future 🤷♀️ ) ########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ########## @@ -407,4 +543,345 @@ private static boolean hasValidRackInformation(final TaskInfo 
task, } return true; } + + private static Map<ProcessId, KafkaStreamsAssignment> tagBasedStandbyTaskAssignment(final ApplicationState applicationState, + final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) { + final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas(); + final Map<TaskId, Integer> tasksToRemainingStandbys = applicationState.allTasks().values().stream() + .collect(Collectors.toMap(TaskInfo::id, taskInfo -> numStandbyReplicas)); + final Map<ProcessId, KafkaStreamsState> streamStates = applicationState.kafkaStreamsStates(false); + + final Set<String> rackAwareAssignmentTags = new HashSet<>(getRackAwareAssignmentTags(applicationState)); + final TagStatistics tagStatistics = new TagStatistics(applicationState); + + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = standbyTaskPriorityListByLoad(streamStates, kafkaStreamsAssignments); + + final Set<TaskId> statefulTaskIds = applicationState.allTasks().values().stream() + .filter(TaskInfo::isStateful) + .map(TaskInfo::id) + .collect(Collectors.toSet()); + final Map<UUID, KafkaStreamsAssignment> clientsByUuid = kafkaStreamsAssignments.entrySet().stream().collect(Collectors.toMap( + entry -> entry.getKey().id(), + Map.Entry::getValue + )); + final Function<KafkaStreamsState, Map<String, String>> clientTagGetter = createClientTagGetter(applicationState); + + final Map<TaskId, ProcessId> pendingStandbyTasksToClientId = new HashMap<>(); + for (final TaskId statefulTaskId : statefulTaskIds) { + for (final KafkaStreamsAssignment assignment : clientsByUuid.values()) { + if (assignment.tasks().containsKey(statefulTaskId)) { + assignStandbyTasksToClientsWithDifferentTags( + numStandbyReplicas, + standbyTaskClientsByTaskLoad, + statefulTaskId, + assignment.processId(), + rackAwareAssignmentTags, + streamStates, + kafkaStreamsAssignments, + tasksToRemainingStandbys, + tagStatistics.tagKeyToValues, + tagStatistics.tagEntryToClients, + 
pendingStandbyTasksToClientId, + clientTagGetter + ); + } + } + } + + if (!tasksToRemainingStandbys.isEmpty()) { + assignPendingStandbyTasksToLeastLoadedClients(clientsByUuid, + numStandbyReplicas, + standbyTaskClientsByTaskLoad, + tasksToRemainingStandbys); + } + + return kafkaStreamsAssignments; + } + + private static Map<ProcessId, KafkaStreamsAssignment> loadBasedStandbyTaskAssignment(final ApplicationState applicationState, + final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) { + final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas(); + final Map<TaskId, Integer> tasksToRemainingStandbys = applicationState.allTasks().values().stream() + .collect(Collectors.toMap(TaskInfo::id, taskInfo -> numStandbyReplicas)); + final Map<ProcessId, KafkaStreamsState> streamStates = applicationState.kafkaStreamsStates(false); + + final Set<TaskId> statefulTaskIds = applicationState.allTasks().values().stream() + .filter(TaskInfo::isStateful) + .map(TaskInfo::id) + .collect(Collectors.toSet()); + final Map<UUID, KafkaStreamsAssignment> clients = kafkaStreamsAssignments.entrySet().stream().collect(Collectors.toMap( + entry -> entry.getKey().id(), + Map.Entry::getValue + )); + + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = standbyTaskPriorityListByLoad(streamStates, kafkaStreamsAssignments); + standbyTaskClientsByTaskLoad.offerAll(streamStates.keySet().stream().map(ProcessId::id).collect(Collectors.toSet())); + for (final TaskId task : statefulTaskIds) { + assignStandbyTasksForActiveTask(numStandbyReplicas, clients, + tasksToRemainingStandbys, standbyTaskClientsByTaskLoad, task); Review Comment: nit: another weird linebreak, should be either all parameters on the same line or all on separate lines (though I know this wasn't your doing) ########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ########## @@ -407,4 +543,345 @@ private static boolean 
hasValidRackInformation(final TaskInfo task, } return true; } + + private static Map<ProcessId, KafkaStreamsAssignment> tagBasedStandbyTaskAssignment(final ApplicationState applicationState, + final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) { + final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas(); + final Map<TaskId, Integer> tasksToRemainingStandbys = applicationState.allTasks().values().stream() + .collect(Collectors.toMap(TaskInfo::id, taskInfo -> numStandbyReplicas)); Review Comment: nit: this isn't a correctness issue, but for logging/debugging/inspection sanity we should only fill in this map with stateful tasks. Technically the "remaining standbys" for any stateless task is just 0 (though I'd rather just not include them at all in this map, not suggesting you set their value to 0). We can also save a small bit of processing overhead by just building this map based on the `statefulTaskIds` set that's initialized a bit later. 
We can move this down to after that gets created and then simplify it to just `statefulTaskIds.stream().collect(toMap(t -> t, t -> numStandbyReplicas))` ########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ########## @@ -407,4 +543,345 @@ private static boolean hasValidRackInformation(final TaskInfo task, } return true; } + + private static Map<ProcessId, KafkaStreamsAssignment> tagBasedStandbyTaskAssignment(final ApplicationState applicationState, + final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) { + final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas(); + final Map<TaskId, Integer> tasksToRemainingStandbys = applicationState.allTasks().values().stream() + .collect(Collectors.toMap(TaskInfo::id, taskInfo -> numStandbyReplicas)); + final Map<ProcessId, KafkaStreamsState> streamStates = applicationState.kafkaStreamsStates(false); + + final Set<String> rackAwareAssignmentTags = new HashSet<>(getRackAwareAssignmentTags(applicationState)); + final TagStatistics tagStatistics = new TagStatistics(applicationState); + + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = standbyTaskPriorityListByLoad(streamStates, kafkaStreamsAssignments); + + final Set<TaskId> statefulTaskIds = applicationState.allTasks().values().stream() + .filter(TaskInfo::isStateful) + .map(TaskInfo::id) + .collect(Collectors.toSet()); + final Map<UUID, KafkaStreamsAssignment> clientsByUuid = kafkaStreamsAssignments.entrySet().stream().collect(Collectors.toMap( + entry -> entry.getKey().id(), + Map.Entry::getValue + )); + final Function<KafkaStreamsState, Map<String, String>> clientTagGetter = createClientTagGetter(applicationState); + + final Map<TaskId, ProcessId> pendingStandbyTasksToClientId = new HashMap<>(); + for (final TaskId statefulTaskId : statefulTaskIds) { + for (final KafkaStreamsAssignment assignment : clientsByUuid.values()) { + if 
(assignment.tasks().containsKey(statefulTaskId)) { Review Comment: one more minor bug: in the original code this is actually checking for assigned _active_ tasks specifically, so this should be ```suggestion if (assignment.tasks().containsKey(statefulTaskId) && assignment.tasks().get(statefulTaskId).type() == Type.ACTIVE) { ``` ########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ########## @@ -407,4 +543,345 @@ private static boolean hasValidRackInformation(final TaskInfo task, } return true; } + + private static Map<ProcessId, KafkaStreamsAssignment> tagBasedStandbyTaskAssignment(final ApplicationState applicationState, + final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) { + final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas(); + final Map<TaskId, Integer> tasksToRemainingStandbys = applicationState.allTasks().values().stream() + .collect(Collectors.toMap(TaskInfo::id, taskInfo -> numStandbyReplicas)); + final Map<ProcessId, KafkaStreamsState> streamStates = applicationState.kafkaStreamsStates(false); + + final Set<String> rackAwareAssignmentTags = new HashSet<>(getRackAwareAssignmentTags(applicationState)); + final TagStatistics tagStatistics = new TagStatistics(applicationState); + + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = standbyTaskPriorityListByLoad(streamStates, kafkaStreamsAssignments); + + final Set<TaskId> statefulTaskIds = applicationState.allTasks().values().stream() + .filter(TaskInfo::isStateful) + .map(TaskInfo::id) + .collect(Collectors.toSet()); + final Map<UUID, KafkaStreamsAssignment> clientsByUuid = kafkaStreamsAssignments.entrySet().stream().collect(Collectors.toMap( + entry -> entry.getKey().id(), + Map.Entry::getValue + )); + final Function<KafkaStreamsState, Map<String, String>> clientTagGetter = createClientTagGetter(applicationState); + + final Map<TaskId, ProcessId> pendingStandbyTasksToClientId = new 
HashMap<>(); + for (final TaskId statefulTaskId : statefulTaskIds) { + for (final KafkaStreamsAssignment assignment : clientsByUuid.values()) { + if (assignment.tasks().containsKey(statefulTaskId)) { + assignStandbyTasksToClientsWithDifferentTags( + numStandbyReplicas, + standbyTaskClientsByTaskLoad, + statefulTaskId, + assignment.processId(), + rackAwareAssignmentTags, + streamStates, + kafkaStreamsAssignments, + tasksToRemainingStandbys, + tagStatistics.tagKeyToValues, + tagStatistics.tagEntryToClients, + pendingStandbyTasksToClientId, + clientTagGetter + ); + } + } + } + + if (!tasksToRemainingStandbys.isEmpty()) { + assignPendingStandbyTasksToLeastLoadedClients(clientsByUuid, + numStandbyReplicas, + standbyTaskClientsByTaskLoad, + tasksToRemainingStandbys); + } + + return kafkaStreamsAssignments; + } + + private static Map<ProcessId, KafkaStreamsAssignment> loadBasedStandbyTaskAssignment(final ApplicationState applicationState, + final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) { + final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas(); + final Map<TaskId, Integer> tasksToRemainingStandbys = applicationState.allTasks().values().stream() + .collect(Collectors.toMap(TaskInfo::id, taskInfo -> numStandbyReplicas)); Review Comment: ditto here, move down a bit and then we can just initialize it as `statefulTaskIds.stream().collect(toMap(t -> t, t -> numStandbyReplicas))` ########## streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java: ########## @@ -407,4 +543,345 @@ private static boolean hasValidRackInformation(final TaskInfo task, } return true; } + + private static Map<ProcessId, KafkaStreamsAssignment> tagBasedStandbyTaskAssignment(final ApplicationState applicationState, + final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) { + final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas(); + final Map<TaskId, Integer> 
tasksToRemainingStandbys = applicationState.allTasks().values().stream() + .collect(Collectors.toMap(TaskInfo::id, taskInfo -> numStandbyReplicas)); + final Map<ProcessId, KafkaStreamsState> streamStates = applicationState.kafkaStreamsStates(false); + + final Set<String> rackAwareAssignmentTags = new HashSet<>(getRackAwareAssignmentTags(applicationState)); + final TagStatistics tagStatistics = new TagStatistics(applicationState); + + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = standbyTaskPriorityListByLoad(streamStates, kafkaStreamsAssignments); + + final Set<TaskId> statefulTaskIds = applicationState.allTasks().values().stream() + .filter(TaskInfo::isStateful) + .map(TaskInfo::id) + .collect(Collectors.toSet()); + final Map<UUID, KafkaStreamsAssignment> clientsByUuid = kafkaStreamsAssignments.entrySet().stream().collect(Collectors.toMap( + entry -> entry.getKey().id(), + Map.Entry::getValue + )); + final Function<KafkaStreamsState, Map<String, String>> clientTagGetter = createClientTagGetter(applicationState); + + final Map<TaskId, ProcessId> pendingStandbyTasksToClientId = new HashMap<>(); + for (final TaskId statefulTaskId : statefulTaskIds) { + for (final KafkaStreamsAssignment assignment : clientsByUuid.values()) { + if (assignment.tasks().containsKey(statefulTaskId)) { + assignStandbyTasksToClientsWithDifferentTags( + numStandbyReplicas, + standbyTaskClientsByTaskLoad, + statefulTaskId, + assignment.processId(), + rackAwareAssignmentTags, + streamStates, + kafkaStreamsAssignments, + tasksToRemainingStandbys, + tagStatistics.tagKeyToValues, + tagStatistics.tagEntryToClients, + pendingStandbyTasksToClientId, + clientTagGetter + ); + } + } + } + + if (!tasksToRemainingStandbys.isEmpty()) { + assignPendingStandbyTasksToLeastLoadedClients(clientsByUuid, + numStandbyReplicas, + standbyTaskClientsByTaskLoad, + tasksToRemainingStandbys); + } + + return kafkaStreamsAssignments; + } + + private static Map<ProcessId, KafkaStreamsAssignment> 
loadBasedStandbyTaskAssignment(final ApplicationState applicationState, + final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments) { + final int numStandbyReplicas = applicationState.assignmentConfigs().numStandbyReplicas(); + final Map<TaskId, Integer> tasksToRemainingStandbys = applicationState.allTasks().values().stream() + .collect(Collectors.toMap(TaskInfo::id, taskInfo -> numStandbyReplicas)); + final Map<ProcessId, KafkaStreamsState> streamStates = applicationState.kafkaStreamsStates(false); + + final Set<TaskId> statefulTaskIds = applicationState.allTasks().values().stream() + .filter(TaskInfo::isStateful) + .map(TaskInfo::id) + .collect(Collectors.toSet()); + final Map<UUID, KafkaStreamsAssignment> clients = kafkaStreamsAssignments.entrySet().stream().collect(Collectors.toMap( + entry -> entry.getKey().id(), + Map.Entry::getValue + )); + + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = standbyTaskPriorityListByLoad(streamStates, kafkaStreamsAssignments); + standbyTaskClientsByTaskLoad.offerAll(streamStates.keySet().stream().map(ProcessId::id).collect(Collectors.toSet())); + for (final TaskId task : statefulTaskIds) { + assignStandbyTasksForActiveTask(numStandbyReplicas, clients, + tasksToRemainingStandbys, standbyTaskClientsByTaskLoad, task); + } + return kafkaStreamsAssignments; + } + + private static void assignStandbyTasksForActiveTask(final int numStandbyReplicas, + final Map<UUID, KafkaStreamsAssignment> clients, + final Map<TaskId, Integer> tasksToRemainingStandbys, + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad, + final TaskId activeTaskId) { + int numRemainingStandbys = tasksToRemainingStandbys.get(activeTaskId); + while (numRemainingStandbys > 0) { + final UUID client = standbyTaskClientsByTaskLoad.poll(activeTaskId); + if (client == null) { + break; + } + clients.get(client).assignTask(new AssignedTask(activeTaskId, AssignedTask.Type.STANDBY)); + numRemainingStandbys--; + 
standbyTaskClientsByTaskLoad.offer(client); + tasksToRemainingStandbys.put(activeTaskId, numRemainingStandbys); + } + + if (numRemainingStandbys > 0) { + LOG.warn("Unable to assign {} of {} standby tasks for task [{}]. " + + "There is not enough available capacity. You should " + + "increase the number of application instances " + + "to maintain the requested number of standby replicas.", + numRemainingStandbys, numStandbyReplicas, activeTaskId); + } + } + + private static void assignStandbyTasksToClientsWithDifferentTags(final int numberOfStandbyClients, + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad, + final TaskId activeTaskId, + final ProcessId activeTaskClient, + final Set<String> rackAwareAssignmentTags, + final Map<ProcessId, KafkaStreamsState> clientStates, + final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments, + final Map<TaskId, Integer> tasksToRemainingStandbys, + final Map<String, Set<String>> tagKeyToValues, + final Map<KeyValue<String, String>, Set<ProcessId>> tagEntryToClients, + final Map<TaskId, ProcessId> pendingStandbyTasksToClientId, + final Function<KafkaStreamsState, Map<String, String>> clientTagGetter) { + standbyTaskClientsByTaskLoad.offerAll(clientStates.keySet().stream() + .map(ProcessId::id).collect(Collectors.toSet())); + + // We set countOfUsedClients as 1 because client where active task is located has to be considered as used. 
+ int countOfUsedClients = 1; + int numRemainingStandbys = tasksToRemainingStandbys.get(activeTaskId); + + final Map<KeyValue<String, String>, Set<ProcessId>> tagEntryToUsedClients = new HashMap<>(); + + ProcessId lastUsedClient = activeTaskClient; + do { + updateClientsOnAlreadyUsedTagEntries( + clientStates.get(lastUsedClient), + countOfUsedClients, + rackAwareAssignmentTags, + tagEntryToClients, + tagKeyToValues, + tagEntryToUsedClients, + clientTagGetter + ); + + final UUID clientOnUnusedTagDimensions = standbyTaskClientsByTaskLoad.poll( + activeTaskId, uuid -> !isClientUsedOnAnyOfTheTagEntries(new ProcessId(uuid), tagEntryToUsedClients) + ); + + if (clientOnUnusedTagDimensions == null) { + break; + } + + final KafkaStreamsState clientStateOnUsedTagDimensions = clientStates.get(new ProcessId(clientOnUnusedTagDimensions)); + countOfUsedClients++; + numRemainingStandbys--; + + LOG.debug("Assigning {} out of {} standby tasks for an active task [{}] with client tags {}. " + + "Standby task client tags are {}.", + numberOfStandbyClients - numRemainingStandbys, numberOfStandbyClients, activeTaskId, + clientTagGetter.apply(clientStates.get(activeTaskClient)), + clientTagGetter.apply(clientStateOnUsedTagDimensions)); + + kafkaStreamsAssignments.get(clientStateOnUsedTagDimensions.processId()).assignTask( + new AssignedTask(activeTaskId, AssignedTask.Type.STANDBY) + ); + lastUsedClient = new ProcessId(clientOnUnusedTagDimensions); + } while (numRemainingStandbys > 0); + + if (numRemainingStandbys > 0) { + pendingStandbyTasksToClientId.put(activeTaskId, activeTaskClient); + tasksToRemainingStandbys.put(activeTaskId, numRemainingStandbys); + LOG.warn("Rack aware standby task assignment was not able to assign {} of {} standby tasks for the " + + "active task [{}] with the rack aware assignment tags {}. " + + "This may happen when there aren't enough application instances on different tag " + + "dimensions compared to an active and corresponding standby task. 
" + + "Consider launching application instances on different tag dimensions than [{}]. " + + "Standby task assignment will fall back to assigning standby tasks to the least loaded clients.", + numRemainingStandbys, numberOfStandbyClients, + activeTaskId, rackAwareAssignmentTags, + clientTagGetter.apply(clientStates.get(activeTaskClient))); + + } else { + tasksToRemainingStandbys.remove(activeTaskId); + } + } + + private static boolean isClientUsedOnAnyOfTheTagEntries(final ProcessId client, + final Map<KeyValue<String, String>, Set<ProcessId>> tagEntryToUsedClients) { + return tagEntryToUsedClients.values().stream().anyMatch(usedClients -> usedClients.contains(client)); + } + + private static void updateClientsOnAlreadyUsedTagEntries(final KafkaStreamsState usedClient, + final int countOfUsedClients, + final Set<String> rackAwareAssignmentTags, + final Map<KeyValue<String, String>, Set<ProcessId>> tagEntryToClients, + final Map<String, Set<String>> tagKeyToValues, + final Map<KeyValue<String, String>, Set<ProcessId>> tagEntryToUsedClients, + final Function<KafkaStreamsState, Map<String, String>> clientTagGetter) { + final Map<String, String> usedClientTags = clientTagGetter.apply(usedClient); + + for (final Map.Entry<String, String> usedClientTagEntry : usedClientTags.entrySet()) { + final String tagKey = usedClientTagEntry.getKey(); + + if (!rackAwareAssignmentTags.contains(tagKey)) { + LOG.warn("Client tag with key [{}] will be ignored when computing rack aware standby " + + "task assignment because it is not part of the configured rack awareness [{}].", + tagKey, rackAwareAssignmentTags); + continue; + } + + final Set<String> allTagValues = tagKeyToValues.get(tagKey); + + if (allTagValues.size() <= countOfUsedClients) { + allTagValues.forEach(tagValue -> tagEntryToUsedClients.remove(new KeyValue<>(tagKey, tagValue))); + } else { + final String tagValue = usedClientTagEntry.getValue(); + final KeyValue<String, String> tagEntry = new KeyValue<>(tagKey, tagValue); 
+ final Set<ProcessId> clientsOnUsedTagValue = tagEntryToClients.get(tagEntry); + tagEntryToUsedClients.put(tagEntry, clientsOnUsedTagValue); + } + } + } + + private static Function<KafkaStreamsState, Map<String, String>> createClientTagGetter(final ApplicationState applicationState) { + final boolean hasRackAwareAssignmentTags = !applicationState.assignmentConfigs().rackAwareAssignmentTags().isEmpty(); + final boolean canPerformRackAwareOptimization = canPerformRackAwareOptimization(applicationState, AssignedTask.Type.STANDBY); + + if (hasRackAwareAssignmentTags || !canPerformRackAwareOptimization) { + return KafkaStreamsState::clientTags; + } else { + return state -> mkMap(mkEntry("rack", state.rackId().get())); + } + } + + private static List<String> getRackAwareAssignmentTags(final ApplicationState applicationState) { + final boolean hasRackAwareAssignmentTags = !applicationState.assignmentConfigs().rackAwareAssignmentTags().isEmpty(); + + if (hasRackAwareAssignmentTags) { + return applicationState.assignmentConfigs().rackAwareAssignmentTags(); + } else if (canPerformRackAwareOptimization(applicationState, AssignedTask.Type.STANDBY)) { + return Collections.singletonList("rack"); Review Comment: Similar to [this comment](https://github.com/apache/kafka/pull/16129/files#r1619754391), I don't think the middle case exists in non-testing code, and we shouldn't have special cases for tests in production code. Also, once you make the simplification from that comment, it should only be possible to reach this if there are rack aware tags. So we should be able to just inline the `applicationState.assignmentConfigs().rackAwareAssignmentTags()` and get rid of this method altogether -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org