This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 60a51170014 KAFKA-15022: [7/N] use RackAwareTaskAssignor in HAAssignor (#14139)
60a51170014 is described below

commit 60a51170014622f65a22675ac54ec78058299fda
Author: Hao Li <1127478+lihao...@users.noreply.github.com>
AuthorDate: Tue Aug 8 08:01:05 2023 -0700

    KAFKA-15022: [7/N] use RackAwareTaskAssignor in HAAssignor (#14139)
    
    Part of KIP-915.
    
    - Change TaskAssignor interface to take RackAwareTaskAssignor
    - Integrate RackAwareTaskAssignor into StreamsPartitionAssignor and HighAvailabilityTaskAssignor
    - Update HAAssignor tests
    
    Reviewers: Anna Sophie Blee-Goldman <ableegold...@apache.org>, Matthias J. Sax <matth...@confluent.io>
---
 .../processor/internals/ChangelogTopics.java       |   4 +
 .../internals/StreamsPartitionAssignor.java        |  12 +-
 .../assignment/FallbackPriorTaskAssignor.java      |   4 +-
 .../assignment/HighAvailabilityTaskAssignor.java   |  42 +-
 .../assignment/RackAwareTaskAssignor.java          |   3 +
 .../internals/assignment/StandbyTaskAssignor.java  |  16 +
 .../internals/assignment/StickyTaskAssignor.java   |   2 +
 .../internals/assignment/TaskAssignor.java         |  11 +-
 .../assignment/FallbackPriorTaskAssignorTest.java  |   1 +
 .../HighAvailabilityTaskAssignorTest.java          | 756 +++++++++++++++++----
 .../assignment/StickyTaskAssignorTest.java         |   2 +
 .../assignment/TaskAssignorConvergenceTest.java    |   1 +
 12 files changed, 714 insertions(+), 140 deletions(-)
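The thread running through every hunk below is the widened `TaskAssignor` contract: `assign` now takes an `Optional<RackAwareTaskAssignor>` between the stateful task set and the configs. A minimal caller-side sketch (identifiers mirror the diff; the `RackAwareTaskAssignor` construction is shown in the StreamsPartitionAssignor hunk and elided here):

```java
// Sketch only: driving the widened five-argument TaskAssignor contract.
// `rackAwareTaskAssignor` is assumed to be built as in the
// StreamsPartitionAssignor hunk below.
final TaskAssignor taskAssignor = new HighAvailabilityTaskAssignor();
final boolean probingRebalanceNeeded = taskAssignor.assign(
    clientStates,                        // Map<UUID, ClientState>
    allTasks,                            // Set<TaskId>
    statefulTasks,                       // Set<TaskId>
    Optional.of(rackAwareTaskAssignor),  // new parameter in this change
    assignmentConfigs);
```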

diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogTopics.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogTopics.java
index c5f7067be2c..aaf8ba16a51 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogTopics.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogTopics.java
@@ -131,4 +131,8 @@ public class ChangelogTopics {
     public Set<TaskId> statefulTaskIds() {
         return Collections.unmodifiableSet(changelogPartitionsForStatefulTask.keySet());
     }
+
+    public Map<TaskId, Set<TopicPartition>> changelogPartionsForTask() {
+        return Collections.unmodifiableMap(changelogPartitionsForStatefulTask);
+    }
 }
\ No newline at end of file
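The new accessor hands `RackAwareTaskAssignor` the changelog partitions of each stateful task so that placement can account for where a task's changelog replicas live (the spelling `changelogPartionsForTask` is as committed). A hypothetical caller, for illustration only:

```java
// Hypothetical illustration; `changelogTopics` is assumed to be an
// initialized ChangelogTopics instance.
final Map<TaskId, Set<TopicPartition>> changelogsByTask =
    changelogTopics.changelogPartionsForTask();
// Each stateful task maps to the changelog partitions whose replica
// racks inform rack-aware placement of that task.
changelogsByTask.forEach((taskId, partitions) ->
    System.out.println(taskId + " -> " + partitions));
```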
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
index 83f705a2561..018f4237474 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import java.util.Optional;
 import org.apache.kafka.clients.admin.Admin;
 import org.apache.kafka.clients.admin.ListOffsetsResult;
 import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
@@ -46,6 +47,7 @@ import org.apache.kafka.streams.processor.internals.assignment.AssignorError;
 import org.apache.kafka.streams.processor.internals.assignment.ClientState;
 import org.apache.kafka.streams.processor.internals.assignment.CopartitionedTopicsEnforcer;
 import org.apache.kafka.streams.processor.internals.assignment.FallbackPriorTaskAssignor;
+import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor;
 import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer;
 import org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor;
 import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo;
@@ -316,6 +318,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
         final Map<UUID, ClientMetadata> clientMetadataMap = new HashMap<>();
         final Set<TopicPartition> allOwnedPartitions = new HashSet<>();
+        final Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer = new HashMap<>();
 
         int minReceivedMetadataVersion = LATEST_SUPPORTED_VERSION;
         int minSupportedMetadataVersion = LATEST_SUPPORTED_VERSION;
@@ -346,6 +349,8 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                 processId = info.processId();
             }
 
+            racksForProcessConsumer.computeIfAbsent(processId, kv -> new HashMap<>()).put(consumerId, subscription.rackId());
+
             ClientMetadata clientMetadata = clientMetadataMap.get(processId);
 
             // create the new client metadata if necessary
@@ -410,7 +415,8 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
             final Set<TaskId> statefulTasks = new HashSet<>();
 
-            final boolean probingRebalanceNeeded = assignTasksToClients(fullMetadata, allSourceTopics, topicGroups, clientMetadataMap, partitionsForTask, statefulTasks);
+            final boolean probingRebalanceNeeded = assignTasksToClients(fullMetadata, allSourceTopics, topicGroups,
+                clientMetadataMap, partitionsForTask, racksForProcessConsumer, statefulTasks);
 
             // ---------------- Step Three ---------------- //
 
@@ -597,6 +603,7 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
                                          final Map<Subtopology, TopicsInfo> topicGroups,
                                          final Map<UUID, ClientMetadata> clientMetadataMap,
                                          final Map<TaskId, Set<TopicPartition>> partitionsForTask,
+                                         final Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer,
                                          final Set<TaskId> statefulTasks) {
         if (!statefulTasks.isEmpty()) {
             throw new TaskAssignmentException("The stateful tasks should not 
be populated before assigning tasks to clients");
@@ -633,9 +640,12 @@ public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Conf
 
         final TaskAssignor taskAssignor = createTaskAssignor(lagComputationSuccessful);
 
+        final RackAwareTaskAssignor rackAwareTaskAssignor = new RackAwareTaskAssignor(fullMetadata, partitionsForTask,
+            changelogTopics.changelogPartionsForTask(), tasksForTopicGroup, racksForProcessConsumer, internalTopicManager, assignmentConfigs);
         final boolean probingRebalanceNeeded = taskAssignor.assign(clientStates,
                                                                    allTasks,
                                                                    statefulTasks,
+                                                                   Optional.of(rackAwareTaskAssignor),
                                                                    assignmentConfigs);
 
         log.info("{} assigned tasks {} including stateful {} to {} clients as: 
\n{}.",
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java
index 58456ac7ac3..562a3d0a2f9 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import java.util.Optional;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
 
@@ -42,8 +43,9 @@ public class FallbackPriorTaskAssignor implements TaskAssignor {
     public boolean assign(final Map<UUID, ClientState> clients,
                           final Set<TaskId> allTaskIds,
                           final Set<TaskId> statefulTaskIds,
+                          final Optional<RackAwareTaskAssignor> rackAwareTaskAssignor,
                           final AssignmentConfigs configs) {
-        delegate.assign(clients, allTaskIds, statefulTaskIds, configs);
+        delegate.assign(clients, allTaskIds, statefulTaskIds, rackAwareTaskAssignor, configs);
         return true;
     }
 }
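The fallback assignor is used when task lags could not be computed, so it forwards the new parameter untouched to its sticky delegate and, by returning `true`, always requests a follow-up probing rebalance. A caller-side sketch of that behavior (passing `Optional.empty()` is an assumption for illustration; production code supplies a populated Optional):

```java
// Sketch: the fallback path always reports an unstable assignment, so a
// probing rebalance will re-evaluate task lags later.
final boolean unstable = new FallbackPriorTaskAssignor().assign(
    clients, allTaskIds, statefulTaskIds, Optional.empty(), configs);
// `unstable` is true by construction.
```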
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
index f402f8c279d..ac47085f5cc 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import java.util.Optional;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.Task;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
@@ -43,21 +44,27 @@ import static org.apache.kafka.streams.processor.internals.assignment.TaskMoveme
 
 public class HighAvailabilityTaskAssignor implements TaskAssignor {
     private static final Logger log = LoggerFactory.getLogger(HighAvailabilityTaskAssignor.class);
+    private static final int DEFAULT_STATEFUL_TRAFFIC_COST = 10;
+    private static final int DEFAULT_STATEFUL_NON_OVERLAP_COST = 1;
+    private static final int STATELESS_TRAFFIC_COST = 1;
+    private static final int STATELESS_NON_OVERLAP_COST = 1;
 
     @Override
     public boolean assign(final Map<UUID, ClientState> clients,
                           final Set<TaskId> allTaskIds,
                           final Set<TaskId> statefulTaskIds,
+                          final Optional<RackAwareTaskAssignor> rackAwareTaskAssignor,
                           final AssignmentConfigs configs) {
         final SortedSet<TaskId> statefulTasks = new TreeSet<>(statefulTaskIds);
         final TreeMap<UUID, ClientState> clientStates = new TreeMap<>(clients);
 
-        assignActiveStatefulTasks(clientStates, statefulTasks);
+        assignActiveStatefulTasks(clientStates, statefulTasks, rackAwareTaskAssignor, configs);
 
         assignStandbyReplicaTasks(
             clientStates,
             allTaskIds,
             statefulTasks,
+            rackAwareTaskAssignor,
             configs
         );
 
@@ -94,7 +101,7 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
             warmups
         );
 
-        assignStatelessActiveTasks(clientStates, diff(TreeSet::new, allTaskIds, statefulTasks));
+        assignStatelessActiveTasks(clientStates, diff(TreeSet::new, allTaskIds, statefulTasks), rackAwareTaskAssignor);
 
         final boolean probingRebalanceNeeded = neededActiveTaskMovements + neededStandbyTaskMovements > 0;
 
@@ -108,7 +115,9 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
     }
 
     private static void assignActiveStatefulTasks(final SortedMap<UUID, ClientState> clientStates,
-                                                  final SortedSet<TaskId> statefulTasks) {
+                                                  final SortedSet<TaskId> statefulTasks,
+                                                  final Optional<RackAwareTaskAssignor> rackAwareTaskAssignor,
+                                                  final AssignmentConfigs configs) {
         Iterator<ClientState> clientStateIterator = null;
         for (final TaskId task : statefulTasks) {
             if (clientStateIterator == null || !clientStateIterator.hasNext()) {
@@ -124,11 +133,20 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
             ClientState::assignActive,
             (source, destination) -> true
         );
+
+        if (rackAwareTaskAssignor != null && rackAwareTaskAssignor.isPresent() && rackAwareTaskAssignor.get().canEnableRackAwareAssignor()) {
+            final int trafficCost = configs.rackAwareAssignmentTrafficCost == null ?
+                DEFAULT_STATEFUL_TRAFFIC_COST : configs.rackAwareAssignmentTrafficCost;
+            final int nonOverlapCost = configs.rackAwareAssignmentNonOverlapCost == null ?
+                DEFAULT_STATEFUL_NON_OVERLAP_COST : configs.rackAwareAssignmentNonOverlapCost;
+            rackAwareTaskAssignor.get().optimizeActiveTasks(statefulTasks, clientStates, trafficCost, nonOverlapCost);
+        }
     }
 
     private void assignStandbyReplicaTasks(final TreeMap<UUID, ClientState> clientStates,
                                            final Set<TaskId> allTaskIds,
                                            final Set<TaskId> statefulTasks,
+                                           final Optional<RackAwareTaskAssignor> rackAwareTaskAssignor,
                                            final AssignmentConfigs configs) {
         if (configs.numStandbyReplicas == 0) {
             return;
@@ -145,6 +163,14 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
             ClientState::assignStandby,
             standbyTaskAssignor::isAllowedTaskMovement
         );
+
+        if (rackAwareTaskAssignor != null && rackAwareTaskAssignor.isPresent() && rackAwareTaskAssignor.get().canEnableRackAwareAssignor()) {
+            final int trafficCost = configs.rackAwareAssignmentTrafficCost == null ?
+                DEFAULT_STATEFUL_TRAFFIC_COST : configs.rackAwareAssignmentTrafficCost;
+            final int nonOverlapCost = configs.rackAwareAssignmentNonOverlapCost == null ?
+                DEFAULT_STATEFUL_NON_OVERLAP_COST : configs.rackAwareAssignmentNonOverlapCost;
+            rackAwareTaskAssignor.get().optimizeStandbyTasks(clientStates, trafficCost, nonOverlapCost, standbyTaskAssignor::isAllowedTaskMovement);
+        }
     }
 
     private static void balanceTasksOverThreads(final SortedMap<UUID, ClientState> clientStates,
@@ -208,19 +234,27 @@ public class HighAvailabilityTaskAssignor implements TaskAssignor {
     }
 
     private static void assignStatelessActiveTasks(final TreeMap<UUID, ClientState> clientStates,
-                                                   final Iterable<TaskId> statelessTasks) {
+                                                   final Iterable<TaskId> statelessTasks,
+                                                   final Optional<RackAwareTaskAssignor> rackAwareTaskAssignor) {
         final ConstrainedPrioritySet statelessActiveTaskClientsByTaskLoad = new ConstrainedPrioritySet(
             (client, task) -> true,
             client -> clientStates.get(client).activeTaskLoad()
         );
         statelessActiveTaskClientsByTaskLoad.offerAll(clientStates.keySet());
 
+        final SortedSet<TaskId> sortedTasks = new TreeSet<>();
         for (final TaskId task : statelessTasks) {
+            sortedTasks.add(task);
             final UUID client = statelessActiveTaskClientsByTaskLoad.poll(task);
             final ClientState state = clientStates.get(client);
             state.assignActive(task);
             statelessActiveTaskClientsByTaskLoad.offer(client);
         }
+
+        if (rackAwareTaskAssignor != null && rackAwareTaskAssignor.isPresent() && rackAwareTaskAssignor.get().canEnableRackAwareAssignor()) {
+            rackAwareTaskAssignor.get().optimizeActiveTasks(sortedTasks, clientStates,
+                STATELESS_TRAFFIC_COST, STATELESS_NON_OVERLAP_COST);
+        }
     }
 
     private static Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients(final Set<TaskId> statefulTasks,
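Both stateful phases share the same cost defaulting: a traffic cost of 10 against a non-overlap cost of 1 biases the optimizer toward minimizing cross-rack traffic rather than preserving the pre-optimization placement, while stateless tasks use the fixed 1/1 pair. A condensed sketch of the defaulting, assuming the nullable config fields shown above:

```java
// Condensed sketch of the stateful cost selection; the config fields are
// null when no explicit rack-aware cost overrides were set.
final int trafficCost = configs.rackAwareAssignmentTrafficCost == null
    ? DEFAULT_STATEFUL_TRAFFIC_COST       // 10: weigh cross-rack traffic heavily
    : configs.rackAwareAssignmentTrafficCost;
final int nonOverlapCost = configs.rackAwareAssignmentNonOverlapCost == null
    ? DEFAULT_STATEFUL_NON_OVERLAP_COST   // 1: weak penalty for moving tasks
    : configs.rackAwareAssignmentNonOverlapCost;
```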
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java
index 00b5306ae4b..c3921fc9b75 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/RackAwareTaskAssignor.java
@@ -192,6 +192,9 @@ public class RackAwareTaskAssignor {
     }
 
     private boolean validateClientRack(final Map<UUID, Map<String, Optional<String>>> racksForProcessConsumer) {
+        if (racksForProcessConsumer == null) {
+            return false;
+        }
         /*
          * Check rack information is populated correctly in clients
          * 1. RackId exist for all clients
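The added null check mirrors the guards at the call sites: the `Optional` parameter itself may be `null` (some tests pass `null` rather than `Optional.empty()`), so both the reference and its presence are verified before the optimizer runs. The defensive pattern, sketched:

```java
// Defensive pattern used at each call site in this change: guard the
// reference, then its presence, then the enablement check.
if (rackAwareTaskAssignor != null
        && rackAwareTaskAssignor.isPresent()
        && rackAwareTaskAssignor.get().canEnableRackAwareAssignor()) {
    // safe to run the rack-aware optimization
}
```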
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignor.java
index a5c1ca2ddb5..367cc8cba52 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StandbyTaskAssignor.java
@@ -17,8 +17,11 @@
 package org.apache.kafka.streams.processor.internals.assignment;
 
 import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
 import java.util.UUID;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
 
 interface StandbyTaskAssignor extends TaskAssignor {
     default boolean isAllowedTaskMovement(final ClientState source, final ClientState destination) {
@@ -39,4 +42,17 @@ interface StandbyTaskAssignor extends TaskAssignor {
                                           final Map<UUID, ClientState> clientStateMap) {
         return true;
     }
+
+    default boolean assign(final Map<UUID, ClientState> clients,
+                           final Set<TaskId> allTaskIds,
+                           final Set<TaskId> statefulTaskIds,
+                           final Optional<RackAwareTaskAssignor> rackAwareTaskAssignor,
+                           final AssignmentConfigs configs) {
+        return assign(clients, allTaskIds, statefulTaskIds, configs);
+    }
+
+    boolean assign(final Map<UUID, ClientState> clients,
+                   final Set<TaskId> allTaskIds,
+                   final Set<TaskId> statefulTaskIds,
+                   final AssignorConfiguration.AssignmentConfigs configs);
 }
\ No newline at end of file
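The default method lets existing standby implementations keep the legacy four-argument `assign` while still satisfying the widened `TaskAssignor` interface. A hypothetical implementation relying on the bridge:

```java
// Hypothetical standby assignor: it implements only the legacy
// four-argument method; the default five-argument overload declared in
// StandbyTaskAssignor silently drops the rack-aware parameter for it.
class NoopStandbyTaskAssignor implements StandbyTaskAssignor {
    @Override
    public boolean assign(final Map<UUID, ClientState> clients,
                          final Set<TaskId> allTaskIds,
                          final Set<TaskId> statefulTaskIds,
                          final AssignorConfiguration.AssignmentConfigs configs) {
        return false;  // no follow-up probing rebalance requested
    }
}
```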
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
index 18abbc14c4a..9a7ad46f2ca 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import java.util.Optional;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
 import org.slf4j.Logger;
@@ -57,6 +58,7 @@ public class StickyTaskAssignor implements TaskAssignor {
     public boolean assign(final Map<UUID, ClientState> clients,
                           final Set<TaskId> allTaskIds,
                           final Set<TaskId> statefulTaskIds,
+                          final Optional<RackAwareTaskAssignor> rackAwareTaskAssignor,
                           final AssignmentConfigs configs) {
         this.clients = clients;
         this.allTaskIds = allTaskIds;
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
index aeb2192c63e..faa32a73a34 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java
@@ -16,18 +16,21 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import java.util.Optional;
 import org.apache.kafka.streams.processor.TaskId;
 
 import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
+import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
 
 public interface TaskAssignor {
     /**
      * @return whether the generated assignment requires a followup probing rebalance to satisfy all conditions
      */
-    boolean assign(Map<UUID, ClientState> clients,
-                   Set<TaskId> allTaskIds,
-                   Set<TaskId> statefulTaskIds,
-                   AssignorConfiguration.AssignmentConfigs configs);
+    boolean assign(final Map<UUID, ClientState> clients,
+                   final Set<TaskId> allTaskIds,
+                   final Set<TaskId> statefulTaskIds,
+                   final Optional<RackAwareTaskAssignor> rackAwareTaskAssignor,
+                   final AssignmentConfigs configs);
 }
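Implementations that ignore rack information, such as `StickyTaskAssignor`, accept and discard the new parameter, so a caller without rack metadata can pass an empty Optional. A sketch under that assumption:

```java
// Sketch: invoking the widened interface without rack information.
final boolean needsProbingRebalance = new StickyTaskAssignor().assign(
    clientStates,
    allTaskIds,
    statefulTaskIds,
    Optional.empty(),   // no RackAwareTaskAssignor available
    assignmentConfigs);
```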
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java
index 0473d9bee45..26ff4685284 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java
@@ -54,6 +54,7 @@ public class FallbackPriorTaskAssignorTest {
             clients,
             new HashSet<>(taskIds),
             new HashSet<>(taskIds),
+            null,
             new AssignorConfiguration.AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
         assertThat(probingRebalanceNeeded, is(true));
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
index 90e0fed51f3..e8b9eb06203 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java
@@ -16,8 +16,15 @@
  */
 package org.apache.kafka.streams.processor.internals.assignment;
 
+import java.util.Collection;
+import java.util.Optional;
+import java.util.SortedMap;
+import java.util.SortedSet;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs;
+import org.junit.Before;
 import org.junit.Test;
 
 import java.util.HashMap;
@@ -26,7 +33,11 @@ import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
 import java.util.stream.Collectors;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
 
+import static java.util.Arrays.asList;
 import static java.util.Collections.emptySet;
 import static java.util.Collections.singleton;
 import static java.util.Collections.singletonMap;
@@ -56,9 +67,21 @@ import static org.apache.kafka.streams.processor.internals.assignment.Assignment
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedTasks;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertValidAssignment;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClientStatesMap;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClusterForAllTopics;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getProcessRacksForAllProcess;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getRandomClientState;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getRandomCluster;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getRandomProcessRacks;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getTaskChangelogMapForAllTasks;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getTaskTopicPartitionMap;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getTaskTopicPartitionMapForAllTasks;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getTopologyGroupTaskMap;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasActiveTasks;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasAssignedTasks;
 import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasStandbyTasks;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.mockInternalTopicManagerForChangelog;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.mockInternalTopicManagerForRandomChangelog;
+import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.verifyStandbySatisfyRackReplica;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.empty;
@@ -66,30 +89,71 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.fail;
-
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+@RunWith(Parameterized.class)
 public class HighAvailabilityTaskAssignorTest {
-    private final AssignmentConfigs configWithoutStandbys = new AssignmentConfigs(
-        /*acceptableRecoveryLag*/ 100L,
-        /*maxWarmupReplicas*/ 2,
-        /*numStandbyReplicas*/ 0,
-        /*probingRebalanceIntervalMs*/ 60 * 1000L,
-        /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS
-    );
-
-    private final AssignmentConfigs configWithStandbys = new AssignmentConfigs(
-        /*acceptableRecoveryLag*/ 100L,
-        /*maxWarmupReplicas*/ 2,
-        /*numStandbyReplicas*/ 1,
-        /*probingRebalanceIntervalMs*/ 60 * 1000L,
-        /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS
-    );
+    private AssignmentConfigs getConfigWithoutStandbys() {
+        return new AssignmentConfigs(
+            /*acceptableRecoveryLag*/ 100L,
+            /*maxWarmupReplicas*/ 2,
+            /*numStandbyReplicas*/ 0,
+            /*probingRebalanceIntervalMs*/ 60 * 1000L,
+            /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+    }
+
+    private AssignmentConfigs getConfigWithStandbys() {
+        return getConfigWithStandbys(1);
+    }
+
+    private AssignmentConfigs getConfigWithStandbys(final int replicaNum) {
+        return new AssignmentConfigs(
+            /*acceptableRecoveryLag*/ 100L,
+            /*maxWarmupReplicas*/ 2,
+            /*numStandbyReplicas*/ replicaNum,
+            /*probingRebalanceIntervalMs*/ 60 * 1000L,
+            /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+    }
+
+    @Parameter
+    public boolean enableRackAwareTaskAssignor;
+
+    private String rackAwareStrategy = StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_NONE;
+
+    @Before
+    public void setUp() {
+        if (enableRackAwareTaskAssignor) {
+            rackAwareStrategy = StreamsConfig.RACK_AWARE_ASSIGNMENT_STRATEGY_MIN_TRAFFIC;
+        }
+    }
+
+    @Parameterized.Parameters(name = "enableRackAwareTaskAssignor={0}")
+    public static Collection<Object[]> getParamStoreType() {
+        return asList(new Object[][] {
+            {true},
+            {false}
+        });
+    }
 
     @Test
     public void shouldBeStickyForActiveAndStandbyTasksWhileWarmingUp() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
-        final ClientState clientState1 = new ClientState(allTaskIds, emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)), EMPTY_CLIENT_TAGS, 1);
-        final ClientState clientState2 = new ClientState(emptySet(), allTaskIds, allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)), EMPTY_CLIENT_TAGS, 1);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState1 = new ClientState(allTaskIds, emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)), EMPTY_CLIENT_TAGS, 1, UUID_1);
+        final ClientState clientState2 = new ClientState(emptySet(), allTaskIds, allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)), EMPTY_CLIENT_TAGS, 1, UUID_2);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), EMPTY_CLIENT_TAGS, 1, UUID_3);
 
         final Map<UUID, ClientState> clientStates = mkMap(
             mkEntry(UUID_1, clientState1),
@@ -97,11 +161,24 @@ public class HighAvailabilityTaskAssignorTest {
             mkEntry(UUID_3, clientState3)
         );
 
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            11L,
+            2,
+            1,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(11L, 2, 1, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertThat(clientState1, hasAssignedTasks(allTaskIds.size()));
@@ -111,14 +188,16 @@ public class HighAvailabilityTaskAssignorTest {
         assertThat(clientState3, hasAssignedTasks(2));
 
         assertThat(unstable, is(true));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTaskIds, clientStates, true);
     }
 
     @Test
     public void shouldSkipWarmupsWhenAcceptableLagIsMax() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
-        final ClientState clientState1 = new ClientState(allTaskIds, emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)), EMPTY_CLIENT_TAGS, 1);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), EMPTY_CLIENT_TAGS, 1);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState1 = new ClientState(allTaskIds, emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)), EMPTY_CLIENT_TAGS, 1, UUID_1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), EMPTY_CLIENT_TAGS, 1, UUID_2);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), EMPTY_CLIENT_TAGS, 1, UUID_3);
 
         final Map<UUID, ClientState> clientStates = mkMap(
             mkEntry(UUID_1, clientState1),
@@ -126,73 +205,143 @@ public class HighAvailabilityTaskAssignorTest {
             mkEntry(UUID_3, clientState3)
         );
 
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            Long.MAX_VALUE,
+            1,
+            1,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(Long.MAX_VALUE, 1, 1, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertThat(clientState1, hasAssignedTasks(6));
         assertThat(clientState2, hasAssignedTasks(6));
         assertThat(clientState3, hasAssignedTasks(6));
         assertThat(unstable, is(false));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTaskIds, clientStates, true);
     }
 
     @Test
     public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfClientsIntegralDivisorOfNumberOfTasks() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1, UUID_1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1, UUID_2);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1, UUID_3);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
+
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
         assertThat(unstable, is(false));
         assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
-        assertBalancedTasks(clientStates);
+
+        if (!enableRackAwareTaskAssignor) {
+            // Subtopology is not balanced with min_traffic rack aware assignment
+            assertBalancedTasks(clientStates);
+        }
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTaskIds, clientStates, false);
     }
 
     @Test
     public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfThreadsIntegralDivisorOfNumberOfTasks() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3, UUID_1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3, UUID_2);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3, UUID_3);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
+
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
+
         assertThat(unstable, is(false));
         assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
-        assertBalancedTasks(clientStates);
+
+        if (!enableRackAwareTaskAssignor) {
+            // Subtopology is not balanced with min_traffic rack aware assignment
+            assertBalancedTasks(clientStates);
+        }
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTaskIds, clientStates, false);
     }
 
     @Test
     public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfClientsNotIntegralDivisorOfNumberOfTasks() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1, UUID_1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1, UUID_2);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2);
+
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertThat(unstable, is(false));
@@ -200,21 +349,37 @@ public class HighAvailabilityTaskAssignorTest {
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
         assertBalancedTasks(clientStates);
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTaskIds, clientStates, false);
     }
 
     @Test
     public void shouldAssignActiveStatefulTasksEvenlyOverUnevenlyDistributedStreamThreads() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 2);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1, UUID_1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 2, UUID_2);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3, UUID_3);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
+
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertThat(unstable, is(false));
@@ -229,21 +394,37 @@ public class HighAvailabilityTaskAssignorTest {
         if (taskSkewReport.totalSkewedTasks() == 0) {
             fail("Expected a skewed task assignment, but was: " + 
taskSkewReport);
         }
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTaskIds, clientStates, false);
     }
 
     @Test
     public void shouldAssignActiveStatefulTasksEvenlyOverClientsWithMoreClientsThanTasks() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1, UUID_1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1, UUID_2);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 1, UUID_3);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
+
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertThat(unstable, is(false));
@@ -251,28 +432,50 @@ public class HighAvailabilityTaskAssignorTest {
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
         assertBalancedTasks(clientStates);
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTaskIds, clientStates, false);
     }
 
     @Test
     public void shouldAssignActiveStatefulTasksEvenlyOverClientsAndStreamThreadsWithEqualStreamThreadsPerClientAsTasks() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 9);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 9);
-        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 9);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 9, UUID_1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 9, UUID_2);
+        final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 9, UUID_3);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2, clientState3);
+
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertThat(unstable, is(false));
         assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder());
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
-        assertBalancedTasks(clientStates);
+
+        if (!enableRackAwareTaskAssignor) {
+            // Subtopology is not balanced with min_traffic rack aware assignment
+            assertBalancedTasks(clientStates);
+        }
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTaskIds, clientStates, false);
     }
 
     @Test
@@ -281,45 +484,81 @@ public class HighAvailabilityTaskAssignorTest {
         final Map<TaskId, Long> lagsForCaughtUpClient = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L));
         final Map<TaskId, Long> lagsForNotCaughtUpClient =
             allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE));
-        final ClientState caughtUpClientState = new ClientState(allTaskIds, emptySet(), lagsForCaughtUpClient, EMPTY_CLIENT_TAGS, 5);
-        final ClientState notCaughtUpClientState1 = new ClientState(emptySet(), emptySet(), lagsForNotCaughtUpClient, EMPTY_CLIENT_TAGS, 5);
-        final ClientState notCaughtUpClientState2 = new ClientState(emptySet(), emptySet(), lagsForNotCaughtUpClient, EMPTY_CLIENT_TAGS, 5);
+        final ClientState caughtUpClientState = new ClientState(allTaskIds, emptySet(), lagsForCaughtUpClient, EMPTY_CLIENT_TAGS, 5, UUID_1);
+        final ClientState notCaughtUpClientState1 = new ClientState(emptySet(), emptySet(), lagsForNotCaughtUpClient, EMPTY_CLIENT_TAGS, 5, UUID_2);
+        final ClientState notCaughtUpClientState2 = new ClientState(emptySet(), emptySet(), lagsForNotCaughtUpClient, EMPTY_CLIENT_TAGS, 5, UUID_3);
         final Map<UUID, ClientState> clientStates =
             getClientStatesMap(caughtUpClientState, notCaughtUpClientState1, notCaughtUpClientState2);
+
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            allTaskIds.size() / 3 + 1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, allTaskIds.size() / 3 + 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertThat(unstable, is(true));
         assertThat(notCaughtUpClientState1.standbyTaskCount(), greaterThanOrEqualTo(allTaskIds.size() / 3));
         assertThat(notCaughtUpClientState2.standbyTaskCount(), greaterThanOrEqualTo(allTaskIds.size() / 3));
         assertValidAssignment(0, allTaskIds.size() / 3 + 1, allTaskIds, emptySet(), clientStates, new StringBuilder());
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTaskIds, clientStates, false);
     }
 
     @Test
     public void shouldEvenlyAssignActiveStatefulTasksIfClientsAreWarmedUpToBalanceTaskOverClients() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_1_0, TASK_1_1);
-        final Set<TaskId> warmedUpTaskIds1 = mkSet(TASK_0_1);
-        final Set<TaskId> warmedUpTaskIds2 = mkSet(TASK_1_0);
+
+        // If RackAwareTaskAssignor is enabled, TASK_1_1 is assigned UUID_2
+        final TaskId warmupTaskId1 = enableRackAwareTaskAssignor ? TASK_1_1 : TASK_0_1;
+        // If RackAwareTaskAssignor is enabled, TASK_0_1 is assigned UUID_3
+        final TaskId warmupTaskId2 = enableRackAwareTaskAssignor ? TASK_0_1 : TASK_1_0;
+        final Set<TaskId> warmedUpTaskIds1 = mkSet(warmupTaskId1);
+        final Set<TaskId> warmedUpTaskIds2 = mkSet(warmupTaskId2);
         final Map<TaskId, Long> lagsForCaughtUpClient = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L));
         final Map<TaskId, Long> lagsForWarmedUpClient1 =
             allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE));
-        lagsForWarmedUpClient1.put(TASK_0_1, 0L);
+        lagsForWarmedUpClient1.put(warmupTaskId1, 0L);
         final Map<TaskId, Long> lagsForWarmedUpClient2 =
             allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE));
-        lagsForWarmedUpClient2.put(TASK_1_0, 0L);
-        final ClientState caughtUpClientState = new ClientState(allTaskIds, emptySet(), lagsForCaughtUpClient, EMPTY_CLIENT_TAGS, 5);
-        final ClientState warmedUpClientState1 = new ClientState(emptySet(), warmedUpTaskIds1, lagsForWarmedUpClient1, EMPTY_CLIENT_TAGS, 5);
-        final ClientState warmedUpClientState2 = new ClientState(emptySet(), warmedUpTaskIds2, lagsForWarmedUpClient2, EMPTY_CLIENT_TAGS, 5);
+        lagsForWarmedUpClient2.put(warmupTaskId2, 0L);
+
+        final ClientState caughtUpClientState = new ClientState(allTaskIds, emptySet(), lagsForCaughtUpClient, EMPTY_CLIENT_TAGS, 5, UUID_1);
+        final ClientState warmedUpClientState1 = new ClientState(emptySet(), warmedUpTaskIds1, lagsForWarmedUpClient1, EMPTY_CLIENT_TAGS, 5, UUID_2);
+        final ClientState warmedUpClientState2 = new ClientState(emptySet(), warmedUpTaskIds2, lagsForWarmedUpClient2, EMPTY_CLIENT_TAGS, 5, UUID_3);
         final Map<UUID, ClientState> clientStates =
             getClientStatesMap(caughtUpClientState, warmedUpClientState1, warmedUpClientState2);
+
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            allTaskIds.size() / 3 + 1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
+            Optional.of(rackAwareTaskAssignor),
             new AssignmentConfigs(0L, allTaskIds.size() / 3 + 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
 
@@ -328,20 +567,36 @@ public class HighAvailabilityTaskAssignorTest {
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
         assertBalancedTasks(clientStates);
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTaskIds, clientStates, false);
     }
 
     @Test
     public void shouldAssignActiveStatefulTasksEvenlyOverStreamThreadsButBestEffortOverClients() {
         final Set<TaskId> allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2);
         final Map<TaskId, Long> lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L));
-        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 6);
-        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3);
+        final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 6, UUID_1);
+        final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, EMPTY_CLIENT_TAGS, 3, UUID_2);
         final Map<UUID, ClientState> clientStates = getClientStatesMap(clientState1, clientState2);
+
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean unstable = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTaskIds,
             allTaskIds,
-            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertThat(unstable, is(false));
@@ -350,18 +605,24 @@ public class HighAvailabilityTaskAssignorTest {
         assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder());
         assertThat(clientState1, hasActiveTasks(6));
         assertThat(clientState2, hasActiveTasks(3));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTaskIds, clientStates, false);
     }
 
     @Test
     public void shouldComputeNewAssignmentIfThereAreUnassignedActiveTasks() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
-        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1);
+        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1, UUID_1);
         final Map<UUID, ClientState> clientStates = singletonMap(UUID_1, client1);
 
+        final AssignmentConfigs configs = getConfigWithoutStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(clientStates,
                                                                                           allTasks,
                                                                                           singleton(TASK_0_0),
-                                                                                          configWithoutStandbys);
+                                                                                          Optional.of(rackAwareTaskAssignor),
+                                                                                          configs);
 
         assertThat(probingRebalanceNeeded, is(false));
         assertThat(client1, hasActiveTasks(2));
@@ -371,20 +632,26 @@ public class HighAvailabilityTaskAssignorTest {
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertBalancedStatefulAssignment(allTasks, clientStates, new StringBuilder());
         assertBalancedTasks(clientStates);
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
     }
 
     @Test
     public void shouldComputeNewAssignmentIfThereAreUnassignedStandbyTasks() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0);
-        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1);
-        final ClientState client2 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1);
+        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1, UUID_1);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1, UUID_2);
         final Map<UUID, ClientState> clientStates = mkMap(mkEntry(UUID_1, client1), mkEntry(UUID_2, client2));
 
+        final AssignmentConfigs configs = getConfigWithStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(clientStates,
                                                                                           allTasks,
                                                                                           statefulTasks,
-                                                                                          configWithStandbys);
+                                                                                          Optional.of(rackAwareTaskAssignor),
+                                                                                          configs);
 
         assertThat(clientStates.get(UUID_2).standbyTasks(), not(empty()));
         assertThat(probingRebalanceNeeded, is(false));
@@ -392,21 +659,26 @@ public class HighAvailabilityTaskAssignorTest {
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertBalancedStatefulAssignment(allTasks, clientStates, new StringBuilder());
         assertBalancedTasks(clientStates);
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, true);
     }
 
     @Test
     public void shouldComputeNewAssignmentIfActiveTasksWasNotOnCaughtUpClient() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0);
-        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 500L), EMPTY_CLIENT_TAGS, 1);
-        final ClientState client2 = new ClientState(singleton(TASK_0_1), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1);
+        final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 500L), EMPTY_CLIENT_TAGS, 1, UUID_1);
+        final ClientState client2 = new ClientState(singleton(TASK_0_1), emptySet(), singletonMap(TASK_0_0, 0L), EMPTY_CLIENT_TAGS, 1, UUID_2);
         final Map<UUID, ClientState> clientStates = mkMap(
             mkEntry(UUID_1, client1),
             mkEntry(UUID_2, client2)
         );
 
+        final AssignmentConfigs configs = getConfigWithoutStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
 
         assertThat(clientStates.get(UUID_1).activeTasks(), is(singleton(TASK_0_1)));
         assertThat(clientStates.get(UUID_2).activeTasks(), is(singleton(TASK_0_0)));
@@ -417,23 +689,28 @@ public class HighAvailabilityTaskAssignorTest {
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertBalancedStatefulAssignment(allTasks, clientStates, new StringBuilder());
         assertBalancedTasks(clientStates);
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
     }
 
     @Test
     public void shouldAssignToMostCaughtUpIfActiveTasksWasNotOnCaughtUpClient() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0);
-        final ClientState client1 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, Long.MAX_VALUE), EMPTY_CLIENT_TAGS, 1);
-        final ClientState client2 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 1000L), EMPTY_CLIENT_TAGS, 1);
-        final ClientState client3 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 500L), EMPTY_CLIENT_TAGS, 1);
+        final ClientState client1 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, Long.MAX_VALUE), EMPTY_CLIENT_TAGS, 1, UUID_1);
+        final ClientState client2 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 1000L), EMPTY_CLIENT_TAGS, 1, UUID_2);
+        final ClientState client3 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 500L), EMPTY_CLIENT_TAGS, 1, UUID_3);
         final Map<UUID, ClientState> clientStates = mkMap(
                 mkEntry(UUID_1, client1),
                 mkEntry(UUID_2, client2),
                 mkEntry(UUID_3, client3)
         );
 
+        final AssignmentConfigs configs = getConfigWithStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded =
-                new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
+                new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
 
         assertThat(clientStates.get(UUID_1).activeTasks(), is(emptySet()));
         assertThat(clientStates.get(UUID_2).activeTasks(), is(emptySet()));
@@ -448,6 +725,8 @@ public class HighAvailabilityTaskAssignorTest {
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertBalancedStatefulAssignment(allTasks, clientStates, new StringBuilder());
         assertBalancedTasks(clientStates);
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, true);
     }
 
     @Test
@@ -455,12 +734,15 @@ public class HighAvailabilityTaskAssignorTest {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
 
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0), statefulTasks);
-        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_1), statefulTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0), statefulTasks, UUID_1);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_1), statefulTasks, UUID_2);
+
+        final AssignmentConfigs configs = getConfigWithStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
 
 
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0)));
@@ -468,6 +750,8 @@ public class HighAvailabilityTaskAssignorTest {
         assertThat(client1.standbyTasks(), equalTo(mkSet(TASK_0_1)));
         assertThat(client2.standbyTasks(), equalTo(mkSet(TASK_0_0)));
         assertThat(probingRebalanceNeeded, is(false));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, true);
     }
 
     @Test
@@ -475,30 +759,40 @@ public class HighAvailabilityTaskAssignorTest {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
         final Set<TaskId> statefulTasks = EMPTY_TASKS;
 
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
-        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_1);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_2);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+
+        final AssignmentConfigs configs = getConfigWithStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
 
 
         assertThat(client1.activeTaskCount(), equalTo(1));
         assertThat(client2.activeTaskCount(), equalTo(1));
         assertHasNoStandbyTasks(client1, client2);
         assertThat(probingRebalanceNeeded, is(false));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, true);
     }
 
     @Test
     public void shouldAssignWarmupReplicasEvenIfNoStandbyReplicasConfigured() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1), statefulTasks);
-        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1), statefulTasks, UUID_1);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_2);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+
+        final AssignmentConfigs configs = getConfigWithoutStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
 
 
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1)));
@@ -506,6 +800,8 @@ public class HighAvailabilityTaskAssignorTest {
         assertHasNoStandbyTasks(client1);
         assertHasNoActiveTasks(client2);
         assertThat(probingRebalanceNeeded, is(true));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
     }
 
 
@@ -513,21 +809,29 @@ public class HighAvailabilityTaskAssignorTest {
     public void shouldNotAssignMoreThanMaxWarmupReplicas() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3), statefulTasks);
-        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3), statefulTasks, UUID_1);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_2);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            /*acceptableRecoveryLag*/ 100L,
+            /*maxWarmupReplicas*/ 1,
+            /*numStandbyReplicas*/ 0,
+            /*probingRebalanceIntervalMs*/ 60 * 1000L,
+            /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTasks,
             statefulTasks,
-            new AssignmentConfigs(
-                /*acceptableRecoveryLag*/ 100L,
-                /*maxWarmupReplicas*/ 1,
-                /*numStandbyReplicas*/ 0,
-                /*probingRebalanceIntervalMs*/ 60 * 1000L,
-                /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS
-            )
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
 
@@ -536,27 +840,37 @@ public class HighAvailabilityTaskAssignorTest {
         assertHasNoStandbyTasks(client1);
         assertHasNoActiveTasks(client2);
         assertThat(probingRebalanceNeeded, is(true));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
     }
 
     @Test
     public void shouldNotAssignWarmupAndStandbyToTheSameClient() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3), statefulTasks);
-        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3), statefulTasks, UUID_1);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_2);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            /*acceptableRecoveryLag*/ 100L,
+            /*maxWarmupReplicas*/ 1,
+            /*numStandbyReplicas*/ 1,
+            /*probingRebalanceIntervalMs*/ 60 * 1000L,
+            /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTasks,
             statefulTasks,
-            new AssignmentConfigs(
-                /*acceptableRecoveryLag*/ 100L,
-                /*maxWarmupReplicas*/ 1,
-                /*numStandbyReplicas*/ 1,
-                /*probingRebalanceIntervalMs*/ 60 * 1000L,
-                /*rackAwareAssignmentTags*/ EMPTY_RACK_AWARE_ASSIGNMENT_TAGS
-            )
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3)));
@@ -564,50 +878,66 @@ public class HighAvailabilityTaskAssignorTest {
         assertHasNoStandbyTasks(client1);
         assertHasNoActiveTasks(client2);
         assertThat(probingRebalanceNeeded, is(true));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, true);
     }
 
     @Test
     public void shouldNotAssignAnyStandbysWithInsufficientCapacity() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1), statefulTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1), statefulTasks, UUID_1);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1);
+
+        final AssignmentConfigs configs = getConfigWithStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
 
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1)));
         assertHasNoStandbyTasks(client1);
         assertThat(probingRebalanceNeeded, is(false));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, true);
     }
 
     @Test
     public void shouldAssignActiveTasksToNotCaughtUpClientIfNoneExist() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_1);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1);
 
+        final AssignmentConfigs configs = getConfigWithStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
         assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1)));
         assertHasNoStandbyTasks(client1);
         assertThat(probingRebalanceNeeded, is(false));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, true);
     }
 
     @Test
     public void shouldNotAssignMoreThanMaxWarmupReplicasWithStandbys() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(statefulTasks, statefulTasks);
-        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
-        final ClientState client3 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(statefulTasks, statefulTasks, UUID_1);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_2);
+        final ClientState client3 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_3);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
+        final AssignmentConfigs configs = getConfigWithStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
 
         assertValidAssignment(
             1,
@@ -618,6 +948,8 @@ public class HighAvailabilityTaskAssignorTest {
             new StringBuilder()
         );
         assertThat(probingRebalanceNeeded, is(true));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, true);
     }
 
     @Test
@@ -626,13 +958,16 @@ public class HighAvailabilityTaskAssignorTest {
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
         final Set<TaskId> statelessTasks = mkSet(TASK_1_0, TASK_1_1, TASK_1_2);
 
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(statefulTasks, statefulTasks);
-        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(statefulTasks, statefulTasks, UUID_1);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_2);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
 
+        final AssignmentConfigs configs = getConfigWithStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
         assertValidAssignment(
             1,
             2,
@@ -650,6 +985,8 @@ public class HighAvailabilityTaskAssignorTest {
         assertThat(taskSkewReport.toString(), taskSkewReport.skewedSubtopologies(), not(empty()));
 
         assertThat(probingRebalanceNeeded, is(true));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, true);
     }
 
     @Test
@@ -664,57 +1001,81 @@ public class HighAvailabilityTaskAssignorTest {
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
+        final AssignmentConfigs configs = getConfigWithoutStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
 
         assertThat(client1.activeTasks(), not(empty()));
         assertThat(client2.activeTasks(), not(empty()));
         assertThat(client3.activeTasks(), not(empty()));
         assertThat(probingRebalanceNeeded, is(false));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
     }
 
     @Test
     public void shouldReturnFalseIfPreviousAssignmentIsReused() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
         final Set<TaskId> statefulTasks = new HashSet<>(allTasks);
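+        // With the rack-aware optimizer enabled the optimal placement differs, so pick caught-up sets that
+        // already match the rack-aware result; otherwise the previous assignment could not be reused as-is.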
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_2), statefulTasks);
-        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_1, TASK_0_3), statefulTasks);
+        final Set<TaskId> caughtUpTasks1 = enableRackAwareTaskAssignor ? mkSet(TASK_0_0, TASK_0_3) : mkSet(TASK_0_0, TASK_0_2);
+        final Set<TaskId> caughtUpTasks2 = enableRackAwareTaskAssignor ? mkSet(TASK_0_1, TASK_0_2) : mkSet(TASK_0_1, TASK_0_3);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(caughtUpTasks1, statefulTasks, UUID_1);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(caughtUpTasks2, statefulTasks, UUID_2);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+
+        final AssignmentConfigs configs = getConfigWithoutStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
 
         assertThat(probingRebalanceNeeded, is(false));
         assertThat(client1.activeTasks(), equalTo(client1.prevActiveTasks()));
         assertThat(client2.activeTasks(), equalTo(client2.prevActiveTasks()));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
     }
 
     @Test
     public void shouldReturnFalseIfNoWarmupTasksAreAssigned() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3);
         final Set<TaskId> statefulTasks = EMPTY_TASKS;
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
-        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_1);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_2);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+
+        final AssignmentConfigs configs = getConfigWithoutStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
         assertThat(probingRebalanceNeeded, is(false));
         assertHasNoStandbyTasks(client1, client2);
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
     }
 
     @Test
     public void shouldReturnTrueIfWarmupTasksAreAssigned() {
         final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1);
         final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1);
-        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(allTasks, statefulTasks);
-        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks);
+        final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(allTasks, statefulTasks, UUID_1);
+        final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks, UUID_2);
+
+        final AssignmentConfigs configs = getConfigWithoutStandbys();
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
         final boolean probingRebalanceNeeded =
-            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys);
+            new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, Optional.of(rackAwareTaskAssignor), configs);
         assertThat(probingRebalanceNeeded, is(true));
         assertThat(client2.standbyTaskCount(), equalTo(1));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
     }
 
     @Test
@@ -730,11 +1091,24 @@ public class HighAvailabilityTaskAssignorTest {
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTasks,
             statefulTasks,
-            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertValidAssignment(
@@ -746,6 +1120,8 @@ public class HighAvailabilityTaskAssignorTest {
         );
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertThat(probingRebalanceNeeded, is(false));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
     }
 
     @Test
@@ -761,11 +1137,24 @@ public class HighAvailabilityTaskAssignorTest {
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTasks,
             statefulTasks,
-            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertValidAssignment(
@@ -777,6 +1166,8 @@ public class HighAvailabilityTaskAssignorTest {
         );
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertThat(probingRebalanceNeeded, is(false));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
     }
 
     @Test
@@ -792,11 +1183,24 @@ public class HighAvailabilityTaskAssignorTest {
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTasks,
             statefulTasks,
-            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertValidAssignment(
@@ -808,6 +1212,8 @@ public class HighAvailabilityTaskAssignorTest {
         );
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertThat(probingRebalanceNeeded, is(false));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
     }
 
     @Test
@@ -823,11 +1229,24 @@ public class HighAvailabilityTaskAssignorTest {
 
         final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3);
 
+        final AssignmentConfigs configs = new AssignmentConfigs(
+            0L,
+            1,
+            0,
+            60_000L,
+            EMPTY_RACK_AWARE_ASSIGNMENT_TAGS,
+            null,
+            null,
+            rackAwareStrategy
+        );
+        final RackAwareTaskAssignor rackAwareTaskAssignor = getRackAwareTaskAssignor(configs);
+
         final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(
             clientStates,
             allTasks,
             statefulTasks,
-            new AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
+            Optional.of(rackAwareTaskAssignor),
+            configs
         );
 
         assertValidAssignment(
@@ -839,6 +1258,52 @@ public class HighAvailabilityTaskAssignorTest {
         );
         assertBalancedActiveAssignment(clientStates, new StringBuilder());
         assertThat(probingRebalanceNeeded, is(false));
+
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, allTasks, clientStates, false);
+    }
+
+    @Test
+    public void shouldAssignRandomInput() {
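+        // Randomized smoke test: build a random cluster, topology, rack layout, and client set, run the HA
+        // assignor with the rack-aware optimizer, and verify the result is still a valid, balanced assignment.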
+        final int nodeSize = 50;
+        final int tpSize = 60;
+        final int clientSize = 50;
+        final int replicaCount = 3;
+        final int maxCapacity = 3;
+        final SortedMap<TaskId, Set<TopicPartition>> taskTopicPartitionMap = getTaskTopicPartitionMap(
+            tpSize, false);
+        final AssignmentConfigs assignorConfiguration = getConfigWithStandbys(replicaCount);
+
+        final RackAwareTaskAssignor rackAwareTaskAssignor = spy(new RackAwareTaskAssignor(
+            getRandomCluster(nodeSize, tpSize),
+            taskTopicPartitionMap,
+            getTaskTopicPartitionMap(tpSize, true),
+            getTopologyGroupTaskMap(),
+            getRandomProcessRacks(clientSize, nodeSize),
+            mockInternalTopicManagerForRandomChangelog(nodeSize, tpSize),
+            assignorConfiguration
+        ));
+
+        final SortedMap<UUID, ClientState> clientStateMap = getRandomClientState(clientSize,
+            tpSize, maxCapacity, false);
+        final SortedSet<TaskId> taskIds = (SortedSet<TaskId>) taskTopicPartitionMap.keySet();
+
+        new HighAvailabilityTaskAssignor().assign(
+            clientStateMap,
+            taskIds,
+            taskIds,
+            Optional.of(rackAwareTaskAssignor),
+            assignorConfiguration
+        );
+
+        assertValidAssignment(
+            replicaCount,
+            taskIds,
+            mkSet(),
+            clientStateMap,
+            new StringBuilder()
+        );
+        assertBalancedActiveAssignment(clientStateMap, new StringBuilder());
+        verifyTaskPlacementWithRackAwareAssignor(rackAwareTaskAssignor, taskIds, clientStateMap, true);
     }
 
     private static void assertHasNoActiveTasks(final ClientState... clients) {
@@ -854,7 +1319,8 @@ public class HighAvailabilityTaskAssignorTest {
     }
 
     private static ClientState getMockClientWithPreviousCaughtUpTasks(final Set<TaskId> statefulActiveTasks,
-                                                                      final Set<TaskId> statefulTasks) {
+                                                                      final Set<TaskId> statefulTasks,
+                                                                      final UUID processId) {
         if (!statefulTasks.containsAll(statefulActiveTasks)) {
             throw new IllegalArgumentException("Need to initialize stateful tasks set before creating mock clients");
         }
@@ -866,6 +1332,36 @@ public class HighAvailabilityTaskAssignorTest {
                 taskLags.put(task, Long.MAX_VALUE);
             }
         }
-        return new ClientState(statefulActiveTasks, emptySet(), taskLags, EMPTY_CLIENT_TAGS, 1);
+        return new ClientState(statefulActiveTasks, emptySet(), taskLags, EMPTY_CLIENT_TAGS, 1, processId);
+    }
+
+    private static RackAwareTaskAssignor getRackAwareTaskAssignor(final AssignmentConfigs configs) {
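+        // Returned as a Mockito spy so tests can verify whether the rack-aware optimization entry points ran.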
+        return spy(
+            new RackAwareTaskAssignor(
+                getClusterForAllTopics(),
+                getTaskTopicPartitionMapForAllTasks(),
+                getTaskChangelogMapForAllTasks(),
+                new HashMap<>(),
+                getProcessRacksForAllProcess(),
+                mockInternalTopicManagerForChangelog(),
+                configs
+            )
+        );
+    }
+
+    private void verifyTaskPlacementWithRackAwareAssignor(final RackAwareTaskAssignor rackAwareTaskAssignor,
+                                                          final Set<TaskId> allTaskIds,
+                                                          final Map<UUID, ClientState> clientStates,
+                                                          final boolean hasStandby) {
+        // Verifies that the active and standby replicas of a task are placed on different clients
+        verifyStandbySatisfyRackReplica(allTaskIds, rackAwareTaskAssignor.racksForProcess(), clientStates, null, true, null);
+
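+        // With rack awareness enabled we expect exactly two active-task optimization passes, plus one standby
+        // pass when standbys exist; with it disabled the optimizer must never be invoked.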
+        if (enableRackAwareTaskAssignor) {
+            verify(rackAwareTaskAssignor, times(2)).optimizeActiveTasks(any(), any(), anyInt(), anyInt());
+            verify(rackAwareTaskAssignor, hasStandby ? times(1) : never()).optimizeStandbyTasks(any(), anyInt(), anyInt(), any());
+        } else {
+            verify(rackAwareTaskAssignor, never()).optimizeActiveTasks(any(), any(), anyInt(), anyInt());
+            verify(rackAwareTaskAssignor, never()).optimizeStandbyTasks(any(), anyInt(), anyInt(), any());
+        }
+        }
     }
 }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
index 8c1347f22d9..e94f3aee8d7 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java
@@ -677,6 +677,7 @@ public class StickyTaskAssignorTest {
             clients,
             new HashSet<>(taskIds),
             new HashSet<>(taskIds),
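+            // null: these tests do not exercise the new rack-aware assignor parameter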
+            null,
             new AssignorConfiguration.AssignmentConfigs(0L, 1, 0, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
         assertThat(probingRebalanceNeeded, is(false));
@@ -696,6 +697,7 @@ public class StickyTaskAssignorTest {
             clients,
             new HashSet<>(taskIds),
             new HashSet<>(taskIds),
+            null,
             new AssignorConfiguration.AssignmentConfigs(0L, 1, numStandbys, 60_000L, EMPTY_RACK_AWARE_ASSIGNMENT_TAGS)
         );
     }
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
index c1be5f33fa2..2aff41b48a3 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java
@@ -413,6 +413,7 @@ public class TaskAssignorConvergenceTest {
                 harness.clientStates,
                 allTasks,
                 harness.statefulTaskEndOffsetSums.keySet(),
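+                // null: the convergence harness does not exercise the new rack-aware assignor parameter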
+                null,
                 configs
             );
             harness.recordAfter(iteration, rebalancePending);

