ableegoldman commented on code in PR #16033:
URL: https://github.com/apache/kafka/pull/16033#discussion_r1618288306


##########
streams/src/main/java/org/apache/kafka/streams/processor/assignment/TaskAssignmentUtils.java:
##########
@@ -16,78 +16,408 @@
  */
 package org.apache.kafka.streams.processor.assignment;
 
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
 import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
 import java.util.SortedSet;
+import java.util.UUID;
+import java.util.stream.Collectors;
 import org.apache.kafka.streams.processor.TaskId;
+import 
org.apache.kafka.streams.processor.assignment.KafkaStreamsAssignment.AssignedTask;
+import org.apache.kafka.streams.processor.internals.assignment.Graph;
+import 
org.apache.kafka.streams.processor.internals.assignment.MinTrafficGraphConstructor;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructor;
+import 
org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructorFactory;
+import org.apache.kafka.streams.StreamsConfig;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * A set of utilities to help implement task assignment via the {@link 
TaskAssignor}
  */
 public final class TaskAssignmentUtils {
+    private static final Logger LOG = 
LoggerFactory.getLogger(TaskAssignmentUtils.class);
+
+    private TaskAssignmentUtils() {}
+
     /**
-     * Assign standby tasks to KafkaStreams clients according to the default 
logic.
-     * <p>
-     * If rack-aware client tags are configured, the rack-aware standby task 
assignor will be used
+     * Return a "no-op" assignment that just copies the previous assignment of 
tasks to KafkaStreams clients
      *
-     * @param applicationState        the metadata and other info describing 
the current application state
-     * @param kafkaStreamsAssignments the current assignment of tasks to 
KafkaStreams clients
+     * @param applicationState the metadata and other info describing the 
current application state
      *
-     * @return a new map containing the mappings from KafkaStreamsAssignments 
updated with the default
-     *         standby assignment
+     * @return a new map containing an assignment that replicates exactly the 
previous assignment reported
+     *         in the applicationState
      */
-    public static Map<ProcessId, KafkaStreamsAssignment> 
defaultStandbyTaskAssignment(
-        final ApplicationState applicationState,
-        final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments
-    ) {
-        throw new UnsupportedOperationException("Not Implemented.");
+    public static Map<ProcessId, KafkaStreamsAssignment> 
identityAssignment(final ApplicationState applicationState) {
+        final Map<ProcessId, KafkaStreamsAssignment> assignments = new 
HashMap<>();
+        applicationState.kafkaStreamsStates(false).forEach((processId, state) 
-> {
+            final Set<AssignedTask> tasks = new HashSet<>();
+            state.previousActiveTasks().forEach(taskId -> {
+                tasks.add(new AssignedTask(taskId,
+                    AssignedTask.Type.ACTIVE));
+            });
+            state.previousStandbyTasks().forEach(taskId -> {
+                tasks.add(new AssignedTask(taskId,
+                    AssignedTask.Type.STANDBY));
+            });
+
+            final KafkaStreamsAssignment newAssignment = 
KafkaStreamsAssignment.of(processId, tasks);
+            assignments.put(processId, newAssignment);
+        });
+        return assignments;
     }
 
     /**
-     * Optimize the active task assignment for rack-awareness
+     * Optimize active task assignment for rack awareness. This optimization 
is based on the
+     * {@link StreamsConfig#RACK_AWARE_ASSIGNMENT_TRAFFIC_COST_CONFIG 
trafficCost}
+     * and {@link StreamsConfig#RACK_AWARE_ASSIGNMENT_NON_OVERLAP_COST_CONFIG 
nonOverlapCost}
+     * configs which balance cross rack traffic minimization and task movement.
+     * Setting {@code trafficCost} to a larger number reduces the overall 
cross rack traffic of the resulting
+     * assignment, but can increase the number of tasks shuffled around 
between clients.
+     * Setting {@code nonOverlapCost} to a larger number increases the 
affinity of tasks to their intended client
+     * and reduces the amount by which the rack-aware optimization can shuffle 
tasks around, at the cost of higher
+     * cross-rack traffic.
+     * In an extreme case, if we set {@code nonOverlapCost} to 0 and @{code 
trafficCost} to a positive value,
+     * the resulting assignment will have an absolute minimum of cross rack 
traffic. If we set {@code trafficCost} to 0,
+     * and {@code nonOverlapCost} to a positive value, the resulting 
assignment will be identical to the input assignment.
+     * <p>
+     * This method optimizes cross-rack traffic for active tasks only. For 
standby task optimization,
+     * use {@link #optimizeRackAwareStandbyTasks}.
      *
      * @param applicationState        the metadata and other info describing 
the current application state
      * @param kafkaStreamsAssignments the current assignment of tasks to 
KafkaStreams clients
-     * @param tasks                   the set of tasks to reassign if 
possible. Must already be assigned
-     *                                to a KafkaStreams client
+     * @param tasks                   the set of tasks to reassign if 
possible. Must already be assigned to a KafkaStreams client
      *
-     * @return a new map containing the mappings from KafkaStreamsAssignments 
updated with the default
-     *         rack-aware assignment for active tasks
+     * @return a new map containing the mappings from KafkaStreamsAssignments 
updated with the default rack-aware assignment for active tasks
      */
     public static Map<ProcessId, KafkaStreamsAssignment> 
optimizeRackAwareActiveTasks(
         final ApplicationState applicationState,
         final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments,
         final SortedSet<TaskId> tasks
     ) {
-        throw new UnsupportedOperationException("Not Implemented.");
+        if (tasks.isEmpty()) {
+            return kafkaStreamsAssignments;
+        }
+
+        if (!hasValidRackInformation(applicationState)) {
+            LOG.warn("Cannot optimize active tasks with invalid rack 
information.");
+            return kafkaStreamsAssignments;
+        }
+
+        final int crossRackTrafficCost = 
applicationState.assignmentConfigs().rackAwareTrafficCost();
+        final int nonOverlapCost = 
applicationState.assignmentConfigs().rackAwareNonOverlapCost();
+        final long currentCost = computeTaskCost(
+            applicationState.allTasks().stream().filter(taskInfo -> 
tasks.contains(taskInfo.id())).collect(
+                Collectors.toSet()),
+            applicationState.kafkaStreamsStates(false),
+            crossRackTrafficCost,
+            nonOverlapCost,
+            false,
+            false
+        );
+        LOG.info("Assignment before active task optimization has cost {}", 
currentCost);
+
+        final List<UUID> clientIds = 
kafkaStreamsAssignments.keySet().stream().map(ProcessId::id).collect(
+            Collectors.toList());
+        final Map<ProcessId, KafkaStreamsState> kafkaStreamsStates = 
applicationState.kafkaStreamsStates(false);
+        final Map<UUID, Optional<String>> clientRacks = 
kafkaStreamsStates.values().stream().collect(
+                Collectors.toMap(state -> state.processId().id(), 
KafkaStreamsState::rackId));
+        final Map<UUID, Set<TaskId>> previousTaskIdsByProcess = 
kafkaStreamsStates.values().stream().collect(Collectors.toMap(
+            state -> state.processId().id(),
+            KafkaStreamsState::previousActiveTasks
+        ));
+        final Map<TaskId, Set<TaskTopicPartition>> topicPartitionsByTaskId = 
applicationState.allTasks().stream()
+            .filter(taskInfo -> tasks.contains(taskInfo.id()))
+            .collect(Collectors.toMap(TaskInfo::id, 
TaskInfo::topicPartitions));
+
+        final List<TaskId> taskIds = new ArrayList<>(tasks);
+        final RackAwareGraphConstructor<UUID> graphConstructor = 
RackAwareGraphConstructorFactory.create(
+            
applicationState.assignmentConfigs().rackAwareAssignmentStrategy(), taskIds);
+        final AssignmentGraph assignmentGraph = buildTaskGraph(
+            clientIds,
+            clientRacks,
+            taskIds,
+            previousTaskIdsByProcess,
+            topicPartitionsByTaskId,
+            crossRackTrafficCost,
+            nonOverlapCost,
+            false,
+            false,
+            graphConstructor
+        );
+
+        assignmentGraph.graph.solveMinCostFlow();
+
+        final Map<UUID, Set<AssignedTask>> reassigned = new HashMap<>();
+        final Map<UUID, Set<TaskId>> unassigned = new HashMap<>();
+        graphConstructor.assignTaskFromMinCostFlow(
+            assignmentGraph.graph,
+            clientIds,
+            taskIds,
+            clientIds.stream().collect(Collectors.toMap(id -> id, id -> id)),
+            assignmentGraph.taskCountByClient,
+            assignmentGraph.clientByTask,
+            (processId, taskId) -> {
+                reassigned.computeIfAbsent(processId, k -> new HashSet<>());
+                reassigned.get(processId).add(new AssignedTask(taskId, 
AssignedTask.Type.ACTIVE));
+            },
+            (processId, taskId) -> {
+                unassigned.computeIfAbsent(processId, k -> new HashSet<>());
+                unassigned.get(processId).add(taskId);
+            },
+            (processId, taskId) -> {
+                return 
previousTaskIdsByProcess.get(processId).contains(taskId);
+            }
+        );
+
+        return processTaskMoves(kafkaStreamsAssignments.values(), reassigned, 
unassigned);
     }
 
     /**
-     * Optimize the standby task assignment for rack-awareness
+     * Optimize standby task assignment for rack awareness. This optimization 
is based on the
+     * {@link StreamsConfig#RACK_AWARE_ASSIGNMENT_TRAFFIC_COST_CONFIG 
trafficCost}
+     * and {@link StreamsConfig#RACK_AWARE_ASSIGNMENT_NON_OVERLAP_COST_CONFIG 
nonOverlapCost}
+     * configs which balance cross rack traffic minimization and task movement.
+     * Setting {@code trafficCost} to a larger number reduces the overall 
cross rack traffic of the resulting
+     * assignment, but can increase the number of tasks shuffled around 
between clients.
+     * Setting {@code nonOverlapCost} to a larger number increases the 
affinity of tasks to their intended client
+     * and reduces the amount by which the rack-aware optimization can shuffle 
tasks around, at the cost of higher
+     * cross-rack traffic.
+     * In an extreme case, if we set {@code nonOverlapCost} to 0 and @{code 
trafficCost} to a positive value,
+     * the resulting assignment will have an absolute minimum of cross rack 
traffic. If we set {@code trafficCost} to 0,
+     * and {@code nonOverlapCost} to a positive value, the resulting 
assignment will be identical to the input assignment.
+     * <p>
+     * This method optimizes cross-rack traffic for standby tasks only. For 
active task optimization,
+     * use {@link #optimizeRackAwareActiveTasks}.
      *
      * @param kafkaStreamsAssignments the current assignment of tasks to 
KafkaStreams clients
      * @param applicationState        the metadata and other info describing 
the current application state
      *
-     * @return a new map containing the mappings from KafkaStreamsAssignments 
updated with the default
-     *         rack-aware assignment for standby tasks
+     * @return a new map containing the mappings from KafkaStreamsAssignments 
updated with the default rack-aware assignment for standy tasks
      */
     public static Map<ProcessId, KafkaStreamsAssignment> 
optimizeRackAwareStandbyTasks(
         final ApplicationState applicationState,
         final Map<ProcessId, KafkaStreamsAssignment> kafkaStreamsAssignments
     ) {
+        if (!hasValidRackInformation(applicationState)) {
+            LOG.warn("Cannot optimize standby tasks with invalid rack 
information.");
+            return kafkaStreamsAssignments;
+        }
+
+        final int crossRackTrafficCost = 
applicationState.assignmentConfigs().rackAwareTrafficCost();
+        final int nonOverlapCost = 
applicationState.assignmentConfigs().rackAwareNonOverlapCost();
+        final long currentCost = computeTaskCost(
+            applicationState.allTasks(),
+            applicationState.kafkaStreamsStates(false),
+            crossRackTrafficCost,
+            nonOverlapCost,
+            true,
+            true
+        );
+        LOG.info("Assignment before standby task optimization has cost {}", 
currentCost);
         throw new UnsupportedOperationException("Not Implemented.");
     }
 
+    private static long computeTaskCost(final Set<TaskInfo> tasks,
+                                        final Map<ProcessId, 
KafkaStreamsState> clients,
+                                        final int crossRackTrafficCost,
+                                        final int nonOverlapCost,
+                                        final boolean hasReplica,
+                                        final boolean isStandby) {
+        if (tasks.isEmpty()) {
+            return 0;
+        }
+
+        final List<UUID> clientIds = 
clients.keySet().stream().map(ProcessId::id).collect(
+            Collectors.toList());
+
+        final List<TaskId> taskIds = 
tasks.stream().map(TaskInfo::id).collect(Collectors.toList());
+        final Map<TaskId, Set<TaskTopicPartition>> topicPartitionsByTaskId = 
tasks.stream().collect(
+            Collectors.toMap(TaskInfo::id, TaskInfo::topicPartitions));
+
+        final Map<UUID, Optional<String>> clientRacks = 
clients.values().stream().collect(
+            Collectors.toMap(state -> state.processId().id(), 
KafkaStreamsState::rackId));
+
+        final Map<UUID, Set<TaskId>> taskIdsByProcess = 
clients.values().stream().collect(
+            Collectors.toMap(state -> state.processId().id(), state -> {
+                if (isStandby) {
+                    return state.previousStandbyTasks();
+                }
+                return state.previousActiveTasks();
+            })
+        );
+
+        final RackAwareGraphConstructor<UUID> graphConstructor = new 
MinTrafficGraphConstructor<>();
+        final AssignmentGraph assignmentGraph = buildTaskGraph(clientIds, 
clientRacks, taskIds, taskIdsByProcess, topicPartitionsByTaskId,
+            crossRackTrafficCost, nonOverlapCost, hasReplica, isStandby, 
graphConstructor);
+        return assignmentGraph.graph.totalCost();
+    }
+
+    private static AssignmentGraph buildTaskGraph(final List<UUID> clientIds,
+                                                  final Map<UUID, 
Optional<String>> clientRacks,
+                                                  final List<TaskId> taskIds,
+                                                  final Map<UUID, Set<TaskId>> 
previousTaskIdsByProcess,
+                                                  final Map<TaskId, 
Set<TaskTopicPartition>> topicPartitionsByTaskId,
+                                                  final int 
crossRackTrafficCost,
+                                                  final int nonOverlapCost,
+                                                  final boolean hasReplica,
+                                                  final boolean isStandby,
+                                                  final 
RackAwareGraphConstructor<UUID> graphConstructor) {
+        final Map<UUID, UUID> clientsUuidByUuid = 
clientIds.stream().collect(Collectors.toMap(id -> id, id -> id));
+        final Map<TaskId, UUID> clientByTask = new HashMap<>();
+        final Map<UUID, Integer> taskCountByClient = new HashMap<>();
+        final Graph<Integer> graph = graphConstructor.constructTaskGraph(
+            clientIds,
+            taskIds,
+            clientsUuidByUuid,
+            clientByTask,
+            taskCountByClient,
+            (processId, taskId) -> {
+                return 
previousTaskIdsByProcess.get(processId).contains(taskId);
+            },
+            (taskId, processId, inCurrentAssignment, unused0, unused1, 
unused2) -> {
+                final int assignmentChangeCost = !inCurrentAssignment ? 
nonOverlapCost : 0;
+                final Optional<String> clientRack = clientRacks.get(processId);
+                final Set<TaskTopicPartition> topicPartitions = 
topicPartitionsByTaskId.get(taskId).stream().filter(tp -> {
+                    return isStandby ? tp.isChangelog() : true;
+                }).collect(Collectors.toSet());
+                return assignmentChangeCost + 
getCrossRackTrafficCost(topicPartitions, clientRack, crossRackTrafficCost);
+            },
+            crossRackTrafficCost,
+            nonOverlapCost,
+            hasReplica,
+            isStandby
+        );
+        return new AssignmentGraph(graph, clientByTask, taskCountByClient);
+    }
+
+    /**
+     * This internal structure is used to keep track of the graph solving 
outputs alongside the graph
+     * structure itself.
+     */
+    private static final class AssignmentGraph {
+        public final Graph<Integer> graph;
+        public final Map<TaskId, UUID> clientByTask;
+        public final Map<UUID, Integer> taskCountByClient;
+
+        public AssignmentGraph(final Graph<Integer> graph,
+                               final Map<TaskId, UUID> clientByTask,
+                               final Map<UUID, Integer> taskCountByClient) {
+            this.graph = graph;
+            this.clientByTask = clientByTask;
+            this.taskCountByClient = taskCountByClient;
+        }
+    }
+
     /**
-     * Return a "no-op" assignment that just copies the previous assignment of 
tasks to KafkaStreams clients
      *
-     * @param applicationState the metadata and other info describing the 
current application state
+     * @return the traffic cost of assigning this {@param task} to the client 
{@param streamsState}.
+     */
+    private static int getCrossRackTrafficCost(final Set<TaskTopicPartition> 
topicPartitions,
+                                               final Optional<String> 
clientRack,
+                                               final int crossRackTrafficCost) 
{
+        if (!clientRack.isPresent()) {
+            throw new IllegalStateException("Client doesn't have rack 
configured.");
+        }
+
+        int cost = 0;
+        for (final TaskTopicPartition topicPartition : topicPartitions) {
+            final Optional<Set<String>> topicPartitionRacks = 
topicPartition.rackIds();
+            if (topicPartitionRacks == null || 
!topicPartitionRacks.isPresent()) {
+                throw new IllegalStateException("TopicPartition " + 
topicPartition + " has no rack information.");
+            }
+
+            if (topicPartitionRacks.get().contains(clientRack.get())) {
+                continue;
+            }
+
+            cost += crossRackTrafficCost;
+        }
+        return cost;
+    }
+
+    /**
+     * This function returns whether the current application state has the 
required rack information
+     * to make assignment decisions with.
+     * This includes ensuring that every client has a known rack id, and that 
every topic partition for
+     * every logical task that needs to be assigned also has known rack ids.
+     * If a logical task has no source topic partitions, it will return false.
+     * If standby tasks are configured, but a logical task has no changelog 
topic partitions, it will return false.
      *
-     * @return a new map containing an assignment that replicates exactly the 
previous assignment reported
-     *         in the applicationState
+     * @return whether rack-aware assignment decisions can be made for this 
application.
      */
-    public static Map<ProcessId, KafkaStreamsAssignment> identityAssignment(
-        final ApplicationState applicationState
-    ) {
-        throw new UnsupportedOperationException("Not Implemented.");
+    private static boolean hasValidRackInformation(final ApplicationState 
applicationState) {
+        for (final KafkaStreamsState state : 
applicationState.kafkaStreamsStates(false).values()) {
+            if (!hasValidRackInformation(state)) {
+                return false;
+            }
+        }
+
+        for (final TaskInfo task : applicationState.allTasks()) {
+            if (!hasValidRackInformation(task)) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    private static boolean hasValidRackInformation(final KafkaStreamsState 
state) {
+        if (!state.rackId().isPresent()) {
+            LOG.error("Client " + state.processId() + " doesn't have rack 
configured.");
+            return false;
+        }
+        return true;
+    }
+
+    private static boolean hasValidRackInformation(final TaskInfo task) {
+        for (final TaskTopicPartition topicPartition : task.topicPartitions()) 
{
+            final Optional<Set<String>> racks = topicPartition.rackIds();
+            if (!racks.isPresent() || racks.get().isEmpty()) {
+                LOG.error("Topic partition {} for task {} does not have racks 
configured.", topicPartition, task.id());
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /**
+     * This function returns a copy of the old collection of {@code 
KafkaStreamsAssignment} with tasks
+     * moved according to the {@param reassigned} tasks and {@param 
unassigned} tasks.
+     *
+     * @param kafkaStreamsAssignments the collection to start from when moving 
tasks from process to process
+     * @param reassigned              the map from process id to tasks that 
this client has newly been assigned
+     * @param unassigned              the map from process id to tasks that 
this client has newly been unassigned
+     *
+     * @return the new mapping from processId to {@code 
KafkaStreamsAssignment}.
+     */
+    private static Map<ProcessId, KafkaStreamsAssignment> 
processTaskMoves(final Collection<KafkaStreamsAssignment> 
kafkaStreamsAssignments,

Review Comment:
   Alright, I went ahead and tried this out for myself. I think it helps a lot
to simplify things in this PR, so it makes sense to incorporate it here rather than
doing it as an entirely separate PR. Let me know your thoughts, as always:
https://github.com/apourchet/kafka/pull/1



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to