lucasbru commented on code in PR #20889:
URL: https://github.com/apache/kafka/pull/20889#discussion_r2541447923


##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java:
##########
@@ -1985,53 +1984,31 @@ public void 
shouldComputeOffsetSumForUnassignedTaskWeCanLock() throws Exception
         assertThat(taskManager.taskOffsetSums(), is(expectedOffsetSums));
     }
 
-    @Test
-    public void shouldComputeOffsetSumFromCheckpointFileForUninitializedTask() 
throws Exception {
+    @ParameterizedTest
+    @EnumSource(value = State.class, names = {"CREATED", "CLOSED"})
+    public void 
shouldComputeOffsetSumFromCheckpointFileForCreatedAndClosedTasks(final State 
state) throws Exception {
         final Map<TopicPartition, Long> changelogOffsets = mkMap(
             mkEntry(new TopicPartition("changelog", 0), 5L),
             mkEntry(new TopicPartition("changelog", 1), 10L)
         );
         final Map<TaskId, Long> expectedOffsetSums = mkMap(mkEntry(taskId00, 
15L));
 
-        expectLockObtainedFor(taskId00);
-        makeTaskFolders(taskId00.toString());
-        writeCheckpointFile(taskId00, changelogOffsets);
-
-        taskManager.handleRebalanceStart(singleton("topic"));
-        final StateMachineTask uninitializedTask = new 
StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
-        when(activeTaskCreator.createTasks(any(), 
eq(taskId00Assignment))).thenReturn(singleton(uninitializedTask));
-
-        taskManager.handleAssignment(taskId00Assignment, emptyMap());
-
-        assertThat(uninitializedTask.state(), is(State.CREATED));
+        final StreamTask task = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .inState(state)
+            .withInputPartitions(taskId00Partitions)
+            .build();
 
-        assertThat(taskManager.taskOffsetSums(), is(expectedOffsetSums));
-    }
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
-    @Test
-    public void shouldComputeOffsetSumFromCheckpointFileForClosedTask() throws 
Exception {
-        final Map<TopicPartition, Long> changelogOffsets = mkMap(
-            mkEntry(new TopicPartition("changelog", 0), 5L),
-            mkEntry(new TopicPartition("changelog", 1), 10L)
-        );
-        final Map<TaskId, Long> expectedOffsetSums = mkMap(mkEntry(taskId00, 
15L));
+        when(tasks.allTasksPerId()).thenReturn(mkMap(mkEntry(taskId00, task)));

Review Comment:
   It's a bit odd that you pass `tasks` to `setUpTaskManagerWithStateUpdater`
first and only stub it afterwards. I would reverse the order.



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java:
##########
@@ -4203,107 +4175,85 @@ public void shouldNotFailOnTimeoutException() {
         final AtomicReference<TimeoutException> timeoutException = new 
AtomicReference<>();
         timeoutException.set(new TimeoutException("Skip me!"));
 
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
-        task00.transitionTo(State.RESTORING);
-        task00.transitionTo(State.RUNNING);
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager) {
-            @Override
-            public boolean process(final long wallClockTime) {
-                final TimeoutException exception = timeoutException.get();
-                if (exception != null) {
-                    throw exception;
-                }
-                return true;
-            }
-        };
-        task01.transitionTo(State.RESTORING);
-        task01.transitionTo(State.RUNNING);
-        final StateMachineTask task02 = new StateMachineTask(taskId02, 
taskId02Partitions, true, stateManager);
-        task02.transitionTo(State.RESTORING);
-        task02.transitionTo(State.RUNNING);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId00Partitions)
+            .build();
+        // throws TimeoutException on first call, then processes 2 records
+        final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId01Partitions)
+            .build();
+        final StreamTask task02 = statefulTask(taskId02, 
taskId02ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId02Partitions)
+            .build();
 
-        taskManager.addTask(task00);
-        taskManager.addTask(task01);
-        taskManager.addTask(task02);
+        when(task00.process(anyLong()))
+            .thenReturn(true)
+            .thenReturn(true)
+            .thenReturn(false);
 
-        task00.addRecords(
-            t1p0,
-            Arrays.asList(
-                getConsumerRecord(t1p0, 0L),
-                getConsumerRecord(t1p0, 1L)
-            )
-        );
-        task01.addRecords(
-            t1p1,
-            Arrays.asList(
-                getConsumerRecord(t1p1, 0L),
-                getConsumerRecord(t1p1, 1L)
-            )
-        );
-        task02.addRecords(
-            t1p2,
-            Arrays.asList(
-                getConsumerRecord(t1p2, 0L),
-                getConsumerRecord(t1p2, 1L)
-            )
-        );
+        when(task01.process(anyLong()))
+            .thenThrow(timeoutException.get())  // throws TimeoutException
+            .thenReturn(true)
+            .thenReturn(true)
+            .thenReturn(false);
+
+        when(task02.process(anyLong()))
+            .thenReturn(true)
+            .thenReturn(true)
+            .thenReturn(false);
+
+        final TasksRegistry tasks = mock(TasksRegistry.class);
+        when(tasks.activeTasks()).thenReturn(Set.of(task00, task01, task02));
+
+        final TaskManager taskManager = 
setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
 
         // should only process 2 records, because task01 throws 
TimeoutException
         assertThat(taskManager.process(1, time), is(2));
-        assertThat(task01.timeout, equalTo(time.milliseconds()));
+        verify(task01).maybeInitTaskTimeoutOrThrow(anyLong(), 
any(TimeoutException.class));
 
-        //  retry without error
+        //  retry without error - clear the timeout and update the mock
         timeoutException.set(null);
         assertThat(taskManager.process(1, time), is(3));

Review Comment:
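   Yes, this doesn't work the way you intended. `timeoutException.get()` is evaluated only once, when the stub is created. Setting the reference to `null` afterwards therefore has no effect on the mock.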



##########
streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java:
##########
@@ -4203,107 +4175,85 @@ public void shouldNotFailOnTimeoutException() {
         final AtomicReference<TimeoutException> timeoutException = new 
AtomicReference<>();
         timeoutException.set(new TimeoutException("Skip me!"));
 
-        final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true, stateManager);
-        task00.transitionTo(State.RESTORING);
-        task00.transitionTo(State.RUNNING);
-        final StateMachineTask task01 = new StateMachineTask(taskId01, 
taskId01Partitions, true, stateManager) {
-            @Override
-            public boolean process(final long wallClockTime) {
-                final TimeoutException exception = timeoutException.get();
-                if (exception != null) {
-                    throw exception;
-                }
-                return true;
-            }
-        };
-        task01.transitionTo(State.RESTORING);
-        task01.transitionTo(State.RUNNING);
-        final StateMachineTask task02 = new StateMachineTask(taskId02, 
taskId02Partitions, true, stateManager);
-        task02.transitionTo(State.RESTORING);
-        task02.transitionTo(State.RUNNING);
+        final StreamTask task00 = statefulTask(taskId00, 
taskId00ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId00Partitions)
+            .build();
+        // throws TimeoutException on first call, then processes 2 records
+        final StreamTask task01 = statefulTask(taskId01, 
taskId01ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId01Partitions)
+            .build();
+        final StreamTask task02 = statefulTask(taskId02, 
taskId02ChangelogPartitions)
+            .inState(State.RUNNING)
+            .withInputPartitions(taskId02Partitions)
+            .build();
 
-        taskManager.addTask(task00);
-        taskManager.addTask(task01);
-        taskManager.addTask(task02);
+        when(task00.process(anyLong()))
+            .thenReturn(true)
+            .thenReturn(true)
+            .thenReturn(false);
 
-        task00.addRecords(
-            t1p0,
-            Arrays.asList(
-                getConsumerRecord(t1p0, 0L),
-                getConsumerRecord(t1p0, 1L)
-            )
-        );
-        task01.addRecords(
-            t1p1,
-            Arrays.asList(
-                getConsumerRecord(t1p1, 0L),
-                getConsumerRecord(t1p1, 1L)
-            )
-        );
-        task02.addRecords(
-            t1p2,
-            Arrays.asList(
-                getConsumerRecord(t1p2, 0L),
-                getConsumerRecord(t1p2, 1L)
-            )
-        );
+        when(task01.process(anyLong()))
+            .thenThrow(timeoutException.get())  // throws TimeoutException
+            .thenReturn(true)
+            .thenReturn(true)
+            .thenReturn(false);

Review Comment:
   Yes. I don't think the `AtomicReference` makes sense anymore, since the mock's behavior is fixed at stubbing time.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to