This is an automated email from the ASF dual-hosted git repository.

lucasbru pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 225660a43dc KAFKA-19683: Cleanup and rewrite more tests in 
TaskManagerTest [3/N] (#20692)
225660a43dc is described below

commit 225660a43dc6e34e2dc10e8b7549467a7b75cb96
Author: Shashank <[email protected]>
AuthorDate: Tue Oct 14 05:52:25 2025 -0700

    KAFKA-19683: Cleanup and rewrite more tests in TaskManagerTest [3/N] 
(#20692)
    
    Cleaned up and rewrote more tests in `TaskManagerTest.java`.
    
    Reviewers: Lucas Brutschy <[email protected]>
---
 .../processor/internals/TaskManagerTest.java       | 545 +++++++++------------
 1 file changed, 245 insertions(+), 300 deletions(-)

diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index f89faffd971..66078b44abc 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -120,6 +120,7 @@ import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.inOrder;
@@ -1716,17 +1717,17 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldPauseAllTopicsWithoutStateUpdaterOnRebalanceComplete() {
+    public void shouldPauseAllTopicsOnRebalanceComplete() {
         final Set<TopicPartition> assigned = Set.of(t1p0, t1p1);
         when(consumer.assignment()).thenReturn(assigned);
-
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, null);
         taskManager.handleRebalanceComplete();
 
         verify(consumer).pause(assigned);
     }
 
     @Test
-    public void shouldNotPauseReadyTasksWithStateUpdaterOnRebalanceComplete() {
+    public void shouldNotPauseReadyTasksOnRebalanceComplete() {
         final StreamTask statefulTask0 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
             .inState(State.RUNNING)
             .withInputPartitions(taskId00Partitions).build();
@@ -1743,29 +1744,6 @@ public class TaskManagerTest {
 
     @Test
     public void shouldReleaseLockForUnassignedTasksAfterRebalance() throws 
Exception {
-        expectLockObtainedFor(taskId00, taskId01, taskId02);
-        expectDirectoryNotEmpty(taskId00, taskId01, taskId02);
-
-        makeTaskFolders(
-            taskId00.toString(),  // active task
-            taskId01.toString(),  // standby task
-            taskId02.toString()   // unassigned but able to lock
-        );
-        taskManager.handleRebalanceStart(singleton("topic"));
-
-        assertThat(taskManager.lockedTaskDirectories(), is(Set.of(taskId00, 
taskId01, taskId02)));
-
-        handleAssignment(taskId00Assignment, taskId01Assignment, emptyMap());
-
-        taskManager.handleRebalanceComplete();
-        assertThat(taskManager.lockedTaskDirectories(), is(Set.of(taskId00, 
taskId01)));
-
-        verify(stateDirectory).unlock(taskId02);
-        verify(consumer).pause(assignment);
-    }
-
-    @Test
-    public void 
shouldReleaseLockForUnassignedTasksAfterRebalanceWithStateUpdater() throws 
Exception {
         final StreamTask runningStatefulTask = statefulTask(taskId00, 
taskId00ChangelogPartitions)
             .inState(State.RUNNING)
             .withInputPartitions(taskId00Partitions).build();
@@ -1805,10 +1783,12 @@ public class TaskManagerTest {
         final StreamTask runningStatefulTask = statefulTask(taskId00, 
taskId00ChangelogPartitions)
             .inState(State.RUNNING).build();
         final long changelogOffsetOfRunningTask = Task.LATEST_OFFSET;
-        when(runningStatefulTask.changelogOffsets())
-            .thenReturn(mkMap(mkEntry(t1p0changelog, 
changelogOffsetOfRunningTask)));
+        final Map<TopicPartition, Long> changelogOffsets = mkMap(
+            mkEntry(t1p0changelog, changelogOffsetOfRunningTask)
+        );
+        
when(runningStatefulTask.changelogOffsets()).thenReturn(changelogOffsets);
         final TasksRegistry tasks = mock(TasksRegistry.class);
-        final TaskManager taskManager = 
setUpTaskManagerWithoutStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks, false);
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
         when(tasks.allTasksPerId()).thenReturn(mkMap(mkEntry(taskId00, 
runningStatefulTask)));
 
         assertThat(
@@ -1819,20 +1799,32 @@ public class TaskManagerTest {
 
     @Test
     public void shouldComputeOffsetSumForNonRunningActiveTask() throws 
Exception {
+        final StreamTask restoringStatefulTask = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .inState(State.RESTORING).build();
         final Map<TopicPartition, Long> changelogOffsets = mkMap(
             mkEntry(new TopicPartition("changelog", 0), 5L),
             mkEntry(new TopicPartition("changelog", 1), 10L)
         );
-        final Map<TaskId, Long> expectedOffsetSums = mkMap(mkEntry(taskId00, 
15L));
+        final Map<TaskId, Long> expectedOffsetSums = mkMap(
+            mkEntry(taskId00, 15L)
+        );
+        when(restoringStatefulTask.changelogOffsets())
+            .thenReturn(changelogOffsets);
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
+        when(stateUpdater.tasks()).thenReturn(Set.of(restoringStatefulTask));
 
-        computeOffsetSumAndVerify(changelogOffsets, expectedOffsetSums);
+        assertThat(taskManager.taskOffsetSums(), is(expectedOffsetSums));
     }
 
     @Test
-    public void shouldComputeOffsetSumForRestoringActiveTaskWithStateUpdater() 
throws Exception {
+    public void shouldComputeOffsetSumForRestoringActiveTask() throws 
Exception {
         final StreamTask restoringStatefulTask = statefulTask(taskId00, 
taskId00ChangelogPartitions)
             .inState(State.RESTORING).build();
         final long changelogOffset = 42L;
+        final Map<TaskId, Long> expectedOffsetSums = mkMap(
+            mkEntry(taskId00, changelogOffset)
+        );
         
when(restoringStatefulTask.changelogOffsets()).thenReturn(mkMap(mkEntry(t1p0changelog,
 changelogOffset)));
         expectLockObtainedFor(taskId00);
         makeTaskFolders(taskId00.toString());
@@ -1843,11 +1835,11 @@ public class TaskManagerTest {
         when(stateUpdater.tasks()).thenReturn(Set.of(restoringStatefulTask));
         taskManager.handleRebalanceStart(singleton("topic"));
 
-        assertThat(taskManager.taskOffsetSums(), is(mkMap(mkEntry(taskId00, 
changelogOffset))));
+        assertThat(taskManager.taskOffsetSums(), is(expectedOffsetSums));
     }
 
     @Test
-    public void 
shouldComputeOffsetSumForRestoringStandbyTaskWithStateUpdater() throws 
Exception {
+    public void shouldComputeOffsetSumForRestoringStandbyTask() throws 
Exception {
         final StandbyTask restoringStandbyTask = standbyTask(taskId00, 
taskId00ChangelogPartitions)
             .inState(State.RUNNING).build();
         final long changelogOffset = 42L;
@@ -1919,23 +1911,6 @@ public class TaskManagerTest {
         );
     }
 
-    private void computeOffsetSumAndVerify(final Map<TopicPartition, Long> 
changelogOffsets,
-                                           final Map<TaskId, Long> 
expectedOffsetSums) throws Exception {
-        expectLockObtainedFor(taskId00);
-        expectDirectoryNotEmpty(taskId00);
-        makeTaskFolders(taskId00.toString());
-
-        taskManager.handleRebalanceStart(singleton("topic"));
-        final StateMachineTask restoringTask = handleAssignment(
-            emptyMap(),
-            emptyMap(),
-            taskId00Assignment
-        ).get(taskId00);
-        restoringTask.setChangelogOffsets(changelogOffsets);
-
-        assertThat(taskManager.taskOffsetSums(), is(expectedOffsetSums));
-    }
-
     @Test
     public void shouldComputeOffsetSumForStandbyTask() throws Exception {
         final Map<TopicPartition, Long> changelogOffsets = mkMap(
@@ -2073,48 +2048,45 @@ public class TaskManagerTest {
 
     @Test
     public void 
shouldCloseActiveUnassignedSuspendedTasksWhenClosingRevokedTasks() {
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
-        final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
-        task00.setCommittableOffsetsAndMetadata(offsets);
-
-        // first `handleAssignment`
-        when(consumer.assignment()).thenReturn(assignment);
-
-        when(activeTaskCreator.createTasks(any(), 
eq(taskId00Assignment))).thenReturn(singletonList(task00));
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.SUSPENDED)
+            .build();
 
-        taskManager.handleAssignment(taskId00Assignment, emptyMap());
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
-        assertThat(task00.state(), is(Task.State.RUNNING));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allNonFailedTasks()).thenReturn(Set.of(task00));
 
-        taskManager.handleRevocation(taskId00Partitions);
-        assertThat(task00.state(), is(Task.State.SUSPENDED));
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         taskManager.handleAssignment(emptyMap(), emptyMap());
-        assertThat(task00.state(), is(Task.State.CLOSED));
-        assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
-        assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap());
+
+        verify(task00).prepareCommit(true);
+        verify(task00).closeClean();
+        verify(tasks).removeTask(task00);
     }
 
     @Test
     public void 
shouldCloseDirtyActiveUnassignedTasksWhenErrorCleanClosingTask() {
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager) {
-            @Override
-            public void closeClean() {
-                throw new RuntimeException("KABOOM!");
-            }
-        };
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.SUSPENDED)
+            .build();
 
-        when(activeTaskCreator.createTasks(any(), 
eq(taskId00Assignment))).thenReturn(singletonList(task00));
+        doThrow(new RuntimeException("KABOOM!")).when(task00).closeClean();
 
-        taskManager.handleAssignment(taskId00Assignment, emptyMap());
-        taskManager.handleRevocation(taskId00Partitions);
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allNonFailedTasks()).thenReturn(Set.of(task00));
+
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         final RuntimeException thrown = assertThrows(
             RuntimeException.class,
             () -> taskManager.handleAssignment(emptyMap(), emptyMap())
         );
 
-        assertThat(task00.state(), is(Task.State.CLOSED));
+        verify(task00).closeClean();
+        verify(task00).closeDirty();
+        verify(tasks).removeTask(task00);
         assertThat(
             thrown.getMessage(),
             is("Encounter unexpected fatal error for task 0_0")
@@ -2124,13 +2096,18 @@ public class TaskManagerTest {
 
     @Test
     public void shouldCloseActiveTasksWhenHandlingLostTasks() {
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        // `handleAssignment`
-        when(consumer.assignment()).thenReturn(assignment);
-        when(activeTaskCreator.createTasks(any(), 
eq(taskId00Assignment))).thenReturn(singletonList(task00));
-        
when(standbyTaskCreator.createTasks(taskId01Assignment)).thenReturn(singletonList(task01));
+        final StandbyTask task01 = standbyTask(taskId01, 
taskId01ChangelogPartitions)
+            .inState(State.RUNNING)
+            .build();
+
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allTasks()).thenReturn(Set.of(task00, task01));
+        when(tasks.allTaskIds()).thenReturn(Set.of(taskId00, taskId01));
 
         final ArrayList<TaskDirectory> taskFolders = new ArrayList<>(2);
         taskFolders.add(new 
TaskDirectory(testFolder.resolve(taskId00.toString()).toFile(), null));
@@ -2143,21 +2120,26 @@ public class TaskManagerTest {
         expectLockObtainedFor(taskId00, taskId01);
         expectDirectoryNotEmpty(taskId00, taskId01);
 
-        taskManager.handleRebalanceStart(emptySet());
-        assertThat(taskManager.lockedTaskDirectories(), 
Matchers.is(Set.of(taskId00, taskId01)));
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
-        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
-        assertThat(task00.state(), is(Task.State.RUNNING));
-        assertThat(task01.state(), is(Task.State.RUNNING));
+        taskManager.handleRebalanceStart(emptySet());
+        assertThat(taskManager.lockedTaskDirectories(), is(Set.of(taskId00, 
taskId01)));
 
-        // `handleLostAll`
+        // this should close only active tasks as zombies
         taskManager.handleLostAll();
-        assertThat(task00.commitPrepared, is(true));
-        assertThat(task00.state(), is(Task.State.CLOSED));
-        assertThat(task01.state(), is(Task.State.RUNNING));
-        assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap());
-        assertThat(taskManager.standbyTaskMap(), is(singletonMap(taskId01, 
task01)));
+
+        // close of active task
+        verify(task00).prepareCommit(false);
+        verify(task00).suspend();
+        verify(task00).closeDirty();
+        verify(tasks).removeTask(task00);
+
+        // standby task not closed
+        verify(task01, never()).prepareCommit(anyBoolean());
+        verify(task01, never()).suspend();
+        verify(task01, never()).closeDirty();
+        verify(task01, never()).closeClean();
+        verify(tasks, never()).removeTask(task01);
 
         // The locked task map will not be cleared.
         assertThat(taskManager.lockedTaskDirectories(), is(Set.of(taskId00, 
taskId01)));
@@ -2316,37 +2298,7 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void shouldNotCommitNonRunningNonCorruptedTasks() {
-        final ProcessorStateManager stateManager = 
mock(ProcessorStateManager.class);
-
-        final StateMachineTask corruptedTask = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
-        final StateMachineTask nonRunningNonCorruptedTask = new 
StateMachineTask(taskId01, taskId01Partitions, true, stateManager);
-
-        nonRunningNonCorruptedTask.setCommitNeeded();
-
-        final Map<TaskId, Set<TopicPartition>> assignment = new 
HashMap<>(taskId00Assignment);
-        assignment.putAll(taskId01Assignment);
-
-        // `handleAssignment`
-        when(activeTaskCreator.createTasks(any(), eq(assignment)))
-            .thenReturn(asList(corruptedTask, nonRunningNonCorruptedTask));
-        when(consumer.assignment()).thenReturn(taskId00Partitions);
-
-        taskManager.handleAssignment(assignment, emptyMap());
-
-        corruptedTask.setChangelogOffsets(singletonMap(t1p0, 0L));
-        taskManager.handleCorruption(singleton(taskId00));
-
-        assertThat(nonRunningNonCorruptedTask.state(), is(Task.State.CREATED));
-        assertThat(nonRunningNonCorruptedTask.partitionsForOffsetReset, 
equalTo(Collections.emptySet()));
-        assertThat(corruptedTask.partitionsForOffsetReset, 
equalTo(taskId00Partitions));
-
-        assertFalse(nonRunningNonCorruptedTask.commitPrepared);
-        verify(stateManager).markChangelogAsCorrupted(taskId00Partitions);
-    }
-
-    @Test
-    public void 
shouldNotCommitNonCorruptedRestoringActiveTasksAndNotCommitRunningStandbyTasksWithStateUpdaterEnabled()
 {
+    public void 
shouldNotCommitNonCorruptedRestoringActiveTasksAndNotCommitRunningStandbyTasks()
 {
         final StreamTask activeRestoringTask = statefulTask(taskId00, 
taskId00ChangelogPartitions)
             .withInputPartitions(taskId00Partitions)
             .inState(State.RESTORING).build();
@@ -3643,119 +3595,150 @@ public class TaskManagerTest {
 
     @Test
     public void shouldCommitActiveAndStandbyTasks() {
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
         final Map<TopicPartition, OffsetAndMetadata> offsets = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
-        task00.setCommittableOffsetsAndMetadata(offsets);
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager);
 
-        when(consumer.assignment()).thenReturn(assignment);
-        when(activeTaskCreator.createTasks(any(), eq(taskId00Assignment)))
-            .thenReturn(singletonList(task00));
-        when(standbyTaskCreator.createTasks(taskId01Assignment))
-            .thenReturn(singletonList(task01));
+        final StandbyTask task01 = standbyTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
+        when(task00.commitNeeded()).thenReturn(true);
+        when(task00.prepareCommit(true)).thenReturn(offsets);
+        when(task01.commitNeeded()).thenReturn(true);
+        when(task01.prepareCommit(true)).thenReturn(emptyMap());
 
-        assertThat(task00.state(), is(Task.State.RUNNING));
-        assertThat(task01.state(), is(Task.State.RUNNING));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allTasks()).thenReturn(Set.of(task00, task01));
 
-        task00.setCommitNeeded();
-        task01.setCommitNeeded();
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         assertThat(taskManager.commitAll(), equalTo(2));
-        assertThat(task00.commitNeeded, is(false));
-        assertThat(task01.commitNeeded, is(false));
 
+        verify(task00, times(2)).commitNeeded();
+        verify(task00).prepareCommit(true);
+        verify(task00).postCommit(false);
+        verify(task01, times(2)).commitNeeded();
+        verify(task01).prepareCommit(true);
+        verify(task01).postCommit(false);
         verify(consumer).commitSync(offsets);
     }
 
     @Test
     public void shouldCommitProvidedTasksIfNeeded() {
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager);
-        final StateMachineTask task02 = new StateMachineTask(taskId02, 
taskId02Partitions, true, stateManager);
-        final StateMachineTask task03 = new StateMachineTask(taskId03, 
taskId03Partitions, false, stateManager);
-        final StateMachineTask task04 = new StateMachineTask(taskId04, 
taskId04Partitions, false, stateManager);
-        final StateMachineTask task05 = new StateMachineTask(taskId05, 
taskId05Partitions, false, stateManager);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
+        final Map<TopicPartition, OffsetAndMetadata> offsetsTask00 = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
 
-        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
-            mkEntry(taskId00, taskId00Partitions),
-            mkEntry(taskId01, taskId01Partitions),
-            mkEntry(taskId02, taskId02Partitions)
-        );
-        final Map<TaskId, Set<TopicPartition>> assignmentStandby = mkMap(
-            mkEntry(taskId03, taskId03Partitions),
-            mkEntry(taskId04, taskId04Partitions),
-            mkEntry(taskId05, taskId05Partitions)
-        );
+        final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
+        final Map<TopicPartition, OffsetAndMetadata> offsetsTask01 = 
singletonMap(t1p1, new OffsetAndMetadata(1L, null));
 
-        when(consumer.assignment()).thenReturn(assignment);
-        when(activeTaskCreator.createTasks(any(), eq(assignmentActive)))
-            .thenReturn(Arrays.asList(task00, task01, task02));
-        when(standbyTaskCreator.createTasks(assignmentStandby))
-            .thenReturn(Arrays.asList(task03, task04, task05));
+        final StreamTask task02 = statefulTask(taskId02, 
taskId02ChangelogPartitions)
+            .withInputPartitions(taskId02Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        taskManager.handleAssignment(assignmentActive, assignmentStandby);
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
+        final StandbyTask task03 = standbyTask(taskId03, 
taskId03ChangelogPartitions)
+            .withInputPartitions(taskId03Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        assertThat(task00.state(), is(Task.State.RUNNING));
-        assertThat(task01.state(), is(Task.State.RUNNING));
+        final StandbyTask task04 = standbyTask(taskId04, 
taskId04ChangelogPartitions)
+            .withInputPartitions(taskId04Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        task00.setCommitNeeded();
-        task01.setCommitNeeded();
-        task03.setCommitNeeded();
-        task04.setCommitNeeded();
+        final StandbyTask task05 = standbyTask(taskId05, 
taskId05ChangelogPartitions)
+            .withInputPartitions(taskId05Partitions)
+            .inState(State.RUNNING)
+            .build();
+
+        when(task00.commitNeeded()).thenReturn(true);
+        when(task00.prepareCommit(true)).thenReturn(offsetsTask00);
+        when(task01.commitNeeded()).thenReturn(true);
+        when(task01.prepareCommit(true)).thenReturn(offsetsTask01);
+        when(task02.commitNeeded()).thenReturn(false);
+        when(task03.commitNeeded()).thenReturn(true);
+        when(task03.prepareCommit(true)).thenReturn(emptyMap());
+        when(task04.commitNeeded()).thenReturn(true);
+        when(task04.prepareCommit(true)).thenReturn(emptyMap());
+        when(task05.commitNeeded()).thenReturn(false);
+
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         assertThat(taskManager.commit(Set.of(task00, task02, task03, task05)), 
equalTo(2));
-        assertThat(task00.commitNeeded, is(false));
-        assertThat(task01.commitNeeded, is(true));
-        assertThat(task02.commitNeeded, is(false));
-        assertThat(task03.commitNeeded, is(false));
-        assertThat(task04.commitNeeded, is(true));
-        assertThat(task05.commitNeeded, is(false));
+
+        verify(task00, times(2)).commitNeeded();
+        verify(task00).prepareCommit(true);
+        verify(task00).postCommit(false);
+        verify(task01, never()).prepareCommit(anyBoolean());
+        verify(task01, never()).postCommit(anyBoolean());
+        verify(task02, atLeastOnce()).commitNeeded();
+        verify(task02, never()).prepareCommit(anyBoolean());
+        verify(task02, never()).postCommit(anyBoolean());
+        verify(task03, times(2)).commitNeeded();
+        verify(task03).prepareCommit(true);
+        verify(task03).postCommit(false);
+        verify(task04, never()).prepareCommit(anyBoolean());
+        verify(task04, never()).postCommit(anyBoolean());
+        verify(task05, atLeastOnce()).commitNeeded();
+        verify(task05, never()).prepareCommit(anyBoolean());
+        verify(task05, never()).postCommit(anyBoolean());
+        verify(consumer).commitSync(offsetsTask00);
     }
 
     @Test
     public void shouldNotCommitOffsetsIfOnlyStandbyTasksAssigned() {
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, false, stateManager);
-
-        when(consumer.assignment()).thenReturn(assignment);
-        
when(standbyTaskCreator.createTasks(taskId00Assignment)).thenReturn(singletonList(task00));
+        final StandbyTask task00 = standbyTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        taskManager.handleAssignment(Collections.emptyMap(), 
taskId00Assignment);
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
+        when(task00.commitNeeded()).thenReturn(true);
+        when(task00.prepareCommit(true)).thenReturn(emptyMap());
 
-        assertThat(task00.state(), is(Task.State.RUNNING));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allTasks()).thenReturn(Set.of(task00));
 
-        task00.setCommitNeeded();
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         assertThat(taskManager.commitAll(), equalTo(1));
-        assertThat(task00.commitNeeded, is(false));
+
+        verify(task00, times(2)).commitNeeded();
+        verify(task00).prepareCommit(true);
+        verify(task00).postCommit(false);
+        verify(consumer, never()).commitSync(any(Map.class));
     }
 
     @Test
-    public void shouldNotCommitActiveAndStandbyTasksWhileRebalanceInProgress() 
throws Exception {
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, false, stateManager);
+    public void shouldNotCommitActiveAndStandbyTasksWhileRebalanceInProgress() 
{
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        makeTaskFolders(taskId00.toString(), taskId01.toString());
-        expectDirectoryNotEmpty(taskId00, taskId01);
-        expectLockObtainedFor(taskId00, taskId01);
-        when(consumer.assignment()).thenReturn(assignment);
-        when(activeTaskCreator.createTasks(any(), eq(taskId00Assignment)))
-            .thenReturn(singletonList(task00));
-        when(standbyTaskCreator.createTasks(taskId01Assignment))
-            .thenReturn(singletonList(task01));
+        final StandbyTask task01 = standbyTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        taskManager.handleAssignment(taskId00Assignment, taskId01Assignment);
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
+        when(task00.commitNeeded()).thenReturn(true);
+        when(task01.commitNeeded()).thenReturn(true);
 
-        assertThat(task00.state(), is(Task.State.RUNNING));
-        assertThat(task01.state(), is(Task.State.RUNNING));
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allTasks()).thenReturn(Set.of(task00, task01));
 
-        task00.setCommitNeeded();
-        task01.setCommitNeeded();
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         taskManager.handleRebalanceStart(emptySet());
 
@@ -3953,63 +3936,79 @@ public class TaskManagerTest {
 
     @Test
     public void shouldMaybeCommitAllActiveTasksThatNeedCommit() {
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .withInputPartitions(taskId00Partitions)
+            .inState(State.RUNNING)
+            .build();
         final Map<TopicPartition, OffsetAndMetadata> offsets0 = 
singletonMap(t1p0, new OffsetAndMetadata(0L, null));
-        task00.setCommittableOffsetsAndMetadata(offsets0);
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager);
+
+        final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .withInputPartitions(taskId01Partitions)
+            .inState(State.RUNNING)
+            .build();
         final Map<TopicPartition, OffsetAndMetadata> offsets1 = 
singletonMap(t1p1, new OffsetAndMetadata(1L, null));
-        task01.setCommittableOffsetsAndMetadata(offsets1);
-        final StateMachineTask task02 = new StateMachineTask(taskId02, 
taskId02Partitions, true, stateManager);
-        final Map<TopicPartition, OffsetAndMetadata> offsets2 = 
singletonMap(t1p2, new OffsetAndMetadata(2L, null));
-        task02.setCommittableOffsetsAndMetadata(offsets2);
-        final StateMachineTask task03 = new StateMachineTask(taskId03, 
taskId03Partitions, true, stateManager);
-        final StateMachineTask task04 = new StateMachineTask(taskId10, 
taskId10Partitions, false, stateManager);
 
-        final Map<TopicPartition, OffsetAndMetadata> expectedCommittedOffsets 
= new HashMap<>();
-        expectedCommittedOffsets.putAll(offsets0);
-        expectedCommittedOffsets.putAll(offsets1);
+        final StreamTask task02 = statefulTask(taskId02, 
taskId02ChangelogPartitions)
+            .withInputPartitions(taskId02Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
-            mkEntry(taskId00, taskId00Partitions),
-            mkEntry(taskId01, taskId01Partitions),
-            mkEntry(taskId02, taskId02Partitions),
-            mkEntry(taskId03, taskId03Partitions)
-        );
+        final StreamTask task03 = statefulTask(taskId03, 
taskId03ChangelogPartitions)
+            .withInputPartitions(taskId03Partitions)
+            .inState(State.RUNNING)
+            .build();
 
-        final Map<TaskId, Set<TopicPartition>> assignmentStandby = mkMap(
-            mkEntry(taskId10, taskId10Partitions)
-        );
+        // for task00 both commitRequested AND commitNeeded, so it should 
trigger commit
+        when(task00.commitRequested()).thenReturn(true);
+        when(task00.commitNeeded()).thenReturn(true);
+        when(task00.prepareCommit(true)).thenReturn(offsets0);
 
-        when(consumer.assignment()).thenReturn(assignment);
-        when(activeTaskCreator.createTasks(any(), eq(assignmentActive)))
-            .thenReturn(asList(task00, task01, task02, task03));
-        when(standbyTaskCreator.createTasks(assignmentStandby))
-            .thenReturn(singletonList(task04));
+        // for task01 only commitNeeded (no commitRequested), so it gets 
committed once another task triggers the commit
+        when(task01.commitRequested()).thenReturn(false);
+        when(task01.commitNeeded()).thenReturn(true);
+        when(task01.prepareCommit(true)).thenReturn(offsets1);
 
-        taskManager.handleAssignment(assignmentActive, assignmentStandby);
-        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), 
null), is(true));
+        // for task02 only commitRequested (no commitNeeded), so does not get 
committed
+        when(task02.commitRequested()).thenReturn(true);
+        when(task02.commitNeeded()).thenReturn(false);
 
-        assertThat(task00.state(), is(Task.State.RUNNING));
-        assertThat(task01.state(), is(Task.State.RUNNING));
-        assertThat(task02.state(), is(Task.State.RUNNING));
-        assertThat(task03.state(), is(Task.State.RUNNING));
-        assertThat(task04.state(), is(Task.State.RUNNING));
+        // for task03 both commitRequested AND commitNeeded, so should trigger 
commit
+        when(task03.commitRequested()).thenReturn(true);
+        when(task03.commitNeeded()).thenReturn(true);
+        when(task03.prepareCommit(true)).thenReturn(emptyMap());
 
-        task00.setCommitNeeded();
-        task00.setCommitRequested();
+        // expected committed offsets only for task00 and task01 (task03 has 
empty offsets)
+        final Map<TopicPartition, OffsetAndMetadata> expectedCommittedOffsets 
= new HashMap<>();
+        expectedCommittedOffsets.putAll(offsets0);
+        expectedCommittedOffsets.putAll(offsets1);
 
-        task01.setCommitNeeded();
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.allTasks()).thenReturn(Set.of(task00, task01, task02, 
task03));
 
-        task02.setCommitRequested();
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
-        task03.setCommitNeeded();
-        task03.setCommitRequested();
+        // maybeCommitActiveTasksPerUserRequested checks if any task has both 
commitRequested AND commitNeeded.
+        // If found, it commits all active running tasks that have commitNeeded.
+        // It returns the count of committed tasks: task00, task01, and task03 (3 
tasks).
+        assertThat(taskManager.maybeCommitActiveTasksPerUserRequested(), 
equalTo(3));
 
-        task04.setCommitNeeded();
-        task04.setCommitRequested();
+        // Verify commit flow for tasks that needed commit
+        verify(task00, atLeastOnce()).commitNeeded();
+        verify(task00).prepareCommit(true);
+        verify(task00).postCommit(false);
 
-        assertThat(taskManager.maybeCommitActiveTasksPerUserRequested(), 
equalTo(3));
+        verify(task01, atLeastOnce()).commitNeeded();
+        verify(task01).prepareCommit(true);
+        verify(task01).postCommit(false);
+
+        verify(task03, atLeastOnce()).commitNeeded();
+        verify(task03).prepareCommit(true);
+        verify(task03).postCommit(false);
 
+        // task02 should not be committed (no commitNeeded)
+        verify(task02, never()).prepareCommit(anyBoolean());
+
+        // Consumer should commit combined offsets from task00 and task01
         verify(consumer).commitSync(expectedCommittedOffsets);
     }
 
@@ -4229,11 +4228,6 @@ public class TaskManagerTest {
         final TasksRegistry tasks = mock(TasksRegistry.class);
         when(tasks.activeTasks()).thenReturn(Set.of(task00));
 
-        when(stateUpdater.restoresActiveTasks()).thenReturn(false);
-        when(stateUpdater.hasExceptionsAndFailedTasks()).thenReturn(false);
-        
when(stateUpdater.drainRestoredActiveTasks(any(Duration.class))).thenReturn(Set.of());
-        
when(stateUpdater.drainExceptionsAndFailedTasks()).thenReturn(List.of());
-
         final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         // one for stream and one for system time
@@ -4669,55 +4663,6 @@ public class TaskManagerTest {
 
     @Test
     public void shouldRecycleStartupTasksFromStateDirectoryAsActive() {
-        final StandbyTask startupTask = standbyTask(taskId00, 
taskId00ChangelogPartitions).build();
-        final StreamTask activeTask = statefulTask(taskId00, 
taskId00ChangelogPartitions).build();
-        when(activeTaskCreator.createActiveTaskFromStandby(eq(startupTask), 
eq(taskId00Partitions), any()))
-            .thenReturn(activeTask);
-
-        when(stateDirectory.hasStartupTasks()).thenReturn(true, false);
-        
when(stateDirectory.removeStartupTask(taskId00)).thenReturn(startupTask, (Task) 
null);
-
-        taskManager.handleAssignment(taskId00Assignment, 
Collections.emptyMap());
-
-        // ensure we recycled our existing startup Standby into an Active task
-        verify(activeTaskCreator).createActiveTaskFromStandby(eq(startupTask), 
eq(taskId00Partitions), any());
-
-        // ensure we didn't construct any new Tasks
-        verify(activeTaskCreator).createTasks(any(), 
eq(Collections.emptyMap()));
-        verify(standbyTaskCreator).createTasks(Collections.emptyMap());
-        verifyNoMoreInteractions(activeTaskCreator);
-        verifyNoMoreInteractions(standbyTaskCreator);
-
-        // verify the recycled task is now being used as an assigned Active
-        assertEquals(Collections.singletonMap(taskId00, activeTask), 
taskManager.activeTaskMap());
-        assertEquals(Collections.emptyMap(), taskManager.standbyTaskMap());
-    }
-
-    @Test
-    public void shouldUseStartupTasksFromStateDirectoryAsStandby() {
-        final StandbyTask startupTask = standbyTask(taskId00, 
taskId00ChangelogPartitions).build();
-
-        when(stateDirectory.hasStartupTasks()).thenReturn(true, true, false);
-        
when(stateDirectory.removeStartupTask(taskId00)).thenReturn(startupTask, (Task) 
null);
-
-        taskManager.handleAssignment(Collections.emptyMap(), 
taskId00Assignment);
-
-        // ensure we used our existing startup Task directly as a Standby
-        verify(startupTask).resume();
-
-        // ensure we didn't construct any new Tasks, or recycle an existing 
Task; we only used the one we already have
-        verify(activeTaskCreator).createTasks(any(), 
eq(Collections.emptyMap()));
-        verify(standbyTaskCreator).createTasks(Collections.emptyMap());
-        verifyNoMoreInteractions(activeTaskCreator);
-        verifyNoMoreInteractions(standbyTaskCreator);
-
-        // verify the startup Standby is now being used as an assigned Standby
-        assertEquals(Collections.emptyMap(), taskManager.activeTaskMap());
-        assertEquals(Collections.singletonMap(taskId00, startupTask), 
taskManager.standbyTaskMap());
-    }
-
-    @Test
-    public void 
shouldRecycleStartupTasksFromStateDirectoryAsActiveWithStateUpdater() {
         final Tasks taskRegistry = new Tasks(new LogContext());
         final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, taskRegistry);
         final StandbyTask startupTask = standbyTask(taskId00, 
taskId00ChangelogPartitions).build();
@@ -4755,7 +4700,7 @@ public class TaskManagerTest {
     }
 
     @Test
-    public void 
shouldUseStartupTasksFromStateDirectoryAsStandbyWithStateUpdater() {
+    public void shouldUseStartupTasksFromStateDirectoryAsStandby() {
         final Tasks taskRegistry = new Tasks(new LogContext());
         final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, taskRegistry);
         final StandbyTask startupTask = standbyTask(taskId00, 
taskId00ChangelogPartitions).build();


Reply via email to