This is an automated email from the ASF dual-hosted git repository.

lucasbru pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new e31ae3ebfc1 KAFKA-16836: Fix incorrect suspend reason in standby update listener (#21500)
e31ae3ebfc1 is described below

commit e31ae3ebfc1ff7c25b9059923560c6ce39cf39ae
Author: Lucas Brutschy <[email protected]>
AuthorDate: Thu Feb 19 14:42:58 2026 +0100

    KAFKA-16836: Fix incorrect suspend reason in standby update listener (#21500)
    
    `StoreChangelogReader.unregister()` incorrectly inferred the
    `SuspendReason` from whether the task state was `RUNNING`. Because
    `unregister()` is called from within the state updater thread, where
    standby tasks are always in the `RUNNING` state, `onUpdateSuspended()`
    always reported `PROMOTED`, even when the task was actually being
    migrated.
    
    The fix passes the `SuspendReason` explicitly through the call chain
    from `TaskManager` → `StateUpdater.remove()` → `DefaultStateUpdater` →
    `ChangelogReader.unregister()`, removing the incorrect inference logic.
    `TaskManager` knows the actual reason: `PROMOTED` when recycling a
    standby task to active, and `MIGRATED` for all other cases (task
    closing, active recycling, input partition changes, revocation,
    shutdown). The existing single-argument `unregister()` overload is
    kept and defaults to `MIGRATED`.
    
    Reviewers: TengYao Chi <[email protected]>
---
 .../processor/internals/ChangelogRegister.java     |  11 +-
 .../processor/internals/DefaultStateUpdater.java   |  38 +++---
 .../streams/processor/internals/StateUpdater.java  |   4 +-
 .../processor/internals/StoreChangelogReader.java  |  13 ++-
 .../streams/processor/internals/TaskAndAction.java |  21 +++-
 .../streams/processor/internals/TaskManager.java   |  19 +--
 .../internals/DefaultStateUpdaterTest.java         |  63 ++++++----
 .../processor/internals/MockChangelogReader.java   |   7 ++
 .../internals/StoreChangelogReaderTest.java        |  33 ++++++
 .../processor/internals/TaskAndActionTest.java     |   8 +-
 .../processor/internals/TaskManagerTest.java       | 127 +++++++++++----------
 .../apache/kafka/streams/TopologyTestDriver.java   |   5 +
 12 files changed, 227 insertions(+), 122 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRegister.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRegister.java
index 74ef370c2f6..8bcf6928056 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRegister.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRegister.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.streams.processor.StandbyUpdateListener;
 
 import java.util.Collection;
 import java.util.Set;
@@ -36,8 +37,16 @@ public interface ChangelogRegister {
     void register(final Set<TopicPartition> partitions, final 
ProcessorStateManager stateManager);
 
     /**
-     * Unregisters and removes the passed in partitions from the set of 
changelogs
+     * Unregisters and removes the passed in partitions from the set of 
changelogs.
+     * Defaults to {@link StandbyUpdateListener.SuspendReason#MIGRATED} for 
the standby suspend reason.
      * @param removedPartitions the set of partitions to remove
      */
     void unregister(final Collection<TopicPartition> removedPartitions);
+
+    /**
+     * Unregisters and removes the passed in partitions from the set of 
changelogs.
+     * @param removedPartitions the set of partitions to remove
+     * @param reason the reason for suspending standby update, passed to the 
standby update listener
+     */
+    void unregister(final Collection<TopicPartition> removedPartitions, final 
StandbyUpdateListener.SuspendReason reason);
 }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
index 7792a64f88f..cc06a3bf7d9 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -35,6 +35,7 @@ import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StandbyUpdateListener;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.internals.Task.State;
 import org.apache.kafka.streams.processor.internals.TaskAndAction.Action;
@@ -222,7 +223,7 @@ public class DefaultStateUpdater implements StateUpdater {
                             addTask(taskAndAction.task());
                             break;
                         case REMOVE:
-                            removeTask(taskAndAction.taskId(), 
taskAndAction.futureForRemove());
+                            removeTask(taskAndAction.taskId(), 
taskAndAction.futureForRemove(), taskAndAction.suspendReason());
                             break;
                         default:
                             throw new IllegalStateException("Unknown action 
type " + action);
@@ -519,10 +520,12 @@ public class DefaultStateUpdater implements StateUpdater {
             }
         }
 
-        private void removeTask(final TaskId taskId, final 
CompletableFuture<RemovedTaskResult> future) {
+        private void removeTask(final TaskId taskId,
+                                final CompletableFuture<RemovedTaskResult> 
future,
+                                final StandbyUpdateListener.SuspendReason 
suspendReason) {
             try {
-                if (!removeUpdatingTask(taskId, future)
-                    && !removePausedTask(taskId, future)
+                if (!removeUpdatingTask(taskId, future, suspendReason)
+                    && !removePausedTask(taskId, future, suspendReason)
                     && !removeRestoredTask(taskId, future)
                     && !removeFailedTask(taskId, future)) {
 
@@ -539,12 +542,14 @@ public class DefaultStateUpdater implements StateUpdater {
             }
         }
 
-        private boolean removeUpdatingTask(final TaskId taskId, final 
CompletableFuture<RemovedTaskResult> future) {
+        private boolean removeUpdatingTask(final TaskId taskId,
+                                           final 
CompletableFuture<RemovedTaskResult> future,
+                                           final 
StandbyUpdateListener.SuspendReason suspendReason) {
             if (!updatingTasks.containsKey(taskId)) {
                 return false;
             }
             final Task task = updatingTasks.get(taskId);
-            prepareUpdatingTaskForRemoval(task);
+            prepareUpdatingTaskForRemoval(task, suspendReason);
             updatingTasks.remove(taskId);
             if (task.isActive()) {
                 transitToUpdateStandbysIfOnlyStandbysLeft();
@@ -555,18 +560,21 @@ public class DefaultStateUpdater implements StateUpdater {
             return true;
         }
 
-        private void prepareUpdatingTaskForRemoval(final Task task) {
+        private void prepareUpdatingTaskForRemoval(final Task task,
+                                                   final 
StandbyUpdateListener.SuspendReason suspendReason) {
             measureCheckpointLatency(() -> task.maybeCheckpoint(true));
             final Collection<TopicPartition> changelogPartitions = 
task.changelogPartitions();
-            changelogReader.unregister(changelogPartitions);
+            changelogReader.unregister(changelogPartitions, suspendReason);
         }
 
-        private boolean removePausedTask(final TaskId taskId, final 
CompletableFuture<RemovedTaskResult> future) {
+        private boolean removePausedTask(final TaskId taskId,
+                                         final 
CompletableFuture<RemovedTaskResult> future,
+                                         final 
StandbyUpdateListener.SuspendReason suspendReason) {
             if (!pausedTasks.containsKey(taskId)) {
                 return false;
             }
             final Task task = pausedTasks.get(taskId);
-            preparePausedTaskForRemoval(task);
+            preparePausedTaskForRemoval(task, suspendReason);
             pausedTasks.remove(taskId);
             log.info((task.isActive() ? "Active" : "Standby")
                 + " task " + task.id() + " was removed from the paused 
tasks.");
@@ -574,9 +582,10 @@ public class DefaultStateUpdater implements StateUpdater {
             return true;
         }
 
-        private void preparePausedTaskForRemoval(final Task task) {
+        private void preparePausedTaskForRemoval(final Task task,
+                                                 final 
StandbyUpdateListener.SuspendReason suspendReason) {
             final Collection<TopicPartition> changelogPartitions = 
task.changelogPartitions();
-            changelogReader.unregister(changelogPartitions);
+            changelogReader.unregister(changelogPartitions, suspendReason);
         }
 
         private boolean removeRestoredTask(final TaskId taskId, final 
CompletableFuture<RemovedTaskResult> future) {
@@ -896,11 +905,12 @@ public class DefaultStateUpdater implements StateUpdater {
     }
 
     @Override
-    public CompletableFuture<RemovedTaskResult> remove(final TaskId taskId) {
+    public CompletableFuture<RemovedTaskResult> remove(final TaskId taskId,
+                                                       final 
StandbyUpdateListener.SuspendReason suspendReason) {
         final CompletableFuture<RemovedTaskResult> future = new 
CompletableFuture<>();
         tasksAndActionsLock.lock();
         try {
-            tasksAndActions.add(TaskAndAction.createRemoveTask(taskId, 
future));
+            tasksAndActions.add(TaskAndAction.createRemoveTask(taskId, future, 
suspendReason));
             tasksAndActionsCondition.signalAll();
         } finally {
             tasksAndActionsLock.unlock();
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
index 8dd3d41c689..03d4de98c62 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateUpdater.java
@@ -18,6 +18,7 @@ package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.internals.KafkaFutureImpl;
+import org.apache.kafka.streams.processor.StandbyUpdateListener;
 import org.apache.kafka.streams.processor.TaskId;
 
 import java.time.Duration;
@@ -144,8 +145,9 @@ public interface StateUpdater {
      * restored tasks, or failed tasks.
      *
      * @param taskId ID of the task to remove
+     * @param suspendReason the reason for suspending standby update, passed 
through to the changelog reader
      */
-    CompletableFuture<RemovedTaskResult> remove(final TaskId taskId);
+    CompletableFuture<RemovedTaskResult> remove(final TaskId taskId, final 
StandbyUpdateListener.SuspendReason suspendReason);
 
     /**
      * Wakes up the state updater if it is currently dormant, to check if a 
paused task should be resumed.
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
index 3ec8aec6fca..eab7da800d8 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java
@@ -1041,6 +1041,12 @@ public class StoreChangelogReader implements 
ChangelogReader {
 
     @Override
     public void unregister(final Collection<TopicPartition> revokedChangelogs) 
{
+        unregister(revokedChangelogs, 
StandbyUpdateListener.SuspendReason.MIGRATED);
+    }
+
+    @Override
+    public void unregister(final Collection<TopicPartition> revokedChangelogs,
+                           final StandbyUpdateListener.SuspendReason reason) {
         // Only changelogs that are initialized have been added to the restore 
consumer's assignment
         final List<TopicPartition> revokedInitializedChangelogs = new 
ArrayList<>();
 
@@ -1068,13 +1074,8 @@ public class StoreChangelogReader implements 
ChangelogReader {
                             // endOffset and storeOffset may be unknown at 
this point
                             final long storeOffset = storeMetadata.offset() != 
null ? storeMetadata.offset() : -1;
                             final long endOffset = storeMetadata.endOffset() 
!= null ? storeMetadata.endOffset() : -1;
-                            // Unregistering running standby tasks means the 
task has been promoted to active.
-                            final StandbyUpdateListener.SuspendReason 
suspendReason = 
-                                changelogMetadata.stateManager.taskState() == 
Task.State.RUNNING 
-                                    ? 
StandbyUpdateListener.SuspendReason.PROMOTED
-                                    : 
StandbyUpdateListener.SuspendReason.MIGRATED;
                             try {
-                                
standbyUpdateListener.onUpdateSuspended(partition, storeName, storeOffset, 
endOffset, suspendReason);
+                                
standbyUpdateListener.onUpdateSuspended(partition, storeName, storeOffset, 
endOffset, reason);
                             } catch (final Exception e) {
                                 throw new StreamsException("Standby updater 
listener failed on update suspended", e);
                             }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
index ec6c6830bbd..f36d80c216a 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.streams.processor.StandbyUpdateListener;
 import org.apache.kafka.streams.processor.TaskId;
 
 import java.util.Objects;
@@ -32,27 +33,32 @@ public class TaskAndAction {
     private final TaskId taskId;
     private final Action action;
     private final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForRemove;
+    private final StandbyUpdateListener.SuspendReason suspendReason;
 
     private TaskAndAction(final Task task,
                           final TaskId taskId,
                           final Action action,
-                          final 
CompletableFuture<StateUpdater.RemovedTaskResult> futureForRemove) {
+                          final 
CompletableFuture<StateUpdater.RemovedTaskResult> futureForRemove,
+                          final StandbyUpdateListener.SuspendReason 
suspendReason) {
         this.task = task;
         this.taskId = taskId;
         this.action = action;
         this.futureForRemove = futureForRemove;
+        this.suspendReason = suspendReason;
     }
 
     public static TaskAndAction createAddTask(final Task task) {
         Objects.requireNonNull(task, "Task to add is null!");
-        return new TaskAndAction(task, null, Action.ADD, null);
+        return new TaskAndAction(task, null, Action.ADD, null, null);
     }
 
     public static TaskAndAction createRemoveTask(final TaskId taskId,
-                                                 final 
CompletableFuture<StateUpdater.RemovedTaskResult> future) {
+                                                 final 
CompletableFuture<StateUpdater.RemovedTaskResult> future,
+                                                 final 
StandbyUpdateListener.SuspendReason suspendReason) {
         Objects.requireNonNull(taskId, "Task ID of task to remove is null!");
         Objects.requireNonNull(future, "Future for task to remove is null!");
-        return new TaskAndAction(null, taskId, Action.REMOVE, future);
+        Objects.requireNonNull(suspendReason, "Suspend reason for task to 
remove is null!");
+        return new TaskAndAction(null, taskId, Action.REMOVE, future, 
suspendReason);
     }
 
     public Task task() {
@@ -76,6 +82,13 @@ public class TaskAndAction {
         return futureForRemove;
     }
 
+    public StandbyUpdateListener.SuspendReason suspendReason() {
+        if (action != Action.REMOVE) {
+            throw new IllegalStateException("Action type " + action + " cannot 
have a suspend reason!");
+        }
+        return suspendReason;
+    }
+
     public Action action() {
         return action;
     }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 9f421a9671c..6e7ebf46d90 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -37,6 +37,7 @@ import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.errors.TaskIdFormatException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode;
+import org.apache.kafka.streams.processor.StandbyUpdateListener;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.assignment.ProcessId;
 import 
org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory;
@@ -611,25 +612,29 @@ public class TaskManager {
             if (activeTasksToCreate.containsKey(taskId)) {
                 if (task.isActive()) {
                     if 
(!task.inputPartitions().equals(activeTasksToCreate.get(taskId))) {
-                        final 
CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(taskId);
+                        final 
CompletableFuture<StateUpdater.RemovedTaskResult> future =
+                            stateUpdater.remove(taskId, 
StandbyUpdateListener.SuspendReason.MIGRATED);
                         futuresForUpdatingInputPartitions.put(taskId, future);
                         newInputPartitions.put(taskId, 
activeTasksToCreate.get(taskId));
                     }
                 } else {
-                    final CompletableFuture<StateUpdater.RemovedTaskResult> 
future = stateUpdater.remove(taskId);
+                    final CompletableFuture<StateUpdater.RemovedTaskResult> 
future =
+                        stateUpdater.remove(taskId, 
StandbyUpdateListener.SuspendReason.PROMOTED);
                     futuresForStandbyTasksToRecycle.put(taskId, future);
                     activeInputPartitions.put(taskId, 
activeTasksToCreate.get(taskId));
                 }
                 activeTasksToCreate.remove(taskId);
             } else if (standbyTasksToCreate.containsKey(taskId)) {
                 if (task.isActive()) {
-                    final CompletableFuture<StateUpdater.RemovedTaskResult> 
future = stateUpdater.remove(taskId);
+                    final CompletableFuture<StateUpdater.RemovedTaskResult> 
future =
+                        stateUpdater.remove(taskId, 
StandbyUpdateListener.SuspendReason.MIGRATED);
                     futuresForActiveTasksToRecycle.put(taskId, future);
                     standbyInputPartitions.put(taskId, 
standbyTasksToCreate.get(taskId));
                 }
                 standbyTasksToCreate.remove(taskId);
             } else {
-                final CompletableFuture<StateUpdater.RemovedTaskResult> future 
= stateUpdater.remove(taskId);
+                final CompletableFuture<StateUpdater.RemovedTaskResult> future 
=
+                    stateUpdater.remove(taskId, 
StandbyUpdateListener.SuspendReason.MIGRATED);
                 futuresForTasksToClose.put(taskId, future);
             }
         }
@@ -1144,7 +1149,7 @@ public class TaskManager {
         for (final Task restoringTask : stateUpdater.tasks()) {
             if (restoringTask.isActive()) {
                 if 
(remainingRevokedPartitions.containsAll(restoringTask.inputPartitions())) {
-                    futures.put(restoringTask.id(), 
stateUpdater.remove(restoringTask.id()));
+                    futures.put(restoringTask.id(), 
stateUpdater.remove(restoringTask.id(), 
StandbyUpdateListener.SuspendReason.MIGRATED));
                     
remainingRevokedPartitions.removeAll(restoringTask.inputPartitions());
                 }
             }
@@ -1213,7 +1218,7 @@ public class TaskManager {
         final Set<Task> tasksToCloseDirty = new 
TreeSet<>(Comparator.comparing(Task::id));
         for (final Task restoringTask : stateUpdater.tasks()) {
             if (restoringTask.isActive()) {
-                futures.put(restoringTask.id(), 
stateUpdater.remove(restoringTask.id()));
+                futures.put(restoringTask.id(), 
stateUpdater.remove(restoringTask.id(), 
StandbyUpdateListener.SuspendReason.MIGRATED));
             }
         }
 
@@ -1465,7 +1470,7 @@ public class TaskManager {
 
         final Map<TaskId, CompletableFuture<StateUpdater.RemovedTaskResult>> 
futures = new LinkedHashMap<>();
         for (final Task task : stateUpdater.tasks()) {
-            final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id());
+            final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id(), StandbyUpdateListener.SuspendReason.MIGRATED);
             futures.put(task.id(), future);
         }
         final Set<Task> tasksToCloseClean = new 
TreeSet<>(Comparator.comparing(Task::id));
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
index 8087e28c8b3..4dacd079dd0 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -26,6 +26,7 @@ import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
+import org.apache.kafka.streams.processor.StandbyUpdateListener;
 import org.apache.kafka.streams.processor.TaskId;
 import 
org.apache.kafka.streams.processor.internals.StateUpdater.ExceptionAndTask;
 import org.apache.kafka.streams.processor.internals.Task.State;
@@ -448,7 +449,7 @@ class DefaultStateUpdaterTest {
             .thenReturn(false);
         stateUpdater.start();
         stateUpdater.add(task);
-        stateUpdater.remove(task.id()).get();
+        stateUpdater.remove(task.id(), 
StandbyUpdateListener.SuspendReason.MIGRATED).get();
         verifyRestoredActiveTasks();
         verifyUpdatingTasks();
         verifyExceptionsAndFailedTasks();
@@ -715,8 +716,8 @@ class DefaultStateUpdaterTest {
         stateUpdater.add(standbyTask);
         verifyUpdatingTasks(activeTask1, activeTask2, standbyTask);
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future1 = 
stateUpdater.remove(activeTask1.id());
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future2 = 
stateUpdater.remove(activeTask2.id());
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future1 = 
stateUpdater.remove(activeTask1.id(), 
StandbyUpdateListener.SuspendReason.MIGRATED);
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future2 = 
stateUpdater.remove(activeTask2.id(), 
StandbyUpdateListener.SuspendReason.MIGRATED);
         CompletableFuture.allOf(future1, future2).get();
 
         final InOrder orderVerifier = inOrder(changelogReader);
@@ -735,7 +736,7 @@ class DefaultStateUpdaterTest {
         stateUpdater.add(standbyTask2);
         verifyUpdatingTasks(standbyTask1, standbyTask2);
 
-        stateUpdater.remove(standbyTask2.id()).get();
+        stateUpdater.remove(standbyTask2.id(), 
StandbyUpdateListener.SuspendReason.MIGRATED).get();
 
         verify(changelogReader).transitToUpdateStandby();
     }
@@ -752,6 +753,22 @@ class DefaultStateUpdaterTest {
         shouldRemoveUpdatingStatefulTask(task);
     }
 
+    @Test
+    public void shouldPassSuspendReasonToChangelogReaderOnRemove() throws 
Exception {
+        final StandbyTask task = standbyTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RUNNING).build();
+        
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(false);
+        stateUpdater.start();
+        stateUpdater.add(task);
+        verifyUpdatingTasks(task);
+
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future =
+            stateUpdater.remove(task.id(), 
StandbyUpdateListener.SuspendReason.PROMOTED);
+
+        assertEquals(new StateUpdater.RemovedTaskResult(task), future.get());
+        verify(changelogReader).unregister(task.changelogPartitions(), 
StandbyUpdateListener.SuspendReason.PROMOTED);
+    }
+
     private void shouldRemoveUpdatingStatefulTask(final Task task) throws 
Exception {
         
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
         when(changelogReader.allChangelogsCompleted()).thenReturn(false);
@@ -759,7 +776,7 @@ class DefaultStateUpdaterTest {
         stateUpdater.add(task);
         verifyUpdatingTasks(task);
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id());
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id(), StandbyUpdateListener.SuspendReason.MIGRATED);
 
         assertEquals(new StateUpdater.RemovedTaskResult(task), future.get());
         verifyCheckpointTasks(true, task);
@@ -767,7 +784,7 @@ class DefaultStateUpdaterTest {
         verifyUpdatingTasks();
         verifyPausedTasks();
         verifyExceptionsAndFailedTasks();
-        verify(changelogReader).unregister(task.changelogPartitions());
+        verify(changelogReader).unregister(task.changelogPartitions(), 
StandbyUpdateListener.SuspendReason.MIGRATED);
     }
 
     @Test
@@ -776,7 +793,7 @@ class DefaultStateUpdaterTest {
         final StreamsException streamsException = new 
StreamsException("Something happened", task.id());
         setupShouldThrowIfRemovingUpdatingStatefulTaskFailsWithException(task, 
streamsException);
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id());
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id(), StandbyUpdateListener.SuspendReason.MIGRATED);
 
         verifyRemovingUpdatingStatefulTaskFails(future, task, 
streamsException, true);
 
@@ -788,7 +805,7 @@ class DefaultStateUpdaterTest {
         final RuntimeException runtimeException = new 
RuntimeException("Something happened");
         setupShouldThrowIfRemovingUpdatingStatefulTaskFailsWithException(task, 
runtimeException);
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id());
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id(), StandbyUpdateListener.SuspendReason.MIGRATED);
 
         verifyRemovingUpdatingStatefulTaskFails(future, task, 
runtimeException, false);
     }
@@ -799,7 +816,7 @@ class DefaultStateUpdaterTest {
         final StreamsException streamsException = new 
StreamsException("Something happened", task.id());
         setupShouldThrowIfRemovingUpdatingStatefulTaskFailsWithException(task, 
streamsException);
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id());
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id(), StandbyUpdateListener.SuspendReason.MIGRATED);
 
         verifyRemovingUpdatingStatefulTaskFails(future, task, 
streamsException, true);
     }
@@ -810,7 +827,7 @@ class DefaultStateUpdaterTest {
         final RuntimeException runtimeException = new 
RuntimeException("Something happened");
         setupShouldThrowIfRemovingUpdatingStatefulTaskFailsWithException(task, 
runtimeException);
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id());
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id(), StandbyUpdateListener.SuspendReason.MIGRATED);
 
         verifyRemovingUpdatingStatefulTaskFails(future, task, 
runtimeException, false);
     }
@@ -820,7 +837,7 @@ class DefaultStateUpdaterTest {
         
when(changelogReader.completedChangelogs()).thenReturn(Collections.emptySet());
         when(changelogReader.allChangelogsCompleted()).thenReturn(false);
         final Collection<TopicPartition> changelogPartitions = 
task.changelogPartitions();
-        
doThrow(exception).when(changelogReader).unregister(changelogPartitions);
+        
doThrow(exception).when(changelogReader).unregister(changelogPartitions, 
StandbyUpdateListener.SuspendReason.MIGRATED);
         stateUpdater.start();
         stateUpdater.add(task);
         verifyUpdatingTasks(task);
@@ -851,8 +868,8 @@ class DefaultStateUpdaterTest {
         verifyPausedTasks(statefulTask, standbyTask);
         verifyUpdatingTasks();
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureOfStatefulTask = stateUpdater.remove(statefulTask.id());
-        final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureOfStandbyTask = stateUpdater.remove(standbyTask.id());
+        final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureOfStatefulTask = stateUpdater.remove(statefulTask.id(), 
StandbyUpdateListener.SuspendReason.MIGRATED);
+        final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureOfStandbyTask = stateUpdater.remove(standbyTask.id(), 
StandbyUpdateListener.SuspendReason.MIGRATED);
 
         assertEquals(new StateUpdater.RemovedTaskResult(statefulTask), 
futureOfStatefulTask.get());
         assertEquals(new StateUpdater.RemovedTaskResult(standbyTask), 
futureOfStandbyTask.get());
@@ -860,8 +877,8 @@ class DefaultStateUpdaterTest {
         verifyCheckpointTasks(true, statefulTask, standbyTask);
         verifyUpdatingTasks();
         verifyExceptionsAndFailedTasks();
-        verify(changelogReader).unregister(statefulTask.changelogPartitions());
-        verify(changelogReader).unregister(standbyTask.changelogPartitions());
+        verify(changelogReader).unregister(statefulTask.changelogPartitions(), 
StandbyUpdateListener.SuspendReason.MIGRATED);
+        verify(changelogReader).unregister(standbyTask.changelogPartitions(), 
StandbyUpdateListener.SuspendReason.MIGRATED);
     }
 
     @Test
@@ -869,7 +886,7 @@ class DefaultStateUpdaterTest {
         final StreamTask statefulTask = statefulTask(TASK_0_0, 
Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
         final StreamsException streamsException = new 
StreamsException("Something happened", statefulTask.id());
         final Collection<TopicPartition> changelogPartitions = 
statefulTask.changelogPartitions();
-        
doThrow(streamsException).when(changelogReader).unregister(changelogPartitions);
+        
doThrow(streamsException).when(changelogReader).unregister(changelogPartitions, 
StandbyUpdateListener.SuspendReason.MIGRATED);
         stateUpdater.start();
         stateUpdater.add(statefulTask);
         verifyUpdatingTasks(statefulTask);
@@ -877,7 +894,7 @@ class DefaultStateUpdaterTest {
         verifyPausedTasks(statefulTask);
         verifyUpdatingTasks();
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(statefulTask.id());
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(statefulTask.id(), 
StandbyUpdateListener.SuspendReason.MIGRATED);
 
         final ExecutionException executionException = 
assertThrows(ExecutionException.class, future::get);
         assertInstanceOf(StreamsException.class, 
executionException.getCause());
@@ -907,7 +924,7 @@ class DefaultStateUpdaterTest {
         stateUpdater.add(task);
         verifyRestoredActiveTasks(task);
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id());
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id(), StandbyUpdateListener.SuspendReason.MIGRATED);
         future.get();
 
         assertEquals(new StateUpdater.RemovedTaskResult(task), future.get());
@@ -943,7 +960,7 @@ class DefaultStateUpdaterTest {
         final ExceptionAndTask expectedExceptionAndTasks = new 
ExceptionAndTask(streamsException, task);
         verifyExceptionsAndFailedTasks(expectedExceptionAndTasks);
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id());
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(task.id(), StandbyUpdateListener.SuspendReason.MIGRATED);
 
         assertEquals(new StateUpdater.RemovedTaskResult(task, 
streamsException), future.get());
         verifyPausedTasks();
@@ -973,7 +990,7 @@ class DefaultStateUpdaterTest {
         verifyUpdatingTasks(updatingTask);
         verifyPausedTasks();
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(TASK_1_0);
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(TASK_1_0, StandbyUpdateListener.SuspendReason.MIGRATED);
 
         assertNull(future.get());
         verifyRestoredActiveTasks(restoredTask);
@@ -986,7 +1003,7 @@ class DefaultStateUpdaterTest {
     public void shouldCompleteWithNullIfNoTasks() throws Exception {
         stateUpdater.start();
 
-        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(TASK_0_1);
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
stateUpdater.remove(TASK_0_1, StandbyUpdateListener.SuspendReason.MIGRATED);
 
         assertNull(future.get());
         assertTrue(stateUpdater.isRunning());
@@ -1487,7 +1504,7 @@ class DefaultStateUpdaterTest {
         stateUpdater.add(activeTask1);
         stateUpdater.add(standbyTask1);
         stateUpdater.add(standbyTask2);
-        stateUpdater.remove(TASK_0_0);
+        stateUpdater.remove(TASK_0_0, 
StandbyUpdateListener.SuspendReason.MIGRATED);
         stateUpdater.add(activeTask2);
         stateUpdater.add(standbyTask3);
 
@@ -1782,7 +1799,7 @@ class DefaultStateUpdaterTest {
         verifyUpdatingTasks(failedStatefulTask, activeTask1);
 
         throwException.set(true);
-        final ExecutionException exception = 
assertThrows(ExecutionException.class, () -> 
stateUpdater.remove(TASK_0_2).get());
+        final ExecutionException exception = 
assertThrows(ExecutionException.class, () -> stateUpdater.remove(TASK_0_2, 
StandbyUpdateListener.SuspendReason.MIGRATED).get());
         assertEquals(processorStateException, exception.getCause());
 
         stateUpdater.add(activeTask2);
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
index 49d18d888ed..765f7d85ee1 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java
@@ -17,6 +17,7 @@
 package org.apache.kafka.streams.processor.internals;
 
 import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.streams.processor.StandbyUpdateListener.SuspendReason;
 import org.apache.kafka.streams.processor.TaskId;
 
 import java.util.Collection;
@@ -91,6 +92,12 @@ public class MockChangelogReader implements ChangelogReader {
         }
     }
 
+    @Override
+    public void unregister(final Collection<TopicPartition> partitions,
+                           final SuspendReason reason) {
+        unregister(partitions);
+    }
+
     @Override
     public boolean isEmpty() {
         return restoredOffsets.isEmpty() && restoringPartitions.isEmpty();
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
index 13389239602..72954175f5b 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java
@@ -37,6 +37,7 @@ import org.apache.kafka.common.utils.LogContext;
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.streams.StreamsConfig;
 import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.processor.StandbyUpdateListener.SuspendReason;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.TaskId;
 import 
org.apache.kafka.streams.processor.internals.ProcessorStateManager.StateStoreMetadata;
@@ -270,6 +271,8 @@ public class StoreChangelogReaderTest {
             assertNull(callback.storeNameCalledStates.get(RESTORE_SUSPENDED));
             assertEquals(storeName, 
standbyListener.capturedStore(UPDATE_START));
             assertEquals(tp, standbyListener.updatePartition);
+            assertEquals(storeName, 
standbyListener.capturedStore(UPDATE_SUSPENDED));
+            assertEquals(SuspendReason.MIGRATED, 
standbyListener.updateSuspendedReason);
         }
         assertNull(callback.storeNameCalledStates.get(RESTORE_BATCH));
     }
@@ -330,9 +333,39 @@ public class StoreChangelogReaderTest {
             assertNull(callback.storeNameCalledStates.get(UPDATE_BATCH));
             assertEquals(storeName, 
standbyListener.capturedStore(UPDATE_START));
             assertEquals(tp, standbyListener.updatePartition);
+            assertEquals(storeName, 
standbyListener.capturedStore(UPDATE_SUSPENDED));
+            assertEquals(SuspendReason.MIGRATED, 
standbyListener.updateSuspendedReason);
         }
     }
 
+    @Test
+    public void shouldPassSuspendReasonToStandbyListener() {
+        setupStateManagerMock(STANDBY);
+        setupStoreMetadata();
+        setupStore();
+        @SuppressWarnings("unchecked")
+        final Map<TaskId, Task> mockTasks = mock(Map.class);
+        when(mockTasks.get(null)).thenReturn(mock(Task.class));
+        when(mockTasks.containsKey(null)).thenReturn(true);
+        when(storeMetadata.offset()).thenReturn(9L);
+        when(storeMetadata.endOffset()).thenReturn(10L);
+        when(stateManager.changelogAsSource(tp)).thenReturn(true);
+
+        adminClient.updateEndOffsets(Collections.singletonMap(tp, 100L));
+
+        final StoreChangelogReader changelogReader =
+            new StoreChangelogReader(time, config, logContext, adminClient, 
consumer, callback, standbyListener);
+
+        changelogReader.register(tp, stateManager);
+        changelogReader.transitToUpdateStandby();
+        changelogReader.restore(mockTasks);
+
+        changelogReader.unregister(Collections.singleton(tp), 
SuspendReason.PROMOTED);
+
+        assertEquals(storeName, 
standbyListener.capturedStore(UPDATE_SUSPENDED));
+        assertEquals(SuspendReason.PROMOTED, 
standbyListener.updateSuspendedReason);
+    }
+
     @ParameterizedTest
     @EnumSource(value = Task.TaskType.class, names = {"ACTIVE", "STANDBY"})
     public void shouldInitializeChangelogAndCheckForCompletion(final 
Task.TaskType type) {
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskAndActionTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskAndActionTest.java
index 704ce9f403b..4624b446243 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskAndActionTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskAndActionTest.java
@@ -16,6 +16,7 @@
  */
 package org.apache.kafka.streams.processor.internals;
 
+import org.apache.kafka.streams.processor.StandbyUpdateListener;
 import org.apache.kafka.streams.processor.TaskId;
 
 import org.junit.jupiter.api.Test;
@@ -52,11 +53,12 @@ class TaskAndActionTest {
         final TaskId taskId = new TaskId(0, 0);
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
 
-        final TaskAndAction removeTask = createRemoveTask(taskId, future);
+        final TaskAndAction removeTask = createRemoveTask(taskId, future, 
StandbyUpdateListener.SuspendReason.MIGRATED);
 
         assertEquals(REMOVE, removeTask.action());
         assertEquals(taskId, removeTask.taskId());
         assertEquals(future, removeTask.futureForRemove());
+        assertEquals(StandbyUpdateListener.SuspendReason.MIGRATED, 
removeTask.suspendReason());
         final Exception exceptionForTask = 
assertThrows(IllegalStateException.class, removeTask::task);
         assertEquals("Action type REMOVE cannot have a task!", 
exceptionForTask.getMessage());
     }
@@ -71,7 +73,7 @@ class TaskAndActionTest {
     public void shouldThrowIfRemoveTaskActionIsCreatedWithNullTaskId() {
         final Exception exception = assertThrows(
             NullPointerException.class,
-            () -> createRemoveTask(null, new CompletableFuture<>())
+            () -> createRemoveTask(null, new CompletableFuture<>(), 
StandbyUpdateListener.SuspendReason.MIGRATED)
         );
         assertTrue(exception.getMessage().contains("Task ID of task to remove 
is null!"));
     }
@@ -80,7 +82,7 @@ class TaskAndActionTest {
     public void shouldThrowIfRemoveTaskActionIsCreatedWithNullFuture() {
         final Exception exception = assertThrows(
             NullPointerException.class,
-            () -> createRemoveTask(new TaskId(0, 0), null)
+            () -> createRemoveTask(new TaskId(0, 0), null, 
StandbyUpdateListener.SuspendReason.MIGRATED)
         );
         assertTrue(exception.getMessage().contains("Future for task to remove 
is null!"));
     }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index ce369cd1c08..586af7cc1ae 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -43,6 +43,7 @@ import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.errors.TaskMigratedException;
 import org.apache.kafka.streams.internals.StreamsConfigUtils;
 import org.apache.kafka.streams.internals.StreamsConfigUtils.ProcessingMode;
+import org.apache.kafka.streams.processor.StandbyUpdateListener.SuspendReason;
 import org.apache.kafka.streams.processor.TaskId;
 import org.apache.kafka.streams.processor.assignment.ProcessId;
 import 
org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory;
@@ -368,7 +369,7 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
         when(stateUpdater.tasks()).thenReturn(Set.of(activeTaskToClose));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        when(stateUpdater.remove(activeTaskToClose.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(activeTaskToClose.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         future.complete(new StateUpdater.RemovedTaskResult(activeTaskToClose));
 
         taskManager.handleAssignment(Collections.emptyMap(), 
Collections.emptyMap());
@@ -388,7 +389,7 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
         when(stateUpdater.tasks()).thenReturn(Set.of(activeTaskToClose));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        when(stateUpdater.remove(activeTaskToClose.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(activeTaskToClose.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         future.complete(new StateUpdater.RemovedTaskResult(activeTaskToClose, 
new RuntimeException("KABOOM!")));
 
         taskManager.handleAssignment(Collections.emptyMap(), 
Collections.emptyMap());
@@ -409,7 +410,7 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
         when(stateUpdater.tasks()).thenReturn(Set.of(standbyTaskToClose));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        when(stateUpdater.remove(standbyTaskToClose.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(standbyTaskToClose.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         future.complete(new 
StateUpdater.RemovedTaskResult(standbyTaskToClose));
 
         taskManager.handleAssignment(Collections.emptyMap(), 
Collections.emptyMap());
@@ -429,7 +430,7 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
         when(stateUpdater.tasks()).thenReturn(Set.of(standbyTaskToClose));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        when(stateUpdater.remove(standbyTaskToClose.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(standbyTaskToClose.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         future.complete(new StateUpdater.RemovedTaskResult(standbyTaskToClose, 
new RuntimeException("KABOOM!")));
 
         taskManager.handleAssignment(Collections.emptyMap(), 
Collections.emptyMap());
@@ -450,7 +451,7 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
         when(stateUpdater.tasks()).thenReturn(Set.of(failedStandbyTask));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        when(stateUpdater.remove(failedStandbyTask.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(failedStandbyTask.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         final RuntimeException kaboom = new RuntimeException("KABOOM!");
         future.completeExceptionally(kaboom);
         when(stateUpdater.drainExceptionsAndFailedTasks())
@@ -477,7 +478,7 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
         
when(stateUpdater.tasks()).thenReturn(Set.of(activeTaskToUpdateInputPartitions));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        
when(stateUpdater.remove(activeTaskToUpdateInputPartitions.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(activeTaskToUpdateInputPartitions.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         future.complete(new 
StateUpdater.RemovedTaskResult(activeTaskToUpdateInputPartitions));
 
         taskManager.handleAssignment(
@@ -507,7 +508,7 @@ public class TaskManagerTest {
         
when(standbyTaskCreator.createStandbyTaskFromActive(activeTaskToRecycle, 
taskId03Partitions))
             .thenReturn(recycledStandbyTask);
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        when(stateUpdater.remove(taskId03)).thenReturn(future);
+        when(stateUpdater.remove(eq(taskId03), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         future.complete(new 
StateUpdater.RemovedTaskResult(activeTaskToRecycle));
 
         taskManager.handleAssignment(
@@ -531,7 +532,7 @@ public class TaskManagerTest {
         
when(standbyTaskCreator.createStandbyTaskFromActive(activeTaskToRecycle, 
activeTaskToRecycle.inputPartitions()))
             .thenThrow(new RuntimeException());
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        when(stateUpdater.remove(activeTaskToRecycle.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(activeTaskToRecycle.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         future.complete(new 
StateUpdater.RemovedTaskResult(activeTaskToRecycle));
 
         assertThrows(
@@ -561,7 +562,7 @@ public class TaskManagerTest {
         
when(activeTaskCreator.createActiveTaskFromStandby(standbyTaskToRecycle, 
taskId03Partitions, consumer))
             .thenReturn(recycledActiveTask);
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        
when(stateUpdater.remove(standbyTaskToRecycle.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(standbyTaskToRecycle.id()), 
eq(SuspendReason.PROMOTED))).thenReturn(future);
         future.complete(new 
StateUpdater.RemovedTaskResult(standbyTaskToRecycle));
 
         taskManager.handleAssignment(
@@ -588,7 +589,7 @@ public class TaskManagerTest {
             consumer))
             .thenThrow(new RuntimeException());
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        
when(stateUpdater.remove(standbyTaskToRecycle.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(standbyTaskToRecycle.id()), 
eq(SuspendReason.PROMOTED))).thenReturn(future);
         future.complete(new 
StateUpdater.RemovedTaskResult(standbyTaskToRecycle));
 
         assertThrows(
@@ -618,7 +619,7 @@ public class TaskManagerTest {
             Collections.emptyMap()
         );
 
-        verify(stateUpdater, never()).remove(reassignedActiveTask.id());
+        verify(stateUpdater, never()).remove(eq(reassignedActiveTask.id()), 
any());
         verify(activeTaskCreator).createTasks(consumer, 
Collections.emptyMap());
         verify(standbyTaskCreator).createTasks(Collections.emptyMap());
     }
@@ -652,7 +653,7 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
         
when(stateUpdater.tasks()).thenReturn(Set.of(failedActiveTaskToRecycle));
         final RuntimeException taskException = new RuntimeException("Nobody 
expects the Spanish inquisition!");
-        when(stateUpdater.remove(failedActiveTaskToRecycle.id()))
+        when(stateUpdater.remove(eq(failedActiveTaskToRecycle.id()), 
eq(SuspendReason.MIGRATED)))
             .thenReturn(CompletableFuture.completedFuture(
                 new StateUpdater.RemovedTaskResult(failedActiveTaskToRecycle, 
taskException)
             ));
@@ -682,7 +683,7 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
         
when(stateUpdater.tasks()).thenReturn(Set.of(failedStandbyTaskToRecycle));
         final RuntimeException taskException = new RuntimeException("Nobody 
expects the Spanish inquisition!");
-        when(stateUpdater.remove(failedStandbyTaskToRecycle.id()))
+        when(stateUpdater.remove(eq(failedStandbyTaskToRecycle.id()), 
eq(SuspendReason.PROMOTED)))
             .thenReturn(CompletableFuture.completedFuture(
                 new StateUpdater.RemovedTaskResult(failedStandbyTaskToRecycle, 
taskException)
             ));
@@ -712,7 +713,7 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
         
when(stateUpdater.tasks()).thenReturn(Set.of(failedActiveTaskToReassign));
         final RuntimeException taskException = new RuntimeException("Nobody 
expects the Spanish inquisition!");
-        when(stateUpdater.remove(failedActiveTaskToReassign.id()))
+        when(stateUpdater.remove(eq(failedActiveTaskToReassign.id()), 
eq(SuspendReason.MIGRATED)))
             .thenReturn(CompletableFuture.completedFuture(
                 new StateUpdater.RemovedTaskResult(failedActiveTaskToReassign, 
taskException)
             ));
@@ -745,7 +746,7 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
         
when(tasks.allNonFailedInitializedTasks()).thenReturn(Set.of(reassignedActiveTask1));
         when(stateUpdater.tasks()).thenReturn(Set.of(reassignedActiveTask2));
-        when(stateUpdater.remove(reassignedActiveTask2.id()))
+        when(stateUpdater.remove(eq(reassignedActiveTask2.id()), 
eq(SuspendReason.MIGRATED)))
             .thenReturn(CompletableFuture.completedFuture(new 
StateUpdater.RemovedTaskResult(reassignedActiveTask2)));
 
         taskManager.handleAssignment(
@@ -757,7 +758,7 @@ public class TaskManagerTest {
         );
 
         final InOrder inOrder = inOrder(stateUpdater, tasks);
-        inOrder.verify(stateUpdater).remove(reassignedActiveTask2.id());
+        inOrder.verify(stateUpdater).remove(eq(reassignedActiveTask2.id()), 
eq(SuspendReason.MIGRATED));
         inOrder.verify(tasks).removeTask(reassignedActiveTask1);
         inOrder.verify(stateUpdater).add(reassignedActiveTask1);
     }
@@ -775,7 +776,7 @@ public class TaskManagerTest {
             Collections.emptyMap(),
             mkMap(mkEntry(standbyTaskToUpdateInputPartitions.id(), 
taskId03Partitions))
         );
-        verify(stateUpdater, 
never()).remove(standbyTaskToUpdateInputPartitions.id());
+        verify(stateUpdater, 
never()).remove(eq(standbyTaskToUpdateInputPartitions.id()), any());
         verify(activeTaskCreator).createTasks(consumer, 
Collections.emptyMap());
         verify(standbyTaskCreator).createTasks(Collections.emptyMap());
     }
@@ -813,12 +814,12 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
         when(stateUpdater.tasks()).thenReturn(Set.of(activeTaskToClose, 
standbyTaskToRecycle));
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForActiveTaskToClose = new CompletableFuture<>();
-        
when(stateUpdater.remove(activeTaskToClose.id())).thenReturn(futureForActiveTaskToClose);
+        when(stateUpdater.remove(eq(activeTaskToClose.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(futureForActiveTaskToClose);
         futureForActiveTaskToClose.complete(new 
StateUpdater.RemovedTaskResult(activeTaskToClose));
         
when(activeTaskCreator.createActiveTaskFromStandby(standbyTaskToRecycle, 
taskId02Partitions, consumer))
             .thenReturn(recycledActiveTask);
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForStandbyTaskToRecycle = new CompletableFuture<>();
-        
when(stateUpdater.remove(standbyTaskToRecycle.id())).thenReturn(futureForStandbyTaskToRecycle);
+        when(stateUpdater.remove(eq(standbyTaskToRecycle.id()), 
eq(SuspendReason.PROMOTED))).thenReturn(futureForStandbyTaskToRecycle);
         futureForStandbyTaskToRecycle.complete(new 
StateUpdater.RemovedTaskResult(standbyTaskToRecycle));
 
         taskManager.handleAssignment(
@@ -1314,14 +1315,14 @@ public class TaskManagerTest {
         final TaskManager taskManager = 
setupForRevocationAndLost(Set.of(task), tasks);
         when(stateUpdater.tasks()).thenReturn(Set.of(task));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(task.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         future.complete(new StateUpdater.RemovedTaskResult(task));
 
         taskManager.handleRevocation(task.inputPartitions());
 
         verify(task).suspend();
         verify(tasks).addActiveTask(task);
-        verify(stateUpdater).remove(task.id());
+        verify(stateUpdater).remove(eq(task.id()), eq(SuspendReason.MIGRATED));
     }
 
     @Test
@@ -1335,10 +1336,10 @@ public class TaskManagerTest {
         final TasksRegistry tasks = mock(TasksRegistry.class);
         final TaskManager taskManager = 
setupForRevocationAndLost(Set.of(task1, task2), tasks);
         final CompletableFuture<StateUpdater.RemovedTaskResult> future1 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task1.id())).thenReturn(future1);
+        when(stateUpdater.remove(eq(task1.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future1);
         future1.complete(new StateUpdater.RemovedTaskResult(task1));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future2 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task2.id())).thenReturn(future2);
+        when(stateUpdater.remove(eq(task2.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future2);
         future2.complete(new StateUpdater.RemovedTaskResult(task2));
 
         taskManager.handleRevocation(union(HashSet::new, taskId00Partitions, 
taskId01Partitions));
@@ -1361,7 +1362,7 @@ public class TaskManagerTest {
 
         verify(task, never()).suspend();
         verify(tasks, never()).addActiveTask(task);
-        verify(stateUpdater, never()).remove(task.id());
+        verify(stateUpdater, never()).remove(eq(task.id()), any());
     }
 
     @Test
@@ -1376,7 +1377,7 @@ public class TaskManagerTest {
 
         verify(task, never()).suspend();
         verify(tasks, never()).addStandbyTask(task);
-        verify(stateUpdater, never()).remove(task.id());
+        verify(stateUpdater, never()).remove(eq(task.id()), any());
     }
 
     @Test
@@ -1390,10 +1391,10 @@ public class TaskManagerTest {
         final TasksRegistry tasks = mock(TasksRegistry.class);
         final TaskManager taskManager = 
setupForRevocationAndLost(Set.of(task1, task2), tasks);
         final CompletableFuture<StateUpdater.RemovedTaskResult> future1 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task1.id())).thenReturn(future1);
+        when(stateUpdater.remove(eq(task1.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future1);
         future1.complete(new StateUpdater.RemovedTaskResult(task1));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future2 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task2.id())).thenReturn(future2);
+        when(stateUpdater.remove(eq(task2.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future2);
         final RuntimeException taskException = new RuntimeException("Nobody 
expects the Spanish inquisition!");
         future2.complete(new StateUpdater.RemovedTaskResult(task2, 
taskException));
 
@@ -1424,10 +1425,10 @@ public class TaskManagerTest {
         final TasksRegistry tasks = mock(TasksRegistry.class);
         final TaskManager taskManager = 
setupForRevocationAndLost(Set.of(task1, task2, task3), tasks);
         final CompletableFuture<StateUpdater.RemovedTaskResult> future1 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task1.id())).thenReturn(future1);
+        when(stateUpdater.remove(eq(task1.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future1);
         future1.complete(new StateUpdater.RemovedTaskResult(task1));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future3 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task3.id())).thenReturn(future3);
+        when(stateUpdater.remove(eq(task3.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future3);
         future3.complete(new StateUpdater.RemovedTaskResult(task3));
 
         taskManager.handleLostAll();
@@ -1436,7 +1437,7 @@ public class TaskManagerTest {
         verify(task1).closeClean();
         verify(task3).suspend();
         verify(task3).closeClean();
-        verify(stateUpdater, never()).remove(task2.id());
+        verify(stateUpdater, never()).remove(eq(task2.id()), 
eq(SuspendReason.MIGRATED));
     }
 
     @Test
@@ -1470,10 +1471,10 @@ public class TaskManagerTest {
         final TasksRegistry tasks = mock(TasksRegistry.class);
         final TaskManager taskManager = 
setupForRevocationAndLost(Set.of(task1, task2), tasks);
         final CompletableFuture<StateUpdater.RemovedTaskResult> future1 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task1.id())).thenReturn(future1);
+        when(stateUpdater.remove(eq(task1.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future1);
         future1.complete(new StateUpdater.RemovedTaskResult(task1, new 
StreamsException("Something happened")));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future3 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task2.id())).thenReturn(future3);
+        when(stateUpdater.remove(eq(task2.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future3);
         future3.complete(new StateUpdater.RemovedTaskResult(task2, new 
StreamsException("Something else happened")));
 
         taskManager.handleLostAll();
@@ -1501,10 +1502,10 @@ public class TaskManagerTest {
         when(tasks.drainPendingActiveTasksToInit()).thenReturn(Set.of(task1));
         final TaskManager taskManager = 
setupForRevocationAndLost(Set.of(task2, task3), tasks);
         final CompletableFuture<StateUpdater.RemovedTaskResult> future2 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task2.id())).thenReturn(future2);
+        when(stateUpdater.remove(eq(task2.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future2);
         future2.complete(new StateUpdater.RemovedTaskResult(task2, new 
StreamsException("Something happened")));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future3 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task3.id())).thenReturn(future3);
+        when(stateUpdater.remove(eq(task3.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future3);
         future3.complete(new StateUpdater.RemovedTaskResult(task3));
 
         taskManager.handleLostAll();
@@ -2765,12 +2766,12 @@ public class TaskManagerTest {
 
         // mock future for removing task from StateUpdater
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task00.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(task00.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         future.complete(new StateUpdater.RemovedTaskResult(task00));
 
         taskManager.handleAssignment(emptyMap(), emptyMap());
 
-        verify(stateUpdater).remove(task00.id());
+        verify(stateUpdater).remove(eq(task00.id()), 
eq(SuspendReason.MIGRATED));
         verify(task00).suspend();
         verify(task00).closeClean();
 
@@ -3195,7 +3196,7 @@ public class TaskManagerTest {
         verify(task00, never()).postCommit(anyBoolean());
 
         // standby task not removed from state updater
-        verify(stateUpdater, never()).remove(task01.id());
+        verify(stateUpdater, never()).remove(eq(task01.id()), any());
         verify(task01, never()).prepareCommit(anyBoolean());
         verify(task01, never()).postCommit(anyBoolean());
 
@@ -3220,7 +3221,7 @@ public class TaskManagerTest {
 
         // mock to remove standby task from state updater
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        when(stateUpdater.remove(task01.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(task01.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         future.complete(new StateUpdater.RemovedTaskResult(task01));
 
         final Map<TaskId, Set<TopicPartition>> assignmentActive = 
singletonMap(taskId00, taskId00Partitions);
@@ -3230,7 +3231,7 @@ public class TaskManagerTest {
         verify(task00, never()).prepareCommit(anyBoolean());
         verify(task00, never()).postCommit(anyBoolean());
 
-        verify(stateUpdater).remove(task01.id());
+        verify(stateUpdater).remove(eq(task01.id()), 
eq(SuspendReason.MIGRATED));
         verify(task01).suspend();
         verify(task01).closeClean();
 
@@ -3392,7 +3393,7 @@ public class TaskManagerTest {
 
         when(stateUpdater.tasks()).thenReturn(singleton(task00));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = 
mock(CompletableFuture.class);
-        when(stateUpdater.remove(eq(taskId00))).thenReturn(future);
+        when(stateUpdater.remove(eq(taskId00), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         when(future.get(anyLong(), any())).thenThrow(new 
java.util.concurrent.TimeoutException());
 
         taskManager.shutdown(true);
@@ -3420,7 +3421,7 @@ public class TaskManagerTest {
 
         // task01 is revoked, task00 stays
         final CompletableFuture<StateUpdater.RemovedTaskResult> futureTask01 = 
new CompletableFuture<>();
-        when(stateUpdater.remove(task01.id())).thenReturn(futureTask01);
+        when(stateUpdater.remove(eq(task01.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(futureTask01);
         futureTask01.complete(new StateUpdater.RemovedTaskResult(task01));
 
         final RuntimeException thrown = assertThrows(RuntimeException.class,
@@ -3432,7 +3433,7 @@ public class TaskManagerTest {
 
         verify(task01, times(2)).suspend();
         verify(task01).closeDirty();
-        verify(stateUpdater, never()).remove(task00.id());
+        verify(stateUpdater, never()).remove(eq(task00.id()), any()); // task00 must not be removed for ANY SuspendReason
         verify(task00, never()).suspend();
         verify(task00, never()).prepareCommit(anyBoolean());
         verify(task00, never()).closeClean();
@@ -3533,7 +3534,7 @@ public class TaskManagerTest {
         
when(tasks.standbyInitializedTasks()).thenReturn(Set.of(standbyTask00));
 
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForStandbyTask = new CompletableFuture<>();
-        when(stateUpdater.remove(taskId00)).thenReturn(futureForStandbyTask);
+        when(stateUpdater.remove(eq(taskId00), 
eq(SuspendReason.MIGRATED))).thenReturn(futureForStandbyTask);
 
         final TaskManager taskManager = 
setUpTaskManager(ProcessingMode.AT_LEAST_ONCE, tasks);
 
@@ -3619,13 +3620,13 @@ public class TaskManagerTest {
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForRemovedFailedStandbyTask = new CompletableFuture<>();
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForRemovedFailedStatefulTaskDuringRemoval = new CompletableFuture<>();
         final CompletableFuture<StateUpdater.RemovedTaskResult> 
futureForRemovedFailedStandbyTaskDuringRemoval = new CompletableFuture<>();
-        
when(stateUpdater.remove(removedStatefulTask.id())).thenReturn(futureForRemovedStatefulTask);
-        
when(stateUpdater.remove(removedStandbyTask.id())).thenReturn(futureForRemovedStandbyTask);
-        
when(stateUpdater.remove(removedFailedStatefulTask.id())).thenReturn(futureForRemovedFailedStatefulTask);
-        
when(stateUpdater.remove(removedFailedStandbyTask.id())).thenReturn(futureForRemovedFailedStandbyTask);
-        when(stateUpdater.remove(removedFailedStatefulTaskDuringRemoval.id()))
+        when(stateUpdater.remove(eq(removedStatefulTask.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(futureForRemovedStatefulTask);
+        when(stateUpdater.remove(eq(removedStandbyTask.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(futureForRemovedStandbyTask);
+        when(stateUpdater.remove(eq(removedFailedStatefulTask.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(futureForRemovedFailedStatefulTask);
+        when(stateUpdater.remove(eq(removedFailedStandbyTask.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(futureForRemovedFailedStandbyTask);
+        
when(stateUpdater.remove(eq(removedFailedStatefulTaskDuringRemoval.id()), 
eq(SuspendReason.MIGRATED)))
             .thenReturn(futureForRemovedFailedStatefulTaskDuringRemoval);
-        when(stateUpdater.remove(removedFailedStandbyTaskDuringRemoval.id()))
+        
when(stateUpdater.remove(eq(removedFailedStandbyTaskDuringRemoval.id()), 
eq(SuspendReason.MIGRATED)))
             .thenReturn(futureForRemovedFailedStandbyTaskDuringRemoval);
         when(stateUpdater.drainExceptionsAndFailedTasks())
                 .thenReturn(Arrays.asList(
@@ -4429,11 +4430,11 @@ public class TaskManagerTest {
 
         // mock futures for removing tasks from StateUpdater
         final CompletableFuture<StateUpdater.RemovedTaskResult> future01 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(taskId01)).thenReturn(future01);
+        when(stateUpdater.remove(eq(taskId01), 
eq(SuspendReason.MIGRATED))).thenReturn(future01);
         future01.complete(new StateUpdater.RemovedTaskResult(migratedTask01));
 
         final CompletableFuture<StateUpdater.RemovedTaskResult> future02 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(taskId02)).thenReturn(future02);
+        when(stateUpdater.remove(eq(taskId02), 
eq(SuspendReason.MIGRATED))).thenReturn(future02);
         future02.complete(new StateUpdater.RemovedTaskResult(migratedTask02));
 
         final TaskMigratedException thrown = assertThrows(
@@ -4448,8 +4449,8 @@ public class TaskManagerTest {
         );
         verify(migratedTask01, times(2)).suspend();
         verify(migratedTask02, times(2)).suspend();
-        verify(stateUpdater).remove(taskId01);
-        verify(stateUpdater).remove(taskId02);
+        verify(stateUpdater).remove(eq(taskId01), eq(SuspendReason.MIGRATED));
+        verify(stateUpdater).remove(eq(taskId02), eq(SuspendReason.MIGRATED));
     }
 
     @Test
@@ -4475,11 +4476,11 @@ public class TaskManagerTest {
 
         // mock futures for removing tasks from StateUpdater
         final CompletableFuture<StateUpdater.RemovedTaskResult> future01 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(taskId01)).thenReturn(future01);
+        when(stateUpdater.remove(eq(taskId01), 
eq(SuspendReason.MIGRATED))).thenReturn(future01);
         future01.complete(new StateUpdater.RemovedTaskResult(migratedTask01));
 
         final CompletableFuture<StateUpdater.RemovedTaskResult> future02 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(taskId02)).thenReturn(future02);
+        when(stateUpdater.remove(eq(taskId02), 
eq(SuspendReason.MIGRATED))).thenReturn(future02);
         future02.complete(new StateUpdater.RemovedTaskResult(migratedTask02));
 
         final RuntimeException thrown = assertThrows(
@@ -4493,8 +4494,8 @@ public class TaskManagerTest {
 
         verify(migratedTask01, times(2)).suspend();
         verify(migratedTask02, times(2)).suspend();
-        verify(stateUpdater).remove(taskId01);
-        verify(stateUpdater).remove(taskId02);
+        verify(stateUpdater).remove(eq(taskId01), eq(SuspendReason.MIGRATED));
+        verify(stateUpdater).remove(eq(taskId02), eq(SuspendReason.MIGRATED));
     }
 
     @Test
@@ -4520,11 +4521,11 @@ public class TaskManagerTest {
 
         // mock futures for removing tasks from StateUpdater
         final CompletableFuture<StateUpdater.RemovedTaskResult> future01 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(taskId01)).thenReturn(future01);
+        when(stateUpdater.remove(eq(taskId01), 
eq(SuspendReason.MIGRATED))).thenReturn(future01);
         future01.complete(new StateUpdater.RemovedTaskResult(migratedTask01));
 
         final CompletableFuture<StateUpdater.RemovedTaskResult> future02 = new 
CompletableFuture<>();
-        when(stateUpdater.remove(taskId02)).thenReturn(future02);
+        when(stateUpdater.remove(eq(taskId02), 
eq(SuspendReason.MIGRATED))).thenReturn(future02);
         future02.complete(new StateUpdater.RemovedTaskResult(migratedTask02));
 
         final StreamsException thrown = assertThrows(
@@ -4540,8 +4541,8 @@ public class TaskManagerTest {
 
         verify(migratedTask01, times(2)).suspend();
         verify(migratedTask02, times(2)).suspend();
-        verify(stateUpdater).remove(taskId01);
-        verify(stateUpdater).remove(taskId02);
+        verify(stateUpdater).remove(eq(taskId01), eq(SuspendReason.MIGRATED));
+        verify(stateUpdater).remove(eq(taskId02), eq(SuspendReason.MIGRATED));
     }
 
     @Test
@@ -4807,7 +4808,7 @@ public class TaskManagerTest {
         // convert active to standby
         when(stateUpdater.tasks()).thenReturn(Set.of(activeTaskToRecycle));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        when(stateUpdater.remove(activeTaskToRecycle.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(activeTaskToRecycle.id()), 
eq(SuspendReason.MIGRATED))).thenReturn(future);
         future.complete(new 
StateUpdater.RemovedTaskResult(activeTaskToRecycle));
 
         taskManager.handleAssignment(Collections.emptyMap(), 
taskId00Assignment);
@@ -4839,7 +4840,7 @@ public class TaskManagerTest {
         // convert standby to active
         when(stateUpdater.tasks()).thenReturn(Set.of(standbyTaskToRecycle));
         final CompletableFuture<StateUpdater.RemovedTaskResult> future = new 
CompletableFuture<>();
-        
when(stateUpdater.remove(standbyTaskToRecycle.id())).thenReturn(future);
+        when(stateUpdater.remove(eq(standbyTaskToRecycle.id()), 
eq(SuspendReason.PROMOTED))).thenReturn(future);
         future.complete(new 
StateUpdater.RemovedTaskResult(standbyTaskToRecycle));
 
         taskManager.handleAssignment(taskId00Assignment, 
Collections.emptyMap());
diff --git 
a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
 
b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
index 071ee507c24..65fcb58c24f 100644
--- 
a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
+++ 
b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java
@@ -48,6 +48,7 @@ import org.apache.kafka.streams.internals.StreamsConfigUtils;
 import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.processor.PunctuationType;
 import org.apache.kafka.streams.processor.Punctuator;
+import org.apache.kafka.streams.processor.StandbyUpdateListener.SuspendReason;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
@@ -1159,6 +1160,10 @@ public class TopologyTestDriver implements Closeable {
 
         @Override
         public void unregister(final Collection<TopicPartition> partitions) { }
+
+        @Override
+        public void unregister(final Collection<TopicPartition> partitions,
+                               final SuspendReason reason) { }
     }
 
     static class MockTime implements Time {

Reply via email to