hgeraldino commented on code in PR #15316:
URL: https://github.com/apache/kafka/pull/15316#discussion_r1494816048


##########
connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskMockitoTest.java:
##########
@@ -601,6 +690,567 @@ public void testPartialRevocationAndAssignment() {
         verify(sinkTask, times(4)).put(Collections.emptyList());
     }
 
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testPreCommitFailureAfterPartialRevocationAndAssignment() {
+        createTask(initialState);
+        expectTaskGetTopic();
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        verifyInitializeTask();
+
+        when(consumer.assignment())
+                .thenReturn(INITIAL_ASSIGNMENT, INITIAL_ASSIGNMENT)
+                .thenReturn(new HashSet<>(Arrays.asList(TOPIC_PARTITION2)))
+                .thenReturn(new HashSet<>(Arrays.asList(TOPIC_PARTITION2)))
+                .thenReturn(new HashSet<>(Arrays.asList(TOPIC_PARTITION2)))
+                .thenReturn(new HashSet<>(Arrays.asList(TOPIC_PARTITION2, TOPIC_PARTITION3)))
+                .thenReturn(new HashSet<>(Arrays.asList(TOPIC_PARTITION2, TOPIC_PARTITION3)))
+                .thenReturn(new HashSet<>(Arrays.asList(TOPIC_PARTITION2, TOPIC_PARTITION3)));
+
+        INITIAL_ASSIGNMENT.forEach(tp -> when(consumer.position(tp)).thenReturn(FIRST_OFFSET));
+        when(consumer.position(TOPIC_PARTITION3)).thenReturn(FIRST_OFFSET);
+
+        // First poll; assignment is [TP1, TP2]
+        when(consumer.poll(any(Duration.class)))
+                .thenAnswer((Answer<ConsumerRecords<byte[], byte[]>>) invocation -> {
+                    rebalanceListener.getValue().onPartitionsAssigned(INITIAL_ASSIGNMENT);
+                    return ConsumerRecords.empty();
+                })
+                // Second poll; a single record is delivered from TP1
+                .thenAnswer(expectConsumerPoll(1))
+                // Third poll; assignment changes to [TP2]
+                .thenAnswer(invocation -> {
+                    rebalanceListener.getValue().onPartitionsRevoked(Collections.singleton(TOPIC_PARTITION));
+                    rebalanceListener.getValue().onPartitionsAssigned(Collections.emptySet());
+                    return ConsumerRecords.empty();
+                })
+                // Fourth poll; assignment changes to [TP2, TP3]
+                .thenAnswer(invocation -> {
+                    rebalanceListener.getValue().onPartitionsRevoked(Collections.emptySet());
+                    rebalanceListener.getValue().onPartitionsAssigned(Collections.singleton(TOPIC_PARTITION3));
+                    return ConsumerRecords.empty();
+                })
+                // Fifth poll; an offset commit takes place
+                .thenAnswer(expectConsumerPoll(0));
+
+        expectConversionAndTransformation(null, new RecordHeaders());
+
+        // First iteration--first call to poll, first consumer assignment
+        workerTask.iteration();
+        // Second iteration--second call to poll, delivery of one record
+        workerTask.iteration();
+        // Third iteration--third call to poll, partial consumer revocation
+        final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
+        offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
+        when(sinkTask.preCommit(offsets)).thenReturn(offsets);
+        doAnswer(invocation -> null).when(consumer).commitSync(offsets);
+
+        workerTask.iteration();
+        verify(sinkTask).close(Collections.singleton(TOPIC_PARTITION));
+        verify(sinkTask, times(2)).put(Collections.emptyList());
+
+        // Fourth iteration--fourth call to poll, partial consumer assignment
+        workerTask.iteration();
+
+        verify(sinkTask).open(Collections.singleton(TOPIC_PARTITION3));
+
+        final Map<TopicPartition, OffsetAndMetadata> workerCurrentOffsets = new HashMap<>();
+        workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+        workerCurrentOffsets.put(TOPIC_PARTITION3, new OffsetAndMetadata(FIRST_OFFSET));
+        when(sinkTask.preCommit(workerCurrentOffsets)).thenThrow(new ConnectException("Failed to flush"));
+
+        // Fifth iteration--task-requested offset commit with failure in SinkTask::preCommit
+        sinkTaskContext.getValue().requestCommit();
+        workerTask.iteration();
+
+        verify(consumer).seek(TOPIC_PARTITION2, FIRST_OFFSET);
+        verify(consumer).seek(TOPIC_PARTITION3, FIRST_OFFSET);
+    }
+
+    @Test
+    public void testWakeupInCommitSyncCausesRetry() {
+        createTask(initialState);
+
+        workerTask.initialize(TASK_CONFIG);
+        time.sleep(30000L);
+        workerTask.initializeAndStart();
+        time.sleep(30000L);
+        verifyInitializeTask();
+
+        expectTaskGetTopic();
+        expectPollInitialAssignment()
+                .thenAnswer(expectConsumerPoll(1))
+                .thenAnswer(invocation -> {
+                    rebalanceListener.getValue().onPartitionsRevoked(INITIAL_ASSIGNMENT);
+                    rebalanceListener.getValue().onPartitionsAssigned(INITIAL_ASSIGNMENT);
+                    return ConsumerRecords.empty();
+                });
+        expectConversionAndTransformation(null, new RecordHeaders());
+
+        workerTask.iteration(); // poll for initial assignment
+        time.sleep(30000L);
+
+        final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
+        offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
+        offsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+        when(sinkTask.preCommit(offsets)).thenReturn(offsets);
+
+        // first one raises wakeup
+        doThrow(new WakeupException())
+                // and succeed the second time
+                .doAnswer(invocation -> null)
+                .when(consumer).commitSync(eq(offsets));
+
+        workerTask.iteration(); // first record delivered
+
+        workerTask.iteration(); // now rebalance with the wakeup triggered
+        time.sleep(30000L);
+
+        verify(sinkTask).close(INITIAL_ASSIGNMENT);
+        verify(sinkTask, times(2)).open(INITIAL_ASSIGNMENT);
+
+        INITIAL_ASSIGNMENT.forEach(tp -> {
+            verify(consumer).resume(Collections.singleton(tp));
+        });
+
+        verify(statusListener).onResume(taskId);
+
+        assertSinkMetricValue("partition-count", 2);
+        assertSinkMetricValue("sink-record-read-total", 1.0);
+        assertSinkMetricValue("sink-record-send-total", 1.0);
+        assertSinkMetricValue("sink-record-active-count", 0.0);
+        assertSinkMetricValue("sink-record-active-count-max", 1.0);
+        assertSinkMetricValue("sink-record-active-count-avg", 0.33333);
+        assertSinkMetricValue("offset-commit-seq-no", 1.0);
+        assertSinkMetricValue("offset-commit-completion-total", 1.0);
+        assertSinkMetricValue("offset-commit-skip-total", 0.0);
+        assertTaskMetricValue("status", "running");
+        assertTaskMetricValue("running-ratio", 1.0);
+        assertTaskMetricValue("pause-ratio", 0.0);
+        assertTaskMetricValue("batch-size-max", 1.0);
+        assertTaskMetricValue("batch-size-avg", 1.0);
+        assertTaskMetricValue("offset-commit-max-time-ms", 0.0);
+        assertTaskMetricValue("offset-commit-avg-time-ms", 0.0);
+        assertTaskMetricValue("offset-commit-failure-percentage", 0.0);
+        assertTaskMetricValue("offset-commit-success-percentage", 1.0);
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testWakeupNotThrownDuringShutdown() {
+        createTask(initialState);
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        verifyInitializeTask();
+
+        expectTaskGetTopic();
+        expectPollInitialAssignment()
+                .thenAnswer(expectConsumerPoll(1))
+                .thenAnswer(invocation -> {
+                    // stop the task during its second iteration
+                    workerTask.stop();
+                    return new ConsumerRecords<>(Collections.emptyMap());
+                });
+        expectConversionAndTransformation(null, new RecordHeaders());
+
+        final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
+        offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
+        offsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+        when(sinkTask.preCommit(offsets)).thenReturn(offsets);
+
+        // fail the first time
+        doThrow(new WakeupException())
+                // and succeed the second time
+                .doAnswer(invocation -> null)
+                .when(consumer).commitSync(eq(offsets));
+
+        workerTask.execute();
+
+        assertEquals(0, workerTask.commitFailures());
+        verify(consumer).wakeup();
+        verify(sinkTask).close(any(Collection.class));
+    }
+
+    @Test
+    public void testRequestCommit() {
+        createTask(initialState);
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        verifyInitializeTask();
+
+        expectTaskGetTopic();
+        expectPollInitialAssignment()
+                .thenAnswer(expectConsumerPoll(1))
+                .thenAnswer(expectConsumerPoll(0));
+        expectConversionAndTransformation(null, new RecordHeaders());
+
+        // Initial assignment
+        time.sleep(30000L);
+        workerTask.iteration();
+        assertSinkMetricValue("partition-count", 2);
+
+        final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
+        offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
+        offsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+        when(sinkTask.preCommit(offsets)).thenReturn(offsets);
+
+        // First record delivered
+        workerTask.iteration();
+        assertSinkMetricValue("partition-count", 2);
+        assertSinkMetricValue("sink-record-read-total", 1.0);
+        assertSinkMetricValue("sink-record-send-total", 1.0);
+        assertSinkMetricValue("sink-record-active-count", 1.0);
+        assertSinkMetricValue("sink-record-active-count-max", 1.0);
+        assertSinkMetricValue("sink-record-active-count-avg", 0.333333);
+        assertSinkMetricValue("offset-commit-seq-no", 0.0);
+        assertSinkMetricValue("offset-commit-completion-total", 0.0);
+        assertSinkMetricValue("offset-commit-skip-total", 0.0);
+        assertTaskMetricValue("status", "running");
+        assertTaskMetricValue("running-ratio", 1.0);
+        assertTaskMetricValue("pause-ratio", 0.0);
+        assertTaskMetricValue("batch-size-max", 1.0);
+        assertTaskMetricValue("batch-size-avg", 0.5);
+        assertTaskMetricValue("offset-commit-failure-percentage", 0.0);
+        assertTaskMetricValue("offset-commit-success-percentage", 0.0);
+
+        // Grab the commit time prior to requesting a commit.
+        // This time should advance slightly after committing.
+        // KAFKA-8229
+        final long previousCommitValue = workerTask.getNextCommit();
+        sinkTaskContext.getValue().requestCommit();
+        assertTrue(sinkTaskContext.getValue().isCommitRequested());
+        assertNotEquals(offsets, workerTask.lastCommittedOffsets());
+
+        ArgumentCaptor<OffsetCommitCallback> callback = ArgumentCaptor.forClass(OffsetCommitCallback.class);
+        time.sleep(10000L);
+        workerTask.iteration(); // triggers the commit
+        verify(consumer).commitAsync(eq(offsets), callback.capture());
+        callback.getValue().onComplete(offsets, null);
+        time.sleep(10000L);
+
+        assertFalse(sinkTaskContext.getValue().isCommitRequested()); // should have been cleared
+        assertEquals(offsets, workerTask.lastCommittedOffsets());
+        assertEquals(0, workerTask.commitFailures());
+
+        // Assert the next commit time advances slightly, the amount it advances
+        // is the normal commit time less the two sleeps since it started each
+        // of those sleeps were 10 seconds.
+        // KAFKA-8229
+        assertEquals("Should have only advanced by 40 seconds",
+                previousCommitValue  +
+                        (WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT - 
10000L * 2),
+                workerTask.getNextCommit());
+
+        assertSinkMetricValue("partition-count", 2);
+        assertSinkMetricValue("sink-record-read-total", 1.0);
+        assertSinkMetricValue("sink-record-send-total", 1.0);
+        assertSinkMetricValue("sink-record-active-count", 0.0);
+        assertSinkMetricValue("sink-record-active-count-max", 1.0);
+        assertSinkMetricValue("sink-record-active-count-avg", 0.2);
+        assertSinkMetricValue("offset-commit-seq-no", 1.0);
+        assertSinkMetricValue("offset-commit-completion-total", 1.0);
+        assertSinkMetricValue("offset-commit-skip-total", 0.0);
+        assertTaskMetricValue("status", "running");
+        assertTaskMetricValue("running-ratio", 1.0);
+        assertTaskMetricValue("pause-ratio", 0.0);
+        assertTaskMetricValue("batch-size-max", 1.0);
+        assertTaskMetricValue("batch-size-avg", 0.33333);
+        assertTaskMetricValue("offset-commit-max-time-ms", 0.0);
+        assertTaskMetricValue("offset-commit-avg-time-ms", 0.0);
+        assertTaskMetricValue("offset-commit-failure-percentage", 0.0);
+        assertTaskMetricValue("offset-commit-success-percentage", 1.0);
+    }
+
+    @Test
+    public void testPreCommit() {
+        createTask(initialState);
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        verifyInitializeTask();
+
+        expectTaskGetTopic();
+        expectPollInitialAssignment()
+                .thenAnswer(expectConsumerPoll(2))
+                .thenAnswer(expectConsumerPoll(0));
+        expectConversionAndTransformation(null, new RecordHeaders());
+
+        workerTask.iteration(); // iter 1 -- initial assignment
+
+        final Map<TopicPartition, OffsetAndMetadata> workerStartingOffsets = new HashMap<>();
+        workerStartingOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET));
+        workerStartingOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+
+        assertEquals(workerStartingOffsets, workerTask.currentOffsets());
+
+        final Map<TopicPartition, OffsetAndMetadata> workerCurrentOffsets = new HashMap<>();
+        workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 2));
+        workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+
+        final Map<TopicPartition, OffsetAndMetadata> taskOffsets = new HashMap<>();
+        taskOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1)); // act like FIRST_OFFSET+2 has not yet been flushed by the task
+        taskOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET + 1)); // should be ignored because > current offset
+        taskOffsets.put(new TopicPartition(TOPIC, 3), new OffsetAndMetadata(FIRST_OFFSET)); // should be ignored because this partition is not assigned
+
+        when(sinkTask.preCommit(workerCurrentOffsets)).thenReturn(taskOffsets);
+
+        workerTask.iteration(); // iter 2 -- deliver 2 records
+
+        final Map<TopicPartition, OffsetAndMetadata> committableOffsets = new HashMap<>();
+        committableOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
+        committableOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+
+        assertEquals(workerCurrentOffsets, workerTask.currentOffsets());
+        assertEquals(workerStartingOffsets, workerTask.lastCommittedOffsets());
+
+        sinkTaskContext.getValue().requestCommit();
+        workerTask.iteration(); // iter 3 -- commit
+
+        // Expect extra invalid topic partition to be filtered, which causes the consumer assignment to be logged
+        ArgumentCaptor<OffsetCommitCallback> callback = ArgumentCaptor.forClass(OffsetCommitCallback.class);
+        verify(consumer).commitAsync(eq(committableOffsets), callback.capture());
+        callback.getValue().onComplete(committableOffsets, null);
+
+        assertEquals(committableOffsets, workerTask.lastCommittedOffsets());
+    }
+
+    @Test
+    public void testPreCommitFailure() {
+        createTask(initialState);
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        verifyInitializeTask();
+
+        expectTaskGetTopic();
+        expectPollInitialAssignment()
+                // Put one message through the task to get some offsets to commit
+                .thenAnswer(expectConsumerPoll(2))
+                .thenAnswer(expectConsumerPoll(0));
+
+        expectConversionAndTransformation(null, new RecordHeaders());
+
+        workerTask.iteration(); // iter 1 -- initial assignment
+
+        workerTask.iteration(); // iter 2 -- deliver 2 records
+
+        // iter 3
+        final Map<TopicPartition, OffsetAndMetadata> workerCurrentOffsets = new HashMap<>();
+        workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 2));
+        workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+        when(sinkTask.preCommit(workerCurrentOffsets)).thenThrow(new ConnectException("Failed to flush"));
+
+        sinkTaskContext.getValue().requestCommit();
+        workerTask.iteration(); // iter 3 -- commit
+
+        verify(consumer).seek(TOPIC_PARTITION, FIRST_OFFSET);
+        verify(consumer).seek(TOPIC_PARTITION2, FIRST_OFFSET);
+    }
+
+    @Test
+    public void testIgnoredCommit() {
+        createTask(initialState);
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        verifyInitializeTask();
+
+        expectTaskGetTopic();
+        // iter 1
+        expectPollInitialAssignment()
+                // iter 2
+                .thenAnswer(expectConsumerPoll(1))
+                // no actual consumer.commit() triggered
+                .thenAnswer(expectConsumerPoll(0));
+
+        expectConversionAndTransformation(null, new RecordHeaders());
+
+        workerTask.iteration(); // iter 1 -- initial assignment
+
+        final Map<TopicPartition, OffsetAndMetadata> workerStartingOffsets = new HashMap<>();
+        workerStartingOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET));
+        workerStartingOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+
+        assertEquals(workerStartingOffsets, workerTask.currentOffsets());
+        assertEquals(workerStartingOffsets, workerTask.lastCommittedOffsets());
+
+        workerTask.iteration(); // iter 2 -- deliver 2 records
+
+        final Map<TopicPartition, OffsetAndMetadata> workerCurrentOffsets = new HashMap<>();
+        workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
+        workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+
+        when(sinkTask.preCommit(workerCurrentOffsets)).thenReturn(workerStartingOffsets);
+
+        sinkTaskContext.getValue().requestCommit();
+        workerTask.iteration(); // iter 3 -- commit
+    }
+
+    // Test that the commitTimeoutMs timestamp is correctly computed and checked in WorkerSinkTask.iteration()
+    // when there is a long running commit in process. See KAFKA-4942 for more information.
+    @Test
+    public void testLongRunningCommitWithoutTimeout() throws InterruptedException {
+        createTask(initialState);
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        verifyInitializeTask();
+
+        expectTaskGetTopic();
+        expectPollInitialAssignment()
+                .thenAnswer(expectConsumerPoll(1))
+                // no actual consumer.commit() triggered
+                .thenAnswer(expectConsumerPoll(0));
+        expectConversionAndTransformation(null, new RecordHeaders());
+
+        final Map<TopicPartition, OffsetAndMetadata> workerStartingOffsets = new HashMap<>();
+        workerStartingOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET));
+        workerStartingOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+
+        workerTask.iteration(); // iter 1 -- initial assignment
+        assertEquals(workerStartingOffsets, workerTask.currentOffsets());
+        assertEquals(workerStartingOffsets, workerTask.lastCommittedOffsets());
+
+        time.sleep(WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_DEFAULT);
+        workerTask.iteration(); // iter 2 -- deliver 2 records
+
+        final Map<TopicPartition, OffsetAndMetadata> workerCurrentOffsets = new HashMap<>();
+        workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
+        workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+
+        // iter 3 - note that we return the current offset to indicate they should be committed
+        when(sinkTask.preCommit(workerCurrentOffsets)).thenReturn(workerCurrentOffsets);
+
+        // We need to delay the result of trying to commit offsets to Kafka via the consumer.commitAsync
+        // method. We do this so that we can test that we do not erroneously mark a commit as timed out
+        // while it is still running and under time. To fake this for tests we have the commit run in a
+        // separate thread and wait for a latch which we control back in the main thread.
+        final ExecutorService executor = Executors.newSingleThreadExecutor();
+        final CountDownLatch latch = new CountDownLatch(1);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            final Map<TopicPartition, OffsetAndMetadata> offsets = invocation.getArgument(0);
+            @SuppressWarnings("unchecked")
+            final OffsetCommitCallback callback = invocation.getArgument(1);
+
+            executor.execute(() -> {
+                try {
+                    latch.await();
+                } catch (InterruptedException e) {
+                    Thread.currentThread().interrupt();
+                }
+
+                callback.onComplete(offsets, null);
+            });
+
+            return null;
+        }).when(consumer).commitAsync(eq(workerCurrentOffsets), any(OffsetCommitCallback.class));
+
+        sinkTaskContext.getValue().requestCommit();
+        workerTask.iteration(); // iter 3 -- commit in progress
+
+        // Make sure the "committing" flag didn't immediately get flipped back 
to false due to an incorrect timeout
+        assertTrue("Expected worker to be in the process of committing 
offsets", workerTask.isCommitting());
+
+        // Let the async commit finish and wait for it to end
+        latch.countDown();
+        executor.shutdown();
+        executor.awaitTermination(30, TimeUnit.SECONDS);
+
+        assertEquals(workerCurrentOffsets, workerTask.currentOffsets());
+        assertEquals(workerCurrentOffsets, workerTask.lastCommittedOffsets());
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testSinkTasksHandleCloseErrors() {
+        createTask(initialState);
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        verifyInitializeTask();
+
+        expectTaskGetTopic();
+        expectPollInitialAssignment()
+                // Put one message through the task to get some offsets to commit
+                .thenAnswer(expectConsumerPoll(1))
+                .thenAnswer(expectConsumerPoll(1));
+
+        expectConversionAndTransformation(null, new RecordHeaders());
+
+        doAnswer(invocation -> null)
+                // Throw an exception on the next put to trigger shutdown behavior
+                // This exception is the true "cause" of the failure
+                .doAnswer(invocation -> {
+                    workerTask.stop();
+                    return null;
+                })
+                .when(sinkTask).put(anyList());
+
+        Throwable closeException = new RuntimeException();
+        when(sinkTask.preCommit(anyMap())).thenReturn(Collections.emptyMap());
+
+        // Throw another exception while closing the task's assignment
+        doThrow(closeException).when(sinkTask).close(any(Collection.class));
+
+        try {
+            workerTask.execute();
+            fail("workerTask.execute should have thrown an exception");
+        } catch (RuntimeException e) {
+            assertSame("Exception from close should propagate as-is", 
closeException, e);
+        }

Review Comment:
   I think the goal is not just to ensure that an exception is thrown, but that the exception instance is exactly the one configured on the mock (and not some other exception).
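
   For reference, a minimal sketch of how the same guarantee could be written with `assertThrows` (assuming `org.junit.Assert.assertThrows` from JUnit 4.13+ is available to this test class), keeping the identity check on the mocked exception without the manual try/catch:

   ```java
   // Hypothetical alternative to the try/catch above; the essential assertion is
   // still assertSame, which verifies that the exact mocked instance propagates.
   RuntimeException thrown = assertThrows(RuntimeException.class, () -> workerTask.execute());
   assertSame("Exception from close should propagate as-is", closeException, thrown);
   ```

   Either way, the important part is asserting on the exact `closeException` instance rather than only on the exception type.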



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org
