This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 2181ddbb039 KAFKA-18943: Kafka Streams incorrectly commits TX during task revocation (#19164)
2181ddbb039 is described below

commit 2181ddbb039ff688f5ff41784d943cb579f7575c
Author: Matthias J. Sax <[email protected]>
AuthorDate: Thu Mar 13 09:37:11 2025 -0700

    KAFKA-18943: Kafka Streams incorrectly commits TX during task revocation (#19164)
    
    Fixes two issues:
     - do not commit TX if no revoked tasks need to be committed
     - commit revoked tasks after a punctuation was triggered
    
    Reviewers: Lucas Brutschy <[email protected]>, Anna Sophie Blee-Goldman <[email protected]>, Bruno Cadonna <[email protected]>, Bill Bejeck <[email protected]>
---
 .../streams/integration/EosIntegrationTest.java    | 171 ++++++++++++++--
 .../integration/RebalanceIntegrationTest.java      | 221 +++++++++++++++++++++
 .../streams/integration/TestTaskAssignor.java      |  42 ++++
 .../streams/processor/internals/TaskManager.java   |  24 ++-
 .../processor/internals/TaskManagerTest.java       |  70 ++++++-
 5 files changed, 501 insertions(+), 27 deletions(-)

diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
index 9d6940a3b5e..95c19fd9cb8 100644
--- a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
+++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java
@@ -22,10 +22,13 @@ import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.Producer;
 import org.apache.kafka.clients.producer.ProducerConfig;
 import org.apache.kafka.common.IsolationLevel;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
+import org.apache.kafka.common.serialization.ByteArraySerializer;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.IntegerDeserializer;
 import org.apache.kafka.common.serialization.IntegerSerializer;
@@ -47,9 +50,11 @@ import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
 import org.apache.kafka.streams.kstream.KStream;
 import org.apache.kafka.streams.processor.StateRestoreListener;
 import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.api.ContextualProcessor;
 import org.apache.kafka.streams.processor.api.Processor;
 import org.apache.kafka.streams.processor.api.ProcessorContext;
 import org.apache.kafka.streams.processor.api.Record;
+import org.apache.kafka.streams.processor.internals.DefaultKafkaClientSupplier;
 import org.apache.kafka.streams.processor.internals.StreamThread;
 import org.apache.kafka.streams.query.QueryResult;
 import org.apache.kafka.streams.query.RangeQuery;
@@ -79,6 +84,7 @@ import java.io.IOException;
 import java.nio.file.Paths;
 import java.time.Duration;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -122,7 +128,10 @@ public class EosIntegrationTest {
 
     public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(
         NUM_BROKERS,
-        Utils.mkProperties(Collections.singletonMap("auto.create.topics.enable", "true"))
+        Utils.mkProperties(mkMap(
+            mkEntry("auto.create.topics.enable", "true"),
+            mkEntry("transaction.max.timeout.ms", "" + Integer.MAX_VALUE)
+        ))
     );
 
     @BeforeAll
@@ -871,6 +880,7 @@ public class EosIntegrationTest {
                                        final String storeName,
                                        final long startingOffset,
                                        final long endingOffset) {}
+
             @Override
             public void onBatchRestored(final TopicPartition topicPartition,
                                         final String storeName,
@@ -883,6 +893,7 @@ public class EosIntegrationTest {
                     }
                 }
             }
+
             @Override
             public void onRestoreEnd(final TopicPartition topicPartition,
                                      final String storeName,
@@ -892,9 +903,7 @@ public class EosIntegrationTest {
         ensureCommittedRecordsInTopicPartition(
             applicationId + "-" + stateStoreName + "-changelog",
             partitionToVerify,
-            2000,
-            IntegerDeserializer.class,
-            IntegerDeserializer.class
+            2000
         );
         throwException.set(true);
         final List<KeyValue<Integer, Integer>> recordBatch2 = IntStream.range(endKey - 1000, endKey).mapToObj(i -> KeyValue.pair(i, 0)).collect(Collectors.toList());
@@ -922,6 +931,129 @@ public class EosIntegrationTest {
         );
     }
 
+
+    private final AtomicReference<String> transactionalProducerId = new AtomicReference<>();
+
+    private class TestClientSupplier extends DefaultKafkaClientSupplier {
+        @Override
+        public Producer<byte[], byte[]> getProducer(final Map<String, Object> config) {
+            transactionalProducerId.compareAndSet(null, (String) config.get(ProducerConfig.TRANSACTIONAL_ID_CONFIG));
+
+            return new KafkaProducer<>(config, new ByteArraySerializer(), new ByteArraySerializer());
+        }
+    }
+
+    static final AtomicReference<TaskId> TASK_WITH_DATA = new AtomicReference<>();
+    static final AtomicBoolean DID_REVOKE_IDLE_TASK = new AtomicBoolean(false);
+
+    @Test
+    public void shouldNotCommitActiveTasksWithPendingInputIfRevokedTaskDidNotMakeProgress() throws Exception {
+        final AtomicBoolean requestCommit = new AtomicBoolean(false);
+
+        final StreamsBuilder builder = new StreamsBuilder();
+        builder.<Long, Long>stream(MULTI_PARTITION_INPUT_TOPIC)
+            .process(() -> new ContextualProcessor<Long, Long, Long, Long>() {
+                @Override
+                public void process(final Record<Long, Long> record) {
+                    if (!requestCommit.get()) {
+                        if (TASK_WITH_DATA.get() != null) {
+                            throw new IllegalStateException("Should only process single record using single task");
+                        }
+                        TASK_WITH_DATA.set(context().taskId());
+                    }
+
+                    context().forward(record.withValue(context().recordMetadata().get().offset()));
+
+                    if (requestCommit.get()) {
+                        context().commit();
+                    }
+                }
+            })
+            .to(SINGLE_PARTITION_OUTPUT_TOPIC);
+
+        final Properties properties = new Properties();
+        properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2);
+        properties.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
+        properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, Integer.MAX_VALUE);
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), 1);
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.METADATA_MAX_AGE_CONFIG), "1000");
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "earliest");
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG), MAX_POLL_INTERVAL_MS - 1);
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), MAX_POLL_INTERVAL_MS);
+        properties.put(StreamsConfig.producerPrefix(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG), Integer.MAX_VALUE);
+        properties.put(StreamsConfig.TASK_ASSIGNOR_CLASS_CONFIG, TestTaskAssignor.class.getName());
+
+        final Properties config = StreamsTestUtils.getStreamsConfig(
+            applicationId,
+            CLUSTER.bootstrapServers(),
+            Serdes.LongSerde.class.getName(),
+            Serdes.LongSerde.class.getName(),
+            properties
+        );
+
+
+        try (final KafkaStreams streams = new KafkaStreams(builder.build(), config, new TestClientSupplier())) {
+            startApplicationAndWaitUntilRunning(streams);
+
+            // PHASE 1:
+            // write single input record, and wait for it to get into output topic (uncommitted)
+            // StreamThread-1 now has a task with progress, and one task w/o progress
+            final List<KeyValue<Long, Long>> inputDataTask0 = Collections.singletonList(KeyValue.pair(1L, -1L));
+
+            IntegrationTestUtils.produceKeyValuesSynchronously(
+                MULTI_PARTITION_INPUT_TOPIC,
+                inputDataTask0,
+                TestUtils.producerConfig(CLUSTER.bootstrapServers(), LongSerializer.class, LongSerializer.class),
+                CLUSTER.time
+            );
+
+            final List<KeyValue<Long, Long>> expectedUncommittedResultTask0 = Collections.singletonList(KeyValue.pair(1L, 0L));
+            final List<KeyValue<Long, Long>> uncommittedRecordsBeforeRebalance = readResult(SINGLE_PARTITION_OUTPUT_TOPIC, expectedUncommittedResultTask0.size(), null);
+            checkResultPerKey(uncommittedRecordsBeforeRebalance, expectedUncommittedResultTask0, "The uncommitted records do not match what expected");
+
+            // PHASE 2:
+            // add second thread, to trigger rebalance
+            // expect idle task to get revoked -- this should not trigger a TX commit
+            streams.addStreamThread();
+
+            waitForCondition(DID_REVOKE_IDLE_TASK::get, "Idle Task was not revoked as expected.");
+
+            // best-effort sanity check (might pass and not detect issue in slow environments)
+            try {
+                readResult(SINGLE_PARTITION_OUTPUT_TOPIC, 1, "consumer", 10_000L);
+                throw new Exception("Should not be able to read records, as they should have not been committed.");
+            } catch (final AssertionError expected) {
+                // swallow -- we expect to not be able to read uncommitted data, but time-out
+            }
+
+            // PHASE 3:
+            // fence producer to abort pending TX of first input record
+            // expect rebalancing and recovery until both input records are processed
+            requestCommit.set(true);
+
+            // produce into input topic to fence KS producer
+            final List<KeyValue<Long, Long>> inputDataTask0Fencing = Collections.singletonList(KeyValue.pair(4L, -3L));
+
+            final Properties producerConfigs = new Properties();
+            producerConfigs.setProperty(ProducerConfig.TRANSACTIONAL_ID_CONFIG, transactionalProducerId.get());
+            IntegrationTestUtils.produceKeyValuesSynchronously(
+                MULTI_PARTITION_INPUT_TOPIC,
+                inputDataTask0Fencing,
+                TestUtils.producerConfig(CLUSTER.bootstrapServers(), LongSerializer.class, LongSerializer.class, producerConfigs),
+                CLUSTER.time,
+                true
+            );
+
+            final List<KeyValue<Long, Long>> expectedUncommittedResultAfterError = Arrays.asList(KeyValue.pair(1L, 0L), KeyValue.pair(1L, 0L), KeyValue.pair(4L, 1L));
+            final List<KeyValue<Long, Long>> uncommittedRecordsAfterError = readResult(SINGLE_PARTITION_OUTPUT_TOPIC, expectedUncommittedResultAfterError.size(), null);
+            checkResultPerKey(uncommittedRecordsAfterError, expectedUncommittedResultAfterError, "The committed records do not match what expected");
+        }
+
+        final List<KeyValue<Long, Long>> expectedFinalResult = Arrays.asList(KeyValue.pair(1L, 0L), KeyValue.pair(4L, 1L));
+        final List<KeyValue<Long, Long>> finalResult = readResult(SINGLE_PARTITION_OUTPUT_TOPIC, 2, "committed-only-consumer");
+        checkResultPerKey(finalResult, expectedFinalResult, "The committed records do not match what expected");
+    }
+
     private void verifyOffsetsAreInCheckpoint(final int partition) throws IOException {
         final String stateStoreDir = stateTmpDir + File.separator + "appDir" + File.separator + applicationId + File.separator + "0_" + partition + File.separator;
 
@@ -936,8 +1068,8 @@ public class EosIntegrationTest {
             KafkaConsumer<String, String> consumer = new KafkaConsumer<>(
                 consumerConfig(
                     CLUSTER.bootstrapServers(),
-                    Serdes.ByteArray().deserializer().getClass(),
-                    Serdes.ByteArray().deserializer().getClass()
+                    ByteArrayDeserializer.class,
+                    ByteArrayDeserializer.class
                 )
             )
         ) {
@@ -979,7 +1111,6 @@ public class EosIntegrationTest {
         return data;
     }
 
-    @SuppressWarnings("deprecation")
     // the threads should no longer fail one thread one at a time
     private KafkaStreams getKafkaStreams(final String dummyHostName,
                                          final boolean withState,
@@ -1129,14 +1260,22 @@ public class EosIntegrationTest {
     private List<KeyValue<Long, Long>> readResult(final String topic,
                                                   final int numberOfRecords,
                                                   final String groupId) throws Exception {
-        return readResult(topic, numberOfRecords, LongDeserializer.class, LongDeserializer.class, groupId);
+        return readResult(topic, numberOfRecords, LongDeserializer.class, LongDeserializer.class, groupId, DEFAULT_TIMEOUT);
+    }
+
+    private List<KeyValue<Long, Long>> readResult(final String topic,
+                                                  final int numberOfRecords,
+                                                  final String groupId,
+                                                  final long timeout) throws Exception {
+        return readResult(topic, numberOfRecords, LongDeserializer.class, LongDeserializer.class, groupId, timeout);
     }
 
     private <K, V> List<KeyValue<K, V>> readResult(final String topic,
                                                    final int numberOfRecords,
                                                    final Class<? extends Deserializer<K>> keyDeserializer,
                                                    final Class<? extends Deserializer<V>> valueDeserializer,
-                                                   final String groupId) throws Exception {
+                                                   final String groupId,
+                                                   final long timeout) throws Exception {
         if (groupId != null) {
             return IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(
                 TestUtils.consumerConfig(
@@ -1148,7 +1287,8 @@ public class EosIntegrationTest {
                         ConsumerConfig.ISOLATION_LEVEL_CONFIG,
                         IsolationLevel.READ_COMMITTED.toString()))),
                 topic,
-                numberOfRecords
+                numberOfRecords,
+                timeout
             );
         }
 
@@ -1156,15 +1296,14 @@ public class EosIntegrationTest {
         return IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(
             TestUtils.consumerConfig(CLUSTER.bootstrapServers(), keyDeserializer, valueDeserializer),
             topic,
-            numberOfRecords
+            numberOfRecords,
+            timeout
         );
     }
 
     private <K, V> void ensureCommittedRecordsInTopicPartition(final String topic,
                                                                final int partition,
-                                                               final int numberOfRecords,
-                                                               final Class<? extends Deserializer<K>> keyDeserializer,
-                                                               final Class<? extends Deserializer<V>> valueDeserializer) throws Exception {
+                                                               final int numberOfRecords) throws Exception {
         final long timeoutMs = 2 * DEFAULT_TIMEOUT;
         final int maxTries = 10;
         final long deadline = System.currentTimeMillis() + timeoutMs;
@@ -1174,8 +1313,8 @@ public class EosIntegrationTest {
                 TestUtils.consumerConfig(
                     CLUSTER.bootstrapServers(),
                     CONSUMER_GROUP_ID,
-                    keyDeserializer,
-                    valueDeserializer,
+                    IntegerDeserializer.class,
+                    IntegerDeserializer.class,
                     Utils.mkProperties(Collections.singletonMap(
                         ConsumerConfig.ISOLATION_LEVEL_CONFIG,
                         IsolationLevel.READ_COMMITTED.toString())
diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceIntegrationTest.java
new file mode 100644
index 00000000000..84ae9c2ee40
--- /dev/null
+++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceIntegrationTest.java
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.clients.producer.ProducerConfig;
+import org.apache.kafka.common.serialization.LongDeserializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsBuilder;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.processor.Cancellable;
+import org.apache.kafka.streams.processor.PunctuationType;
+import org.apache.kafka.streams.processor.api.Processor;
+import org.apache.kafka.streams.processor.api.ProcessorContext;
+import org.apache.kafka.streams.processor.api.Record;
+import org.apache.kafka.test.StreamsTestUtils;
+import org.apache.kafka.test.TestUtils;
+
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Tag;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Properties;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+@Tag("integration")
+@Timeout(600)
+public class RebalanceIntegrationTest {
+    private static final Logger LOG = LoggerFactory.getLogger(RebalanceIntegrationTest.class);
+    private static final int NUM_BROKERS = 3;
+    private static final int MAX_POLL_INTERVAL_MS = 30_000;
+
+    public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(
+        NUM_BROKERS,
+        Utils.mkProperties(mkMap(
+            mkEntry("auto.create.topics.enable", "true"),
+            mkEntry("transaction.max.timeout.ms", "" + Integer.MAX_VALUE)
+        ))
+    );
+
+    @BeforeAll
+    public static void startCluster() throws IOException {
+        CLUSTER.start();
+    }
+
+    @AfterAll
+    public static void closeCluster() {
+        CLUSTER.stop();
+    }
+
+
+    private String applicationId;
+    private static final int NUM_TOPIC_PARTITIONS = 2;
+    private static final String MULTI_PARTITION_INPUT_TOPIC = "multiPartitionInputTopic";
+    private static final String SINGLE_PARTITION_OUTPUT_TOPIC = "singlePartitionOutputTopic";
+
+    private static final AtomicInteger TEST_NUMBER = new AtomicInteger(0);
+
+    @BeforeEach
+    public void createTopics() throws Exception {
+        applicationId = "appId-" + TEST_NUMBER.getAndIncrement();
+        CLUSTER.deleteTopics(MULTI_PARTITION_INPUT_TOPIC, SINGLE_PARTITION_OUTPUT_TOPIC);
+
+        CLUSTER.createTopics(SINGLE_PARTITION_OUTPUT_TOPIC);
+        CLUSTER.createTopic(MULTI_PARTITION_INPUT_TOPIC, NUM_TOPIC_PARTITIONS, 1);
+    }
+
+    private void checkResultPerKey(final List<KeyValue<Long, Long>> result,
+                                   final List<KeyValue<Long, Long>> expectedResult) {
+        final Set<Long> allKeys = new HashSet<>();
+        addAllKeys(allKeys, result);
+        addAllKeys(allKeys, expectedResult);
+
+        for (final Long key : allKeys) {
+            assertThat("The records do not match what expected", getAllRecordPerKey(key, result), equalTo(getAllRecordPerKey(key, expectedResult)));
+        }
+    }
+
+    private void addAllKeys(final Set<Long> allKeys, final List<KeyValue<Long, Long>> records) {
+        for (final KeyValue<Long, Long> record : records) {
+            allKeys.add(record.key);
+        }
+    }
+
+    private List<KeyValue<Long, Long>> getAllRecordPerKey(final Long key, final List<KeyValue<Long, Long>> records) {
+        final List<KeyValue<Long, Long>> recordsPerKey = new ArrayList<>(records.size());
+
+        for (final KeyValue<Long, Long> record : records) {
+            if (record.key.equals(key)) {
+                recordsPerKey.add(record);
+            }
+        }
+
+        return recordsPerKey;
+    }
+
+    @Test
+    public void shouldCommitAllTasksIfRevokedTaskTriggerPunctuation() throws Exception {
+        final AtomicBoolean requestCommit = new AtomicBoolean(false);
+
+        final StreamsBuilder builder = new StreamsBuilder();
+        builder.<Long, Long>stream(MULTI_PARTITION_INPUT_TOPIC)
+            .process(() -> new Processor<Long, Long, Long, Long>() {
+                ProcessorContext<Long, Long> context;
+                @Override
+                public void init(final ProcessorContext<Long, Long> context) {
+                    this.context = context;
+
+                    final AtomicReference<Cancellable> cancellable = new AtomicReference<>();
+                    cancellable.set(context.schedule(
+                        Duration.ofSeconds(1),
+                        PunctuationType.WALL_CLOCK_TIME,
+                        time -> {
+                            context.forward(new Record<>(
+                                (context.taskId().partition() + 1) * 100L,
+                                -(context.taskId().partition() + 1L),
+                                context.currentSystemTimeMs()));
+                            cancellable.get().cancel();
+                        }
+                    ));
+                }
+
+                @Override
+                public void process(final Record<Long, Long> record) {
+                    context.forward(record.withValue(context.recordMetadata().get().offset()));
+
+                    if (requestCommit.get()) {
+                        context.commit();
+                    }
+                }
+            })
+            .to(SINGLE_PARTITION_OUTPUT_TOPIC);
+
+        final Properties properties = new Properties();
+        properties.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
+        properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, Integer.MAX_VALUE);
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), 1);
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.METADATA_MAX_AGE_CONFIG), "1000");
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "earliest");
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG), MAX_POLL_INTERVAL_MS - 1);
+        properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), MAX_POLL_INTERVAL_MS);
+        properties.put(StreamsConfig.producerPrefix(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG), Integer.MAX_VALUE);
+        properties.put(StreamsConfig.TASK_ASSIGNOR_CLASS_CONFIG, TestTaskAssignor.class.getName());
+
+        final Properties config = StreamsTestUtils.getStreamsConfig(
+            applicationId,
+            CLUSTER.bootstrapServers(),
+            Serdes.LongSerde.class.getName(),
+            Serdes.LongSerde.class.getName(),
+            properties
+        );
+
+        try (final KafkaStreams streams = new KafkaStreams(builder.build(), config)) {
+            startApplicationAndWaitUntilRunning(streams);
+
+            // PHASE 1:
+            // produce single output record via punctuation (uncommitted) [this happens for both tasks]
+            // StreamThread-1 now has a task with progress, and one task w/o progress
+            final List<KeyValue<Long, Long>> expectedUncommittedResultBeforeRebalance = Arrays.asList(KeyValue.pair(100L, -1L), KeyValue.pair(200L, -2L));
+            final List<KeyValue<Long, Long>> uncommittedRecordsBeforeRebalance = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(
+                TestUtils.consumerConfig(CLUSTER.bootstrapServers(), LongDeserializer.class, LongDeserializer.class),
+                SINGLE_PARTITION_OUTPUT_TOPIC,
+                expectedUncommittedResultBeforeRebalance.size()
+            );
+            checkResultPerKey(uncommittedRecordsBeforeRebalance, expectedUncommittedResultBeforeRebalance);
+
+            // PHASE 2:
+            // add second thread, to trigger rebalance
+            // both tasks should get committed
+            streams.addStreamThread();
+
+            final List<KeyValue<Long, Long>> expectedUncommittedResultAfterRebalance = Arrays.asList(KeyValue.pair(100L, -1L), KeyValue.pair(200L, -2L));
+            final List<KeyValue<Long, Long>> uncommittedRecordsAfterRebalance = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(
+                TestUtils.consumerConfig(CLUSTER.bootstrapServers(), LongDeserializer.class, LongDeserializer.class),
+                SINGLE_PARTITION_OUTPUT_TOPIC,
+                expectedUncommittedResultAfterRebalance.size()
+            );
+            checkResultPerKey(uncommittedRecordsAfterRebalance, expectedUncommittedResultAfterRebalance);
+        }
+    }
+}
\ No newline at end of file
diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/TestTaskAssignor.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/TestTaskAssignor.java
new file mode 100644
index 00000000000..8e4d614e42a
--- /dev/null
+++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/TestTaskAssignor.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor;
+import org.apache.kafka.streams.processor.TaskId;
+import org.apache.kafka.streams.processor.assignment.assignors.StickyTaskAssignor;
+
+public class TestTaskAssignor extends StickyTaskAssignor {
+
+    @Override
+    public void onAssignmentComputed(final ConsumerPartitionAssignor.GroupAssignment assignment,
+                                     final ConsumerPartitionAssignor.GroupSubscription subscription,
+                                     final AssignmentError error) {
+        if (assignment.groupAssignment().size() == 1) {
+            return;
+        }
+
+        for (final String threadName : assignment.groupAssignment().keySet()) {
+            if (threadName.contains("-StreamThread-1-")) {
+                final TaskId taskWithData = EosIntegrationTest.TASK_WITH_DATA.get();
+                if (taskWithData != null && taskWithData.partition() == assignment.groupAssignment().get(threadName).partitions().get(0).partition()) {
+                    EosIntegrationTest.DID_REVOKE_IDLE_TASK.set(true);
+                }
+            }
+        }
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 102e1d07036..eccf0c8f33d 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -1135,12 +1135,16 @@ public class TaskManager {
         final Set<TaskId> lockedTaskIds = activeRunningTaskIterable().stream().map(Task::id).collect(Collectors.toSet());
         maybeLockTasks(lockedTaskIds);
 
+        boolean revokedTasksNeedCommit = false;
         for (final Task task : activeRunningTaskIterable()) {
             if (remainingRevokedPartitions.containsAll(task.inputPartitions())) {
                 // when the task input partitions are included in the revoked list,
                 // this is an active task and should be revoked
+
                 revokedActiveTasks.add(task);
                 remainingRevokedPartitions.removeAll(task.inputPartitions());
+
+                revokedTasksNeedCommit |= task.commitNeeded();
             } else if (task.commitNeeded()) {
                 commitNeededActiveTasks.add(task);
             }
@@ -1154,11 +1158,9 @@ public class TaskManager {
                          "have been cleaned up by the handleAssignment 
callback.", remainingRevokedPartitions);
         }
 
-        prepareCommitAndAddOffsetsToMap(revokedActiveTasks, consumedOffsetsPerTask);
-
-        // if we need to commit any revoking task then we just commit all of those needed committing together
-        final boolean shouldCommitAdditionalTasks = !consumedOffsetsPerTask.isEmpty();
-        if (shouldCommitAdditionalTasks) {
+        if (revokedTasksNeedCommit) {
+            prepareCommitAndAddOffsetsToMap(revokedActiveTasks, consumedOffsetsPerTask);
+            // if we need to commit any revoking task then we just commit all of those needed committing together
+            prepareCommitAndAddOffsetsToMap(commitNeededActiveTasks, consumedOffsetsPerTask);
         }
 
@@ -1167,10 +1169,12 @@ public class TaskManager {
         // as such we just need to skip those dirty tasks in the checkpoint
         final Set<Task> dirtyTasks = new HashSet<>();
         try {
-            // in handleRevocation we must call commitOffsetsOrTransaction() directly rather than
-            // commitAndFillInConsumedOffsetsAndMetadataPerTaskMap() to make sure we don't skip the
-            // offset commit because we are in a rebalance
-            taskExecutor.commitOffsetsOrTransaction(consumedOffsetsPerTask);
+            if (revokedTasksNeedCommit) {
+                // in handleRevocation we must call commitOffsetsOrTransaction() directly rather than
+                // commitAndFillInConsumedOffsetsAndMetadataPerTaskMap() to make sure we don't skip the
+                // offset commit because we are in a rebalance
+                taskExecutor.commitOffsetsOrTransaction(consumedOffsetsPerTask);
+            }
         } catch (final TaskCorruptedException e) {
             log.warn("Some tasks were corrupted when trying to commit offsets, 
these will be cleaned and revived: {}",
                      e.corruptedTasks());
@@ -1203,7 +1207,7 @@ public class TaskManager {
             }
         }
 
-        if (shouldCommitAdditionalTasks) {
+        if (revokedTasksNeedCommit) {
             for (final Task task : commitNeededActiveTasks) {
                 if (!dirtyTasks.contains(task)) {
                     try {
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index b31cde8f1dc..9d7df53adbe 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -3054,7 +3054,7 @@ public class TaskManagerTest {
 
         assertThat(task00.commitNeeded, is(false));
         assertThat(task00.commitPrepared, is(true));
-        assertThat(task00.commitNeeded, is(false));
+        assertThat(task01.commitNeeded, is(false));
         assertThat(task01.commitPrepared, is(true));
         assertThat(task02.commitPrepared, is(false));
         assertThat(task10.commitPrepared, is(false));
@@ -3062,6 +3062,74 @@ public class TaskManagerTest {
         verify(consumer).commitSync(expectedCommittedOffsets);
     }
 
+    @Test
+    public void shouldNotCommitIfNoRevokedTasksNeedCommitting() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
+
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager);
+        task01.setCommitNeeded();
+
+        final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true, stateManager);
+
+        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
+            mkEntry(taskId00, taskId00Partitions),
+            mkEntry(taskId01, taskId01Partitions),
+            mkEntry(taskId02, taskId02Partitions)
+        );
+
+        when(consumer.assignment()).thenReturn(assignment);
+
+        when(activeTaskCreator.createTasks(any(), eq(assignmentActive)))
+            .thenReturn(asList(task00, task01, task02));
+
+        taskManager.handleAssignment(assignmentActive, Collections.emptyMap());
+        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+        assertThat(task02.state(), is(Task.State.RUNNING));
+
+        taskManager.handleRevocation(taskId00Partitions);
+
+        assertThat(task00.commitPrepared, is(false));
+        assertThat(task01.commitPrepared, is(false));
+        assertThat(task02.commitPrepared, is(false));
+    }
+
+    @Test
+    public void shouldNotCommitIfNoRevokedTasksNeedCommittingWithEOSv2() {
+        final TaskManager taskManager = setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2, false);
+
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager);
+
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager);
+        task01.setCommitNeeded();
+
+        final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true, stateManager);
+
+        final Map<TaskId, Set<TopicPartition>> assignmentActive = mkMap(
+            mkEntry(taskId00, taskId00Partitions),
+            mkEntry(taskId01, taskId01Partitions),
+            mkEntry(taskId02, taskId02Partitions)
+        );
+
+        when(consumer.assignment()).thenReturn(assignment);
+
+        when(activeTaskCreator.createTasks(any(), eq(assignmentActive)))
+            .thenReturn(asList(task00, task01, task02));
+
+        taskManager.handleAssignment(assignmentActive, Collections.emptyMap());
+        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+        assertThat(task02.state(), is(Task.State.RUNNING));
+
+        taskManager.handleRevocation(taskId00Partitions);
+
+        assertThat(task00.commitPrepared, is(false));
+        assertThat(task01.commitPrepared, is(false));
+        assertThat(task02.commitPrepared, is(false));
+    }
+
     @Test
     public void shouldNotCommitOnHandleAssignmentIfNoTaskClosed() {
         final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager);

