This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 4297d99528b MINOR: Replace generic Task type with StreamTask and 
StandbyTask (#21502)
4297d99528b is described below

commit 4297d99528b0865e0e95affedc77746f54a84e2b
Author: Nikita Shupletsov <[email protected]>
AuthorDate: Wed Feb 18 19:09:19 2026 -0800

    MINOR: Replace generic Task type with StreamTask and StandbyTask (#21502)
    
    This PR improves type safety by replacing the generic Task type with
    StreamTask and StandbyTask. As a side effect, we can avoid unnecessary
    casts, usage of instanceof, or checking the task type via .isActiveTask.
    
    Reviewers: Matthias J. Sax <[email protected]>
---
 .../processor/internals/ActiveTaskCreator.java     |   7 +-
 .../processor/internals/StandbyTaskCreator.java    |   5 +-
 .../streams/processor/internals/StreamThread.java  |   4 +-
 .../streams/processor/internals/TaskExecutor.java  |  16 ++-
 .../streams/processor/internals/TaskManager.java   | 119 ++++++++++-----------
 .../kafka/streams/processor/internals/Tasks.java   |  96 ++++++++++-------
 .../streams/processor/internals/TasksRegistry.java |  24 +++--
 .../internals/tasks/DefaultTaskManager.java        |  12 +--
 .../processor/internals/StreamThreadTest.java      |   9 +-
 .../processor/internals/TaskManagerTest.java       |  34 +++---
 .../streams/processor/internals/TasksTest.java     |  12 +--
 .../internals/tasks/DefaultTaskManagerTest.java    |   2 +-
 12 files changed, 175 insertions(+), 165 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
index 24ff8fabf0c..ac8ec93a9c3 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java
@@ -131,10 +131,9 @@ class ActiveTaskCreator {
         return isClosed;
     }
 
-    // TODO: convert to StreamTask when we remove TaskManager#StateMachineTask 
with mocks
-    public Collection<Task> createTasks(final Consumer<byte[], byte[]> 
consumer,
-                                        final Map<TaskId, Set<TopicPartition>> 
tasksToBeCreated) {
-        final List<Task> createdTasks = new ArrayList<>();
+    public Collection<StreamTask> createTasks(final Consumer<byte[], byte[]> 
consumer,
+                                              final Map<TaskId, 
Set<TopicPartition>> tasksToBeCreated) {
+        final List<StreamTask> createdTasks = new ArrayList<>();
 
         for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions 
: tasksToBeCreated.entrySet()) {
             final TaskId taskId = newTaskAndPartitions.getKey();
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
index 693cb4ed63a..f04aec38f46 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java
@@ -65,9 +65,8 @@ class StandbyTaskCreator {
         );
     }
 
-    // TODO: convert to StandbyTask when we remove 
TaskManager#StateMachineTask with mocks
-    Collection<Task> createTasks(final Map<TaskId, Set<TopicPartition>> 
tasksToBeCreated) {
-        final List<Task> createdTasks = new ArrayList<>();
+    Collection<StandbyTask> createTasks(final Map<TaskId, Set<TopicPartition>> 
tasksToBeCreated) {
+        final List<StandbyTask> createdTasks = new ArrayList<>();
 
         for (final Map.Entry<TaskId, Set<TopicPartition>> newTaskAndPartitions 
: tasksToBeCreated.entrySet()) {
             final TaskId taskId = newTaskAndPartitions.getKey();
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
index 153840abc8f..a5002a17378 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java
@@ -1926,7 +1926,7 @@ public class StreamThread extends Thread implements 
ProcessingThread {
     }
 
     private void updateThreadMetadata(final Map<TaskId, Task> activeTasks,
-                                      final Map<TaskId, Task> standbyTasks) {
+                                      final Map<TaskId, StandbyTask> 
standbyTasks) {
         final Set<TaskMetadata> activeTasksMetadata = new HashSet<>();
         for (final Map.Entry<TaskId, Task> task : activeTasks.entrySet()) {
             activeTasksMetadata.add(new TaskMetadataImpl(
@@ -1938,7 +1938,7 @@ public class StreamThread extends Thread implements 
ProcessingThread {
             ));
         }
         final Set<TaskMetadata> standbyTasksMetadata = new HashSet<>();
-        for (final Map.Entry<TaskId, Task> task : standbyTasks.entrySet()) {
+        for (final Map.Entry<TaskId, StandbyTask> task : 
standbyTasks.entrySet()) {
             standbyTasksMetadata.add(new TaskMetadataImpl(
                 task.getValue().id(),
                 task.getValue().inputPartitions(),
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java
index 8294bf407d0..641d9c72b11 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskExecutor.java
@@ -68,7 +68,7 @@ public class TaskExecutor {
         int totalProcessed = 0;
         Task lastProcessed = null;
 
-        for (final Task task : tasks.activeInitializedTasks()) {
+        for (final StreamTask task : tasks.activeInitializedTasks()) {
             final long now = time.milliseconds();
             try {
                 if (executionMetadata.canProcessTask(task, now)) {
@@ -136,7 +136,7 @@ public class TaskExecutor {
      * @param consumedOffsetsAndMetadata an empty map that will be filled in 
with the prepared offsets
      * @return number of committed offsets, or -1 if we are in the middle of a 
rebalance and cannot commit
      */
-    int commitTasksAndMaybeUpdateCommittableOffsets(final Collection<Task> 
tasksToCommit,
+    int commitTasksAndMaybeUpdateCommittableOffsets(final Collection<? extends 
Task> tasksToCommit,
                                                     final Map<Task, 
Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadata) {
         int committed = 0;
         for (final Task task : tasksToCommit) {
@@ -233,12 +233,10 @@ public class TaskExecutor {
 
     private void updateTaskCommitMetadata(final Map<TopicPartition, 
OffsetAndMetadata> allOffsets) {
         if (!allOffsets.isEmpty()) {
-            for (final Task task : tasks.activeInitializedTasks()) {
-                if (task instanceof StreamTask) {
-                    for (final TopicPartition topicPartition : 
task.inputPartitions()) {
-                        if (allOffsets.containsKey(topicPartition)) {
-                            ((StreamTask) 
task).updateCommittedOffsets(topicPartition, 
allOffsets.get(topicPartition).offset());
-                        }
+            for (final StreamTask task : tasks.activeInitializedTasks()) {
+                for (final TopicPartition topicPartition : 
task.inputPartitions()) {
+                    if (allOffsets.containsKey(topicPartition)) {
+                        task.updateCommittedOffsets(topicPartition, 
allOffsets.get(topicPartition).offset());
                     }
                 }
             }
@@ -261,7 +259,7 @@ public class TaskExecutor {
     int punctuate() {
         int punctuated = 0;
 
-        for (final Task task : tasks.activeInitializedTasks()) {
+        for (final StreamTask task : tasks.activeInitializedTasks()) {
             try {
                 if (executionMetadata.canPunctuateTask(task)) {
                     if (task.maybePunctuateStreamTime()) {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index f91663cf0c8..9f421a9671c 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -449,8 +449,8 @@ public class TaskManager {
 
     private void createNewTasks(final Map<TaskId, Set<TopicPartition>> 
activeTasksToCreate,
                                 final Map<TaskId, Set<TopicPartition>> 
standbyTasksToCreate) {
-        final Collection<Task> newActiveTasks = 
activeTaskCreator.createTasks(mainConsumer, activeTasksToCreate);
-        final Collection<Task> newStandbyTasks = 
standbyTaskCreator.createTasks(standbyTasksToCreate);
+        final Collection<StreamTask> newActiveTasks = 
activeTaskCreator.createTasks(mainConsumer, activeTasksToCreate);
+        final Collection<StandbyTask> newStandbyTasks = 
standbyTaskCreator.createTasks(standbyTasksToCreate);
 
         tasks.addPendingTasksToInit(newActiveTasks);
         tasks.addPendingTasksToInit(newStandbyTasks);
@@ -540,7 +540,7 @@ public class TaskManager {
             }
             final TaskId taskId = task.id();
             if (activeTasksToCreate.containsKey(taskId)) {
-                handleReassignedActiveTask(task, 
activeTasksToCreate.get(taskId));
+                handleReassignedActiveTask((StreamTask) task, 
activeTasksToCreate.get(taskId));
                 activeTasksToCreate.remove(taskId);
             } else if (standbyTasksToCreate.containsKey(taskId)) {
                 tasksToRecycle.put(task, standbyTasksToCreate.get(taskId));
@@ -551,7 +551,7 @@ public class TaskManager {
         }
     }
 
-    private void handleReassignedActiveTask(final Task task,
+    private void handleReassignedActiveTask(final StreamTask task,
                                             final Set<TopicPartition> 
inputPartitions) {
         if (tasks.updateActiveTaskInputPartitions(task, inputPartitions)) {
             task.updateInputPartitions(inputPartitions, 
topologyMetadata.nodeToSourceTopics(task.id()));
@@ -906,12 +906,12 @@ public class TaskManager {
     /**
      * @throws StreamsException if fetching committed offsets timed out often 
enough to exceed task timeout
      */
-    private void transitRestoredTaskToRunning(final Task task,
+    private void transitRestoredTaskToRunning(final StreamTask task,
                                               final long now,
                                               final 
java.util.function.Consumer<Set<TopicPartition>> offsetResetter) throws 
StreamsException {
         try {
             task.completeRestoration(offsetResetter);
-            tasks.addTask(task);
+            tasks.addActiveTask(task);
             mainConsumer.resume(task.inputPartitions());
             task.clearTaskTimeout();
         } catch (final TimeoutException timeoutException) {
@@ -1000,7 +1000,7 @@ public class TaskManager {
 
         try {
             while (iterator.hasNext()) {
-                final Task task = iterator.next();
+                final StreamTask task = iterator.next();
                 transitRestoredTaskToRunning(task, now, offsetResetter);
                 iterator.remove(); // Remove successfully transitioned tasks
             }
@@ -1026,8 +1026,8 @@ public class TaskManager {
     void handleRevocation(final Collection<TopicPartition> revokedPartitions) {
         final Set<TopicPartition> remainingRevokedPartitions = new 
HashSet<>(revokedPartitions);
 
-        final Set<Task> revokedActiveTasks = new 
TreeSet<>(Comparator.comparing(Task::id));
-        final Set<Task> commitNeededActiveTasks = new 
TreeSet<>(Comparator.comparing(Task::id));
+        final Set<StreamTask> revokedActiveTasks = new 
TreeSet<>(Comparator.comparing(Task::id));
+        final Set<StreamTask> commitNeededActiveTasks = new 
TreeSet<>(Comparator.comparing(Task::id));
         final Map<Task, Map<TopicPartition, OffsetAndMetadata>> 
consumedOffsetsPerTask = new HashMap<>();
         final AtomicReference<RuntimeException> firstException = new 
AtomicReference<>(null);
 
@@ -1035,7 +1035,7 @@ public class TaskManager {
         maybeLockTasks(lockedTaskIds);
 
         boolean revokedTasksNeedCommit = false;
-        for (final Task task : activeRunningTaskIterable()) {
+        for (final StreamTask task : activeRunningTaskIterable()) {
             if 
(remainingRevokedPartitions.containsAll(task.inputPartitions())) {
                 // when the task input partitions are included in the revoked 
list,
                 // this is an active task and should be revoked
@@ -1151,15 +1151,15 @@ public class TaskManager {
         }
         getNonFailedTasks(futures, failedTasksFromStateUpdater).forEach(task 
-> {
             task.suspend();
-            tasks.addTask(task);
+            tasks.addActiveTask((StreamTask) task);
         });
 
         maybeThrowTaskExceptions(failedTasksFromStateUpdater);
     }
 
-    private void prepareCommitAndAddOffsetsToMap(final Set<Task> 
tasksToPrepare,
+    private void prepareCommitAndAddOffsetsToMap(final Set<StreamTask> 
tasksToPrepare,
                                                  final Map<Task, 
Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsPerTask) {
-        for (final Task task : tasksToPrepare) {
+        for (final StreamTask task : tasksToPrepare) {
             try {
                 final Map<TopicPartition, OffsetAndMetadata> 
committableOffsets = task.prepareCommit(true);
                 if (!committableOffsets.isEmpty()) {
@@ -1411,17 +1411,15 @@ public class TaskManager {
 
         final AtomicReference<RuntimeException> firstException = new 
AtomicReference<>(null);
 
-        // TODO: change type to `StreamTask`
-        final Set<Task> activeTasks = new 
TreeSet<>(Comparator.comparing(Task::id));
+        final Set<StreamTask> activeTasks = new 
TreeSet<>(Comparator.comparing(Task::id));
         activeTasks.addAll(tasks.activeInitializedTasks());
-        // TODO: change type to `StandbyTask`
-        final Set<Task> standbyTasks = new 
TreeSet<>(Comparator.comparing(Task::id));
+        final Set<StandbyTask> standbyTasks = new 
TreeSet<>(Comparator.comparing(Task::id));
         standbyTasks.addAll(tasks.standbyInitializedTasks());
 
-        final Set<Task> pendingActiveTasks = 
tasks.drainPendingActiveTasksToInit();
+        final Set<StreamTask> pendingActiveTasks = 
tasks.drainPendingActiveTasksToInit();
         activeTasks.addAll(pendingActiveTasks);
         tasks.addPendingTasksToClose(pendingActiveTasks);
-        final Set<Task> pendingStandbyTasks = 
tasks.drainPendingStandbyTasksToInit();
+        final Set<StandbyTask> pendingStandbyTasks = 
tasks.drainPendingStandbyTasksToInit();
         standbyTasks.addAll(pendingStandbyTasks);
         tasks.addPendingTasksToClose(pendingStandbyTasks);
 
@@ -1503,7 +1501,7 @@ public class TaskManager {
     /**
      * Closes and cleans up after the provided tasks, including closing their 
corresponding task producers
      */
-    void closeAndCleanUpTasks(final Collection<Task> activeTasks, final 
Collection<Task> standbyTasks, final boolean clean) {
+    void closeAndCleanUpTasks(final Collection<StreamTask> activeTasks, final 
Collection<StandbyTask> standbyTasks, final boolean clean) {
         final AtomicReference<RuntimeException> firstException = new 
AtomicReference<>(null);
 
         final Set<TaskId> ids =
@@ -1529,20 +1527,20 @@ public class TaskManager {
     }
 
     // Returns the set of active tasks that must be closed dirty
-    private Collection<Task> tryCloseCleanActiveTasks(final Collection<Task> 
activeTasksToClose,
-                                                      final boolean clean,
-                                                      final 
AtomicReference<RuntimeException> firstException) {
+    private Collection<StreamTask> tryCloseCleanActiveTasks(final 
Collection<StreamTask> activeTasksToClose,
+                                                            final boolean 
clean,
+                                                            final 
AtomicReference<RuntimeException> firstException) {
         if (!clean) {
             return activeTasksToClose;
         }
-        final Comparator<Task> byId = Comparator.comparing(Task::id);
-        final Set<Task> tasksToCommit = new TreeSet<>(byId);
-        final Set<Task> tasksToCloseDirty = new TreeSet<>(byId);
-        final Set<Task> tasksToCloseClean = new TreeSet<>(byId);
+        final Comparator<StreamTask> byId = Comparator.comparing(Task::id);
+        final Set<StreamTask> tasksToCommit = new TreeSet<>(byId);
+        final Set<StreamTask> tasksToCloseDirty = new TreeSet<>(byId);
+        final Set<StreamTask> tasksToCloseClean = new TreeSet<>(byId);
         final Map<Task, Map<TopicPartition, OffsetAndMetadata>> 
consumedOffsetsAndMetadataPerTask = new HashMap<>();
 
         // first committing all tasks and then suspend and close them clean
-        for (final Task task : activeTasksToClose) {
+        for (final StreamTask task : activeTasksToClose) {
             try {
                 final Map<TopicPartition, OffsetAndMetadata> 
committableOffsets = task.prepareCommit(true);
                 tasksToCommit.add(task);
@@ -1578,7 +1576,7 @@ public class TaskManager {
                 if (e instanceof TaskCorruptedException) {
                     final TaskCorruptedException taskCorruptedException = 
(TaskCorruptedException) e;
                     final Set<TaskId> corruptedTaskIds = 
taskCorruptedException.corruptedTasks();
-                    final Set<Task> corruptedTasks = tasksToCommit
+                    final Set<StreamTask> corruptedTasks = tasksToCommit
                         .stream()
                         .filter(task -> corruptedTaskIds.contains(task.id()))
                         .collect(Collectors.toSet());
@@ -1591,7 +1589,7 @@ public class TaskManager {
                 }
             }
 
-            for (final Task task : activeTasksToClose) {
+            for (final StreamTask task : activeTasksToClose) {
                 try {
                     task.postCommit(true);
                 } catch (final RuntimeException e) {
@@ -1603,7 +1601,7 @@ public class TaskManager {
             }
         }
 
-        for (final Task task : tasksToCloseClean) {
+        for (final StreamTask task : tasksToCloseClean) {
             try {
                 task.suspend();
                 closeTaskClean(task);
@@ -1622,16 +1620,16 @@ public class TaskManager {
     }
 
     // Returns the set of standby tasks that must be closed dirty
-    private Collection<Task> tryCloseCleanStandbyTasks(final Collection<Task> 
standbyTasksToClose,
-                                                       final boolean clean,
-                                                       final 
AtomicReference<RuntimeException> firstException) {
+    private Collection<StandbyTask> tryCloseCleanStandbyTasks(final 
Collection<StandbyTask> standbyTasksToClose,
+                                                              final boolean 
clean,
+                                                              final 
AtomicReference<RuntimeException> firstException) {
         if (!clean) {
             return standbyTasksToClose;
         }
-        final Set<Task> tasksToCloseDirty = new 
TreeSet<>(Comparator.comparing(Task::id));
+        final Set<StandbyTask> tasksToCloseDirty = new 
TreeSet<>(Comparator.comparing(Task::id));
 
         // first committing and then suspend / close clean
-        for (final Task task : standbyTasksToClose) {
+        for (final StandbyTask task : standbyTasksToClose) {
             try {
                 task.prepareCommit(true);
                 task.postCommit(true);
@@ -1709,7 +1707,7 @@ public class TaskManager {
         return activeTaskStream().collect(Collectors.toList());
     }
 
-    List<Task> activeRunningTaskIterable() {
+    List<StreamTask> activeRunningTaskIterable() {
         return activeRunningTaskStream().collect(Collectors.toList());
     }
 
@@ -1720,20 +1718,21 @@ public class TaskManager {
         );
     }
 
-    private Stream<Task> activeRunningTaskStream() {
-        return tasks.allInitializedTasks().stream().filter(Task::isActive);
+    private Stream<StreamTask> activeRunningTaskStream() {
+        return 
tasks.allInitializedTasks().stream().filter(Task::isActive).map(StreamTask.class::cast);
     }
 
-    Map<TaskId, Task> standbyTaskMap() {
+    Map<TaskId, StandbyTask> standbyTaskMap() {
         return standbyTaskStream().collect(Collectors.toMap(Task::id, t -> t));
     }
 
-    private List<Task> standbyTaskIterable() {
+    private List<StandbyTask> standbyTaskIterable() {
         return standbyTaskStream().collect(Collectors.toList());
     }
 
-    private Stream<Task> standbyTaskStream() {
-        final Stream<Task> standbyTasksInTaskRegistry = 
tasks.allInitializedTasks().stream().filter(t -> !t.isActive());
+    private Stream<StandbyTask> standbyTaskStream() {
+        final Stream<StandbyTask> standbyTasksInTaskRegistry = 
tasks.allInitializedTasks().stream().filter(t -> !t.isActive())
+                .map(StandbyTask.class::cast);
         return Stream.concat(
             stateUpdater.standbyTasks().stream(),
             standbyTasksInTaskRegistry
@@ -1749,7 +1748,7 @@ public class TaskManager {
      * the corresponding record queues have capacity (again).
      */
     public void resumePollingForPartitionsWithAvailableSpace() {
-        for (final Task t: tasks.activeInitializedTasks()) {
+        for (final StreamTask t: tasks.activeInitializedTasks()) {
             t.resumePollingForPartitionsWithAvailableSpace();
         }
     }
@@ -1758,7 +1757,7 @@ public class TaskManager {
      * Fetches up-to-date lag information from the consumer.
      */
     public void updateLags() {
-        for (final Task t: tasks.activeInitializedTasks()) {
+        for (final StreamTask t: tasks.activeInitializedTasks()) {
             t.updateLags();
         }
     }
@@ -1861,7 +1860,7 @@ public class TaskManager {
      * @throws TaskCorruptedException if committing offsets failed due to 
TimeoutException (EOS)
      * @return number of committed offsets, or -1 if we are in the middle of a 
rebalance and cannot commit
      */
-    int commit(final Collection<Task> tasksToCommit) {
+    int commit(final Collection<? extends Task> tasksToCommit) {
         int committed = 0;
         final Set<TaskId> ids =
             tasksToCommit.stream()
@@ -1893,7 +1892,7 @@ public class TaskManager {
         if (rebalanceInProgress) {
             return -1;
         } else {
-            for (final Task task : activeRunningTaskIterable()) {
+            for (final StreamTask task : activeRunningTaskIterable()) {
                 if (task.commitRequested() && task.commitNeeded()) {
                     return commit(activeRunningTaskIterable());
                 }
@@ -1902,7 +1901,7 @@ public class TaskManager {
         }
     }
 
-    private int commitTasksAndMaybeUpdateCommittableOffsets(final 
Collection<Task> tasksToCommit,
+    private int commitTasksAndMaybeUpdateCommittableOffsets(final Collection<? 
extends Task> tasksToCommit,
                                                             final Map<Task, 
Map<TopicPartition, OffsetAndMetadata>> consumedOffsetsAndMetadata) {
         if (rebalanceInProgress) {
             return -1;
@@ -1912,11 +1911,9 @@ public class TaskManager {
     }
 
     public void updateTaskEndMetadata(final TopicPartition topicPartition, 
final Long offset) {
-        for (final Task task : tasks.activeInitializedTasks()) {
-            if (task instanceof StreamTask) {
-                if (task.inputPartitions().contains(topicPartition)) {
-                    ((StreamTask) task).updateEndOffsets(topicPartition, 
offset);
-                }
+        for (final StreamTask task : tasks.activeInitializedTasks()) {
+            if (task.inputPartitions().contains(topicPartition)) {
+                task.updateEndOffsets(topicPartition, offset);
             }
         }
     }
@@ -1941,21 +1938,21 @@ public class TaskManager {
 
     void maybeCloseTasksFromRemovedTopologies(final Set<String> 
currentNamedTopologies) {
         try {
-            final Set<Task> activeTasksToRemove = new 
TreeSet<>(Comparator.comparing(Task::id));
-            final Set<Task> standbyTasksToRemove = new 
TreeSet<>(Comparator.comparing(Task::id));
+            final Set<StreamTask> activeTasksToRemove = new 
TreeSet<>(Comparator.comparing(Task::id));
+            final Set<StandbyTask> standbyTasksToRemove = new 
TreeSet<>(Comparator.comparing(Task::id));
             for (final Task task : tasks.allInitializedTasks()) {
                 if 
(!currentNamedTopologies.contains(task.id().topologyName())) {
                     if (task.isActive()) {
-                        activeTasksToRemove.add(task);
+                        activeTasksToRemove.add((StreamTask) task);
                     } else {
-                        standbyTasksToRemove.add(task);
+                        standbyTasksToRemove.add((StandbyTask) task);
                     }
                 }
             }
 
-            final Set<Task> allTasksToRemove = union(HashSet::new, 
activeTasksToRemove, standbyTasksToRemove);
+            final Set<TaskId> allTaskIdsToRemove = 
Stream.concat(activeTasksToRemove.stream(), 
standbyTasksToRemove.stream()).map(Task::id).collect(Collectors.toSet());
             closeAndCleanUpTasks(activeTasksToRemove, standbyTasksToRemove, 
true);
-            
releaseLockedDirectoriesForTasks(allTasksToRemove.stream().map(Task::id).collect(Collectors.toSet()));
+            releaseLockedDirectoriesForTasks(allTaskIdsToRemove);
         } catch (final Exception e) {
             // TODO KAFKA-12648: for now just swallow the exception to avoid 
interfering with the other topologies
             //  that are running alongside, but eventually we should be able 
to rethrow up to the handler to inform
@@ -1980,7 +1977,7 @@ public class TaskManager {
     }
 
     void recordTaskProcessRatio(final long totalProcessLatencyMs, final long 
now) {
-        for (final Task task : activeRunningTaskIterable()) {
+        for (final StreamTask task : activeRunningTaskIterable()) {
             task.recordProcessTimeRatioAndBufferSize(totalProcessLatencyMs, 
now);
         }
     }
@@ -2004,7 +2001,7 @@ public class TaskManager {
             }
 
             final Map<TopicPartition, RecordsToDelete> recordsToDelete = new 
HashMap<>();
-            for (final Task task : activeRunningTaskIterable()) {
+            for (final StreamTask task : activeRunningTaskIterable()) {
                 for (final Map.Entry<TopicPartition, Long> entry : 
task.purgeableOffsets().entrySet()) {
                     recordsToDelete.put(entry.getKey(), 
RecordsToDelete.beforeOffset(entry.getValue()));
                 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
index 2e6175446ad..dd496e7a57e 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java
@@ -46,11 +46,10 @@ import static org.apache.kafka.common.utils.Utils.union;
 class Tasks implements TasksRegistry {
     private final Logger log;
 
-    // TODO: convert to Stream/StandbyTask when we remove 
TaskManager#StateMachineTask with mocks
     // note that these two maps may be accessed by concurrent threads and hence
     // should be synchronized when accessed
-    private final Map<TaskId, Task> activeTasksPerId = new TreeMap<>();
-    private final Map<TaskId, Task> standbyTasksPerId = new TreeMap<>();
+    private final Map<TaskId, StreamTask> activeTasksPerId = new TreeMap<>();
+    private final Map<TaskId, StandbyTask> standbyTasksPerId = new TreeMap<>();
 
     // Tasks may have been assigned for a NamedTopology that is not yet known 
by this host. When that occurs we stash
     // these unknown tasks until either the corresponding NamedTopology is 
added and we can create them at last, or
@@ -61,8 +60,7 @@ class Tasks implements TasksRegistry {
     private final Set<Task> pendingTasksToClose = new HashSet<>();
     private final Set<TaskId> failedTaskIds = new HashSet<>();
 
-    // TODO: convert to Stream/StandbyTask when we remove 
TaskManager#StateMachineTask with mocks
-    private final Map<TopicPartition, Task> activeTasksPerPartition = new 
HashMap<>();
+    private final Map<TopicPartition, StreamTask> activeTasksPerPartition = 
new HashMap<>();
 
     Tasks(final LogContext logContext) {
         this.log = logContext.logger(getClass());
@@ -112,13 +110,13 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public Set<Task> drainPendingActiveTasksToInit() {
-        final Set<Task> result = new HashSet<>();
+    public Set<StreamTask> drainPendingActiveTasksToInit() {
+        final Set<StreamTask> result = new HashSet<>();
         final Iterator<Task> iterator = pendingTasksToInit.iterator();
         while (iterator.hasNext()) {
             final Task task = iterator.next();
             if (task.isActive()) {
-                result.add(task);
+                result.add((StreamTask) task);
                 iterator.remove();
             }
         }
@@ -126,13 +124,13 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public Set<Task> drainPendingStandbyTasksToInit() {
-        final Set<Task> result = new HashSet<>();
+    public Set<StandbyTask> drainPendingStandbyTasksToInit() {
+        final Set<StandbyTask> result = new HashSet<>();
         final Iterator<Task> iterator = pendingTasksToInit.iterator();
         while (iterator.hasNext()) {
             final Task task = iterator.next();
             if (!task.isActive()) {
-                result.add(task);
+                result.add((StandbyTask) task);
                 iterator.remove();
             }
         }
@@ -145,7 +143,7 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public void addPendingTasksToInit(final Collection<Task> tasks) {
+    public void addPendingTasksToInit(final Collection<? extends Task> tasks) {
         pendingTasksToInit.addAll(tasks);
     }
 
@@ -160,7 +158,7 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public void addPendingTasksToClose(final Collection<Task> tasks) {
+    public void addPendingTasksToClose(final Collection<? extends Task> tasks) 
{
         pendingTasksToClose.addAll(tasks);
     }
 
@@ -170,26 +168,36 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public void addActiveTasks(final Collection<Task> newTasks) {
+    public void addActiveTasks(final Collection<StreamTask> newTasks) {
         if (!newTasks.isEmpty()) {
-            for (final Task activeTask : newTasks) {
-                addTask(activeTask);
+            for (final StreamTask activeTask : newTasks) {
+                addActiveTask(activeTask);
             }
         }
     }
 
     @Override
-    public void addStandbyTasks(final Collection<Task> newTasks) {
+    public void addStandbyTasks(final Collection<StandbyTask> newTasks) {
         if (!newTasks.isEmpty()) {
-            for (final Task standbyTask : newTasks) {
-                addTask(standbyTask);
+            for (final StandbyTask standbyTask : newTasks) {
+                addStandbyTask(standbyTask);
             }
         }
     }
 
     @Override
-    public synchronized void addTask(final Task task) {
+    public void addTask(final Task task) {
+        if (task.isActive()) {
+            addActiveTask((StreamTask) task);
+        } else {
+            addStandbyTask((StandbyTask) task);
+        }
+    }
+
+    @Override
+    public synchronized void addActiveTask(final StreamTask task) {
         final TaskId taskId = task.id();
+
         if (activeTasksPerId.containsKey(taskId)) {
             throw new IllegalStateException("Attempted to create an active 
task that we already own: " + taskId);
         }
@@ -198,17 +206,28 @@ class Tasks implements TasksRegistry {
             throw new IllegalStateException("Attempted to create an active 
task while we already own its standby: " + taskId);
         }
 
-        if (task.isActive()) {
-            activeTasksPerId.put(task.id(), task);
-            pendingActiveTasksToCreate.remove(task.id());
-            for (final TopicPartition topicPartition : task.inputPartitions()) 
{
-                activeTasksPerPartition.put(topicPartition, task);
-            }
-        } else {
-            standbyTasksPerId.put(task.id(), task);
+        activeTasksPerId.put(taskId, task);
+        pendingActiveTasksToCreate.remove(taskId);
+        for (final TopicPartition topicPartition : task.inputPartitions()) {
+            activeTasksPerPartition.put(topicPartition, task);
         }
     }
 
+    @Override
+    public synchronized void addStandbyTask(final StandbyTask task) {
+        final TaskId taskId = task.id();
+
+        if (standbyTasksPerId.containsKey(taskId)) {
+            throw new IllegalStateException("Attempted to create a standby 
task that we already own: " + taskId);
+        }
+
+        if (activeTasksPerId.containsKey(taskId)) {
+            throw new IllegalStateException("Attempted to create a standby 
task while we already own its active: " + taskId);
+        }
+
+        standbyTasksPerId.put(taskId, task);
+    }
+
     @Override
     public void addFailedTask(final Task task) {
         failedTaskIds.add(task.id());
@@ -252,17 +271,15 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public boolean updateActiveTaskInputPartitions(final Task task, final 
Set<TopicPartition> topicPartitions) {
+    public boolean updateActiveTaskInputPartitions(final StreamTask task, 
final Set<TopicPartition> topicPartitions) {
         final boolean requiresUpdate = 
!task.inputPartitions().equals(topicPartitions);
         if (requiresUpdate) {
             log.debug("Update task {} inputPartitions: current {}, new {}", 
task, task.inputPartitions(), topicPartitions);
-            if (task.isActive()) {
-                for (final TopicPartition inputPartition : 
task.inputPartitions()) {
-                    activeTasksPerPartition.remove(inputPartition);
-                }
-                for (final TopicPartition topicPartition : topicPartitions) {
-                    activeTasksPerPartition.put(topicPartition, task);
-                }
+            for (final TopicPartition inputPartition : task.inputPartitions()) 
{
+                activeTasksPerPartition.remove(inputPartition);
+            }
+            for (final TopicPartition topicPartition : topicPartitions) {
+                activeTasksPerPartition.put(topicPartition, task);
             }
         }
 
@@ -289,9 +306,8 @@ class Tasks implements TasksRegistry {
         failedTaskIds.clear();
     }
 
-    // TODO: change return type to `StreamTask`
     @Override
-    public Task activeInitializedTasksForInputPartition(final TopicPartition 
partition) {
+    public StreamTask activeInitializedTasksForInputPartition(final 
TopicPartition partition) {
         return activeTasksPerPartition.get(partition);
     }
 
@@ -330,12 +346,12 @@ class Tasks implements TasksRegistry {
     }
 
     @Override
-    public synchronized Collection<Task> activeInitializedTasks() {
+    public synchronized Collection<StreamTask> activeInitializedTasks() {
         return Collections.unmodifiableCollection(activeTasksPerId.values());
     }
 
     @Override
-    public synchronized Collection<Task> standbyInitializedTasks() {
+    public synchronized Collection<StandbyTask> standbyInitializedTasks() {
         return Collections.unmodifiableCollection(standbyTasksPerId.values());
     }
 
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java
index 4e037fbd6f6..31ad2843e60 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TasksRegistry.java
@@ -37,39 +37,43 @@ public interface TasksRegistry {
 
     Set<Task> drainPendingTasksToInit();
 
-    Set<Task> drainPendingActiveTasksToInit();
+    Set<StreamTask> drainPendingActiveTasksToInit();
 
-    Set<Task> drainPendingStandbyTasksToInit();
+    Set<StandbyTask> drainPendingStandbyTasksToInit();
 
     Set<Task> pendingTasksToInit();
 
-    void addPendingTasksToInit(final Collection<Task> tasks);
+    void addPendingTasksToInit(final Collection<? extends Task> tasks);
 
     boolean hasPendingTasksToInit();
 
     Set<Task> pendingTasksToClose();
 
-    void addPendingTasksToClose(final Collection<Task> tasks);
+    void addPendingTasksToClose(final Collection<? extends Task> tasks);
 
     boolean hasPendingTasksToClose();
 
-    void addActiveTasks(final Collection<Task> tasks);
+    void addActiveTasks(final Collection<StreamTask> tasks);
 
-    void addStandbyTasks(final Collection<Task> tasks);
+    void addStandbyTasks(final Collection<StandbyTask> tasks);
 
     void addTask(final Task task);
 
+    void addActiveTask(final StreamTask task);
+
+    void addStandbyTask(final StandbyTask task);
+
     void addFailedTask(final Task task);
 
     void removeTask(final Task taskToRemove);
 
     void replaceStandbyWithActive(final StreamTask activeTask);
 
-    boolean updateActiveTaskInputPartitions(final Task task, final 
Set<TopicPartition> topicPartitions);
+    boolean updateActiveTaskInputPartitions(final StreamTask task, final 
Set<TopicPartition> topicPartitions);
 
     void clear();
 
-    Task activeInitializedTasksForInputPartition(final TopicPartition 
partition);
+    StreamTask activeInitializedTasksForInputPartition(final TopicPartition 
partition);
 
     Task initializedTask(final TaskId taskId);
 
@@ -77,9 +81,9 @@ public interface TasksRegistry {
 
     Collection<TaskId> activeInitializedTaskIds();
 
-    Collection<Task> activeInitializedTasks();
+    Collection<StreamTask> activeInitializedTasks();
 
-    Collection<Task> standbyInitializedTasks();
+    Collection<StandbyTask> standbyInitializedTasks();
 
     Set<Task> allInitializedTasks();
 
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManager.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManager.java
index 2259f7768c2..53950c10088 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManager.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManager.java
@@ -102,10 +102,10 @@ public final class DefaultTaskManager implements 
TaskManager {
             }
 
             // the most naive scheduling algorithm for now: give the next 
unlocked, unassigned, and  processable task
-            for (final Task task : tasks.activeInitializedTasks()) {
+            for (final StreamTask task : tasks.activeInitializedTasks()) {
                 if (!assignedTasks.containsKey(task.id()) &&
                     !lockedTasks.contains(task.id()) &&
-                    canProgress((StreamTask) task, time.milliseconds()) &&
+                    canProgress(task, time.milliseconds()) &&
                     !hasUncaughtException(task.id())
                 ) {
 
@@ -113,7 +113,7 @@ public final class DefaultTaskManager implements 
TaskManager {
 
                     log.debug("Assigned task {} to executor {}", task.id(), 
executor.name());
 
-                    return (StreamTask) task;
+                    return task;
                 }
             }
 
@@ -126,10 +126,10 @@ public final class DefaultTaskManager implements 
TaskManager {
     @Override
     public void awaitProcessableTasks(final Supplier<Boolean> isShuttingDown) 
throws InterruptedException {
         final boolean interrupted = returnWithTasksLocked(() -> {
-            for (final Task task : tasks.activeInitializedTasks()) {
+            for (final StreamTask task : tasks.activeInitializedTasks()) {
                 if (!assignedTasks.containsKey(task.id()) &&
                     !lockedTasks.contains(task.id()) &&
-                    canProgress((StreamTask) task, time.milliseconds()) &&
+                    canProgress(task, time.milliseconds()) &&
                     !hasUncaughtException(task.id())
                 ) {
                     log.debug("Await unblocked: returning early from await 
since a processable task {} was found", task.id());
@@ -270,7 +270,7 @@ public final class DefaultTaskManager implements 
TaskManager {
     public void add(final Set<StreamTask> tasksToAdd) {
         executeWithTasksLocked(() -> {
             for (final StreamTask task : tasksToAdd) {
-                tasks.addTask(task);
+                tasks.addActiveTask(task);
             }
             log.debug("Waking up task executors");
             tasksCondition.signalAll();
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
index 947c1eedc18..8db09dba11e 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java
@@ -1068,7 +1068,7 @@ public class StreamThreadTest {
             null
         ) {
             @Override
-            int commit(final Collection<Task> tasksToCommit) {
+            int commit(final Collection<? extends Task> tasksToCommit) {
                 committed.set(true);
                 // we advance time to make sure the commit delay is considered 
when computing the next commit timestamp
                 mockTime.sleep(commitLatency);
@@ -1137,7 +1137,7 @@ public class StreamThreadTest {
         when(consumer.groupMetadata()).thenReturn(consumerGroupMetadata);
         
when(consumerGroupMetadata.groupInstanceId()).thenReturn(Optional.empty());
         when(consumer.poll(any())).thenReturn(ConsumerRecords.empty());
-        final Task task = mock(Task.class);
+        final StreamTask task = mock(StreamTask.class);
         final ActiveTaskCreator activeTaskCreator = 
mock(ActiveTaskCreator.class);
         when(activeTaskCreator.createTasks(any(), 
any())).thenReturn(Collections.singleton(task));
         
when(activeTaskCreator.producerClientIds()).thenReturn("producerClientId");
@@ -1176,7 +1176,7 @@ public class StreamThreadTest {
             schedulingTaskManager
         ) {
             @Override
-            int commit(final Collection<Task> tasksToCommit) {
+            int commit(final Collection<? extends Task> tasksToCommit) {
                 mockTime.sleep(10L);
                 return 1;
             }
@@ -4060,8 +4060,7 @@ public class StreamThreadTest {
         internalTopologyBuilder.setStreamsConfig(config);
     }
 
-    // TODO: change return type to `StandbyTask`
-    private Collection<Task> createStandbyTask(final StreamsConfig config) {
+    private Collection<StandbyTask> createStandbyTask(final StreamsConfig 
config) {
         final LogContext logContext = new LogContext("test");
         final StreamsMetricsImpl streamsMetrics =
             new StreamsMetricsImpl(metrics, CLIENT_ID, mockTime);
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 92d150096a7..ce369cd1c08 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -668,7 +668,7 @@ public class TaskManagerTest {
         assertEquals("Encounter unexpected fatal error for task " + 
failedActiveTaskToRecycle.id(), exception.getMessage());
         assertEquals(taskException, exception.getCause());
         verify(tasks).addFailedTask(failedActiveTaskToRecycle);
-        verify(tasks, never()).addTask(failedActiveTaskToRecycle);
+        verify(tasks, never()).addActiveTask(failedActiveTaskToRecycle);
         verify(tasks).allNonFailedInitializedTasks();
         verify(standbyTaskCreator, 
never()).createStandbyTaskFromActive(failedActiveTaskToRecycle, 
taskId03Partitions);
     }
@@ -698,7 +698,7 @@ public class TaskManagerTest {
         assertEquals("Encounter unexpected fatal error for task " + 
failedStandbyTaskToRecycle.id(), exception.getMessage());
         assertEquals(taskException, exception.getCause());
         verify(tasks).addFailedTask(failedStandbyTaskToRecycle);
-        verify(tasks, never()).addTask(failedStandbyTaskToRecycle);
+        verify(tasks, never()).addStandbyTask(failedStandbyTaskToRecycle);
         verify(tasks).allNonFailedInitializedTasks();
         verify(activeTaskCreator, 
never()).createActiveTaskFromStandby(failedStandbyTaskToRecycle, 
taskId03Partitions, consumer);
     }
@@ -728,7 +728,7 @@ public class TaskManagerTest {
         assertEquals("Encounter unexpected fatal error for task " + 
failedActiveTaskToReassign.id(), exception.getMessage());
         assertEquals(taskException, exception.getCause());
         verify(tasks).addFailedTask(failedActiveTaskToReassign);
-        verify(tasks, never()).addTask(failedActiveTaskToReassign);
+        verify(tasks, never()).addActiveTask(failedActiveTaskToReassign);
         verify(tasks).allNonFailedInitializedTasks();
         verify(tasks, 
never()).updateActiveTaskInputPartitions(failedActiveTaskToReassign, 
taskId00Partitions);
     }
@@ -882,7 +882,7 @@ public class TaskManagerTest {
             .withInputPartitions(taskId03Partitions).build();
         final TasksRegistry tasks = mock(TasksRegistry.class);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
-        final Set<Task> createdTasks = Set.of(activeTaskToBeCreated);
+        final Set<StreamTask> createdTasks = Set.of(activeTaskToBeCreated);
         final Map<TaskId, Set<TopicPartition>> tasksToBeCreated = mkMap(
             mkEntry(activeTaskToBeCreated.id(), 
activeTaskToBeCreated.inputPartitions()));
         when(activeTaskCreator.createTasks(consumer, 
tasksToBeCreated)).thenReturn(createdTasks);
@@ -900,7 +900,7 @@ public class TaskManagerTest {
             .withInputPartitions(taskId02Partitions).build();
         final TasksRegistry tasks = mock(TasksRegistry.class);
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
-        final Set<Task> createdTasks = Set.of(standbyTaskToBeCreated);
+        final Set<StandbyTask> createdTasks = Set.of(standbyTaskToBeCreated);
         when(standbyTaskCreator.createTasks(mkMap(
             mkEntry(standbyTaskToBeCreated.id(), 
standbyTaskToBeCreated.inputPartitions())))
         ).thenReturn(createdTasks);
@@ -1320,7 +1320,7 @@ public class TaskManagerTest {
         taskManager.handleRevocation(task.inputPartitions());
 
         verify(task).suspend();
-        verify(tasks).addTask(task);
+        verify(tasks).addActiveTask(task);
         verify(stateUpdater).remove(task.id());
     }
 
@@ -1344,9 +1344,9 @@ public class TaskManagerTest {
         taskManager.handleRevocation(union(HashSet::new, taskId00Partitions, 
taskId01Partitions));
 
         verify(task1).suspend();
-        verify(tasks).addTask(task1);
+        verify(tasks).addActiveTask(task1);
         verify(task2).suspend();
-        verify(tasks).addTask(task2);
+        verify(tasks).addActiveTask(task2);
     }
 
     @Test
@@ -1360,7 +1360,7 @@ public class TaskManagerTest {
         taskManager.handleRevocation(taskId01Partitions);
 
         verify(task, never()).suspend();
-        verify(tasks, never()).addTask(task);
+        verify(tasks, never()).addActiveTask(task);
         verify(stateUpdater, never()).remove(task.id());
     }
 
@@ -1375,7 +1375,7 @@ public class TaskManagerTest {
         taskManager.handleRevocation(taskId00Partitions);
 
         verify(task, never()).suspend();
-        verify(tasks, never()).addTask(task);
+        verify(tasks, never()).addStandbyTask(task);
         verify(stateUpdater, never()).remove(task.id());
     }
 
@@ -1405,7 +1405,7 @@ public class TaskManagerTest {
         assertEquals("Encounter unexpected fatal error for task " + 
task2.id(), thrownException.getMessage());
         assertEquals(thrownException.getCause(), taskException);
         verify(task1).suspend();
-        verify(tasks).addTask(task1);
+        verify(tasks).addActiveTask(task1);
         verify(task2, never()).suspend();
         verify(tasks).addFailedTask(task2);
     }
@@ -1560,7 +1560,7 @@ public class TaskManagerTest {
         for (final StreamTask restoredTask : restoredTasks) {
             verify(restoredTask).completeRestoration(noOpResetter);
             verify(restoredTask, atLeastOnce()).clearTaskTimeout();
-            verify(tasks).addTask(restoredTask);
+            verify(tasks).addActiveTask(restoredTask);
             verify(consumer).resume(restoredTask.inputPartitions());
         }
     }
@@ -1579,7 +1579,7 @@ public class TaskManagerTest {
 
         verify(task).maybeInitTaskTimeoutOrThrow(anyLong(), 
eq(timeoutException));
         verify(stateUpdater).add(task);
-        verify(tasks, never()).addTask(task);
+        verify(tasks, never()).addActiveTask(task);
         verify(task, never()).clearTaskTimeout();
         verifyNoInteractions(consumer);
     }
@@ -1614,19 +1614,19 @@ public class TaskManagerTest {
         assertThrows(StreamsException.class, () -> 
taskManager.checkStateUpdater(time.milliseconds(), noOpResetter));
 
         // task1 should be successfully transitioned
-        verify(tasks).addTask(task1);
+        verify(tasks).addActiveTask(task1);
         verify(consumer).resume(task1.inputPartitions());
         verify(task1).clearTaskTimeout();
 
         // task2 should be added back to state updater once in the finally 
block
         // (the add in the catch block doesn't execute because 
maybeInitTaskTimeoutOrThrow throws)
         verify(stateUpdater).add(task2);
-        verify(tasks, never()).addTask(task2);
+        verify(tasks, never()).addActiveTask(task2);
         verify(task2, never()).clearTaskTimeout();
 
         // task3 should also be added back to state updater in the finally 
block
         verify(stateUpdater).add(task3);
-        verify(tasks, never()).addTask(task3);
+        verify(tasks, never()).addActiveTask(task3);
         verify(task3, never()).clearTaskTimeout();
     }
 
@@ -2951,7 +2951,7 @@ public class TaskManagerTest {
         assertFalse(restorationComplete);
         verify(task00).completeRestoration(any());
         verify(stateUpdater).add(task00);
-        verify(tasks, never()).addTask(task00);
+        verify(tasks, never()).addActiveTask(task00);
         verifyNoInteractions(consumer);
     }
 
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
index bb2865d223c..30328735e9e 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TasksTest.java
@@ -157,10 +157,9 @@ public class TasksTest {
         final StandbyTask standbyTask2 = standbyTask(TASK_1_1, 
Set.of(TOPIC_PARTITION_A_1)).build();
         tasks.addPendingTasksToInit(Set.of(activeTask1, activeTask2, 
standbyTask1, standbyTask2));
 
-        final Set<Task> activeTasksToInit = 
tasks.drainPendingActiveTasksToInit();
+        final Set<StreamTask> activeTasksToInit = 
tasks.drainPendingActiveTasksToInit();
         assertEquals(2, activeTasksToInit.size());
         assertTrue(activeTasksToInit.containsAll(Set.of(activeTask1, 
activeTask2)));
-        assertFalse(activeTasksToInit.containsAll(Set.of(standbyTask1, 
standbyTask2)));
         assertEquals(2, tasks.pendingTasksToInit().size());
         assertTrue(tasks.hasPendingTasksToInit());
         assertTrue(tasks.pendingTasksToInit().containsAll(Set.of(standbyTask1, 
standbyTask2)));
@@ -174,11 +173,10 @@ public class TasksTest {
         final StandbyTask standbyTask2 = standbyTask(TASK_1_1, 
Set.of(TOPIC_PARTITION_A_1)).build();
         tasks.addPendingTasksToInit(Set.of(activeTask1, activeTask2, 
standbyTask1, standbyTask2));
 
-        final Set<Task> standbyTasksToInit = 
tasks.drainPendingStandbyTasksToInit();
+        final Set<StandbyTask> standbyTasksToInit = 
tasks.drainPendingStandbyTasksToInit();
 
         assertEquals(2, standbyTasksToInit.size());
         assertTrue(standbyTasksToInit.containsAll(Set.of(standbyTask1, 
standbyTask2)));
-        assertFalse(standbyTasksToInit.containsAll(Set.of(activeTask1, 
activeTask2)));
         assertEquals(2, tasks.pendingTasksToInit().size());
         assertTrue(tasks.hasPendingTasksToInit());
         assertTrue(tasks.pendingTasksToInit().containsAll(Set.of(activeTask1, 
activeTask2)));
@@ -188,7 +186,7 @@ public class TasksTest {
     public void shouldAddFailedTask() {
         final StreamTask activeTask1 = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_B_0)).build();
         final StreamTask activeTask2 = statefulTask(TASK_0_1, 
Set.of(TOPIC_PARTITION_B_1)).build();
-        tasks.addTask(activeTask2);
+        tasks.addActiveTask(activeTask2);
 
         tasks.addFailedTask(activeTask1);
 
@@ -210,7 +208,7 @@ public class TasksTest {
         
assertFalse(tasks.allNonFailedInitializedTasks().contains(activeTask1));
         assertFalse(tasks.allInitializedTasks().contains(activeTask1));
 
-        tasks.addTask(activeTask1);
+        tasks.addActiveTask(activeTask1);
         assertTrue(tasks.allNonFailedInitializedTasks().contains(activeTask1));
     }
 
@@ -224,7 +222,7 @@ public class TasksTest {
         
assertFalse(tasks.allNonFailedInitializedTasks().contains(activeTask1));
         assertFalse(tasks.allInitializedTasks().contains(activeTask1));
 
-        tasks.addTask(activeTask1);
+        tasks.addActiveTask(activeTask1);
         assertTrue(tasks.allNonFailedInitializedTasks().contains(activeTask1));
     }
 
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManagerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManagerTest.java
index 879f62f0b08..2f6a2fe30d7 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManagerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/tasks/DefaultTaskManagerTest.java
@@ -93,7 +93,7 @@ public class DefaultTaskManagerTest {
     public void shouldAddTask() {
         taskManager.add(Collections.singleton(task));
 
-        verify(tasks).addTask(task);
+        verify(tasks).addActiveTask(task);
         
when(tasks.activeInitializedTasks()).thenReturn(Collections.singleton(task));
         assertEquals(1, taskManager.getTasks().size());
     }

Reply via email to